feat(checker): type check predicate body

This commit is contained in:
2026-06-19 13:55:32 +02:00
parent 1b100b6ceb
commit 6aacdb98b7
2 changed files with 115 additions and 24 deletions

View File

@@ -31,7 +31,7 @@ class TypedParamSpec:
kw: list[Function.Argument] kw: list[Function.Argument]
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]): class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type]):
"""A resolver which evaluates Midas type definitions and build a registry""" """A resolver which evaluates Midas type definitions and build a registry"""
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None: def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
@@ -41,6 +41,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
self.types: TypesRegistry = types self.types: TypesRegistry = types
self._local_variables: dict[str, TypeVar] = {} self._local_variables: dict[str, TypeVar] = {}
self._predicate_params: dict[str, Type] = {}
self._current_name: Optional[str] = None self._current_name: Optional[str] = None
define_builtins(self.types) define_builtins(self.types)
@@ -59,6 +61,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
self.reporter.error(error.token.get_location(), error.message) self.reporter.error(error.token.get_location(), error.message)
self.resolve(stmts) self.resolve(stmts)
def type_of(self, expr: m.Expr) -> Type:
type: Type = expr.accept(self)
return type
def get_type(self, name: str) -> Type: def get_type(self, name: str) -> Type:
"""Get a type from its name """Get a type from its name
@@ -75,6 +81,14 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
return self._local_variables[name] return self._local_variables[name]
return self.types.get_type(name) return self.types.get_type(name)
def get_variable(self, name: str) -> Type:
if name in self._predicate_params:
return self._predicate_params[name]
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
if predicate is None:
raise NameError(f"Unknown variable '{name}'")
return predicate.type
def resolve(self, stmts: list[m.Stmt]): def resolve(self, stmts: list[m.Stmt]):
"""Process a sequence of statements """Process a sequence of statements
@@ -84,6 +98,11 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
for stmt in stmts: for stmt in stmts:
stmt.accept(self) stmt.accept(self)
def assert_bool(self, expr: m.Expr):
type: Type = self.type_of(expr)
if not self.types.is_subtype(type, self._bool):
self.reporter.error(expr.location, f"Must be a boolean but is {type}")
def visit_type_stmt(self, stmt: m.TypeStmt) -> None: def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
name: str = stmt.name.lexeme name: str = stmt.name.lexeme
self._current_name = name self._current_name = name
@@ -118,17 +137,31 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
) )
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
for spec in stmt.params:
for param in spec.mixed:
assert param.name is not None
self._predicate_params[param.name.lexeme] = param.type.accept(self)
type: Type = self.type_of(stmt.body)
params: list[TypedParamSpec] = [ params: list[TypedParamSpec] = [
self._visit_param_spec(spec) for spec in stmt.params self._visit_param_spec(spec) for spec in stmt.params
] ]
type: Type = self._bool
for spec in reversed(params): if not self._is_valid_predicate(type):
type = Function( self.reporter.error(
pos_args=spec.pos, stmt.body.location,
args=spec.mixed, f"Predicate function body must evaluate to a boolean, got {type}",
kw_args=spec.kw,
returns=type,
) )
if len(params) != 0:
type = self._bool
for spec in reversed(params):
type = Function(
pos_args=spec.pos,
args=spec.mixed,
kw_args=spec.kw,
returns=type,
)
self._predicate_params = {}
self.types.define_predicate( self.types.define_predicate(
stmt.name.lexeme, stmt.name.lexeme,
Predicate( Predicate(
@@ -137,32 +170,90 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
), ),
) )
def visit_logical_expr(self, expr: m.LogicalExpr) -> None: def _is_valid_predicate(self, body: Type) -> bool:
self.reporter.warning(expr.location, "LogicalExpr not yet supported") match body:
case Function(returns=returns):
return self._is_valid_predicate(returns)
case _ if self.types.is_subtype(body, self._bool):
return True
case _:
return False
def visit_binary_expr(self, expr: m.BinaryExpr) -> None: def visit_logical_expr(self, expr: m.LogicalExpr) -> Type:
self.assert_bool(expr.left)
self.assert_bool(expr.right)
return self._bool
def visit_binary_expr(self, expr: m.BinaryExpr) -> Type:
# TODO
self.reporter.warning(expr.location, "BinaryExpr not yet supported") self.reporter.warning(expr.location, "BinaryExpr not yet supported")
return UnknownType()
def visit_unary_expr(self, expr: m.UnaryExpr) -> None: def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
# TODO
self.reporter.warning(expr.location, "UnaryExpr not yet supported") self.reporter.warning(expr.location, "UnaryExpr not yet supported")
return UnknownType()
def visit_call_expr(self, expr: m.CallExpr) -> None: def visit_call_expr(self, expr: m.CallExpr) -> Type:
self.reporter.warning(expr.location, "CallExpr not yet supported") callee: Type = expr.callee.accept(self)
if not isinstance(callee, Function):
self.reporter.error(expr.location, f"Cannot call {callee}")
return UnknownType()
args: list[Type] = [arg.accept(self) for arg in expr.arguments]
def visit_get_expr(self, expr: m.GetExpr) -> None: n_args: int = len(args)
self.reporter.warning(expr.location, "GetExpr not yet supported") n_params: int = len(callee.args)
if n_args != n_params:
self.reporter.error(
expr.location,
f"Wrong number of argument, expected {n_params}, got {n_args}",
)
return UnknownType()
def visit_variable_expr(self, expr: m.VariableExpr) -> None: valid: bool = True
self.reporter.warning(expr.location, "VariableExpr not yet supported") for arg, param in zip(args, callee.args):
if not self.types.is_subtype(arg, param.type):
self.reporter.error(
expr.location,
f"Invalid argument type at pos {param.pos}, expected {param.type}, got {arg}",
)
valid = False
if not valid:
return UnknownType()
return callee.returns
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None: def visit_get_expr(self, expr: m.GetExpr) -> Type:
object: Type = expr.expr.accept(self)
member: Optional[Type] = self.types.lookup_member(object, expr.name.lexeme)
if member is None:
self.reporter.error(
expr.location, f"Unknown member '{expr.name}' of {object}"
)
return UnknownType()
return member
def visit_variable_expr(self, expr: m.VariableExpr) -> Type:
return self.get_variable(expr.name.lexeme)
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type:
return expr.expr.accept(self) return expr.expr.accept(self)
def visit_literal_expr(self, expr: m.LiteralExpr) -> None: def visit_literal_expr(self, expr: m.LiteralExpr) -> Type:
self.reporter.warning(expr.location, "LiteralExpr not yet supported") match expr.value:
case bool(): # Must be before int
return self.types.get_type("bool")
case int():
return self.types.get_type("int")
case float():
return self.types.get_type("float")
case str():
return self.types.get_type("str")
case _:
self.reporter.warning(expr.location, f"Unknown literal {expr}")
return UnknownType()
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type:
self.reporter.warning(expr.location, "WildcardExpr not yet supported") return self.get_variable("_")
def visit_named_type(self, type: m.NamedType) -> Type: def visit_named_type(self, type: m.NamedType) -> Type:
name: str = type.name.lexeme name: str = type.name.lexeme

View File

@@ -652,7 +652,7 @@ class PythonTyper(
If the function has overloads, the function will try to resolve the If the function has overloads, the function will try to resolve the
appropriate signature. appropriate signature.
Argument types are matched to the defined parameters. Argument types are matched to the defined parameters.
The function doesn't take the raw expression as a parameter to accomodate The function doesn't take the raw expression as a parameter to accommodate
for desugared calls such as for operators. for desugared calls such as for operators.
Args: Args: