From 48fcb499a195c7e6b3e7ff0a106939e2be4eefe3 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Thu, 18 Jun 2026 22:48:10 +0200 Subject: [PATCH] feat(gen): generate predicate functions --- midas/checker/registry.py | 3 + midas/cli/commands/compile.py | 2 +- midas/generator/constraints.py | 100 ++++++++++++++++++++++++++++++++- midas/generator/generator.py | 29 ++++++++-- midas/parser/midas.py | 2 +- 5 files changed, 126 insertions(+), 10 deletions(-) diff --git a/midas/checker/registry.py b/midas/checker/registry.py index f97d04f..3e3d509 100644 --- a/midas/checker/registry.py +++ b/midas/checker/registry.py @@ -352,3 +352,6 @@ class TypesRegistry: case _: self.logger.debug(f"Can't get member on {type}") return None + + def lookup_predicate(self, name: str) -> Optional[Predicate]: + return self._predicates.get(name) diff --git a/midas/cli/commands/compile.py b/midas/cli/commands/compile.py index 5a410f7..5a623ec 100644 --- a/midas/cli/commands/compile.py +++ b/midas/cli/commands/compile.py @@ -38,5 +38,5 @@ def compile( if any(map(lambda d: d.type == DiagnosticType.ERROR, diagnostics)): sys.exit(1) - generator = Generator(workdir=source_path.parent) + generator = Generator(workdir=source_path.parent, types=checker.types) generator.generate(typed_ast, source_path) diff --git a/midas/generator/constraints.py b/midas/generator/constraints.py index 329b79b..d739516 100644 --- a/midas/generator/constraints.py +++ b/midas/generator/constraints.py @@ -1,6 +1,9 @@ import ast +from typing import Optional import midas.ast.midas as m +from midas.checker.registry import TypesRegistry +from midas.checker.types import Function, Predicate, Type from midas.lexer.token import TokenType LOGICAL_OPERATORS: dict[TokenType, type[ast.boolop]] = { @@ -31,6 +34,97 @@ COMPARISON_OPERATORS: dict[TokenType, type[ast.cmpop]] = { class ConstraintGenerator(m.Expr.Visitor[ast.expr]): + def __init__(self, types: TypesRegistry): + self.types: TypesRegistry = types + self._id: int = 0 + self._definitions: list[ast.stmt] = [] + self._aliases: dict[str, str] = {} + + def get_definitions(self) -> list[ast.stmt]: + return self._definitions + + def generate(self, expr: m.Expr) -> ast.expr: + match expr: + case m.VariableExpr(): + return expr.accept(self) + case _: + func = Function( + pos_args=[], + args=[ + Function.Argument( + pos=0, + name="_", + type=self.types.get_type("Any"), + required=True, + ) + ], + kw_args=[], + returns=self.types.get_type("bool"), + ) + alias: str = self.make_alias(None) + definition: ast.stmt = self.make_definition( + alias, Predicate(type=func, body=expr) + ) + self._definitions.append(definition) + return ast.Name(id=alias) + + def make_alias(self, name: Optional[str]) -> str: + suffix: str = f"_{name}" if name is not None else "" + alias: str = f"__midas_p{self._id}{suffix}__" + self._id += 1 + return alias + + def make_definition(self, name: str, predicate: Predicate) -> ast.stmt: + body: list[ast.stmt] = [ast.Return(value=predicate.body.accept(self))] + return self.make_func(name, body, predicate.type) + + def make_args(self, func: Function) -> ast.arguments: + return ast.arguments( + posonlyargs=[ast.arg(arg=arg.name) for arg in func.pos_args], + args=[ast.arg(arg=arg.name) for arg in func.args], + kwonlyargs=[ast.arg(arg=arg.name) for arg in func.kw_args], + defaults=[], + kw_defaults=[], + ) + + def make_func( + self, name: str, inner_body: list[ast.stmt], type: Type, level: int = 0 + ) -> ast.stmt: + match type: + case Function(returns=Function()): + inner_name: str = f"inner{level}" + return ast.FunctionDef( + name=name, + args=self.make_args(type), + body=[ + self.make_func(inner_name, inner_body, type.returns, level + 1), + ast.Return(value=ast.Name(id=inner_name)), + ], + decorator_list=[], + ) + + case Function(): + return ast.FunctionDef( + name=name, + args=self.make_args(type), + body=inner_body, + decorator_list=[], + ) + + case _: + raise ValueError(f"Expected function, got {type}") + + def get_predicate(self, name: str) -> Optional[ast.expr]: + if name not in self._aliases: + predicate: Optional[Predicate] = self.types.lookup_predicate(name) + if predicate is None: + return None + alias: str = self.make_alias(name) + self._aliases[name] = alias + self._definitions.append(self.make_definition(alias, predicate)) + + return ast.Name(id=self._aliases[name]) + def visit_logical_expr(self, expr: m.LogicalExpr) -> ast.expr: return ast.BoolOp( op=LOGICAL_OPERATORS[expr.operator.type](), @@ -79,8 +173,10 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]): ) def visit_variable_expr(self, expr: m.VariableExpr) -> ast.expr: - # TODO: lookup predicate - return ast.Name(id=expr.name.lexeme) + name: str = expr.name.lexeme + if (p := self.get_predicate(name)) is not None: + return p + return ast.Name(id=name) def visit_grouping_expr(self, expr: m.GroupingExpr) -> ast.expr: return expr.accept(self) diff --git a/midas/generator/generator.py b/midas/generator/generator.py index 64c8c16..9fc1850 100644 --- a/midas/generator/generator.py +++ b/midas/generator/generator.py @@ -8,6 +8,7 @@ import midas.ast.midas as m import midas.ast.python as p from midas.ast.location import Location from midas.ast.printer import MidasPrinter +from midas.checker.registry import TypesRegistry from midas.checker.types import ( AliasType, AppliedType, @@ -35,7 +36,7 @@ class Scope: class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): - def __init__(self, workdir: Path) -> None: + def __init__(self, workdir: Path, types: TypesRegistry) -> None: self.workdir: Path = workdir.resolve() self.build_dir: Path = self.workdir / "build" / "midas" if self.build_dir.exists(): @@ -48,15 +49,18 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): judgements=[], ) self._alias_count: int = 0 + self._predicate_count: int = 0 self._scopes: list[Scope] = [] - self._constraint_generator: ConstraintGenerator = ConstraintGenerator() + self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types) + self._constraints: list[tuple[m.Expr, ast.expr]] = [] def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST: self.rel_src_path = src_path.relative_to(self.workdir) self._typed_ast = typed_ast body: list[ast.stmt] = self._visit_body(typed_ast.stmts) - module = ast.Module(body=body, type_ignores=[]) + predicates: list[ast.stmt] = self._constraint_generator.get_definitions() + module = ast.Module(body=predicates + body, type_ignores=[]) module = ast.fix_missing_locations(module) return module @@ -253,7 +257,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): return generated def _make_alias(self, expr: ast.expr) -> ast.expr: - name: str = f"__midas_alias_{self._alias_count}__" + name: str = f"__midas_a{self._alias_count}__" alias = ast.Name(id=name) self._alias_count += 1 self._scopes[-1].aliases.append(name) @@ -361,9 +365,13 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): def _make_constraint_assert( self, src_location: Location, expr: ast.expr, constraint: m.Expr ): - test: ast.expr = constraint.accept(self._constraint_generator) + test_func: ast.expr = self._get_constraint(constraint) self._add_assert( - test, + ast.Call( + func=test_func, + args=[expr], + keywords=[], + ), self._make_constraint_assert_message(src_location, expr, constraint), ) @@ -377,3 +385,12 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): return ast.Constant( f"{loc_str}: ConstraintError: Value does not fit constraint '{constraint_str}'" ) + + def _get_constraint(self, expr: m.Expr) -> ast.expr: + for expr2, constraint in self._constraints: + if expr2 == expr: + return constraint + + constraint: ast.expr = self._constraint_generator.generate(expr) + self._constraints.append((expr, constraint)) + return constraint diff --git a/midas/parser/midas.py b/midas/parser/midas.py index 1b60dc1..488a8ca 100644 --- a/midas/parser/midas.py +++ b/midas/parser/midas.py @@ -340,7 +340,7 @@ class MidasParser(Parser): def call(self) -> Expr: expr: Expr = self.reference() - if self.match(TokenType.LEFT_PAREN): + while self.match(TokenType.LEFT_PAREN): expr = self.finish_call(expr) return expr