9 Commits

Author SHA1 Message Date
5d20f8ec3e docs: mention eager evaluation in manual 2026-07-02 17:22:28 +02:00
955c2233ed feat(checker): statically evaluate casts to Any and None 2026-07-02 17:14:30 +02:00
ff69b65171 feat(checker): add same length assertion on frames
safely adding two dataframes is only possible if the sizes are the same, or null values could be added dynamically to pad the shortest dataframe
2026-07-02 17:14:05 +02:00
8df01afd8c feat(gen): materialize assertions from collector 2026-07-02 17:10:27 +02:00
47b2dfdd73 feat(gen): add assertion collector to TypedAST 2026-07-02 17:09:50 +02:00
bd4d793ce0 feat(gen): add Assertion class 2026-07-02 17:08:43 +02:00
f7a36f61b6 fix(checker): pass AST expression to method registry 2026-07-01 22:34:02 +02:00
ad2fabf471 feat(checker): add assertion collector 2026-07-01 22:32:13 +02:00
a59a58d21a feat(gen): generate alias stubs 2026-07-01 14:43:30 +02:00
8 changed files with 313 additions and 50 deletions

View File

@@ -678,6 +678,10 @@ In the following example, a runtime check would be generated to ensure that the
caption: [Typing of `cast` expression],
)
#gc.warning[
Assertions are statements inserted just before a statement using a `cast` expression. This means that the expression is evaluated _before_ its actual intended usage location, which might cause issues if you rely on logical operator short-circuiting. See @eager-eval for more information.
]
There may be some cases where the cost of checking a value at runtime is simply not worth the safety, for example when dealing with a big dataset. If do wish so, you can use `unsafe_cast` which will only tell the type checker the type of the value, without generating a runtime assertion. This maps to the default behavior of `typing`'s own `cast` function.
If the value passed to `cast` or `unsafe_cast` is a literal (e.g. an integer, a string, a list of literals, etc.), the assertion is evaluated _at compile-time_ and no runtime assertion is generated.
@@ -695,3 +699,26 @@ If the value passed to `cast` or `unsafe_cast` is a literal (e.g. an integer, a
== Generating Stubs (`stubs`) <cmd-stubs>
== Showing Type Judgements (`types`) <cmd-types>
== Validating Definitions (`validate`) <cmd-validate>
= Known limitations <limitations>
== Eager evaluation in runtime assertions <eager-eval>
The process of generating assertions to ensure safety at runtime, mainly for `cast` expressions, leads to the creation of aliases for the expressions being casted. These alias definitions eagerly evaluate before the assertion, and most importantly before the real usage location. This means that you should avoid using `cast` expressions inside logical expressions like `and` or `or`, because the normal "short-circuit" behavior will be irrelevant to the evaluations of the operands.
For example:
#figure(
```py
def foo():
print("Foo")
return True
def bar():
print("Bar")
return True
result = foo() or bar()
# Foo
# Bar
```,
caption: [Runtime assertions may eagerly evaluate expressions and bypass logical operator's short-circuit],
)

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import ast
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Optional
@@ -18,6 +19,7 @@ from midas.checker.types import (
UnknownType,
unfold_type,
)
from midas.generator.collector import AssertionCollector
if TYPE_CHECKING:
from midas.checker.python import PythonTyper, TypedExpr
@@ -38,7 +40,9 @@ def frame_method(*names: str):
@dataclass(frozen=True, kw_only=True)
class Call:
location: Location
call_expr: p.Expr
frame: DataFrameType
frame_expr: p.Expr
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
@@ -77,6 +81,10 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
def dispatcher(self) -> CallDispatcher[p.Expr]:
return self.typer.dispatcher
@property
def assertions(self) -> AssertionCollector:
return self.typer.assertions
def call(
self,
method: str,
@@ -100,6 +108,7 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
by_name: dict[str, DataFrameType.Column] = {}
frame2: Optional[DataFrameType] = None
# Get map of operand's columns by name, if there is at least 1 operand, which is a dataframe
if len(call.positional) != 0:
other: Type = call.positional[0][1]
unfolded_other: Type = unfold_type(other)
@@ -109,6 +118,10 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
col.name: col for col in frame2.columns if col.name is not None
}
# Compute new schema:
# Step 1: for all columns in frame1:
# - if present in frame2 with equivalent type -> add to schema as is
# - if not -> add to schema as unknown
in_frame1: set[str] = set()
for column in call.frame.columns:
if column.name is not None:
@@ -129,6 +142,8 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
)
new_columns.append(new_column)
# Step 2: for all columns in frame2
# - if not in frame1 -> add to schema as unknown
if frame2 is not None:
for column in frame2.columns:
if column.name in in_frame1:
@@ -141,6 +156,7 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
)
)
# Build signature with new schema and generic operand
signature = Function(
args=[
Function.Argument(
@@ -153,12 +169,18 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
returns=DataFrameType(columns=new_columns),
)
# Map arguments and compute result type
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
if result.is_valid:
self._assert_same_length(
call.call_expr, call.frame_expr, call.positional[0][0]
)
return result.result
@frame_method()
@@ -199,3 +221,50 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
keywords=call.keywords,
)
return result.result
def _assert_same_length(self, call_expr: p.Expr, frame1: p.Expr, frame2: p.Expr):
func_name: str = "__midas_frame_same_length__"
self.assertions.define(
func_name,
ast.FunctionDef(
name=func_name,
args=ast.arguments(
posonlyargs=[],
args=[
ast.arg(arg="frame1"),
ast.arg(arg="frame2"),
],
kwonlyargs=[],
defaults=[],
kw_defaults=[],
),
body=[
ast.Return(
value=ast.Compare(
left=ast.Attribute(
value=ast.Name(id="frame1"),
attr="size",
),
ops=[ast.Eq()],
comparators=[
ast.Attribute(
value=ast.Name(id="frame2"),
attr="size",
)
],
)
)
],
decorator_list=[],
),
)
self.assertions.add(
bound_expr=call_expr,
inputs=[frame1, frame2],
builder=lambda f1, f2: ast.Call(
func=ast.Name(id=func_name),
args=[f1, f2],
keywords=[],
),
message="DataFrames must have the same length",
)

View File

@@ -141,13 +141,17 @@ class FrameManager:
self,
method: str,
location: Location,
call_expr: p.Expr,
frame: DataFrameType,
frame_expr: p.Expr,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
call: Call = Call(
location=location,
call_expr=call_expr,
frame=frame,
frame_expr=frame_expr,
positional=positional,
keywords=keywords,
)

View File

@@ -28,6 +28,7 @@ from midas.checker.types import (
DerivedType,
Function,
GenericType,
TopType,
TupleType,
Type,
TypeVar,
@@ -36,6 +37,7 @@ from midas.checker.types import (
Variance,
unfold_type,
)
from midas.generator.collector import AssertionCollector
from midas.parser.python import PythonParser
from midas.utils import TypedAST
@@ -87,6 +89,7 @@ class PythonTyper(
self.dispatcher: CallDispatcher[p.Expr] = CallDispatcher[p.Expr](
self.types, self.reporter
)
self.assertions: AssertionCollector = AssertionCollector()
def set_reporter(self, reporter: FileReporter):
self.reporter = reporter
@@ -113,6 +116,7 @@ class PythonTyper(
stmts=stmts,
judgements=self.judgements,
evaluated_casts=self.evaluated_casts,
assertions=self.assertions,
)
def judge(self, expr: p.Expr, type: Type):
@@ -209,23 +213,26 @@ class PythonTyper(
def call_method(
self,
location: Location,
obj: Type,
call_expr: p.Expr,
obj: TypedExpr,
method_name: str,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Optional[Type]:
unfolded: Type = unfold_type(obj)
unfolded: Type = unfold_type(obj[1])
match unfolded:
case DataFrameType():
return self.frame_mgr.call(
method=method_name,
location=location,
call_expr=call_expr,
frame=unfolded,
frame_expr=obj[0],
positional=positional,
keywords=keywords,
)
method: Optional[Type] = self.types.lookup_member(obj, method_name)
method: Optional[Type] = self.types.lookup_member(obj[1], method_name)
if method is None:
raise UndefinedMethodException
@@ -499,7 +506,9 @@ class PythonTyper(
)
return UnknownType()
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
return self._visit_binary_expr(
expr.location, expr, expr.left, expr.right, method
)
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
method: Optional[str] = PY_COMPARATOR_METHODS.get(expr.operator.__class__)
@@ -510,17 +519,31 @@ class PythonTyper(
)
return UnknownType()
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
return self._visit_binary_expr(
expr.location, expr, expr.left, expr.right, method
)
def _visit_binary_expr(
self, location: Location, left_expr: p.Expr, right_expr: p.Expr, method: str
self,
location: Location,
expr: p.Expr,
left_expr: p.Expr,
right_expr: p.Expr,
method: str,
) -> Type:
left: Type = self.type_of(left_expr)
right: Type = self.type_of(right_expr)
result: Optional[Type]
try:
result = self.call_method(location, left, method, [(right_expr, right)], {})
result = self.call_method(
location=location,
call_expr=expr,
obj=(left_expr, left),
method_name=method,
positional=[(right_expr, right)],
keywords={},
)
except UndefinedMethodException:
self.reporter.error(
location,
@@ -543,7 +566,14 @@ class PythonTyper(
result: Optional[Type]
try:
result = self.call_method(expr.location, operand, method, [], {})
result = self.call_method(
location=expr.location,
call_expr=expr,
obj=(expr.right, operand),
method_name=method,
positional=[],
keywords={},
)
except UndefinedMethodException:
self.reporter.error(
expr.location,
@@ -571,11 +601,13 @@ class PythonTyper(
unfolded: Type = unfold_type(obj_type)
if isinstance(unfolded, DataFrameType):
return self.frame_mgr.call(
method,
expr.location,
unfolded,
positional,
keywords,
method=method,
location=expr.location,
call_expr=expr,
frame=unfolded,
frame_expr=obj,
positional=positional,
keywords=keywords,
)
callee: Type = self.type_of(expr.callee)
@@ -936,6 +968,17 @@ class PythonTyper(
self, expr: p.CastExpr, subject_type: Type, target_type: Type, lit_value: Any
) -> bool:
match target_type:
case TopType():
return True
case UnitType():
if lit_value is not None:
self.reporter.error(
expr.location, f"Value {lit_value!r} is not None"
)
return False
return True
case DerivedType(type=base):
return self._evaluate_cast_statically(
expr, subject_type, base, lit_value

View File

@@ -0,0 +1,59 @@
import ast
from dataclasses import dataclass
from typing import Callable
import midas.ast.python as p
AssertionBuilder = Callable[..., ast.expr]
@dataclass
class Assertion:
bound_expr: p.Expr
inputs: list[p.Expr]
builder: AssertionBuilder
message: str
def is_bound_to(self, expr: p.Expr) -> bool:
return expr == self.bound_expr
class AssertionCollector:
def __init__(self):
self.assertions: list[Assertion] = []
self.definitions: dict[str, ast.stmt] = {}
def add(
self,
bound_expr: p.Expr,
inputs: list[p.Expr],
builder: AssertionBuilder,
message: str,
):
self.assertions.append(
Assertion(
bound_expr=bound_expr,
inputs=inputs,
builder=builder,
message=message,
)
)
def remove(self, assertion: Assertion):
try:
self.assertions.remove(assertion)
except ValueError:
pass
def define(self, name: str, stmt: ast.stmt):
if name not in self.definitions:
self.definitions[name] = stmt
def get_definitions(self) -> list[ast.stmt]:
return list(self.definitions.values())
def get_assertions(self) -> list[Assertion]:
return self.assertions
def get_assertions_for(self, expr: p.Expr) -> list[Assertion]:
return list(filter(lambda a: a.is_bound_to(expr), self.assertions))

View File

@@ -30,6 +30,7 @@ from midas.checker.types import (
UnitType,
UnknownType,
)
from midas.generator.collector import Assertion, AssertionCollector
from midas.generator.constraints import ConstraintGenerator
from midas.generator.stubs import StubsGenerator
from midas.utils import TypedAST
@@ -55,10 +56,12 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
stmts=[],
judgements=[],
evaluated_casts=[],
assertions=AssertionCollector(),
)
self._alias_count: int = 0
self._predicate_count: int = 0
self._scopes: list[Scope] = []
self._aliases: list[tuple[p.Expr, ast.expr]] = []
self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
self._constraints: list[tuple[m.Expr, ast.expr]] = []
@@ -71,7 +74,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def generate_ast(self, typed_ast: TypedAST) -> ast.AST:
self._typed_ast = typed_ast
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
body: list[ast.stmt] = self._visit_body(typed_ast.stmts, can_be_empty=True)
predicates: list[ast.stmt] = self._constraint_generator.get_definitions()
body = predicates + body
@@ -129,39 +132,48 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
output: str = ast.unparse(module)
out_path.write_text(output)
def convert(self, expr: p.Expr) -> ast.expr:
for expr2, alias in self._aliases:
if expr2 == expr:
return alias
assertions = self._typed_ast.assertions.get_assertions_for(expr)
if len(assertions) != 0:
return self._apply_assertions(expr, assertions)
return expr.accept(self)
def visit_binary_expr(self, expr: p.BinaryExpr) -> ast.expr:
return ast.BinOp(
left=expr.left.accept(self),
left=self.convert(expr.left),
op=expr.operator,
right=expr.right.accept(self),
right=self.convert(expr.right),
)
def visit_compare_expr(self, expr: p.CompareExpr) -> ast.expr:
return ast.Compare(
left=expr.left.accept(self),
left=self.convert(expr.left),
ops=[expr.operator],
comparators=[expr.right.accept(self)],
comparators=[self.convert(expr.right)],
)
def visit_unary_expr(self, expr: p.UnaryExpr) -> ast.expr:
return ast.UnaryOp(
op=expr.operator,
operand=expr.right.accept(self),
operand=self.convert(expr.right),
)
def visit_call_expr(self, expr: p.CallExpr) -> ast.expr:
return ast.Call(
func=expr.callee.accept(self),
args=[arg.accept(self) for arg in expr.arguments],
func=self.convert(expr.callee),
args=[self.convert(arg) for arg in expr.arguments],
keywords=[
ast.keyword(arg=name, value=arg.accept(self))
ast.keyword(arg=name, value=self.convert(arg))
for name, arg in expr.keywords.items()
],
)
def visit_get_expr(self, expr: p.GetExpr) -> ast.expr:
return ast.Attribute(
value=expr.object.accept(self),
value=self.convert(expr.object),
attr=expr.name,
)
@@ -174,16 +186,16 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def visit_logical_expr(self, expr: p.LogicalExpr) -> ast.expr:
return ast.BoolOp(
op=expr.operator,
values=[expr.left.accept(self), expr.right.accept(self)],
values=[self.convert(expr.left), self.convert(expr.right)],
)
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
expr2: ast.expr = expr.expr.accept(self)
expr2: ast.expr = self.convert(expr.expr)
if expr in self._typed_ast.evaluated_casts or expr.unsafe:
return expr2
alias: ast.expr = self._make_alias(expr2)
alias: ast.expr = self._make_alias(expr.expr, expr2)
type: Type = self._get_expr_type(expr)
asserts: list[ast.stmt] = self._make_cast_asserts(expr.location, alias, type)
@@ -194,38 +206,38 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def visit_ternary_expr(self, expr: p.TernaryExpr) -> ast.expr:
return ast.IfExp(
test=expr.test.accept(self),
body=expr.if_true.accept(self),
orelse=expr.if_false.accept(self),
test=self.convert(expr.test),
body=self.convert(expr.if_true),
orelse=self.convert(expr.if_false),
)
def visit_list_expr(self, expr: p.ListExpr) -> ast.expr:
return ast.List(
elts=[item.accept(self) for item in expr.items],
elts=[self.convert(item) for item in expr.items],
)
def visit_dict_expr(self, expr: p.DictExpr) -> ast.expr:
return ast.Dict(
keys=[key.accept(self) if key is not None else None for key in expr.keys],
values=[value.accept(self) for value in expr.values],
keys=[self.convert(key) if key is not None else None for key in expr.keys],
values=[self.convert(value) for value in expr.values],
)
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> ast.expr:
return ast.Subscript(
value=expr.object.accept(self),
slice=expr.index.accept(self),
value=self.convert(expr.object),
slice=self.convert(expr.index),
)
def visit_slice_expr(self, expr: p.SliceExpr) -> ast.expr:
return ast.Slice(
lower=expr.lower.accept(self) if expr.lower is not None else None,
upper=expr.upper.accept(self) if expr.upper is not None else None,
step=expr.step.accept(self) if expr.step is not None else None,
lower=self.convert(expr.lower) if expr.lower is not None else None,
upper=self.convert(expr.upper) if expr.upper is not None else None,
step=self.convert(expr.step) if expr.step is not None else None,
)
def visit_tuple_expr(self, expr: p.TupleExpr) -> ast.expr:
return ast.Tuple(
elts=[item.accept(self) for item in expr.items],
elts=[self.convert(item) for item in expr.items],
)
def visit_raw_expr(self, expr: p.RawExpr) -> ast.expr:
@@ -233,7 +245,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> ast.stmt:
return ast.Expr(
value=stmt.expr.accept(self),
value=self.convert(stmt.expr),
)
def visit_function(self, stmt: p.Function) -> ast.stmt:
@@ -246,12 +258,12 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
kwonlyargs=[ast.arg(arg=arg.name) for arg in stmt.kwonlyargs],
kwarg=None,
defaults=[
arg.default.accept(self)
self.convert(arg.default)
for arg in stmt.posonlyargs + stmt.args
if arg.default is not None
],
kw_defaults=[
arg.default.accept(self) if arg.default is not None else None
self.convert(arg.default) if arg.default is not None else None
for arg in stmt.kwonlyargs
],
),
@@ -265,20 +277,20 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def visit_assign_stmt(self, stmt: p.AssignStmt) -> ast.stmt:
return ast.Assign(
targets=[target.accept(self) for target in stmt.targets],
value=stmt.value.accept(self),
targets=[self.convert(target) for target in stmt.targets],
value=self.convert(stmt.value),
)
def visit_return_stmt(self, stmt: p.ReturnStmt) -> ast.stmt:
return ast.Return(
value=stmt.value.accept(self) if stmt.value is not None else None,
value=self.convert(stmt.value) if stmt.value is not None else None,
)
def visit_if_stmt(self, stmt: p.IfStmt) -> ast.stmt:
return ast.If(
test=stmt.test.accept(self),
test=self.convert(stmt.test),
body=self._visit_body(stmt.body),
orelse=self._visit_body(stmt.orelse),
orelse=self._visit_body(stmt.orelse, can_be_empty=True),
)
def visit_pass(self, stmt: p.Pass) -> ast.stmt:
@@ -286,8 +298,8 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def visit_for_stmt(self, stmt: p.ForStmt) -> ast.stmt:
return ast.For(
target=stmt.target.accept(self),
iter=stmt.iterator.accept(self),
target=self.convert(stmt.target),
iter=self.convert(stmt.iterator),
body=self._visit_body(stmt.body),
orelse=[],
)
@@ -295,7 +307,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def visit_raw_stmt(self, stmt: p.RawStmt) -> ast.stmt:
return stmt.stmt
def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]:
def _visit_body(
self, stmts: list[p.Stmt], can_be_empty: bool = False
) -> list[ast.stmt]:
generated: list[ast.stmt] = []
for stmt in stmts:
scope = Scope()
@@ -313,9 +327,11 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
# Remove redundant pass statements
if len(generated) > 1:
generated = [stmt for stmt in generated if not isinstance(stmt, ast.Pass)]
if len(generated) == 0 and not can_be_empty:
generated = [ast.Pass()]
return generated
def _make_alias(self, expr: ast.expr) -> ast.expr:
def _make_alias(self, node: p.Expr, expr: ast.expr) -> ast.expr:
name: str = f"__midas_a{self._alias_count}__"
alias = ast.Name(id=name)
self._alias_count += 1
@@ -326,6 +342,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
value=expr,
)
)
self._aliases.append((node, alias))
return alias
def _build_assert(self, expr: ast.expr, message: str | ast.expr) -> ast.stmt:
@@ -640,3 +657,30 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
body=body,
orelse=[],
)
def _convert_assertion(self, assertion: Assertion) -> ast.stmt:
inputs: list[ast.expr] = []
for input in assertion.inputs:
converted: ast.expr = self.convert(input)
alias: ast.expr = self._make_alias(input, converted)
inputs.append(alias)
test: ast.expr = assertion.builder(*inputs)
location: Location = assertion.bound_expr.location
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
return self._build_assert(
test, f"{loc_str}: AssertionError: {assertion.message}"
)
def _apply_assertions(self, expr: p.Expr, assertions: list[Assertion]) -> ast.expr:
for assertion in assertions:
assert_stmt: ast.stmt
assert_stmt = self._convert_assertion(assertion)
self._add_assert(assert_stmt)
# Mutating list in frozen dataclass
# Not ideal but easiest way to avoid duplicate assertions
self._typed_ast.assertions.remove(assertion)
return expr.accept(self)

View File

@@ -91,6 +91,21 @@ class StubsGenerator:
def generate_stub(self, name: str, type: Type):
base_type: Type = type
# TODO: improve
match type:
case DerivedType(name=name_) | GenericType(name=name_) if name_ == name:
pass
case UnitType() if name == "None":
pass
case TopType() if name == "Any":
pass
case _:
alias = ast.Assign(
targets=[ast.Name(id=name)], value=self.dump_type(type)
)
self.add_stub(alias)
return
members: dict[str, Member] = self.types._members.get(name, {})
if isinstance(base_type, (BaseType, TopType, UnitType)) and len(members) == 0:
return

View File

@@ -3,6 +3,7 @@ from typing import Any, Callable, Optional
import midas.ast.python as p
from midas.checker.types import Type
from midas.generator.collector import AssertionCollector
AllowRepeat = Callable[[object], bool]
@@ -63,3 +64,4 @@ class TypedAST:
stmts: list[p.Stmt]
judgements: list[tuple[p.Expr, Type]]
evaluated_casts: list[p.CastExpr]
assertions: AssertionCollector