feat(gen): skip assertions for evaluated casts
avoid generating a runtime assertion for a cast which has already been checked statically
This commit is contained in:
@@ -75,6 +75,7 @@ class PythonTyper(
|
|||||||
self.env: Environment = self.global_env
|
self.env: Environment = self.global_env
|
||||||
self.locals: dict[p.Expr, int] = {}
|
self.locals: dict[p.Expr, int] = {}
|
||||||
self.judgements: list[tuple[p.Expr, Type]] = []
|
self.judgements: list[tuple[p.Expr, Type]] = []
|
||||||
|
self.evaluated_casts: list[p.CastExpr] = []
|
||||||
|
|
||||||
def process(self, source: str, path: Optional[str]) -> TypedAST:
|
def process(self, source: str, path: Optional[str]) -> TypedAST:
|
||||||
self.reporter = self.reporter.for_file(path)
|
self.reporter = self.reporter.for_file(path)
|
||||||
@@ -88,10 +89,15 @@ class PythonTyper(
|
|||||||
self.env = self.global_env
|
self.env = self.global_env
|
||||||
self.locals = resolver.locals
|
self.locals = resolver.locals
|
||||||
self.judgements = []
|
self.judgements = []
|
||||||
|
self.evaluated_casts = []
|
||||||
|
|
||||||
self.check(stmts)
|
self.check(stmts)
|
||||||
|
|
||||||
return TypedAST(stmts=stmts, judgements=self.judgements)
|
return TypedAST(
|
||||||
|
stmts=stmts,
|
||||||
|
judgements=self.judgements,
|
||||||
|
evaluated_casts=self.evaluated_casts,
|
||||||
|
)
|
||||||
|
|
||||||
def judge(self, expr: p.Expr, type: Type):
|
def judge(self, expr: p.Expr, type: Type):
|
||||||
"""Record a typing judgement
|
"""Record a typing judgement
|
||||||
@@ -546,7 +552,11 @@ class PythonTyper(
|
|||||||
target_type: Type = self.resolve_type_expr(expr.type)
|
target_type: Type = self.resolve_type_expr(expr.type)
|
||||||
is_lit, lit_value = self._get_literal(expr.expr)
|
is_lit, lit_value = self._get_literal(expr.expr)
|
||||||
if is_lit:
|
if is_lit:
|
||||||
self._evaluate_cast_statically(expr, subject_type, target_type, lit_value)
|
evaluated: bool = self._evaluate_cast_statically(
|
||||||
|
expr, subject_type, target_type, lit_value
|
||||||
|
)
|
||||||
|
if evaluated:
|
||||||
|
self.evaluated_casts.append(expr)
|
||||||
return target_type
|
return target_type
|
||||||
|
|
||||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
|
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
|
||||||
@@ -1157,20 +1167,25 @@ class PythonTyper(
|
|||||||
|
|
||||||
def _evaluate_cast_statically(
|
def _evaluate_cast_statically(
|
||||||
self, expr: p.CastExpr, subject_type: Type, target_type: Type, lit_value: Any
|
self, expr: p.CastExpr, subject_type: Type, target_type: Type, lit_value: Any
|
||||||
):
|
) -> bool:
|
||||||
match target_type:
|
match target_type:
|
||||||
case AliasType(type=base):
|
case AliasType(type=base):
|
||||||
self._evaluate_cast_statically(expr, subject_type, base, lit_value)
|
return self._evaluate_cast_statically(
|
||||||
|
expr, subject_type, base, lit_value
|
||||||
|
)
|
||||||
|
|
||||||
case AppliedType(name=name, args=args, body=body):
|
case AppliedType(body=body):
|
||||||
generic: Type = self.types.get_type(name)
|
return self._evaluate_cast_statically(
|
||||||
assert isinstance(generic, GenericType)
|
expr, subject_type, body, lit_value
|
||||||
for param, arg in zip(generic.params, args):
|
)
|
||||||
pass
|
|
||||||
self._evaluate_cast_statically(expr, subject_type, body, lit_value)
|
|
||||||
|
|
||||||
case ConstraintType(type=base, constraint=constraint):
|
case ConstraintType(type=base, constraint=constraint):
|
||||||
self._evaluate_cast_statically(expr, subject_type, base, lit_value)
|
evaluated: bool = True
|
||||||
|
if not self._evaluate_cast_statically(
|
||||||
|
expr, subject_type, base, lit_value
|
||||||
|
):
|
||||||
|
evaluated = False
|
||||||
|
|
||||||
evaluator = Evaluator(self.types)
|
evaluator = Evaluator(self.types)
|
||||||
evaluator.set_value("_", lit_value)
|
evaluator.set_value("_", lit_value)
|
||||||
res = evaluator.evaluate(constraint)
|
res = evaluator.evaluate(constraint)
|
||||||
@@ -1181,6 +1196,8 @@ class PythonTyper(
|
|||||||
expr.location,
|
expr.location,
|
||||||
f"Value {lit_value!r} does not fit constraint '{constraint_str}'",
|
f"Value {lit_value!r} does not fit constraint '{constraint_str}'",
|
||||||
)
|
)
|
||||||
|
evaluated = False
|
||||||
|
return evaluated
|
||||||
|
|
||||||
case BaseType():
|
case BaseType():
|
||||||
# TODO: do we want to allow cast(float, int)? would require runtime conversion
|
# TODO: do we want to allow cast(float, int)? would require runtime conversion
|
||||||
@@ -1191,8 +1208,11 @@ class PythonTyper(
|
|||||||
expr.location,
|
expr.location,
|
||||||
f"Value {lit_value!r} of type {subject_type} cannot be cast as {target_type}",
|
f"Value {lit_value!r} of type {subject_type} cannot be cast as {target_type}",
|
||||||
)
|
)
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
self.reporter.info(
|
self.reporter.info(
|
||||||
expr.location, f"Cannot evaluate cast to {target_type} statically"
|
expr.location, f"Cannot evaluate cast to {target_type} statically"
|
||||||
)
|
)
|
||||||
|
return False
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
self._typed_ast: TypedAST = TypedAST(
|
self._typed_ast: TypedAST = TypedAST(
|
||||||
stmts=[],
|
stmts=[],
|
||||||
judgements=[],
|
judgements=[],
|
||||||
|
evaluated_casts=[],
|
||||||
)
|
)
|
||||||
self._alias_count: int = 0
|
self._alias_count: int = 0
|
||||||
self._predicate_count: int = 0
|
self._predicate_count: int = 0
|
||||||
@@ -131,6 +132,10 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
|
|
||||||
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
|
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
|
||||||
expr2: ast.expr = expr.expr.accept(self)
|
expr2: ast.expr = expr.expr.accept(self)
|
||||||
|
|
||||||
|
if expr in self._typed_ast.evaluated_casts:
|
||||||
|
return expr2
|
||||||
|
|
||||||
alias: ast.expr = self._make_alias(expr2)
|
alias: ast.expr = self._make_alias(expr2)
|
||||||
|
|
||||||
type: Type = self._get_expr_type(expr)
|
type: Type = self._get_expr_type(expr)
|
||||||
|
|||||||
Reference in New Issue
Block a user