feat(gen): generate predicate functions

This commit is contained in:
2026-06-18 22:48:10 +02:00
parent bdc1b265a6
commit 48fcb499a1
5 changed files with 126 additions and 10 deletions

View File

@@ -352,3 +352,6 @@ class TypesRegistry:
case _:
self.logger.debug(f"Can't get member on {type}")
return None
def lookup_predicate(self, name: str) -> Optional[Predicate]:
return self._predicates.get(name)

View File

@@ -38,5 +38,5 @@ def compile(
if any(map(lambda d: d.type == DiagnosticType.ERROR, diagnostics)):
sys.exit(1)
generator = Generator(workdir=source_path.parent)
generator = Generator(workdir=source_path.parent, types=checker.types)
generator.generate(typed_ast, source_path)

View File

@@ -1,6 +1,9 @@
import ast
from typing import Optional
import midas.ast.midas as m
from midas.checker.registry import TypesRegistry
from midas.checker.types import Function, Predicate, Type
from midas.lexer.token import TokenType
LOGICAL_OPERATORS: dict[TokenType, type[ast.boolop]] = {
@@ -31,6 +34,97 @@ COMPARISON_OPERATORS: dict[TokenType, type[ast.cmpop]] = {
class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
def __init__(self, types: TypesRegistry):
self.types: TypesRegistry = types
self._id: int = 0
self._definitions: list[ast.stmt] = []
self._aliases: dict[str, str] = {}
def get_definitions(self) -> list[ast.stmt]:
return self._definitions
def generate(self, expr: m.Expr) -> ast.expr:
match expr:
case m.VariableExpr():
return expr.accept(self)
case _:
func = Function(
pos_args=[],
args=[
Function.Argument(
pos=0,
name="_",
type=self.types.get_type("Any"),
required=True,
)
],
kw_args=[],
returns=self.types.get_type("bool"),
)
alias: str = self.make_alias(None)
definition: ast.stmt = self.make_definition(
alias, Predicate(type=func, body=expr)
)
self._definitions.append(definition)
return ast.Name(id=alias)
def make_alias(self, name: Optional[str]) -> str:
suffix: str = f"_{name}" if name is not None else ""
alias: str = f"__midas_p{self._id}{suffix}__"
self._id += 1
return alias
def make_definition(self, name: str, predicate: Predicate) -> ast.stmt:
body: list[ast.stmt] = [ast.Return(value=predicate.body.accept(self))]
return self.make_func(name, body, predicate.type)
def make_args(self, func: Function) -> ast.arguments:
return ast.arguments(
posonlyargs=[ast.arg(arg=arg.name) for arg in func.pos_args],
args=[ast.arg(arg=arg.name) for arg in func.args],
kwonlyargs=[ast.arg(arg=arg.name) for arg in func.kw_args],
defaults=[],
kw_defaults=[],
)
def make_func(
self, name: str, inner_body: list[ast.stmt], type: Type, level: int = 0
) -> ast.stmt:
match type:
case Function(returns=Function()):
inner_name: str = f"inner{level}"
return ast.FunctionDef(
name=name,
args=self.make_args(type),
body=[
self.make_func(inner_name, inner_body, type.returns, level + 1),
ast.Return(value=ast.Name(id=inner_name)),
],
decorator_list=[],
)
case Function():
return ast.FunctionDef(
name=name,
args=self.make_args(type),
body=inner_body,
decorator_list=[],
)
case _:
raise ValueError(f"Expected function, got {type}")
def get_predicate(self, name: str) -> Optional[ast.expr]:
if name not in self._aliases:
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
if predicate is None:
return None
alias: str = self.make_alias(name)
self._aliases[name] = alias
self._definitions.append(self.make_definition(alias, predicate))
return ast.Name(id=self._aliases[name])
def visit_logical_expr(self, expr: m.LogicalExpr) -> ast.expr:
return ast.BoolOp(
op=LOGICAL_OPERATORS[expr.operator.type](),
@@ -79,8 +173,10 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
)
def visit_variable_expr(self, expr: m.VariableExpr) -> ast.expr:
# TODO: lookup predicate
return ast.Name(id=expr.name.lexeme)
name: str = expr.name.lexeme
if (p := self.get_predicate(name)) is not None:
return p
return ast.Name(id=name)
def visit_grouping_expr(self, expr: m.GroupingExpr) -> ast.expr:
return expr.accept(self)

View File

@@ -8,6 +8,7 @@ import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.ast.printer import MidasPrinter
from midas.checker.registry import TypesRegistry
from midas.checker.types import (
AliasType,
AppliedType,
@@ -35,7 +36,7 @@ class Scope:
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def __init__(self, workdir: Path) -> None:
def __init__(self, workdir: Path, types: TypesRegistry) -> None:
self.workdir: Path = workdir.resolve()
self.build_dir: Path = self.workdir / "build" / "midas"
if self.build_dir.exists():
@@ -48,15 +49,18 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
judgements=[],
)
self._alias_count: int = 0
self._predicate_count: int = 0
self._scopes: list[Scope] = []
self._constraint_generator: ConstraintGenerator = ConstraintGenerator()
self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
self._constraints: list[tuple[m.Expr, ast.expr]] = []
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
self.rel_src_path = src_path.relative_to(self.workdir)
self._typed_ast = typed_ast
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
module = ast.Module(body=body, type_ignores=[])
predicates: list[ast.stmt] = self._constraint_generator.get_definitions()
module = ast.Module(body=predicates + body, type_ignores=[])
module = ast.fix_missing_locations(module)
return module
@@ -253,7 +257,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
return generated
def _make_alias(self, expr: ast.expr) -> ast.expr:
name: str = f"__midas_alias_{self._alias_count}__"
name: str = f"__midas_a{self._alias_count}__"
alias = ast.Name(id=name)
self._alias_count += 1
self._scopes[-1].aliases.append(name)
@@ -361,9 +365,13 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def _make_constraint_assert(
self, src_location: Location, expr: ast.expr, constraint: m.Expr
):
test: ast.expr = constraint.accept(self._constraint_generator)
test_func: ast.expr = self._get_constraint(constraint)
self._add_assert(
test,
ast.Call(
func=test_func,
args=[expr],
keywords=[],
),
self._make_constraint_assert_message(src_location, expr, constraint),
)
@@ -377,3 +385,12 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
return ast.Constant(
f"{loc_str}: ConstraintError: Value does not fit constraint '{constraint_str}'"
)
def _get_constraint(self, expr: m.Expr) -> ast.expr:
for expr2, constraint in self._constraints:
if expr2 == expr:
return constraint
constraint: ast.expr = self._constraint_generator.generate(expr)
self._constraints.append((expr, constraint))
return constraint

View File

@@ -340,7 +340,7 @@ class MidasParser(Parser):
def call(self) -> Expr:
expr: Expr = self.reference()
if self.match(TokenType.LEFT_PAREN):
while self.match(TokenType.LEFT_PAREN):
expr = self.finish_call(expr)
return expr