feat(gen): assertions for column values

This commit is contained in:
2026-06-29 11:04:25 +02:00
parent 75bd203d4a
commit 7e0319906a

View File

@@ -35,8 +35,8 @@ from midas.utils import TypedAST
@dataclass @dataclass
class Scope: class Scope:
pre_assertions: list[ast.stmt] = field(default_factory=list) pre_assertions: list[ast.stmt] = field(default_factory=list[ast.stmt])
aliases: list[str] = field(default_factory=list) aliases: list[str] = field(default_factory=list[str])
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
@@ -159,7 +159,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
alias: ast.expr = self._make_alias(expr2) alias: ast.expr = self._make_alias(expr2)
type: Type = self._get_expr_type(expr) type: Type = self._get_expr_type(expr)
self._make_cast_asserts(expr.location, alias, type) asserts: list[ast.stmt] = self._make_cast_asserts(expr.location, alias, type)
for assert_ in asserts:
self._add_assert(assert_)
return alias return alias
@@ -294,15 +296,16 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
) )
return alias return alias
def _add_assert(self, expr: ast.expr, message: str | ast.expr): def _build_assert(self, expr: ast.expr, message: str | ast.expr) -> ast.stmt:
if isinstance(message, str): if isinstance(message, str):
message = ast.Constant(value=message) message = ast.Constant(value=message)
self._scopes[-1].pre_assertions.append( return ast.Assert(
ast.Assert(
test=expr, test=expr,
msg=message, msg=message,
) )
)
def _add_assert(self, assertion: ast.stmt):
self._scopes[-1].pre_assertions.append(assertion)
def _get_expr_type(self, query: p.Expr) -> Type: def _get_expr_type(self, query: p.Expr) -> Type:
for expr, type in self._typed_ast.judgements: for expr, type in self._typed_ast.judgements:
@@ -310,13 +313,16 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
return type return type
raise RuntimeError(f"Cannot get type judgement for {query}") raise RuntimeError(f"Cannot get type judgement for {query}")
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type): def _make_cast_asserts(
self, src_location: Location, expr: ast.expr, type: Type
) -> list[ast.stmt]:
match type: match type:
case UnknownType(): case UnknownType():
pass return []
case BaseType(name=name): case BaseType(name=name):
self._add_assert( return [
self._build_assert(
ast.Call( ast.Call(
func=ast.Name(id="isinstance"), func=ast.Name(id="isinstance"),
args=[expr, ast.Name(id=name)], args=[expr, ast.Name(id=name)],
@@ -324,12 +330,14 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
), ),
self._make_cast_assert_message(src_location, expr, type), self._make_cast_assert_message(src_location, expr, type),
) )
]
case DerivedType(type=base): case DerivedType(type=base):
self._make_cast_asserts(src_location, expr, base) return self._make_cast_asserts(src_location, expr, base)
case UnitType(): case UnitType():
self._add_assert( return [
self._build_assert(
ast.Compare( ast.Compare(
left=expr, left=expr,
ops=[ast.Is()], ops=[ast.Is()],
@@ -338,36 +346,49 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
], ],
), ),
self._make_cast_assert_message(src_location, expr, type), self._make_cast_assert_message(src_location, expr, type),
) ),
]
case AppliedType(body=body): case AppliedType(body=body):
self._make_cast_asserts(src_location, expr, body) return self._make_cast_asserts(src_location, expr, body)
case ConstraintType(type=base, constraint=constraint): case ConstraintType(type=base, constraint=constraint):
self._make_cast_asserts(src_location, expr, base) asserts: list[ast.stmt] = self._make_cast_asserts(
src_location, expr, base
)
asserts.append(
self._make_constraint_assert(src_location, expr, constraint) self._make_constraint_assert(src_location, expr, constraint)
)
return asserts
case TypeVar(bound=bound): case TypeVar(bound=bound):
# TODO: check with type from arguments / use call-site context # TODO: check with type from arguments / use call-site context
if bound is not None: if bound is None:
self._make_cast_asserts(src_location, expr, bound) return []
return self._make_cast_asserts(src_location, expr, bound)
case TupleType(items=items): case TupleType(items=items):
self._add_assert( asserts: list[ast.stmt] = [
self._build_assert(
ast.Call( ast.Call(
func=ast.Name(id="isinstance"), func=ast.Name(id="isinstance"),
args=[expr, ast.Name(id="tuple")], args=[expr, ast.Name(id="tuple")],
keywords=[], keywords=[],
), ),
self._make_cast_assert_message(src_location, expr, type), self._make_cast_assert_message(src_location, expr, type),
) ),
]
assert isinstance(expr, ast.Tuple) assert isinstance(expr, ast.Tuple)
for item, item_type in zip(expr.elts, items): for item, item_type in zip(expr.elts, items):
asserts.extend(
self._make_cast_asserts(src_location, item, item_type) self._make_cast_asserts(src_location, item, item_type)
)
return asserts
case DataFrameType(columns=columns): case DataFrameType(columns=columns):
self.define_is_dataframe = True self.define_is_dataframe = True
self._add_assert( asserts: list[ast.stmt] = [
self._build_assert(
ast.Call( ast.Call(
func=ast.Name(id=self.IS_DATAFRAME_FUNC), func=ast.Name(id=self.IS_DATAFRAME_FUNC),
args=[expr], args=[expr],
@@ -376,18 +397,25 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self._make_cast_assert_message( self._make_cast_assert_message(
src_location, expr, type, ": Not a dataframe" src_location, expr, type, ": Not a dataframe"
), ),
) ),
]
for column in columns: for column in columns:
self._add_assert( asserts.append(
self._build_assert(
ast.Compare( ast.Compare(
left=ast.Constant(value=column.name), left=ast.Constant(value=column.name),
ops=[ast.In()], ops=[ast.In()],
comparators=[expr], comparators=[expr],
), ),
self._make_cast_assert_message( self._make_cast_assert_message(
src_location, expr, type, f": Missing column {column.name}" src_location,
expr,
type,
f": Missing column {column.name}",
), ),
) )
)
asserts.extend(
self._make_cast_asserts( self._make_cast_asserts(
src_location, src_location,
ast.Subscript( ast.Subscript(
@@ -395,10 +423,13 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
), ),
column.type, column.type,
) )
)
return asserts
case ColumnType(type=inner): case ColumnType():
self.define_is_column = True self.define_is_column = True
self._add_assert( asserts: list[ast.stmt] = [
self._build_assert(
ast.Call( ast.Call(
func=ast.Name(id=self.IS_COLUMN_FUNC), func=ast.Name(id=self.IS_COLUMN_FUNC),
args=[expr], args=[expr],
@@ -407,8 +438,10 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self._make_cast_assert_message( self._make_cast_assert_message(
src_location, expr, type, ": Not a column" src_location, expr, type, ": Not a column"
), ),
) ),
# TODO: check value type ]
asserts.append(self._make_column_inner_assert(src_location, expr, type))
return asserts
case ( case (
TopType() TopType()
@@ -419,6 +452,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
| GenericType() | GenericType()
): ):
self.logger.warning(f"Can't make assertion for type {type}") self.logger.warning(f"Can't make assertion for type {type}")
return []
# Ensure exhaustiveness # Ensure exhaustiveness
case _: case _:
@@ -453,9 +487,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def _make_constraint_assert( def _make_constraint_assert(
self, src_location: Location, expr: ast.expr, constraint: m.Expr self, src_location: Location, expr: ast.expr, constraint: m.Expr
): ) -> ast.stmt:
test_func: ast.expr = self._get_constraint(constraint) test_func: ast.expr = self._get_constraint(constraint)
self._add_assert( return self._build_assert(
ast.Call( ast.Call(
func=test_func, func=test_func,
args=[expr], args=[expr],
@@ -555,3 +589,15 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
decorator_list=[], decorator_list=[],
returns=ast.Name(id="bool"), returns=ast.Name(id="bool"),
) )
def _make_column_inner_assert(
self, src_location: Location, column: ast.expr, type: ColumnType
) -> ast.stmt:
# TODO: improve message, maybe chain contexts
col: ast.expr = ast.Name(id="col")
return ast.For(
target=col,
iter=column,
body=self._make_cast_asserts(src_location, col, type.type),
orelse=[],
)