Compare commits

2 Commits

Author SHA1 Message Date
55fba6a088 tests: update test without evaluated casts 2026-06-24 11:28:44 +02:00
70ce263ea2 feat(gen): skip assertions for evaluated casts
avoid generating a runtime assertion for a cast which has already been checked statically
2026-06-24 11:28:43 +02:00
4 changed files with 39 additions and 67 deletions

View File

@@ -75,6 +75,7 @@ class PythonTyper(
self.env: Environment = self.global_env
self.locals: dict[p.Expr, int] = {}
self.judgements: list[tuple[p.Expr, Type]] = []
self.evaluated_casts: list[p.CastExpr] = []
def process(self, source: str, path: Optional[str]) -> TypedAST:
self.reporter = self.reporter.for_file(path)
@@ -88,10 +89,15 @@ class PythonTyper(
self.env = self.global_env
self.locals = resolver.locals
self.judgements = []
self.evaluated_casts = []
self.check(stmts)
return TypedAST(stmts=stmts, judgements=self.judgements)
return TypedAST(
stmts=stmts,
judgements=self.judgements,
evaluated_casts=self.evaluated_casts,
)
def judge(self, expr: p.Expr, type: Type):
"""Record a typing judgement
@@ -546,7 +552,11 @@ class PythonTyper(
target_type: Type = self.resolve_type_expr(expr.type)
is_lit, lit_value = self._get_literal(expr.expr)
if is_lit:
self._evaluate_cast_statically(expr, subject_type, target_type, lit_value)
evaluated: bool = self._evaluate_cast_statically(
expr, subject_type, target_type, lit_value
)
if evaluated:
self.evaluated_casts.append(expr)
return target_type
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
@@ -1157,20 +1167,25 @@ class PythonTyper(
def _evaluate_cast_statically(
self, expr: p.CastExpr, subject_type: Type, target_type: Type, lit_value: Any
):
) -> bool:
match target_type:
case AliasType(type=base):
self._evaluate_cast_statically(expr, subject_type, base, lit_value)
return self._evaluate_cast_statically(
expr, subject_type, base, lit_value
)
case AppliedType(name=name, args=args, body=body):
generic: Type = self.types.get_type(name)
assert isinstance(generic, GenericType)
for param, arg in zip(generic.params, args):
pass
self._evaluate_cast_statically(expr, subject_type, body, lit_value)
case AppliedType(body=body):
return self._evaluate_cast_statically(
expr, subject_type, body, lit_value
)
case ConstraintType(type=base, constraint=constraint):
self._evaluate_cast_statically(expr, subject_type, base, lit_value)
evaluated: bool = True
if not self._evaluate_cast_statically(
expr, subject_type, base, lit_value
):
evaluated = False
evaluator = Evaluator(self.types)
evaluator.set_value("_", lit_value)
res = evaluator.evaluate(constraint)
@@ -1181,6 +1196,8 @@ class PythonTyper(
expr.location,
f"Value {lit_value!r} does not fit constraint '{constraint_str}'",
)
evaluated = False
return evaluated
case BaseType():
# TODO: do we want to allow cast(float, int)? would require runtime conversion
@@ -1191,8 +1208,11 @@ class PythonTyper(
expr.location,
f"Value {lit_value!r} of type {subject_type} cannot be cast as {target_type}",
)
return False
return True
case _:
self.reporter.info(
expr.location, f"Cannot evaluate cast to {target_type} statically"
)
return False

View File

@@ -44,6 +44,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self._typed_ast: TypedAST = TypedAST(
stmts=[],
judgements=[],
evaluated_casts=[],
)
self._alias_count: int = 0
self._predicate_count: int = 0
@@ -131,6 +132,10 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
expr2: ast.expr = expr.expr.accept(self)
if expr in self._typed_ast.evaluated_casts:
return expr2
alias: ast.expr = self._make_alias(expr2)
type: Type = self._get_expr_type(expr)

View File

@@ -62,3 +62,4 @@ class UniversalJSONDumper:
class TypedAST:
stmts: list[p.Stmt]
judgements: list[tuple[p.Expr, Type]]
evaluated_casts: list[p.CastExpr]

View File

@@ -7,68 +7,14 @@ Module(
alias(name='Meter'),
alias(name='Second')],
level=0),
Assign(
targets=[
Name(id='__midas_a0__')],
value=Constant(value=123.45)),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_a0__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
values=[
Constant(value='01_simple_types.py:L3:19: CastError: Cannot cast '),
FormattedValue(
value=Attribute(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_a0__')],
keywords=[]),
attr='__name__'),
conversion=-1),
Constant(value=' to float')])),
Assign(
targets=[
Name(id='distance')],
value=Name(id='__midas_a0__')),
Delete(
targets=[
Name(id='__midas_a0__')]),
Assign(
targets=[
Name(id='__midas_a1__')],
value=Constant(value=6.7)),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_a1__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
values=[
Constant(value='01_simple_types.py:L4:16: CastError: Cannot cast '),
FormattedValue(
value=Attribute(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_a1__')],
keywords=[]),
attr='__name__'),
conversion=-1),
Constant(value=' to float')])),
value=Constant(value=123.45)),
Assign(
targets=[
Name(id='time')],
value=Name(id='__midas_a1__')),
Delete(
targets=[
Name(id='__midas_a1__')]),
value=Constant(value=6.7)),
Assign(
targets=[
Name(id='speed')],