diff --git a/midas/checker/midas.py b/midas/checker/midas.py index 5e0e847..0fc4fb4 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -167,6 +167,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type Predicate( type=type, body=stmt.body, + alias=len(params) == 0, ), ) diff --git a/midas/checker/types.py b/midas/checker/types.py index f66468b..309ad0f 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -242,6 +242,7 @@ def unfold_type(type: Type) -> Type: class Predicate: type: Type body: m.Expr + alias: bool Type = ( diff --git a/midas/generator/constraints.py b/midas/generator/constraints.py index d7bf6ed..e9f21c1 100644 --- a/midas/generator/constraints.py +++ b/midas/generator/constraints.py @@ -63,7 +63,7 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]): ) alias: str = self.make_alias(None) definition: ast.stmt = self.make_definition( - alias, Predicate(type=func, body=expr) + alias, Predicate(type=func, body=expr, alias=False) ) self._definitions.append(definition) return ast.Name(id=alias) @@ -79,8 +79,15 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]): return alias def make_definition(self, name: str, predicate: Predicate) -> ast.stmt: - body: list[ast.stmt] = [ast.Return(value=predicate.body.accept(self))] - return self.make_func(name, body, predicate.type) + body: ast.expr = predicate.body.accept(self) + if predicate.alias: + return ast.Assign( + targets=[ + ast.Name(id=name), + ], + value=body, + ) + return self.make_func(name, [ast.Return(value=body)], predicate.type) def make_args(self, func: Function) -> ast.arguments: return ast.arguments( @@ -116,7 +123,7 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]): ) case _: - raise ValueError(f"Expected function, got {type}") + raise ValueError(f"Expected function, got {type!r}") def get_predicate(self, name: str) -> Optional[ast.expr]: if name not in self._aliases: