diff --git a/examples/01_simple_type_checking/04_complex_types.midas b/examples/01_simple_type_checking/04_complex_types.midas index b561cef..adc76b3 100644 --- a/examples/01_simple_type_checking/04_complex_types.midas +++ b/examples/01_simple_type_checking/04_complex_types.midas @@ -1,8 +1,8 @@ type Meter = float extend Meter { - def __add__: fn(Meter) -> Meter - def __sub__: fn(Meter) -> Meter + def __add__: fn(Meter, /) -> Meter + def __sub__: fn(Meter, /) -> Meter } type Coordinate = object diff --git a/midas/checker/python.py b/midas/checker/python.py index cf25593..88ecde0 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -78,6 +78,12 @@ class PythonTyper( self.judgements.append((expr, type)) return type + def resolve_type_expr(self, expr: p.MidasType) -> Type: + return expr.accept(self) + + def process_stmt(self, stmt: p.Stmt) -> None: + stmt.accept(self) + def process_block(self, block: list[p.Stmt], env: Environment) -> bool: """Evaluate a sequence of statements @@ -93,7 +99,7 @@ class PythonTyper( returned: bool = False for i, stmt in enumerate(block): try: - stmt.accept(self) + self.process_stmt(stmt) except ReturnException: returned = True if i < len(block) - 1: @@ -111,7 +117,7 @@ class PythonTyper( statements (list[p.Stmt]): the statements to evaluate and check """ for stmt in statements: - stmt.accept(self) + self.process_stmt(stmt) self.logger.debug(f"Final environment: {self.env.flat_dict()}") @@ -144,9 +150,9 @@ class PythonTyper( def eval_arg_type(arg: p.Function.Argument) -> Type: if arg.type is not None: - return arg.type.accept(self) + return self.resolve_type_expr(arg.type) if arg.default is not None: - return arg.default.accept(self) + return self.type_of(arg.default) return UnknownType() pos: int = 0 @@ -186,7 +192,7 @@ class PythonTyper( returns_hint: Optional[Type] = None if stmt.returns is not None: - returns_hint = stmt.returns.accept(self) + returns_hint = self.resolve_type_expr(stmt.returns) # Early define to handle simple fully-typed recursion inside_function: Function = Function( pos_args=pos_args, @@ -232,7 +238,7 @@ class PythonTyper( def visit_type_assign(self, stmt: p.TypeAssign) -> None: # TODO check not yet defined locally - type: Type = stmt.type.accept(self) + type: Type = self.resolve_type_expr(stmt.type) self.env.define(stmt.name, type) def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: @@ -287,7 +293,7 @@ class PythonTyper( ) def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: - type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType() + type: Type = self.type_of(stmt.value) if stmt.value is not None else UnitType() self.env.return_types.append(type) raise ReturnException() @@ -297,7 +303,7 @@ class PythonTyper( # if (m := 1 + 1) < 2: # ... # print(m) # <- m is still defined - test_type: Type = stmt.test.accept(self) + test_type: Type = self.type_of(stmt.test) # TODO Allow subtypes or any type if test_type != self.types.get_type("bool"): @@ -419,8 +425,8 @@ class PythonTyper( return type or UnknownType() def visit_logical_expr(self, expr: p.LogicalExpr) -> Type: - left: Type = expr.left.accept(self) - right: Type = expr.right.accept(self) + left: Type = self.type_of(expr.left) + right: Type = self.type_of(expr.right) if self.is_subtype(left, right): return right @@ -434,10 +440,10 @@ class PythonTyper( return UnknownType() def visit_cast_expr(self, expr: p.CastExpr) -> Type: - return expr.type.accept(self) + return self.resolve_type_expr(expr.type) def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type: - test_type: Type = expr.test.accept(self) + test_type: Type = self.type_of(expr.test) # TODO Allow subtypes or any type if test_type != self.types.get_type("bool"): @@ -445,8 +451,8 @@ class PythonTyper( expr.test.location, f"If test must be a boolean, got {test_type}" ) - true_type: Type = expr.if_true.accept(self) - false_type: Type = expr.if_false.accept(self) + true_type: Type = self.type_of(expr.if_true) + false_type: Type = self.type_of(expr.if_false) if self.is_subtype(true_type, false_type): return false_type if self.is_subtype(false_type, true_type): @@ -484,7 +490,7 @@ class PythonTyper( return UnknownType() if node.param is not None: - param: Type = node.param.accept(self) + param: Type = self.resolve_type_expr(node.param) return self.types.apply_generic(base, [param]) return base