feat(gen): generate asserts for dataframes and columns

This commit is contained in:
2026-06-26 14:56:15 +02:00
parent fd0b410d74
commit 9f59366289
2 changed files with 145 additions and 5 deletions

View File

@@ -1306,6 +1306,12 @@ class PythonTyper(
return False
return True
case DataFrameType() | ColumnType():
self.reporter.error(
expr.location, f"Cannot cast {lit_value!r} to {target_type}"
)
return False
case _:
self.reporter.info(
expr.location, f"Cannot evaluate cast to {target_type} statically"

View File

@@ -40,6 +40,9 @@ class Scope:
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
IS_DATAFRAME_FUNC = "__midas_is_dataframe__"
IS_COLUMN_FUNC = "__midas_is_column__"
def __init__(self, workdir: Path, types: TypesRegistry) -> None:
self.workdir: Path = workdir.resolve()
self.build_dir: Path = self.workdir / "build" / "midas"
@@ -58,12 +61,24 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
self._constraints: list[tuple[m.Expr, ast.expr]] = []
self.define_is_dataframe: bool = False
self.define_is_column: bool = False
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
self.rel_src_path = src_path.resolve().relative_to(self.workdir)
self._typed_ast = typed_ast
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
predicates: list[ast.stmt] = self._constraint_generator.get_definitions()
module = ast.Module(body=predicates + body, type_ignores=[])
body = predicates + body
if self.define_is_dataframe:
body = [self._is_dataframe_definition()] + body
if self.define_is_column:
body = [self._is_column_definition()] + body
module = ast.Module(body=body, type_ignores=[])
module = ast.fix_missing_locations(module)
return module
@@ -350,6 +365,51 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
for item, item_type in zip(expr.elts, items):
self._make_cast_asserts(src_location, item, item_type)
case DataFrameType(columns=columns):
self.define_is_dataframe = True
self._add_assert(
ast.Call(
func=ast.Name(id=self.IS_DATAFRAME_FUNC),
args=[expr],
keywords=[],
),
self._make_cast_assert_message(
src_location, expr, type, ": Not a dataframe"
),
)
for column in columns:
self._add_assert(
ast.Compare(
left=ast.Constant(value=column.name),
ops=[ast.In()],
comparators=[expr],
),
self._make_cast_assert_message(
src_location, expr, type, f": Missing column {column.name}"
),
)
self._make_cast_asserts(
src_location,
ast.Subscript(
value=expr, slice=ast.Constant(value=column.name)
),
column.type,
)
case ColumnType(type=inner):
self.define_is_column = True
self._add_assert(
ast.Call(
func=ast.Name(id=self.IS_COLUMN_FUNC),
args=[expr],
keywords=[],
),
self._make_cast_assert_message(
src_location, expr, type, ": Not a column"
),
)
# TODO: check value type
case (
TopType()
| Function()
@@ -357,8 +417,6 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
| ComplexType()
| ExtensionType()
| GenericType()
| ColumnType()
| DataFrameType()
):
self.logger.warning(f"Can't make assertion for type {type}")
@@ -367,7 +425,11 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
assert_never(type)
def _make_cast_assert_message(
self, location: Location, expr: ast.expr, type: Type
self,
location: Location,
expr: ast.expr,
type: Type,
extra: Optional[str] = None,
) -> ast.expr:
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
# f"file.py:L1:1: CastError: Cannot cast {type(expr).__name__} to Type"
@@ -385,7 +447,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
),
conversion=-1,
),
ast.Constant(f" to {type}"),
ast.Constant(f" to {type}{extra or ''}"),
]
)
@@ -421,3 +483,75 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
constraint: ast.expr = self._constraint_generator.generate(expr)
self._constraints.append((expr, constraint))
return constraint
def _is_dataframe_definition(self) -> ast.stmt:
"""
def IS_DATAFRAME_FUNC(obj) -> bool:
import pandas as pd
return isinstance(obj, pd.DataFrame)
"""
return ast.FunctionDef(
name=self.IS_DATAFRAME_FUNC,
args=ast.arguments(
posonlyargs=[ast.arg(arg="obj")],
args=[],
kwonlyargs=[],
defaults=[],
kw_defaults=[],
),
body=[
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
ast.Return(
value=ast.Call(
func=ast.Name(id="isinstance"),
args=[
ast.Name(id="obj"),
ast.Attribute(
value=ast.Name(id="pd"),
attr="DataFrame",
),
],
keywords=[],
)
),
],
decorator_list=[],
returns=ast.Name(id="bool"),
)
def _is_column_definition(self) -> ast.stmt:
"""
def IS_COLUMN_FUNC(obj) -> bool:
import pandas as pd
return isinstance(obj, pd.Series)
"""
return ast.FunctionDef(
name=self.IS_COLUMN_FUNC,
args=ast.arguments(
posonlyargs=[ast.arg(arg="obj")],
args=[],
kwonlyargs=[],
defaults=[],
kw_defaults=[],
),
body=[
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
ast.Return(
value=ast.Call(
func=ast.Name(id="isinstance"),
args=[
ast.Name(id="obj"),
ast.Attribute(
value=ast.Name(id="pd"),
attr="Series",
),
],
keywords=[],
)
),
],
decorator_list=[],
returns=ast.Name(id="bool"),
)