feat(checker): type check predicate body

This commit is contained in:
2026-06-19 13:55:32 +02:00
parent 1b100b6ceb
commit 6aacdb98b7
2 changed files with 115 additions and 24 deletions

View File

@@ -31,7 +31,7 @@ class TypedParamSpec:
kw: list[Function.Argument]
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type]):
"""A resolver which evaluates Midas type definitions and build a registry"""
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
@@ -41,6 +41,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
self.types: TypesRegistry = types
self._local_variables: dict[str, TypeVar] = {}
self._predicate_params: dict[str, Type] = {}
self._current_name: Optional[str] = None
define_builtins(self.types)
@@ -59,6 +61,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
self.reporter.error(error.token.get_location(), error.message)
self.resolve(stmts)
def type_of(self, expr: m.Expr) -> Type:
type: Type = expr.accept(self)
return type
def get_type(self, name: str) -> Type:
"""Get a type from its name
@@ -75,6 +81,14 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
return self._local_variables[name]
return self.types.get_type(name)
def get_variable(self, name: str) -> Type:
if name in self._predicate_params:
return self._predicate_params[name]
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
if predicate is None:
raise NameError(f"Unknown variable '{name}'")
return predicate.type
def resolve(self, stmts: list[m.Stmt]):
"""Process a sequence of statements
@@ -84,6 +98,11 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
for stmt in stmts:
stmt.accept(self)
def assert_bool(self, expr: m.Expr):
type: Type = self.type_of(expr)
if not self.types.is_subtype(type, self._bool):
self.reporter.error(expr.location, f"Must be a boolean but is {type}")
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
name: str = stmt.name.lexeme
self._current_name = name
@@ -118,17 +137,31 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
)
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
for spec in stmt.params:
for param in spec.mixed:
assert param.name is not None
self._predicate_params[param.name.lexeme] = param.type.accept(self)
type: Type = self.type_of(stmt.body)
params: list[TypedParamSpec] = [
self._visit_param_spec(spec) for spec in stmt.params
]
type: Type = self._bool
for spec in reversed(params):
type = Function(
pos_args=spec.pos,
args=spec.mixed,
kw_args=spec.kw,
returns=type,
if not self._is_valid_predicate(type):
self.reporter.error(
stmt.body.location,
f"Predicate function body must evaluate to a boolean, got {type}",
)
if len(params) != 0:
type = self._bool
for spec in reversed(params):
type = Function(
pos_args=spec.pos,
args=spec.mixed,
kw_args=spec.kw,
returns=type,
)
self._predicate_params = {}
self.types.define_predicate(
stmt.name.lexeme,
Predicate(
@@ -137,32 +170,90 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
),
)
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
self.reporter.warning(expr.location, "LogicalExpr not yet supported")
def _is_valid_predicate(self, body: Type) -> bool:
match body:
case Function(returns=returns):
return self._is_valid_predicate(returns)
case _ if self.types.is_subtype(body, self._bool):
return True
case _:
return False
def visit_binary_expr(self, expr: m.BinaryExpr) -> None:
def visit_logical_expr(self, expr: m.LogicalExpr) -> Type:
self.assert_bool(expr.left)
self.assert_bool(expr.right)
return self._bool
def visit_binary_expr(self, expr: m.BinaryExpr) -> Type:
# TODO
self.reporter.warning(expr.location, "BinaryExpr not yet supported")
return UnknownType()
def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
# TODO
self.reporter.warning(expr.location, "UnaryExpr not yet supported")
return UnknownType()
def visit_call_expr(self, expr: m.CallExpr) -> None:
self.reporter.warning(expr.location, "CallExpr not yet supported")
def visit_call_expr(self, expr: m.CallExpr) -> Type:
callee: Type = expr.callee.accept(self)
if not isinstance(callee, Function):
self.reporter.error(expr.location, f"Cannot call {callee}")
return UnknownType()
args: list[Type] = [arg.accept(self) for arg in expr.arguments]
def visit_get_expr(self, expr: m.GetExpr) -> None:
self.reporter.warning(expr.location, "GetExpr not yet supported")
n_args: int = len(args)
n_params: int = len(callee.args)
if n_args != n_params:
self.reporter.error(
expr.location,
f"Wrong number of argument, expected {n_params}, got {n_args}",
)
return UnknownType()
def visit_variable_expr(self, expr: m.VariableExpr) -> None:
self.reporter.warning(expr.location, "VariableExpr not yet supported")
valid: bool = True
for arg, param in zip(args, callee.args):
if not self.types.is_subtype(arg, param.type):
self.reporter.error(
expr.location,
f"Invalid argument type at pos {param.pos}, expected {param.type}, got {arg}",
)
valid = False
if not valid:
return UnknownType()
return callee.returns
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
def visit_get_expr(self, expr: m.GetExpr) -> Type:
object: Type = expr.expr.accept(self)
member: Optional[Type] = self.types.lookup_member(object, expr.name.lexeme)
if member is None:
self.reporter.error(
expr.location, f"Unknown member '{expr.name}' of {object}"
)
return UnknownType()
return member
def visit_variable_expr(self, expr: m.VariableExpr) -> Type:
return self.get_variable(expr.name.lexeme)
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type:
return expr.expr.accept(self)
def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
self.reporter.warning(expr.location, "LiteralExpr not yet supported")
def visit_literal_expr(self, expr: m.LiteralExpr) -> Type:
match expr.value:
case bool(): # Must be before int
return self.types.get_type("bool")
case int():
return self.types.get_type("int")
case float():
return self.types.get_type("float")
case str():
return self.types.get_type("str")
case _:
self.reporter.warning(expr.location, f"Unknown literal {expr}")
return UnknownType()
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
self.reporter.warning(expr.location, "WildcardExpr not yet supported")
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type:
return self.get_variable("_")
def visit_named_type(self, type: m.NamedType) -> Type:
name: str = type.name.lexeme

View File

@@ -652,7 +652,7 @@ class PythonTyper(
If the function has overloads, the function will try to resolve the
appropriate signature.
Argument types are matched to the defined parameters.
The function doesn't take the raw expression as a parameter to accomodate
The function doesn't take the raw expression as a parameter to accommodate
for desugared calls such as for operators.
Args: