diff --git a/midas/checker/checker.py b/midas/checker/checker.py index cec5c80..2b7f41d 100644 --- a/midas/checker/checker.py +++ b/midas/checker/checker.py @@ -85,21 +85,29 @@ class Checker( """ return expr.accept(self) - def evaluate_block(self, block: list[p.Stmt], env: Environment) -> None: + def evaluate_block(self, block: list[p.Stmt], env: Environment) -> bool: """Evaluate a sequence of statements Args: block (list[p.Stmt]): the statements to evaluate env (Environment): the environment in which to evaluate + + Returns: + bool: whether a return statement is present in the block """ previous_env: Environment = self.env self.env = env - for stmt in block: + returned: bool = False + for i, stmt in enumerate(block): try: stmt.accept(self) except ReturnException: + returned = True + if i < len(block) - 1: + self.warning(block[i + 1].location, "Unreachable statement") break self.env = previous_env + return returned def check(self, statements: list[p.Stmt]) -> list[Diagnostic]: """Type check a sequence of statements and returns diagnostics @@ -276,6 +284,27 @@ class Checker( self.env.return_types.append(type) raise ReturnException() + def visit_if_stmt(self, stmt: p.IfStmt) -> None: + # Not evaluated in sub-environment because assignments in the test leak out of the if + # For example: + # if (m := 1 + 1) < 2: + # ... + # print(m) # <- m is still defined + test_type: Type = stmt.test.accept(self) + + # TODO Allow subtypes or any type + if test_type != self.ctx.get_type("bool"): + self.error( + stmt.test.location, f"If test must be a boolean, got {test_type}" + ) + + env: Environment = Environment(self.env) + body_returned: bool = self.evaluate_block(stmt.body, env) + else_returned: bool = self.evaluate_block(stmt.orelse, env) + self.env.return_types.extend(env.return_types) + if body_returned and else_returned: + raise ReturnException() + def visit_binary_expr(self, expr: p.BinaryExpr) -> Type: method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__) if method is None: diff --git a/midas/resolver/resolver.py b/midas/resolver/resolver.py index 9d7581f..221c221 100644 --- a/midas/resolver/resolver.py +++ b/midas/resolver/resolver.py @@ -121,6 +121,24 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]): if stmt.value is not None: self.resolve(stmt.value) + def visit_if_stmt(self, stmt: p.IfStmt) -> None: + # Not resolved in sub-environment because assignments in the test leak out of the if + # For example: + # if (m := 1 + 1) < 2: + # ... + # print(m) # <- m is still defined + self.resolve(stmt.test) + + # Body + self.begin_scope() + self.resolve(*stmt.body) + self.end_scope() + + # Else + self.begin_scope() + self.resolve(*stmt.orelse) + self.end_scope() + def visit_binary_expr(self, expr: p.BinaryExpr) -> None: self.resolve(expr.left) self.resolve(expr.right)