feat(gen): handle predicate aliases

handle cases where a predicate is defined as an alias, i.e. without any parameters
This commit is contained in:
2026-06-19 14:05:34 +02:00
parent 2974386110
commit 657406ea01
3 changed files with 13 additions and 4 deletions

View File

@@ -167,6 +167,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
Predicate(
type=type,
body=stmt.body,
alias=len(params) == 0,
),
)

View File

@@ -242,6 +242,7 @@ def unfold_type(type: Type) -> Type:
class Predicate:
type: Type
body: m.Expr
alias: bool
Type = (

View File

@@ -63,7 +63,7 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
)
alias: str = self.make_alias(None)
definition: ast.stmt = self.make_definition(
alias, Predicate(type=func, body=expr)
alias, Predicate(type=func, body=expr, alias=False)
)
self._definitions.append(definition)
return ast.Name(id=alias)
@@ -79,8 +79,15 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
return alias
def make_definition(self, name: str, predicate: Predicate) -> ast.stmt:
body: list[ast.stmt] = [ast.Return(value=predicate.body.accept(self))]
return self.make_func(name, body, predicate.type)
body: ast.expr = predicate.body.accept(self)
if predicate.alias:
return ast.Assign(
targets=[
ast.Name(id=name),
],
value=body,
)
return self.make_func(name, [ast.Return(value=body)], predicate.type)
def make_args(self, func: Function) -> ast.arguments:
return ast.arguments(
@@ -116,7 +123,7 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
)
case _:
raise ValueError(f"Expected function, got {type}")
raise ValueError(f"Expected function, got {type!r}")
def get_predicate(self, name: str) -> Optional[ast.expr]:
if name not in self._aliases: