feat(checker): type check if statements

This commit is contained in:
2026-06-01 14:13:17 +02:00
parent 9d45163d9c
commit 65164abadb
2 changed files with 49 additions and 2 deletions

View File

@@ -85,21 +85,29 @@ class Checker(
"""
return expr.accept(self)
def evaluate_block(self, block: list[p.Stmt], env: Environment) -> None:
def evaluate_block(self, block: list[p.Stmt], env: Environment) -> bool:
"""Evaluate a sequence of statements
Args:
block (list[p.Stmt]): the statements to evaluate
env (Environment): the environment in which to evaluate
Returns:
bool: whether a return statement is present in the block
"""
previous_env: Environment = self.env
self.env = env
for stmt in block:
returned: bool = False
for i, stmt in enumerate(block):
try:
stmt.accept(self)
except ReturnException:
returned = True
if i < len(block) - 1:
self.warning(block[i + 1].location, "Unreachable statement")
break
self.env = previous_env
return returned
def check(self, statements: list[p.Stmt]) -> list[Diagnostic]:
"""Type check a sequence of statements and returns diagnostics
@@ -276,6 +284,27 @@ class Checker(
self.env.return_types.append(type)
raise ReturnException()
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
# Not evaluated in sub-environment because assignments in the test leak out of the if
# For example:
# if (m := 1 + 1) < 2:
# ...
# print(m) # <- m is still defined
test_type: Type = stmt.test.accept(self)
# TODO Allow subtypes or any type
if test_type != self.ctx.get_type("bool"):
self.error(
stmt.test.location, f"If test must be a boolean, got {test_type}"
)
env: Environment = Environment(self.env)
body_returned: bool = self.evaluate_block(stmt.body, env)
else_returned: bool = self.evaluate_block(stmt.orelse, env)
self.env.return_types.extend(env.return_types)
if body_returned and else_returned:
raise ReturnException()
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
if method is None:

View File

@@ -121,6 +121,24 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
if stmt.value is not None:
self.resolve(stmt.value)
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
# Not resolved in sub-environment because assignments in the test leak out of the if
# For example:
# if (m := 1 + 1) < 2:
# ...
# print(m) # <- m is still defined
self.resolve(stmt.test)
# Body
self.begin_scope()
self.resolve(*stmt.body)
self.end_scope()
# Else
self.begin_scope()
self.resolve(*stmt.orelse)
self.end_scope()
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
self.resolve(expr.left)
self.resolve(expr.right)