feat(checker): type check predicate body
This commit is contained in:
@@ -31,7 +31,7 @@ class TypedParamSpec:
|
||||
kw: list[Function.Argument]
|
||||
|
||||
|
||||
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
|
||||
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type]):
|
||||
"""A resolver which evaluates Midas type definitions and build a registry"""
|
||||
|
||||
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
|
||||
@@ -41,6 +41,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
||||
self.types: TypesRegistry = types
|
||||
self._local_variables: dict[str, TypeVar] = {}
|
||||
|
||||
self._predicate_params: dict[str, Type] = {}
|
||||
|
||||
self._current_name: Optional[str] = None
|
||||
|
||||
define_builtins(self.types)
|
||||
@@ -59,6 +61,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
||||
self.reporter.error(error.token.get_location(), error.message)
|
||||
self.resolve(stmts)
|
||||
|
||||
def type_of(self, expr: m.Expr) -> Type:
|
||||
type: Type = expr.accept(self)
|
||||
return type
|
||||
|
||||
def get_type(self, name: str) -> Type:
|
||||
"""Get a type from its name
|
||||
|
||||
@@ -75,6 +81,14 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
||||
return self._local_variables[name]
|
||||
return self.types.get_type(name)
|
||||
|
||||
def get_variable(self, name: str) -> Type:
|
||||
if name in self._predicate_params:
|
||||
return self._predicate_params[name]
|
||||
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
|
||||
if predicate is None:
|
||||
raise NameError(f"Unknown variable '{name}'")
|
||||
return predicate.type
|
||||
|
||||
def resolve(self, stmts: list[m.Stmt]):
|
||||
"""Process a sequence of statements
|
||||
|
||||
@@ -84,6 +98,11 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
||||
for stmt in stmts:
|
||||
stmt.accept(self)
|
||||
|
||||
def assert_bool(self, expr: m.Expr):
|
||||
type: Type = self.type_of(expr)
|
||||
if not self.types.is_subtype(type, self._bool):
|
||||
self.reporter.error(expr.location, f"Must be a boolean but is {type}")
|
||||
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||
name: str = stmt.name.lexeme
|
||||
self._current_name = name
|
||||
@@ -118,17 +137,31 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
||||
)
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
||||
for spec in stmt.params:
|
||||
for param in spec.mixed:
|
||||
assert param.name is not None
|
||||
self._predicate_params[param.name.lexeme] = param.type.accept(self)
|
||||
|
||||
type: Type = self.type_of(stmt.body)
|
||||
params: list[TypedParamSpec] = [
|
||||
self._visit_param_spec(spec) for spec in stmt.params
|
||||
]
|
||||
type: Type = self._bool
|
||||
for spec in reversed(params):
|
||||
type = Function(
|
||||
pos_args=spec.pos,
|
||||
args=spec.mixed,
|
||||
kw_args=spec.kw,
|
||||
returns=type,
|
||||
|
||||
if not self._is_valid_predicate(type):
|
||||
self.reporter.error(
|
||||
stmt.body.location,
|
||||
f"Predicate function body must evaluate to a boolean, got {type}",
|
||||
)
|
||||
if len(params) != 0:
|
||||
type = self._bool
|
||||
for spec in reversed(params):
|
||||
type = Function(
|
||||
pos_args=spec.pos,
|
||||
args=spec.mixed,
|
||||
kw_args=spec.kw,
|
||||
returns=type,
|
||||
)
|
||||
self._predicate_params = {}
|
||||
self.types.define_predicate(
|
||||
stmt.name.lexeme,
|
||||
Predicate(
|
||||
@@ -137,32 +170,90 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
||||
),
|
||||
)
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
|
||||
self.reporter.warning(expr.location, "LogicalExpr not yet supported")
|
||||
def _is_valid_predicate(self, body: Type) -> bool:
|
||||
match body:
|
||||
case Function(returns=returns):
|
||||
return self._is_valid_predicate(returns)
|
||||
case _ if self.types.is_subtype(body, self._bool):
|
||||
return True
|
||||
case _:
|
||||
return False
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> None:
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> Type:
|
||||
self.assert_bool(expr.left)
|
||||
self.assert_bool(expr.right)
|
||||
return self._bool
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> Type:
|
||||
# TODO
|
||||
self.reporter.warning(expr.location, "BinaryExpr not yet supported")
|
||||
return UnknownType()
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
|
||||
# TODO
|
||||
self.reporter.warning(expr.location, "UnaryExpr not yet supported")
|
||||
return UnknownType()
|
||||
|
||||
def visit_call_expr(self, expr: m.CallExpr) -> None:
|
||||
self.reporter.warning(expr.location, "CallExpr not yet supported")
|
||||
def visit_call_expr(self, expr: m.CallExpr) -> Type:
|
||||
callee: Type = expr.callee.accept(self)
|
||||
if not isinstance(callee, Function):
|
||||
self.reporter.error(expr.location, f"Cannot call {callee}")
|
||||
return UnknownType()
|
||||
args: list[Type] = [arg.accept(self) for arg in expr.arguments]
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> None:
|
||||
self.reporter.warning(expr.location, "GetExpr not yet supported")
|
||||
n_args: int = len(args)
|
||||
n_params: int = len(callee.args)
|
||||
if n_args != n_params:
|
||||
self.reporter.error(
|
||||
expr.location,
|
||||
f"Wrong number of argument, expected {n_params}, got {n_args}",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr) -> None:
|
||||
self.reporter.warning(expr.location, "VariableExpr not yet supported")
|
||||
valid: bool = True
|
||||
for arg, param in zip(args, callee.args):
|
||||
if not self.types.is_subtype(arg, param.type):
|
||||
self.reporter.error(
|
||||
expr.location,
|
||||
f"Invalid argument type at pos {param.pos}, expected {param.type}, got {arg}",
|
||||
)
|
||||
valid = False
|
||||
if not valid:
|
||||
return UnknownType()
|
||||
return callee.returns
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> Type:
|
||||
object: Type = expr.expr.accept(self)
|
||||
member: Optional[Type] = self.types.lookup_member(object, expr.name.lexeme)
|
||||
if member is None:
|
||||
self.reporter.error(
|
||||
expr.location, f"Unknown member '{expr.name}' of {object}"
|
||||
)
|
||||
return UnknownType()
|
||||
return member
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr) -> Type:
|
||||
return self.get_variable(expr.name.lexeme)
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type:
|
||||
return expr.expr.accept(self)
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
|
||||
self.reporter.warning(expr.location, "LiteralExpr not yet supported")
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> Type:
|
||||
match expr.value:
|
||||
case bool(): # Must be before int
|
||||
return self.types.get_type("bool")
|
||||
case int():
|
||||
return self.types.get_type("int")
|
||||
case float():
|
||||
return self.types.get_type("float")
|
||||
case str():
|
||||
return self.types.get_type("str")
|
||||
case _:
|
||||
self.reporter.warning(expr.location, f"Unknown literal {expr}")
|
||||
return UnknownType()
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
|
||||
self.reporter.warning(expr.location, "WildcardExpr not yet supported")
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type:
|
||||
return self.get_variable("_")
|
||||
|
||||
def visit_named_type(self, type: m.NamedType) -> Type:
|
||||
name: str = type.name.lexeme
|
||||
|
||||
@@ -652,7 +652,7 @@ class PythonTyper(
|
||||
If the function has overloads, the function will try to resolve the
|
||||
appropriate signature.
|
||||
Argument types are matched to the defined parameters.
|
||||
The function doesn't take the raw expression as a parameter to accomodate
|
||||
The function doesn't take the raw expression as a parameter to accommodate
|
||||
for desugared calls such as for operators.
|
||||
|
||||
Args:
|
||||
|
||||
Reference in New Issue
Block a user