feat(checker): evaluate function definitions
This commit is contained in:
@@ -8,13 +8,17 @@ from midas.ast.location import Location
|
||||
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||
from midas.checker.environment import Environment
|
||||
from midas.checker.operators import OPERATOR_METHODS
|
||||
from midas.checker.types import Type, UnknownType
|
||||
from midas.checker.types import Function, Type, UnitType, UnknownType
|
||||
from midas.lexer.midas import MidasLexer
|
||||
from midas.lexer.token import Token
|
||||
from midas.parser.midas import MidasParser
|
||||
from midas.resolver.midas import MidasResolver
|
||||
|
||||
|
||||
class ReturnException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Checker(
|
||||
p.Stmt.Visitor[None],
|
||||
p.Expr.Visitor[Type],
|
||||
@@ -63,6 +67,16 @@ class Checker(
|
||||
def evaluate(self, expr: p.Expr) -> Type:
|
||||
return expr.accept(self)
|
||||
|
||||
def evaluate_block(self, block: list[p.Stmt], env: Environment) -> None:
|
||||
previous_env: Environment = self.env
|
||||
self.env = env
|
||||
for stmt in block:
|
||||
try:
|
||||
stmt.accept(self)
|
||||
except ReturnException:
|
||||
break
|
||||
self.env = previous_env
|
||||
|
||||
def check(self, statements: list[p.Stmt]) -> list[Diagnostic]:
|
||||
self.diagnostics = []
|
||||
for stmt in statements:
|
||||
@@ -105,7 +119,69 @@ class Checker(
|
||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||
self.evaluate(stmt.expr)
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> None: ...
|
||||
def visit_function(self, stmt: p.Function) -> None:
|
||||
env: Environment = Environment(self.env)
|
||||
pos_args: list[Function.Argument] = []
|
||||
args: list[Function.Argument] = []
|
||||
kw_args: list[Function.Argument] = []
|
||||
|
||||
def eval_arg_type(arg: p.Function.Argument) -> Type:
|
||||
if arg.type is None:
|
||||
return UnknownType()
|
||||
return arg.type.accept(self)
|
||||
|
||||
for arg in stmt.posonlyargs:
|
||||
pos_args.append(
|
||||
Function.Argument(
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
)
|
||||
)
|
||||
for arg in stmt.args:
|
||||
args.append(
|
||||
Function.Argument(
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
)
|
||||
)
|
||||
for arg in stmt.kwonlyargs:
|
||||
kw_args.append(
|
||||
Function.Argument(
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
)
|
||||
)
|
||||
|
||||
for arg in pos_args + args + kw_args:
|
||||
env.define(arg.name, arg.type)
|
||||
|
||||
self.evaluate_block(stmt.body, env)
|
||||
inferred_return: Type = UnknownType()
|
||||
if len(env.return_types) == 1:
|
||||
inferred_return = list(env.return_types)[0]
|
||||
elif len(env.return_types) > 1:
|
||||
self.error(
|
||||
stmt.location,
|
||||
f"Mixed return types: {env.return_types}",
|
||||
)
|
||||
returns: Type = UnknownType()
|
||||
if stmt.returns is not None:
|
||||
returns = stmt.returns.accept(self)
|
||||
if returns != inferred_return:
|
||||
self.error(
|
||||
stmt.returns.location,
|
||||
f"Return type mismatch, annotated {returns} but returns {inferred_return}",
|
||||
)
|
||||
else:
|
||||
returns = inferred_return
|
||||
|
||||
function: Function = Function(
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
returns=returns,
|
||||
)
|
||||
self.env.define(stmt.name, function)
|
||||
|
||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||
# TODO check not yet defined locally
|
||||
@@ -132,6 +208,11 @@ class Checker(
|
||||
f"Cannot assign {value} to {name} of type {var_type}",
|
||||
)
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||
type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType()
|
||||
self.env.return_types.add(type)
|
||||
raise ReturnException()
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
||||
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
||||
if method is None:
|
||||
|
||||
@@ -9,6 +9,7 @@ class Environment:
|
||||
def __init__(self, enclosing: Optional[Environment] = None) -> None:
|
||||
self.enclosing: Optional[Environment] = enclosing
|
||||
self.values: dict[str, Type] = {}
|
||||
self.return_types: set[Type] = set()
|
||||
|
||||
def define(self, name: str, value: Type):
|
||||
self.values[name] = value
|
||||
|
||||
@@ -19,4 +19,22 @@ class UnknownType:
|
||||
pass
|
||||
|
||||
|
||||
Type = BaseType | SimpleType | UnknownType
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class UnitType:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Function:
|
||||
pos_args: list[Argument]
|
||||
args: list[Argument]
|
||||
kw_args: list[Argument]
|
||||
returns: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
name: str
|
||||
type: Type
|
||||
|
||||
|
||||
Type = BaseType | SimpleType | UnknownType | UnitType | Function
|
||||
|
||||
Reference in New Issue
Block a user