From 6aacdb98b7a3d8a272e73a887dd98d558807fa52 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Fri, 19 Jun 2026 13:55:32 +0200 Subject: [PATCH] feat(checker): type check predicate body --- midas/checker/midas.py | 137 +++++++++++++++++++++++++++++++++------- midas/checker/python.py | 2 +- 2 files changed, 115 insertions(+), 24 deletions(-) diff --git a/midas/checker/midas.py b/midas/checker/midas.py index 5e117c2..5e0e847 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -31,7 +31,7 @@ class TypedParamSpec: kw: list[Function.Argument] -class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]): +class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type]): """A resolver which evaluates Midas type definitions and build a registry""" def __init__(self, types: TypesRegistry, reporter: Reporter) -> None: @@ -41,6 +41,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type self.types: TypesRegistry = types self._local_variables: dict[str, TypeVar] = {} + self._predicate_params: dict[str, Type] = {} + self._current_name: Optional[str] = None define_builtins(self.types) @@ -59,6 +61,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type self.reporter.error(error.token.get_location(), error.message) self.resolve(stmts) + def type_of(self, expr: m.Expr) -> Type: + type: Type = expr.accept(self) + return type + def get_type(self, name: str) -> Type: """Get a type from its name @@ -75,6 +81,14 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type return self._local_variables[name] return self.types.get_type(name) + def get_variable(self, name: str) -> Type: + if name in self._predicate_params: + return self._predicate_params[name] + predicate: Optional[Predicate] = self.types.lookup_predicate(name) + if predicate is None: + raise NameError(f"Unknown variable '{name}'") + return predicate.type + def resolve(self, stmts: list[m.Stmt]): """Process a sequence of statements @@ -84,6 +98,11 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type for stmt in stmts: stmt.accept(self) + def assert_bool(self, expr: m.Expr): + type: Type = self.type_of(expr) + if not self.types.is_subtype(type, self._bool): + self.reporter.error(expr.location, f"Must be a boolean but is {type}") + def visit_type_stmt(self, stmt: m.TypeStmt) -> None: name: str = stmt.name.lexeme self._current_name = name @@ -118,17 +137,31 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type ) def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: + for spec in stmt.params: + for param in spec.mixed: + assert param.name is not None + self._predicate_params[param.name.lexeme] = param.type.accept(self) + + type: Type = self.type_of(stmt.body) params: list[TypedParamSpec] = [ self._visit_param_spec(spec) for spec in stmt.params ] - type: Type = self._bool - for spec in reversed(params): - type = Function( - pos_args=spec.pos, - args=spec.mixed, - kw_args=spec.kw, - returns=type, + + if not self._is_valid_predicate(type): + self.reporter.error( + stmt.body.location, + f"Predicate function body must evaluate to a boolean, got {type}", ) + if len(params) != 0: + type = self._bool + for spec in reversed(params): + type = Function( + pos_args=spec.pos, + args=spec.mixed, + kw_args=spec.kw, + returns=type, + ) + self._predicate_params = {} self.types.define_predicate( stmt.name.lexeme, Predicate( @@ -137,32 +170,90 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type ), ) - def visit_logical_expr(self, expr: m.LogicalExpr) -> None: - self.reporter.warning(expr.location, "LogicalExpr not yet supported") + def _is_valid_predicate(self, body: Type) -> bool: + match body: + case Function(returns=returns): + return self._is_valid_predicate(returns) + case _ if self.types.is_subtype(body, self._bool): + return True + case _: + return False - def visit_binary_expr(self, expr: m.BinaryExpr) -> None: + def visit_logical_expr(self, expr: m.LogicalExpr) -> Type: + self.assert_bool(expr.left) + self.assert_bool(expr.right) + return self._bool + + def visit_binary_expr(self, expr: m.BinaryExpr) -> Type: + # TODO self.reporter.warning(expr.location, "BinaryExpr not yet supported") + return UnknownType() - def visit_unary_expr(self, expr: m.UnaryExpr) -> None: + def visit_unary_expr(self, expr: m.UnaryExpr) -> Type: + # TODO self.reporter.warning(expr.location, "UnaryExpr not yet supported") + return UnknownType() - def visit_call_expr(self, expr: m.CallExpr) -> None: - self.reporter.warning(expr.location, "CallExpr not yet supported") + def visit_call_expr(self, expr: m.CallExpr) -> Type: + callee: Type = expr.callee.accept(self) + if not isinstance(callee, Function): + self.reporter.error(expr.location, f"Cannot call {callee}") + return UnknownType() + args: list[Type] = [arg.accept(self) for arg in expr.arguments] - def visit_get_expr(self, expr: m.GetExpr) -> None: - self.reporter.warning(expr.location, "GetExpr not yet supported") + n_args: int = len(args) + n_params: int = len(callee.args) + if n_args != n_params: + self.reporter.error( + expr.location, + f"Wrong number of argument, expected {n_params}, got {n_args}", + ) + return UnknownType() - def visit_variable_expr(self, expr: m.VariableExpr) -> None: - self.reporter.warning(expr.location, "VariableExpr not yet supported") + valid: bool = True + for arg, param in zip(args, callee.args): + if not self.types.is_subtype(arg, param.type): + self.reporter.error( + expr.location, + f"Invalid argument type at pos {param.pos}, expected {param.type}, got {arg}", + ) + valid = False + if not valid: + return UnknownType() + return callee.returns - def visit_grouping_expr(self, expr: m.GroupingExpr) -> None: + def visit_get_expr(self, expr: m.GetExpr) -> Type: + object: Type = expr.expr.accept(self) + member: Optional[Type] = self.types.lookup_member(object, expr.name.lexeme) + if member is None: + self.reporter.error( + expr.location, f"Unknown member '{expr.name}' of {object}" + ) + return UnknownType() + return member + + def visit_variable_expr(self, expr: m.VariableExpr) -> Type: + return self.get_variable(expr.name.lexeme) + + def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type: return expr.expr.accept(self) - def visit_literal_expr(self, expr: m.LiteralExpr) -> None: - self.reporter.warning(expr.location, "LiteralExpr not yet supported") + def visit_literal_expr(self, expr: m.LiteralExpr) -> Type: + match expr.value: + case bool(): # Must be before int + return self.types.get_type("bool") + case int(): + return self.types.get_type("int") + case float(): + return self.types.get_type("float") + case str(): + return self.types.get_type("str") + case _: + self.reporter.warning(expr.location, f"Unknown literal {expr}") + return UnknownType() - def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: - self.reporter.warning(expr.location, "WildcardExpr not yet supported") + def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type: + return self.get_variable("_") def visit_named_type(self, type: m.NamedType) -> Type: name: str = type.name.lexeme diff --git a/midas/checker/python.py b/midas/checker/python.py index c4bffff..65b48eb 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -652,7 +652,7 @@ class PythonTyper( If the function has overloads, the function will try to resolve the appropriate signature. Argument types are matched to the defined parameters. - The function doesn't take the raw expression as a parameter to accomodate + The function doesn't take the raw expression as a parameter to accommodate for desugared calls such as for operators. Args: