feat(checker): type check predicate body
This commit is contained in:
@@ -31,7 +31,7 @@ class TypedParamSpec:
|
|||||||
kw: list[Function.Argument]
|
kw: list[Function.Argument]
|
||||||
|
|
||||||
|
|
||||||
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
|
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type]):
|
||||||
"""A resolver which evaluates Midas type definitions and build a registry"""
|
"""A resolver which evaluates Midas type definitions and build a registry"""
|
||||||
|
|
||||||
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
|
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
|
||||||
@@ -41,6 +41,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
self.types: TypesRegistry = types
|
self.types: TypesRegistry = types
|
||||||
self._local_variables: dict[str, TypeVar] = {}
|
self._local_variables: dict[str, TypeVar] = {}
|
||||||
|
|
||||||
|
self._predicate_params: dict[str, Type] = {}
|
||||||
|
|
||||||
self._current_name: Optional[str] = None
|
self._current_name: Optional[str] = None
|
||||||
|
|
||||||
define_builtins(self.types)
|
define_builtins(self.types)
|
||||||
@@ -59,6 +61,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
self.reporter.error(error.token.get_location(), error.message)
|
self.reporter.error(error.token.get_location(), error.message)
|
||||||
self.resolve(stmts)
|
self.resolve(stmts)
|
||||||
|
|
||||||
|
def type_of(self, expr: m.Expr) -> Type:
|
||||||
|
type: Type = expr.accept(self)
|
||||||
|
return type
|
||||||
|
|
||||||
def get_type(self, name: str) -> Type:
|
def get_type(self, name: str) -> Type:
|
||||||
"""Get a type from its name
|
"""Get a type from its name
|
||||||
|
|
||||||
@@ -75,6 +81,14 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
return self._local_variables[name]
|
return self._local_variables[name]
|
||||||
return self.types.get_type(name)
|
return self.types.get_type(name)
|
||||||
|
|
||||||
|
def get_variable(self, name: str) -> Type:
|
||||||
|
if name in self._predicate_params:
|
||||||
|
return self._predicate_params[name]
|
||||||
|
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
|
||||||
|
if predicate is None:
|
||||||
|
raise NameError(f"Unknown variable '{name}'")
|
||||||
|
return predicate.type
|
||||||
|
|
||||||
def resolve(self, stmts: list[m.Stmt]):
|
def resolve(self, stmts: list[m.Stmt]):
|
||||||
"""Process a sequence of statements
|
"""Process a sequence of statements
|
||||||
|
|
||||||
@@ -84,6 +98,11 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
for stmt in stmts:
|
for stmt in stmts:
|
||||||
stmt.accept(self)
|
stmt.accept(self)
|
||||||
|
|
||||||
|
def assert_bool(self, expr: m.Expr):
|
||||||
|
type: Type = self.type_of(expr)
|
||||||
|
if not self.types.is_subtype(type, self._bool):
|
||||||
|
self.reporter.error(expr.location, f"Must be a boolean but is {type}")
|
||||||
|
|
||||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||||
name: str = stmt.name.lexeme
|
name: str = stmt.name.lexeme
|
||||||
self._current_name = name
|
self._current_name = name
|
||||||
@@ -118,17 +137,31 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
)
|
)
|
||||||
|
|
||||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
||||||
|
for spec in stmt.params:
|
||||||
|
for param in spec.mixed:
|
||||||
|
assert param.name is not None
|
||||||
|
self._predicate_params[param.name.lexeme] = param.type.accept(self)
|
||||||
|
|
||||||
|
type: Type = self.type_of(stmt.body)
|
||||||
params: list[TypedParamSpec] = [
|
params: list[TypedParamSpec] = [
|
||||||
self._visit_param_spec(spec) for spec in stmt.params
|
self._visit_param_spec(spec) for spec in stmt.params
|
||||||
]
|
]
|
||||||
type: Type = self._bool
|
|
||||||
for spec in reversed(params):
|
if not self._is_valid_predicate(type):
|
||||||
type = Function(
|
self.reporter.error(
|
||||||
pos_args=spec.pos,
|
stmt.body.location,
|
||||||
args=spec.mixed,
|
f"Predicate function body must evaluate to a boolean, got {type}",
|
||||||
kw_args=spec.kw,
|
|
||||||
returns=type,
|
|
||||||
)
|
)
|
||||||
|
if len(params) != 0:
|
||||||
|
type = self._bool
|
||||||
|
for spec in reversed(params):
|
||||||
|
type = Function(
|
||||||
|
pos_args=spec.pos,
|
||||||
|
args=spec.mixed,
|
||||||
|
kw_args=spec.kw,
|
||||||
|
returns=type,
|
||||||
|
)
|
||||||
|
self._predicate_params = {}
|
||||||
self.types.define_predicate(
|
self.types.define_predicate(
|
||||||
stmt.name.lexeme,
|
stmt.name.lexeme,
|
||||||
Predicate(
|
Predicate(
|
||||||
@@ -137,32 +170,90 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
|
def _is_valid_predicate(self, body: Type) -> bool:
|
||||||
self.reporter.warning(expr.location, "LogicalExpr not yet supported")
|
match body:
|
||||||
|
case Function(returns=returns):
|
||||||
|
return self._is_valid_predicate(returns)
|
||||||
|
case _ if self.types.is_subtype(body, self._bool):
|
||||||
|
return True
|
||||||
|
case _:
|
||||||
|
return False
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> None:
|
def visit_logical_expr(self, expr: m.LogicalExpr) -> Type:
|
||||||
|
self.assert_bool(expr.left)
|
||||||
|
self.assert_bool(expr.right)
|
||||||
|
return self._bool
|
||||||
|
|
||||||
|
def visit_binary_expr(self, expr: m.BinaryExpr) -> Type:
|
||||||
|
# TODO
|
||||||
self.reporter.warning(expr.location, "BinaryExpr not yet supported")
|
self.reporter.warning(expr.location, "BinaryExpr not yet supported")
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
|
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
|
||||||
|
# TODO
|
||||||
self.reporter.warning(expr.location, "UnaryExpr not yet supported")
|
self.reporter.warning(expr.location, "UnaryExpr not yet supported")
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
def visit_call_expr(self, expr: m.CallExpr) -> None:
|
def visit_call_expr(self, expr: m.CallExpr) -> Type:
|
||||||
self.reporter.warning(expr.location, "CallExpr not yet supported")
|
callee: Type = expr.callee.accept(self)
|
||||||
|
if not isinstance(callee, Function):
|
||||||
|
self.reporter.error(expr.location, f"Cannot call {callee}")
|
||||||
|
return UnknownType()
|
||||||
|
args: list[Type] = [arg.accept(self) for arg in expr.arguments]
|
||||||
|
|
||||||
def visit_get_expr(self, expr: m.GetExpr) -> None:
|
n_args: int = len(args)
|
||||||
self.reporter.warning(expr.location, "GetExpr not yet supported")
|
n_params: int = len(callee.args)
|
||||||
|
if n_args != n_params:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Wrong number of argument, expected {n_params}, got {n_args}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
def visit_variable_expr(self, expr: m.VariableExpr) -> None:
|
valid: bool = True
|
||||||
self.reporter.warning(expr.location, "VariableExpr not yet supported")
|
for arg, param in zip(args, callee.args):
|
||||||
|
if not self.types.is_subtype(arg, param.type):
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Invalid argument type at pos {param.pos}, expected {param.type}, got {arg}",
|
||||||
|
)
|
||||||
|
valid = False
|
||||||
|
if not valid:
|
||||||
|
return UnknownType()
|
||||||
|
return callee.returns
|
||||||
|
|
||||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
|
def visit_get_expr(self, expr: m.GetExpr) -> Type:
|
||||||
|
object: Type = expr.expr.accept(self)
|
||||||
|
member: Optional[Type] = self.types.lookup_member(object, expr.name.lexeme)
|
||||||
|
if member is None:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location, f"Unknown member '{expr.name}' of {object}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
return member
|
||||||
|
|
||||||
|
def visit_variable_expr(self, expr: m.VariableExpr) -> Type:
|
||||||
|
return self.get_variable(expr.name.lexeme)
|
||||||
|
|
||||||
|
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type:
|
||||||
return expr.expr.accept(self)
|
return expr.expr.accept(self)
|
||||||
|
|
||||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
|
def visit_literal_expr(self, expr: m.LiteralExpr) -> Type:
|
||||||
self.reporter.warning(expr.location, "LiteralExpr not yet supported")
|
match expr.value:
|
||||||
|
case bool(): # Must be before int
|
||||||
|
return self.types.get_type("bool")
|
||||||
|
case int():
|
||||||
|
return self.types.get_type("int")
|
||||||
|
case float():
|
||||||
|
return self.types.get_type("float")
|
||||||
|
case str():
|
||||||
|
return self.types.get_type("str")
|
||||||
|
case _:
|
||||||
|
self.reporter.warning(expr.location, f"Unknown literal {expr}")
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
|
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type:
|
||||||
self.reporter.warning(expr.location, "WildcardExpr not yet supported")
|
return self.get_variable("_")
|
||||||
|
|
||||||
def visit_named_type(self, type: m.NamedType) -> Type:
|
def visit_named_type(self, type: m.NamedType) -> Type:
|
||||||
name: str = type.name.lexeme
|
name: str = type.name.lexeme
|
||||||
|
|||||||
@@ -652,7 +652,7 @@ class PythonTyper(
|
|||||||
If the function has overloads, the function will try to resolve the
|
If the function has overloads, the function will try to resolve the
|
||||||
appropriate signature.
|
appropriate signature.
|
||||||
Argument types are matched to the defined parameters.
|
Argument types are matched to the defined parameters.
|
||||||
The function doesn't take the raw expression as a parameter to accomodate
|
The function doesn't take the raw expression as a parameter to accommodate
|
||||||
for desugared calls such as for operators.
|
for desugared calls such as for operators.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
Reference in New Issue
Block a user