From bdc1b265a62ce7bb8121ff619378266c1f7f1938 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Thu, 18 Jun 2026 13:19:17 +0200 Subject: [PATCH] feat(gen): generate basic constraint assertion --- midas/generator/constraints.py | 92 ++++++++++++++++++++++++++++++++++ midas/generator/generator.py | 22 +++++++- 2 files changed, 112 insertions(+), 2 deletions(-) create mode 100644 midas/generator/constraints.py diff --git a/midas/generator/constraints.py b/midas/generator/constraints.py new file mode 100644 index 0000000..329b79b --- /dev/null +++ b/midas/generator/constraints.py @@ -0,0 +1,92 @@ +import ast + +import midas.ast.midas as m +from midas.lexer.token import TokenType + +LOGICAL_OPERATORS: dict[TokenType, type[ast.boolop]] = { + TokenType.AND: ast.And, + # TokenType.OR: ast.Or, +} + +BINARY_OPERATORS: dict[TokenType, type[ast.operator]] = { + # TokenType.PLUS: ast.Add, + TokenType.MINUS: ast.Sub, + TokenType.STAR: ast.Mult, + TokenType.SLASH: ast.Div, +} + +UNARY_OPERATORS: dict[TokenType, type[ast.unaryop]] = { + # TokenType.PLUS: ast.UAdd, + TokenType.MINUS: ast.USub, +} + +COMPARISON_OPERATORS: dict[TokenType, type[ast.cmpop]] = { + TokenType.GREATER: ast.Gt, + TokenType.GREATER_EQUAL: ast.GtE, + TokenType.LESS: ast.Lt, + TokenType.LESS_EQUAL: ast.LtE, + TokenType.EQUAL_EQUAL: ast.Eq, + TokenType.BANG_EQUAL: ast.NotEq, +} + + +class ConstraintGenerator(m.Expr.Visitor[ast.expr]): + def visit_logical_expr(self, expr: m.LogicalExpr) -> ast.expr: + return ast.BoolOp( + op=LOGICAL_OPERATORS[expr.operator.type](), + values=[ + expr.left.accept(self), + expr.right.accept(self), + ], + ) + + def visit_binary_expr(self, expr: m.BinaryExpr) -> ast.expr: + op: TokenType = expr.operator.type + if op in BINARY_OPERATORS: + return ast.BinOp( + left=expr.left.accept(self), + op=BINARY_OPERATORS[op](), + right=expr.right.accept(self), + ) + if op in COMPARISON_OPERATORS: + return ast.Compare( + left=expr.left.accept(self), + ops=[COMPARISON_OPERATORS[op]()], + comparators=[expr.right.accept(self)], + ) + raise ValueError(f"Unexpected binary operator {op}") + + def visit_unary_expr(self, expr: m.UnaryExpr) -> ast.expr: + return ast.UnaryOp( + op=UNARY_OPERATORS[expr.operator.type](), + operand=expr.right.accept(self), + ) + + def visit_call_expr(self, expr: m.CallExpr) -> ast.expr: + return ast.Call( + func=expr.callee.accept(self), + args=[arg.accept(self) for arg in expr.arguments], + keywords=[ + ast.keyword(arg=name, value=arg.accept(self)) + for name, arg in expr.keywords.items() + ], + ) + + def visit_get_expr(self, expr: m.GetExpr) -> ast.expr: + return ast.Attribute( + value=expr.expr.accept(self), + attr=expr.name.lexeme, + ) + + def visit_variable_expr(self, expr: m.VariableExpr) -> ast.expr: + # TODO: lookup predicate + return ast.Name(id=expr.name.lexeme) + + def visit_grouping_expr(self, expr: m.GroupingExpr) -> ast.expr: + return expr.accept(self) + + def visit_literal_expr(self, expr: m.LiteralExpr) -> ast.expr: + return ast.Constant(value=expr.value) + + def visit_wildcard_expr(self, expr: m.WildcardExpr) -> ast.expr: + return ast.Name(id="_") diff --git a/midas/generator/generator.py b/midas/generator/generator.py index 67e11e9..64c8c16 100644 --- a/midas/generator/generator.py +++ b/midas/generator/generator.py @@ -7,6 +7,7 @@ from typing import Optional, assert_never import midas.ast.midas as m import midas.ast.python as p from midas.ast.location import Location +from midas.ast.printer import MidasPrinter from midas.checker.types import ( AliasType, AppliedType, @@ -23,6 +24,7 @@ from midas.checker.types import ( UnitType, UnknownType, ) +from midas.generator.constraints import ConstraintGenerator from midas.utils import TypedAST @@ -48,6 +50,8 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): self._alias_count: int = 0 self._scopes: list[Scope] = [] + self._constraint_generator: ConstraintGenerator = ConstraintGenerator() + def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST: self.rel_src_path = src_path.relative_to(self.workdir) self._typed_ast = typed_ast @@ -357,5 +361,19 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): def _make_constraint_assert( self, src_location: Location, expr: ast.expr, constraint: m.Expr ): - # TODO - pass + test: ast.expr = constraint.accept(self._constraint_generator) + self._add_assert( + test, + self._make_constraint_assert_message(src_location, expr, constraint), + ) + + def _make_constraint_assert_message( + self, location: Location, expr: ast.expr, constraint: m.Expr + ) -> ast.expr: + printer = MidasPrinter() + constraint_str: str = printer.print(constraint) + loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}" + # f"file.py:L1:1: ConstraintError: Value does not fit constraint 'v > 0'" + return ast.Constant( + f"{loc_str}: ConstraintError: Value does not fit constraint '{constraint_str}'" + )