feat(checker): handle for loops

This commit is contained in:
2026-06-16 00:36:03 +02:00
parent faa98ce0ef
commit 48e13d3348
2 changed files with 103 additions and 39 deletions

View File

@@ -334,6 +334,18 @@ class PythonTyper(
def visit_pass(self, stmt: p.Pass) -> None: def visit_pass(self, stmt: p.Pass) -> None:
pass pass
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
item_type: Optional[Type] = self._get_iterator_type(stmt.iterator)
if item_type is None:
self.reporter.error(stmt.iterator.location, "Iterator is not an iterator")
item_type = UnknownType()
self._assign(stmt.location, stmt.target, item_type)
env: Environment = Environment(self.env)
body_returned: bool = self.process_block(stmt.body, env)
if body_returned:
raise ReturnException()
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type: def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__) method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
if method is None: if method is None:
@@ -370,7 +382,13 @@ class PythonTyper(
) )
return UnknownType() return UnknownType()
return self._get_call_result(location, operation, [(right_expr, right)], {}) result: Optional[Type] = self._get_call_result(
location,
operation,
[(right_expr, right)],
{},
)
return result or UnknownType()
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__) method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__)
@@ -390,9 +408,13 @@ class PythonTyper(
) )
return UnknownType() return UnknownType()
return self._get_call_result( result: Optional[Type] = self._get_call_result(
expr.location, operation, [(expr.right, operand)], {} expr.location,
operation,
[],
{},
) )
return result or UnknownType()
def visit_call_expr(self, expr: p.CallExpr) -> Type: def visit_call_expr(self, expr: p.CallExpr) -> Type:
callee: Type = self.type_of(expr.callee) callee: Type = self.type_of(expr.callee)
@@ -402,12 +424,15 @@ class PythonTyper(
keywords: dict[str, TypedExpr] = { keywords: dict[str, TypedExpr] = {
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items() name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
} }
return self._get_call_result( return (
self._get_call_result(
location=expr.location, location=expr.location,
callee=callee, callee=callee,
positional=positional, positional=positional,
keywords=keywords, keywords=keywords,
) )
or UnknownType()
)
def visit_get_expr(self, expr: p.GetExpr) -> Type: def visit_get_expr(self, expr: p.GetExpr) -> Type:
object: Type = self.type_of(expr.object) object: Type = self.type_of(expr.object)
@@ -509,8 +534,9 @@ class PythonTyper(
return UnknownType() return UnknownType()
index: Type = self.type_of(expr.index) index: Type = self.type_of(expr.index)
return self._get_call_result( return (
expr.location, operation, [(expr.index, index)], {} self._get_call_result(expr.location, operation, [(expr.index, index)], {})
or UnknownType()
) )
def visit_slice_expr(self, expr: p.SliceExpr) -> Type: def visit_slice_expr(self, expr: p.SliceExpr) -> Type:
@@ -547,7 +573,8 @@ class PythonTyper(
callee: Type, callee: Type,
positional: list[TypedExpr], positional: list[TypedExpr],
keywords: dict[str, TypedExpr], keywords: dict[str, TypedExpr],
) -> Type: report_errors: bool = True,
) -> Optional[Type]:
"""Get the result type of a function call """Get the result type of a function call
If the function has overloads, the function will try to resolve the If the function has overloads, the function will try to resolve the
@@ -561,9 +588,10 @@ class PythonTyper(
callee (Type): the called function callee (Type): the called function
positional (list[TypedExpr]): the list positional arguments positional (list[TypedExpr]): the list positional arguments
keywords (dict[str, TypedExpr]): the map of keyword arguments keywords (dict[str, TypedExpr]): the map of keyword arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns: Returns:
Type: the return type of the call, or `UnknownType` if either Type: the return type of the call, or `None` if either
the call is invalid or no overload matched the arguments uniquely the call is invalid or no overload matched the arguments uniquely
""" """
match callee: match callee:
@@ -573,21 +601,22 @@ class PythonTyper(
valid, mapped = self.map_call_arguments( valid, mapped = self.map_call_arguments(
function, location, positional, keywords function, location, positional, keywords
) )
valid = valid and self._are_arguments_valid(mapped) valid = valid and self._are_arguments_valid(mapped, report_errors)
if not valid: if not valid:
return UnknownType() return None
return function.returns return function.returns
case OverloadedFunction(overloads=overloads): case OverloadedFunction(overloads=overloads):
function = self._match_overload( function = self._match_overload(
overloads, location, positional, keywords overloads, location, positional, keywords, report_errors
) )
if function is None: if function is None:
return UnknownType() return None
return function.returns return function.returns
case _: case _:
if report_errors:
self.reporter.error(location, f"{callee} is not callable") self.reporter.error(location, f"{callee} is not callable")
return UnknownType() return None
def _are_arguments_valid( def _are_arguments_valid(
self, self,
@@ -620,6 +649,7 @@ class PythonTyper(
location: Location, location: Location,
positional: list[TypedExpr], positional: list[TypedExpr],
keywords: dict[str, TypedExpr], keywords: dict[str, TypedExpr],
report_errors: bool = True,
) -> Optional[Function]: ) -> Optional[Function]:
"""Try and resolve the appropriate overload for the given arguments """Try and resolve the appropriate overload for the given arguments
@@ -628,6 +658,7 @@ class PythonTyper(
location (Location): the call location location (Location): the call location
positional (list[TypedExpr]): the list of positional arguments positional (list[TypedExpr]): the list of positional arguments
keywords (dict[str, TypedExpr]): the map of keywords arguments keywords (dict[str, TypedExpr]): the map of keywords arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns: Returns:
Optional[Function]: the resolved function signature if it can be Optional[Function]: the resolved function signature if it can be
@@ -637,6 +668,7 @@ class PythonTyper(
for overload in overloads: for overload in overloads:
function: Type = unfold_type(overload) function: Type = unfold_type(overload)
if not isinstance(function, Function): if not isinstance(function, Function):
if report_errors:
self.logger.error( self.logger.error(
f"Overload is not a function: {overload} is {function}" f"Overload is not a function: {overload} is {function}"
) )
@@ -671,6 +703,7 @@ class PythonTyper(
# No match -> invalid call # No match -> invalid call
if n_candidates == 0: if n_candidates == 0:
overloads_str: str = ", ".join(map(str, overloads)) overloads_str: str = ", ".join(map(str, overloads))
if report_errors:
self.reporter.error( self.reporter.error(
location, location,
f"No matching overload in [{overloads_str}] {for_args}", f"No matching overload in [{overloads_str}] {for_args}",
@@ -695,6 +728,7 @@ class PythonTyper(
candidates_str: str = ", ".join( candidates_str: str = ", ".join(
str(candidate.function) for candidate in candidates str(candidate.function) for candidate in candidates
) )
if report_errors:
self.reporter.error( self.reporter.error(
location, location,
f"Multiple matching overloads {for_args}: {candidates_str}", f"Multiple matching overloads {for_args}: {candidates_str}",
@@ -863,3 +897,23 @@ class PythonTyper(
if not self.is_subtype(type1, type2): if not self.is_subtype(type1, type2):
return False return False
return True return True
def _get_iterator_type(self, expr: p.Expr) -> Optional[Type]:
# TODO: lookup __iter__
type: Type = self.type_of(expr)
getitem: Optional[Type] = self.types.lookup_member(type, "__getitem__")
if getitem is None:
return None
index: p.Expr = p.LiteralExpr(location=expr.location, value=0)
index_type: Type = index.accept(
self
) # skip type_of() to avoid recording judgement
result: Optional[Type] = self._get_call_result(
location=expr.location,
callee=getitem,
positional=[(index, index_type)],
keywords={},
report_errors=False,
)
return result

View File

@@ -116,6 +116,9 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
self.resolve(stmt.value) self.resolve(stmt.value)
for target in stmt.targets: for target in stmt.targets:
self._visit_assign(target)
def _visit_assign(self, target: p.Expr):
match target: match target:
case p.VariableExpr(name=name): case p.VariableExpr(name=name):
if not self.is_defined(name): if not self.is_defined(name):
@@ -153,6 +156,13 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
def visit_pass(self, stmt: p.Pass) -> None: def visit_pass(self, stmt: p.Pass) -> None:
pass pass
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
self.resolve(stmt.iterator)
self._visit_assign(stmt.target)
self.begin_scope()
self.resolve(*stmt.body)
self.end_scope()
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
self.resolve(expr.left) self.resolve(expr.left)
self.resolve(expr.right) self.resolve(expr.right)