feat(checker): evaluate function definitions
This commit is contained in:
@@ -8,13 +8,17 @@ from midas.ast.location import Location
|
|||||||
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||||
from midas.checker.environment import Environment
|
from midas.checker.environment import Environment
|
||||||
from midas.checker.operators import OPERATOR_METHODS
|
from midas.checker.operators import OPERATOR_METHODS
|
||||||
from midas.checker.types import Type, UnknownType
|
from midas.checker.types import Function, Type, UnitType, UnknownType
|
||||||
from midas.lexer.midas import MidasLexer
|
from midas.lexer.midas import MidasLexer
|
||||||
from midas.lexer.token import Token
|
from midas.lexer.token import Token
|
||||||
from midas.parser.midas import MidasParser
|
from midas.parser.midas import MidasParser
|
||||||
from midas.resolver.midas import MidasResolver
|
from midas.resolver.midas import MidasResolver
|
||||||
|
|
||||||
|
|
||||||
|
class ReturnException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Checker(
|
class Checker(
|
||||||
p.Stmt.Visitor[None],
|
p.Stmt.Visitor[None],
|
||||||
p.Expr.Visitor[Type],
|
p.Expr.Visitor[Type],
|
||||||
@@ -63,6 +67,16 @@ class Checker(
|
|||||||
def evaluate(self, expr: p.Expr) -> Type:
|
def evaluate(self, expr: p.Expr) -> Type:
|
||||||
return expr.accept(self)
|
return expr.accept(self)
|
||||||
|
|
||||||
|
def evaluate_block(self, block: list[p.Stmt], env: Environment) -> None:
|
||||||
|
previous_env: Environment = self.env
|
||||||
|
self.env = env
|
||||||
|
for stmt in block:
|
||||||
|
try:
|
||||||
|
stmt.accept(self)
|
||||||
|
except ReturnException:
|
||||||
|
break
|
||||||
|
self.env = previous_env
|
||||||
|
|
||||||
def check(self, statements: list[p.Stmt]) -> list[Diagnostic]:
|
def check(self, statements: list[p.Stmt]) -> list[Diagnostic]:
|
||||||
self.diagnostics = []
|
self.diagnostics = []
|
||||||
for stmt in statements:
|
for stmt in statements:
|
||||||
@@ -105,7 +119,69 @@ class Checker(
|
|||||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||||
self.evaluate(stmt.expr)
|
self.evaluate(stmt.expr)
|
||||||
|
|
||||||
def visit_function(self, stmt: p.Function) -> None: ...
|
def visit_function(self, stmt: p.Function) -> None:
|
||||||
|
env: Environment = Environment(self.env)
|
||||||
|
pos_args: list[Function.Argument] = []
|
||||||
|
args: list[Function.Argument] = []
|
||||||
|
kw_args: list[Function.Argument] = []
|
||||||
|
|
||||||
|
def eval_arg_type(arg: p.Function.Argument) -> Type:
|
||||||
|
if arg.type is None:
|
||||||
|
return UnknownType()
|
||||||
|
return arg.type.accept(self)
|
||||||
|
|
||||||
|
for arg in stmt.posonlyargs:
|
||||||
|
pos_args.append(
|
||||||
|
Function.Argument(
|
||||||
|
name=arg.name,
|
||||||
|
type=eval_arg_type(arg),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for arg in stmt.args:
|
||||||
|
args.append(
|
||||||
|
Function.Argument(
|
||||||
|
name=arg.name,
|
||||||
|
type=eval_arg_type(arg),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for arg in stmt.kwonlyargs:
|
||||||
|
kw_args.append(
|
||||||
|
Function.Argument(
|
||||||
|
name=arg.name,
|
||||||
|
type=eval_arg_type(arg),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for arg in pos_args + args + kw_args:
|
||||||
|
env.define(arg.name, arg.type)
|
||||||
|
|
||||||
|
self.evaluate_block(stmt.body, env)
|
||||||
|
inferred_return: Type = UnknownType()
|
||||||
|
if len(env.return_types) == 1:
|
||||||
|
inferred_return = list(env.return_types)[0]
|
||||||
|
elif len(env.return_types) > 1:
|
||||||
|
self.error(
|
||||||
|
stmt.location,
|
||||||
|
f"Mixed return types: {env.return_types}",
|
||||||
|
)
|
||||||
|
returns: Type = UnknownType()
|
||||||
|
if stmt.returns is not None:
|
||||||
|
returns = stmt.returns.accept(self)
|
||||||
|
if returns != inferred_return:
|
||||||
|
self.error(
|
||||||
|
stmt.returns.location,
|
||||||
|
f"Return type mismatch, annotated {returns} but returns {inferred_return}",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
returns = inferred_return
|
||||||
|
|
||||||
|
function: Function = Function(
|
||||||
|
pos_args=pos_args,
|
||||||
|
args=args,
|
||||||
|
kw_args=kw_args,
|
||||||
|
returns=returns,
|
||||||
|
)
|
||||||
|
self.env.define(stmt.name, function)
|
||||||
|
|
||||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||||
# TODO check not yet defined locally
|
# TODO check not yet defined locally
|
||||||
@@ -132,6 +208,11 @@ class Checker(
|
|||||||
f"Cannot assign {value} to {name} of type {var_type}",
|
f"Cannot assign {value} to {name} of type {var_type}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||||
|
type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType()
|
||||||
|
self.env.return_types.add(type)
|
||||||
|
raise ReturnException()
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
||||||
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
||||||
if method is None:
|
if method is None:
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ class Environment:
|
|||||||
def __init__(self, enclosing: Optional[Environment] = None) -> None:
|
def __init__(self, enclosing: Optional[Environment] = None) -> None:
|
||||||
self.enclosing: Optional[Environment] = enclosing
|
self.enclosing: Optional[Environment] = enclosing
|
||||||
self.values: dict[str, Type] = {}
|
self.values: dict[str, Type] = {}
|
||||||
|
self.return_types: set[Type] = set()
|
||||||
|
|
||||||
def define(self, name: str, value: Type):
|
def define(self, name: str, value: Type):
|
||||||
self.values[name] = value
|
self.values[name] = value
|
||||||
|
|||||||
@@ -19,4 +19,22 @@ class UnknownType:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
Type = BaseType | SimpleType | UnknownType
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class UnitType:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Function:
|
||||||
|
pos_args: list[Argument]
|
||||||
|
args: list[Argument]
|
||||||
|
kw_args: list[Argument]
|
||||||
|
returns: Type
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Argument:
|
||||||
|
name: str
|
||||||
|
type: Type
|
||||||
|
|
||||||
|
|
||||||
|
Type = BaseType | SimpleType | UnknownType | UnitType | Function
|
||||||
|
|||||||
Reference in New Issue
Block a user