diff --git a/midas/checker/python.py b/midas/checker/python.py index d667164..4481ba7 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -334,6 +334,18 @@ class PythonTyper( def visit_pass(self, stmt: p.Pass) -> None: pass + def visit_for_stmt(self, stmt: p.ForStmt) -> None: + item_type: Optional[Type] = self._get_iterator_type(stmt.iterator) + if item_type is None: + self.reporter.error(stmt.iterator.location, "Iterator is not an iterator") + item_type = UnknownType() + + self._assign(stmt.location, stmt.target, item_type) + env: Environment = Environment(self.env) + body_returned: bool = self.process_block(stmt.body, env) + if body_returned: + raise ReturnException() + def visit_binary_expr(self, expr: p.BinaryExpr) -> Type: method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__) if method is None: @@ -370,7 +382,13 @@ class PythonTyper( ) return UnknownType() - return self._get_call_result(location, operation, [(right_expr, right)], {}) + result: Optional[Type] = self._get_call_result( + location, + operation, + [(right_expr, right)], + {}, + ) + return result or UnknownType() def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__) @@ -390,9 +408,13 @@ class PythonTyper( ) return UnknownType() - return self._get_call_result( - expr.location, operation, [(expr.right, operand)], {} + result: Optional[Type] = self._get_call_result( + expr.location, + operation, + [], + {}, ) + return result or UnknownType() def visit_call_expr(self, expr: p.CallExpr) -> Type: callee: Type = self.type_of(expr.callee) @@ -402,11 +424,14 @@ class PythonTyper( keywords: dict[str, TypedExpr] = { name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items() } - return self._get_call_result( - location=expr.location, - callee=callee, - positional=positional, - keywords=keywords, + return ( + self._get_call_result( + location=expr.location, + callee=callee, + positional=positional, + keywords=keywords, + ) + or UnknownType() ) def visit_get_expr(self, expr: p.GetExpr) -> Type: @@ -509,8 +534,9 @@ class PythonTyper( return UnknownType() index: Type = self.type_of(expr.index) - return self._get_call_result( - expr.location, operation, [(expr.index, index)], {} + return ( + self._get_call_result(expr.location, operation, [(expr.index, index)], {}) + or UnknownType() ) def visit_slice_expr(self, expr: p.SliceExpr) -> Type: @@ -547,7 +573,8 @@ class PythonTyper( callee: Type, positional: list[TypedExpr], keywords: dict[str, TypedExpr], - ) -> Type: + report_errors: bool = True, + ) -> Optional[Type]: """Get the result type of a function call If the function has overloads, the function will try to resolve the @@ -561,9 +588,10 @@ class PythonTyper( callee (Type): the called function positional (list[TypedExpr]): the list positional arguments keywords (dict[str, TypedExpr]): the map of keyword arguments + report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True. Returns: - Type: the return type of the call, or `UnknownType` if either + Type: the return type of the call, or `None` if either the call is invalid or no overload matched the arguments uniquely """ match callee: @@ -573,21 +601,22 @@ class PythonTyper( valid, mapped = self.map_call_arguments( function, location, positional, keywords ) - valid = valid and self._are_arguments_valid(mapped) + valid = valid and self._are_arguments_valid(mapped, report_errors) if not valid: - return UnknownType() + return None return function.returns case OverloadedFunction(overloads=overloads): function = self._match_overload( - overloads, location, positional, keywords + overloads, location, positional, keywords, report_errors ) if function is None: - return UnknownType() + return None return function.returns case _: - self.reporter.error(location, f"{callee} is not callable") - return UnknownType() + if report_errors: + self.reporter.error(location, f"{callee} is not callable") + return None def _are_arguments_valid( self, @@ -620,6 +649,7 @@ class PythonTyper( location: Location, positional: list[TypedExpr], keywords: dict[str, TypedExpr], + report_errors: bool = True, ) -> Optional[Function]: """Try and resolve the appropriate overload for the given arguments @@ -628,6 +658,7 @@ class PythonTyper( location (Location): the call location positional (list[TypedExpr]): the list of positional arguments keywords (dict[str, TypedExpr]): the map of keywords arguments + report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True. Returns: Optional[Function]: the resolved function signature if it can be @@ -637,9 +668,10 @@ class PythonTyper( for overload in overloads: function: Type = unfold_type(overload) if not isinstance(function, Function): - self.logger.error( - f"Overload is not a function: {overload} is {function}" - ) + if report_errors: + self.logger.error( + f"Overload is not a function: {overload} is {function}" + ) continue valid, mapped = self.map_call_arguments( function=function, @@ -671,10 +703,11 @@ class PythonTyper( # No match -> invalid call if n_candidates == 0: overloads_str: str = ", ".join(map(str, overloads)) - self.reporter.error( - location, - f"No matching overload in [{overloads_str}] {for_args}", - ) + if report_errors: + self.reporter.error( + location, + f"No matching overload in [{overloads_str}] {for_args}", + ) return None # Multiple matches -> see if one <: all others (more specific) @@ -695,10 +728,11 @@ class PythonTyper( candidates_str: str = ", ".join( str(candidate.function) for candidate in candidates ) - self.reporter.error( - location, - f"Multiple matching overloads {for_args}: {candidates_str}", - ) + if report_errors: + self.reporter.error( + location, + f"Multiple matching overloads {for_args}: {candidates_str}", + ) return None def map_call_arguments( @@ -863,3 +897,23 @@ class PythonTyper( if not self.is_subtype(type1, type2): return False return True + + def _get_iterator_type(self, expr: p.Expr) -> Optional[Type]: + # TODO: lookup __iter__ + type: Type = self.type_of(expr) + getitem: Optional[Type] = self.types.lookup_member(type, "__getitem__") + if getitem is None: + return None + + index: p.Expr = p.LiteralExpr(location=expr.location, value=0) + index_type: Type = index.accept( + self + ) # skip type_of() to avoid recording judgement + result: Optional[Type] = self._get_call_result( + location=expr.location, + callee=getitem, + positional=[(index, index_type)], + keywords={}, + report_errors=False, + ) + return result diff --git a/midas/checker/resolver.py b/midas/checker/resolver.py index eb0a6e8..c99a18d 100644 --- a/midas/checker/resolver.py +++ b/midas/checker/resolver.py @@ -116,17 +116,20 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]): def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: self.resolve(stmt.value) for target in stmt.targets: - match target: - case p.VariableExpr(name=name): - if not self.is_defined(name): - self.declare(name) - self.define(name) - target.accept(self) + self._visit_assign(target) - case p.GetExpr(): - target.accept(self) - case _: - raise Exception(f"Unsupported assignment to {target}") + def _visit_assign(self, target: p.Expr): + match target: + case p.VariableExpr(name=name): + if not self.is_defined(name): + self.declare(name) + self.define(name) + target.accept(self) + + case p.GetExpr(): + target.accept(self) + case _: + raise Exception(f"Unsupported assignment to {target}") def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: if stmt.value is not None: @@ -153,6 +156,13 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]): def visit_pass(self, stmt: p.Pass) -> None: pass + def visit_for_stmt(self, stmt: p.ForStmt) -> None: + self.resolve(stmt.iterator) + self._visit_assign(stmt.target) + self.begin_scope() + self.resolve(*stmt.body) + self.end_scope() + def visit_binary_expr(self, expr: p.BinaryExpr) -> None: self.resolve(expr.left) self.resolve(expr.right)