From 70ce263ea2c3c7cf07a5b7ecbc60d7409b40767e Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Wed, 24 Jun 2026 11:24:08 +0200 Subject: [PATCH] feat(gen): skip assertions for evaluated casts avoid generating a runtime assertion for a cast which has already been checked statically --- midas/checker/python.py | 42 ++++++++++++++++++++++++++---------- midas/generator/generator.py | 5 +++++ midas/utils.py | 1 + 3 files changed, 37 insertions(+), 11 deletions(-) diff --git a/midas/checker/python.py b/midas/checker/python.py index 0ff9eed..ffa1600 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -75,6 +75,7 @@ class PythonTyper( self.env: Environment = self.global_env self.locals: dict[p.Expr, int] = {} self.judgements: list[tuple[p.Expr, Type]] = [] + self.evaluated_casts: list[p.CastExpr] = [] def process(self, source: str, path: Optional[str]) -> TypedAST: self.reporter = self.reporter.for_file(path) @@ -88,10 +89,15 @@ class PythonTyper( self.env = self.global_env self.locals = resolver.locals self.judgements = [] + self.evaluated_casts = [] self.check(stmts) - return TypedAST(stmts=stmts, judgements=self.judgements) + return TypedAST( + stmts=stmts, + judgements=self.judgements, + evaluated_casts=self.evaluated_casts, + ) def judge(self, expr: p.Expr, type: Type): """Record a typing judgement @@ -546,7 +552,11 @@ class PythonTyper( target_type: Type = self.resolve_type_expr(expr.type) is_lit, lit_value = self._get_literal(expr.expr) if is_lit: - self._evaluate_cast_statically(expr, subject_type, target_type, lit_value) + evaluated: bool = self._evaluate_cast_statically( + expr, subject_type, target_type, lit_value + ) + if evaluated: + self.evaluated_casts.append(expr) return target_type def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type: @@ -1157,20 +1167,25 @@ class PythonTyper( def _evaluate_cast_statically( self, expr: p.CastExpr, subject_type: Type, target_type: Type, lit_value: Any - ): + ) -> bool: match target_type: case AliasType(type=base): - self._evaluate_cast_statically(expr, subject_type, base, lit_value) + return self._evaluate_cast_statically( + expr, subject_type, base, lit_value + ) - case AppliedType(name=name, args=args, body=body): - generic: Type = self.types.get_type(name) - assert isinstance(generic, GenericType) - for param, arg in zip(generic.params, args): - pass - self._evaluate_cast_statically(expr, subject_type, body, lit_value) + case AppliedType(body=body): + return self._evaluate_cast_statically( + expr, subject_type, body, lit_value + ) case ConstraintType(type=base, constraint=constraint): - self._evaluate_cast_statically(expr, subject_type, base, lit_value) + evaluated: bool = True + if not self._evaluate_cast_statically( + expr, subject_type, base, lit_value + ): + evaluated = False + evaluator = Evaluator(self.types) evaluator.set_value("_", lit_value) res = evaluator.evaluate(constraint) @@ -1181,6 +1196,8 @@ class PythonTyper( expr.location, f"Value {lit_value!r} does not fit constraint '{constraint_str}'", ) + evaluated = False + return evaluated case BaseType(): # TODO: do we want to allow cast(float, int)? would require runtime conversion @@ -1191,8 +1208,11 @@ class PythonTyper( expr.location, f"Value {lit_value!r} of type {subject_type} cannot be cast as {target_type}", ) + return False + return True case _: self.reporter.info( expr.location, f"Cannot evaluate cast to {target_type} statically" ) + return False diff --git a/midas/generator/generator.py b/midas/generator/generator.py index e66f532..0fba91e 100644 --- a/midas/generator/generator.py +++ b/midas/generator/generator.py @@ -44,6 +44,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): self._typed_ast: TypedAST = TypedAST( stmts=[], judgements=[], + evaluated_casts=[], ) self._alias_count: int = 0 self._predicate_count: int = 0 @@ -131,6 +132,10 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr: expr2: ast.expr = expr.expr.accept(self) + + if expr in self._typed_ast.evaluated_casts: + return expr2 + alias: ast.expr = self._make_alias(expr2) type: Type = self._get_expr_type(expr) diff --git a/midas/utils.py b/midas/utils.py index f10bdb3..0ac90cf 100644 --- a/midas/utils.py +++ b/midas/utils.py @@ -62,3 +62,4 @@ class UniversalJSONDumper: class TypedAST: stmts: list[p.Stmt] judgements: list[tuple[p.Expr, Type]] + evaluated_casts: list[p.CastExpr]