diff --git a/midas/checker/checker.py b/midas/checker/checker.py index bc0d323..dcc7188 100644 --- a/midas/checker/checker.py +++ b/midas/checker/checker.py @@ -8,13 +8,17 @@ from midas.ast.location import Location from midas.checker.diagnostic import Diagnostic, DiagnosticType from midas.checker.environment import Environment from midas.checker.operators import OPERATOR_METHODS -from midas.checker.types import Type, UnknownType +from midas.checker.types import Function, Type, UnitType, UnknownType from midas.lexer.midas import MidasLexer from midas.lexer.token import Token from midas.parser.midas import MidasParser from midas.resolver.midas import MidasResolver +class ReturnException(Exception): + pass + + class Checker( p.Stmt.Visitor[None], p.Expr.Visitor[Type], @@ -63,6 +67,16 @@ class Checker( def evaluate(self, expr: p.Expr) -> Type: return expr.accept(self) + def evaluate_block(self, block: list[p.Stmt], env: Environment) -> None: + previous_env: Environment = self.env + self.env = env + for stmt in block: + try: + stmt.accept(self) + except ReturnException: + break + self.env = previous_env + def check(self, statements: list[p.Stmt]) -> list[Diagnostic]: self.diagnostics = [] for stmt in statements: @@ -105,7 +119,69 @@ class Checker( def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None: self.evaluate(stmt.expr) - def visit_function(self, stmt: p.Function) -> None: ... + def visit_function(self, stmt: p.Function) -> None: + env: Environment = Environment(self.env) + pos_args: list[Function.Argument] = [] + args: list[Function.Argument] = [] + kw_args: list[Function.Argument] = [] + + def eval_arg_type(arg: p.Function.Argument) -> Type: + if arg.type is None: + return UnknownType() + return arg.type.accept(self) + + for arg in stmt.posonlyargs: + pos_args.append( + Function.Argument( + name=arg.name, + type=eval_arg_type(arg), + ) + ) + for arg in stmt.args: + args.append( + Function.Argument( + name=arg.name, + type=eval_arg_type(arg), + ) + ) + for arg in stmt.kwonlyargs: + kw_args.append( + Function.Argument( + name=arg.name, + type=eval_arg_type(arg), + ) + ) + + for arg in pos_args + args + kw_args: + env.define(arg.name, arg.type) + + self.evaluate_block(stmt.body, env) + inferred_return: Type = UnknownType() + if len(env.return_types) == 1: + inferred_return = list(env.return_types)[0] + elif len(env.return_types) > 1: + self.error( + stmt.location, + f"Mixed return types: {env.return_types}", + ) + returns: Type = UnknownType() + if stmt.returns is not None: + returns = stmt.returns.accept(self) + if returns != inferred_return: + self.error( + stmt.returns.location, + f"Return type mismatch, annotated {returns} but returns {inferred_return}", + ) + else: + returns = inferred_return + + function: Function = Function( + pos_args=pos_args, + args=args, + kw_args=kw_args, + returns=returns, + ) + self.env.define(stmt.name, function) def visit_type_assign(self, stmt: p.TypeAssign) -> None: # TODO check not yet defined locally @@ -132,6 +208,11 @@ class Checker( f"Cannot assign {value} to {name} of type {var_type}", ) + def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: + type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType() + self.env.return_types.add(type) + raise ReturnException() + def visit_binary_expr(self, expr: p.BinaryExpr) -> Type: method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__) if method is None: diff --git a/midas/checker/environment.py b/midas/checker/environment.py index dc5cac8..fa6dfe2 100644 --- a/midas/checker/environment.py +++ b/midas/checker/environment.py @@ -9,6 +9,7 @@ class Environment: def __init__(self, enclosing: Optional[Environment] = None) -> None: self.enclosing: Optional[Environment] = enclosing self.values: dict[str, Type] = {} + self.return_types: set[Type] = set() def define(self, name: str, value: Type): self.values[name] = value diff --git a/midas/checker/types.py b/midas/checker/types.py index ea94b45..2afd0ae 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -19,4 +19,22 @@ class UnknownType: pass -Type = BaseType | SimpleType | UnknownType +@dataclass(frozen=True, kw_only=True) +class UnitType: + pass + + +@dataclass(frozen=True, kw_only=True) +class Function: + pos_args: list[Argument] + args: list[Argument] + kw_args: list[Argument] + returns: Type + + @dataclass(frozen=True, kw_only=True) + class Argument: + name: str + type: Type + + +Type = BaseType | SimpleType | UnknownType | UnitType | Function