feat(checker): evaluate function definitions

This commit is contained in:
2026-05-29 12:10:09 +02:00
parent 8906ac3db8
commit fd5399f50a
3 changed files with 103 additions and 3 deletions

View File

@@ -8,13 +8,17 @@ from midas.ast.location import Location
from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.checker.environment import Environment
from midas.checker.operators import OPERATOR_METHODS
from midas.checker.types import Type, UnknownType
from midas.checker.types import Function, Type, UnitType, UnknownType
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token
from midas.parser.midas import MidasParser
from midas.resolver.midas import MidasResolver
class ReturnException(Exception):
pass
class Checker(
p.Stmt.Visitor[None],
p.Expr.Visitor[Type],
@@ -63,6 +67,16 @@ class Checker(
def evaluate(self, expr: p.Expr) -> Type:
return expr.accept(self)
def evaluate_block(self, block: list[p.Stmt], env: Environment) -> None:
previous_env: Environment = self.env
self.env = env
for stmt in block:
try:
stmt.accept(self)
except ReturnException:
break
self.env = previous_env
def check(self, statements: list[p.Stmt]) -> list[Diagnostic]:
self.diagnostics = []
for stmt in statements:
@@ -105,7 +119,69 @@ class Checker(
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
self.evaluate(stmt.expr)
def visit_function(self, stmt: p.Function) -> None: ...
def visit_function(self, stmt: p.Function) -> None:
env: Environment = Environment(self.env)
pos_args: list[Function.Argument] = []
args: list[Function.Argument] = []
kw_args: list[Function.Argument] = []
def eval_arg_type(arg: p.Function.Argument) -> Type:
if arg.type is None:
return UnknownType()
return arg.type.accept(self)
for arg in stmt.posonlyargs:
pos_args.append(
Function.Argument(
name=arg.name,
type=eval_arg_type(arg),
)
)
for arg in stmt.args:
args.append(
Function.Argument(
name=arg.name,
type=eval_arg_type(arg),
)
)
for arg in stmt.kwonlyargs:
kw_args.append(
Function.Argument(
name=arg.name,
type=eval_arg_type(arg),
)
)
for arg in pos_args + args + kw_args:
env.define(arg.name, arg.type)
self.evaluate_block(stmt.body, env)
inferred_return: Type = UnknownType()
if len(env.return_types) == 1:
inferred_return = list(env.return_types)[0]
elif len(env.return_types) > 1:
self.error(
stmt.location,
f"Mixed return types: {env.return_types}",
)
returns: Type = UnknownType()
if stmt.returns is not None:
returns = stmt.returns.accept(self)
if returns != inferred_return:
self.error(
stmt.returns.location,
f"Return type mismatch, annotated {returns} but returns {inferred_return}",
)
else:
returns = inferred_return
function: Function = Function(
pos_args=pos_args,
args=args,
kw_args=kw_args,
returns=returns,
)
self.env.define(stmt.name, function)
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
# TODO check not yet defined locally
@@ -132,6 +208,11 @@ class Checker(
f"Cannot assign {value} to {name} of type {var_type}",
)
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType()
self.env.return_types.add(type)
raise ReturnException()
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
if method is None:

View File

@@ -9,6 +9,7 @@ class Environment:
def __init__(self, enclosing: Optional[Environment] = None) -> None:
self.enclosing: Optional[Environment] = enclosing
self.values: dict[str, Type] = {}
self.return_types: set[Type] = set()
def define(self, name: str, value: Type):
self.values[name] = value

View File

@@ -19,4 +19,22 @@ class UnknownType:
pass
Type = BaseType | SimpleType | UnknownType
@dataclass(frozen=True, kw_only=True)
class UnitType:
pass
@dataclass(frozen=True, kw_only=True)
class Function:
pos_args: list[Argument]
args: list[Argument]
kw_args: list[Argument]
returns: Type
@dataclass(frozen=True, kw_only=True)
class Argument:
name: str
type: Type
Type = BaseType | SimpleType | UnknownType | UnitType | Function