feat(gen): handle predicate aliases
handle cases where a predicate is defined as an alias, i.e. without any parameters
This commit is contained in:
@@ -167,6 +167,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
|||||||
Predicate(
|
Predicate(
|
||||||
type=type,
|
type=type,
|
||||||
body=stmt.body,
|
body=stmt.body,
|
||||||
|
alias=len(params) == 0,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -242,6 +242,7 @@ def unfold_type(type: Type) -> Type:
|
|||||||
class Predicate:
|
class Predicate:
|
||||||
type: Type
|
type: Type
|
||||||
body: m.Expr
|
body: m.Expr
|
||||||
|
alias: bool
|
||||||
|
|
||||||
|
|
||||||
Type = (
|
Type = (
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
|
|||||||
)
|
)
|
||||||
alias: str = self.make_alias(None)
|
alias: str = self.make_alias(None)
|
||||||
definition: ast.stmt = self.make_definition(
|
definition: ast.stmt = self.make_definition(
|
||||||
alias, Predicate(type=func, body=expr)
|
alias, Predicate(type=func, body=expr, alias=False)
|
||||||
)
|
)
|
||||||
self._definitions.append(definition)
|
self._definitions.append(definition)
|
||||||
return ast.Name(id=alias)
|
return ast.Name(id=alias)
|
||||||
@@ -79,8 +79,15 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
|
|||||||
return alias
|
return alias
|
||||||
|
|
||||||
def make_definition(self, name: str, predicate: Predicate) -> ast.stmt:
|
def make_definition(self, name: str, predicate: Predicate) -> ast.stmt:
|
||||||
body: list[ast.stmt] = [ast.Return(value=predicate.body.accept(self))]
|
body: ast.expr = predicate.body.accept(self)
|
||||||
return self.make_func(name, body, predicate.type)
|
if predicate.alias:
|
||||||
|
return ast.Assign(
|
||||||
|
targets=[
|
||||||
|
ast.Name(id=name),
|
||||||
|
],
|
||||||
|
value=body,
|
||||||
|
)
|
||||||
|
return self.make_func(name, [ast.Return(value=body)], predicate.type)
|
||||||
|
|
||||||
def make_args(self, func: Function) -> ast.arguments:
|
def make_args(self, func: Function) -> ast.arguments:
|
||||||
return ast.arguments(
|
return ast.arguments(
|
||||||
@@ -116,7 +123,7 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Expected function, got {type}")
|
raise ValueError(f"Expected function, got {type!r}")
|
||||||
|
|
||||||
def get_predicate(self, name: str) -> Optional[ast.expr]:
|
def get_predicate(self, name: str) -> Optional[ast.expr]:
|
||||||
if name not in self._aliases:
|
if name not in self._aliases:
|
||||||
|
|||||||
Reference in New Issue
Block a user