feat(checker): evaluate function definitions

This commit is contained in:
2026-05-29 12:10:09 +02:00
parent 8906ac3db8
commit fd5399f50a
3 changed files with 103 additions and 3 deletions

View File

@@ -8,13 +8,17 @@ from midas.ast.location import Location
from midas.checker.diagnostic import Diagnostic, DiagnosticType from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.checker.environment import Environment from midas.checker.environment import Environment
from midas.checker.operators import OPERATOR_METHODS from midas.checker.operators import OPERATOR_METHODS
from midas.checker.types import Type, UnknownType from midas.checker.types import Function, Type, UnitType, UnknownType
from midas.lexer.midas import MidasLexer from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token from midas.lexer.token import Token
from midas.parser.midas import MidasParser from midas.parser.midas import MidasParser
from midas.resolver.midas import MidasResolver from midas.resolver.midas import MidasResolver
class ReturnException(Exception):
pass
class Checker( class Checker(
p.Stmt.Visitor[None], p.Stmt.Visitor[None],
p.Expr.Visitor[Type], p.Expr.Visitor[Type],
@@ -63,6 +67,16 @@ class Checker(
def evaluate(self, expr: p.Expr) -> Type: def evaluate(self, expr: p.Expr) -> Type:
return expr.accept(self) return expr.accept(self)
def evaluate_block(self, block: list[p.Stmt], env: Environment) -> None:
previous_env: Environment = self.env
self.env = env
for stmt in block:
try:
stmt.accept(self)
except ReturnException:
break
self.env = previous_env
def check(self, statements: list[p.Stmt]) -> list[Diagnostic]: def check(self, statements: list[p.Stmt]) -> list[Diagnostic]:
self.diagnostics = [] self.diagnostics = []
for stmt in statements: for stmt in statements:
@@ -105,7 +119,69 @@ class Checker(
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None: def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
self.evaluate(stmt.expr) self.evaluate(stmt.expr)
def visit_function(self, stmt: p.Function) -> None: ... def visit_function(self, stmt: p.Function) -> None:
env: Environment = Environment(self.env)
pos_args: list[Function.Argument] = []
args: list[Function.Argument] = []
kw_args: list[Function.Argument] = []
def eval_arg_type(arg: p.Function.Argument) -> Type:
if arg.type is None:
return UnknownType()
return arg.type.accept(self)
for arg in stmt.posonlyargs:
pos_args.append(
Function.Argument(
name=arg.name,
type=eval_arg_type(arg),
)
)
for arg in stmt.args:
args.append(
Function.Argument(
name=arg.name,
type=eval_arg_type(arg),
)
)
for arg in stmt.kwonlyargs:
kw_args.append(
Function.Argument(
name=arg.name,
type=eval_arg_type(arg),
)
)
for arg in pos_args + args + kw_args:
env.define(arg.name, arg.type)
self.evaluate_block(stmt.body, env)
inferred_return: Type = UnknownType()
if len(env.return_types) == 1:
inferred_return = list(env.return_types)[0]
elif len(env.return_types) > 1:
self.error(
stmt.location,
f"Mixed return types: {env.return_types}",
)
returns: Type = UnknownType()
if stmt.returns is not None:
returns = stmt.returns.accept(self)
if returns != inferred_return:
self.error(
stmt.returns.location,
f"Return type mismatch, annotated {returns} but returns {inferred_return}",
)
else:
returns = inferred_return
function: Function = Function(
pos_args=pos_args,
args=args,
kw_args=kw_args,
returns=returns,
)
self.env.define(stmt.name, function)
def visit_type_assign(self, stmt: p.TypeAssign) -> None: def visit_type_assign(self, stmt: p.TypeAssign) -> None:
# TODO check not yet defined locally # TODO check not yet defined locally
@@ -132,6 +208,11 @@ class Checker(
f"Cannot assign {value} to {name} of type {var_type}", f"Cannot assign {value} to {name} of type {var_type}",
) )
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType()
self.env.return_types.add(type)
raise ReturnException()
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type: def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__) method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
if method is None: if method is None:

View File

@@ -9,6 +9,7 @@ class Environment:
def __init__(self, enclosing: Optional[Environment] = None) -> None: def __init__(self, enclosing: Optional[Environment] = None) -> None:
self.enclosing: Optional[Environment] = enclosing self.enclosing: Optional[Environment] = enclosing
self.values: dict[str, Type] = {} self.values: dict[str, Type] = {}
self.return_types: set[Type] = set()
def define(self, name: str, value: Type): def define(self, name: str, value: Type):
self.values[name] = value self.values[name] = value

View File

@@ -19,4 +19,22 @@ class UnknownType:
pass pass
Type = BaseType | SimpleType | UnknownType @dataclass(frozen=True, kw_only=True)
class UnitType:
pass
@dataclass(frozen=True, kw_only=True)
class Function:
pos_args: list[Argument]
args: list[Argument]
kw_args: list[Argument]
returns: Type
@dataclass(frozen=True, kw_only=True)
class Argument:
name: str
type: Type
Type = BaseType | SimpleType | UnknownType | UnitType | Function