feat(gen): handle predicate aliases

handle cases where a predicate is defined as an alias, i.e. without any parameters
This commit is contained in:
2026-06-19 14:05:34 +02:00
parent 2974386110
commit 657406ea01
3 changed files with 13 additions and 4 deletions

View File

@@ -167,6 +167,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
Predicate( Predicate(
type=type, type=type,
body=stmt.body, body=stmt.body,
alias=len(params) == 0,
), ),
) )

View File

@@ -242,6 +242,7 @@ def unfold_type(type: Type) -> Type:
class Predicate: class Predicate:
type: Type type: Type
body: m.Expr body: m.Expr
alias: bool
Type = ( Type = (

View File

@@ -63,7 +63,7 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
) )
alias: str = self.make_alias(None) alias: str = self.make_alias(None)
definition: ast.stmt = self.make_definition( definition: ast.stmt = self.make_definition(
alias, Predicate(type=func, body=expr) alias, Predicate(type=func, body=expr, alias=False)
) )
self._definitions.append(definition) self._definitions.append(definition)
return ast.Name(id=alias) return ast.Name(id=alias)
@@ -79,8 +79,15 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
return alias return alias
def make_definition(self, name: str, predicate: Predicate) -> ast.stmt: def make_definition(self, name: str, predicate: Predicate) -> ast.stmt:
body: list[ast.stmt] = [ast.Return(value=predicate.body.accept(self))] body: ast.expr = predicate.body.accept(self)
return self.make_func(name, body, predicate.type) if predicate.alias:
return ast.Assign(
targets=[
ast.Name(id=name),
],
value=body,
)
return self.make_func(name, [ast.Return(value=body)], predicate.type)
def make_args(self, func: Function) -> ast.arguments: def make_args(self, func: Function) -> ast.arguments:
return ast.arguments( return ast.arguments(
@@ -116,7 +123,7 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
) )
case _: case _:
raise ValueError(f"Expected function, got {type}") raise ValueError(f"Expected function, got {type!r}")
def get_predicate(self, name: str) -> Optional[ast.expr]: def get_predicate(self, name: str) -> Optional[ast.expr]:
if name not in self._aliases: if name not in self._aliases: