feat(gen): generate asserts for dataframes and columns
This commit is contained in:
@@ -1306,6 +1306,12 @@ class PythonTyper(
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
case DataFrameType() | ColumnType():
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location, f"Cannot cast {lit_value!r} to {target_type}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
self.reporter.info(
|
self.reporter.info(
|
||||||
expr.location, f"Cannot evaluate cast to {target_type} statically"
|
expr.location, f"Cannot evaluate cast to {target_type} statically"
|
||||||
|
|||||||
@@ -40,6 +40,9 @@ class Scope:
|
|||||||
|
|
||||||
|
|
||||||
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||||
|
IS_DATAFRAME_FUNC = "__midas_is_dataframe__"
|
||||||
|
IS_COLUMN_FUNC = "__midas_is_column__"
|
||||||
|
|
||||||
def __init__(self, workdir: Path, types: TypesRegistry) -> None:
|
def __init__(self, workdir: Path, types: TypesRegistry) -> None:
|
||||||
self.workdir: Path = workdir.resolve()
|
self.workdir: Path = workdir.resolve()
|
||||||
self.build_dir: Path = self.workdir / "build" / "midas"
|
self.build_dir: Path = self.workdir / "build" / "midas"
|
||||||
@@ -58,12 +61,24 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
|
self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
|
||||||
self._constraints: list[tuple[m.Expr, ast.expr]] = []
|
self._constraints: list[tuple[m.Expr, ast.expr]] = []
|
||||||
|
|
||||||
|
self.define_is_dataframe: bool = False
|
||||||
|
self.define_is_column: bool = False
|
||||||
|
|
||||||
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
|
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
|
||||||
self.rel_src_path = src_path.resolve().relative_to(self.workdir)
|
self.rel_src_path = src_path.resolve().relative_to(self.workdir)
|
||||||
self._typed_ast = typed_ast
|
self._typed_ast = typed_ast
|
||||||
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
|
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
|
||||||
predicates: list[ast.stmt] = self._constraint_generator.get_definitions()
|
predicates: list[ast.stmt] = self._constraint_generator.get_definitions()
|
||||||
module = ast.Module(body=predicates + body, type_ignores=[])
|
|
||||||
|
body = predicates + body
|
||||||
|
|
||||||
|
if self.define_is_dataframe:
|
||||||
|
body = [self._is_dataframe_definition()] + body
|
||||||
|
|
||||||
|
if self.define_is_column:
|
||||||
|
body = [self._is_column_definition()] + body
|
||||||
|
|
||||||
|
module = ast.Module(body=body, type_ignores=[])
|
||||||
module = ast.fix_missing_locations(module)
|
module = ast.fix_missing_locations(module)
|
||||||
return module
|
return module
|
||||||
|
|
||||||
@@ -350,6 +365,51 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
for item, item_type in zip(expr.elts, items):
|
for item, item_type in zip(expr.elts, items):
|
||||||
self._make_cast_asserts(src_location, item, item_type)
|
self._make_cast_asserts(src_location, item, item_type)
|
||||||
|
|
||||||
|
case DataFrameType(columns=columns):
|
||||||
|
self.define_is_dataframe = True
|
||||||
|
self._add_assert(
|
||||||
|
ast.Call(
|
||||||
|
func=ast.Name(id=self.IS_DATAFRAME_FUNC),
|
||||||
|
args=[expr],
|
||||||
|
keywords=[],
|
||||||
|
),
|
||||||
|
self._make_cast_assert_message(
|
||||||
|
src_location, expr, type, ": Not a dataframe"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for column in columns:
|
||||||
|
self._add_assert(
|
||||||
|
ast.Compare(
|
||||||
|
left=ast.Constant(value=column.name),
|
||||||
|
ops=[ast.In()],
|
||||||
|
comparators=[expr],
|
||||||
|
),
|
||||||
|
self._make_cast_assert_message(
|
||||||
|
src_location, expr, type, f": Missing column {column.name}"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self._make_cast_asserts(
|
||||||
|
src_location,
|
||||||
|
ast.Subscript(
|
||||||
|
value=expr, slice=ast.Constant(value=column.name)
|
||||||
|
),
|
||||||
|
column.type,
|
||||||
|
)
|
||||||
|
|
||||||
|
case ColumnType(type=inner):
|
||||||
|
self.define_is_column = True
|
||||||
|
self._add_assert(
|
||||||
|
ast.Call(
|
||||||
|
func=ast.Name(id=self.IS_COLUMN_FUNC),
|
||||||
|
args=[expr],
|
||||||
|
keywords=[],
|
||||||
|
),
|
||||||
|
self._make_cast_assert_message(
|
||||||
|
src_location, expr, type, ": Not a column"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# TODO: check value type
|
||||||
|
|
||||||
case (
|
case (
|
||||||
TopType()
|
TopType()
|
||||||
| Function()
|
| Function()
|
||||||
@@ -357,8 +417,6 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
| ComplexType()
|
| ComplexType()
|
||||||
| ExtensionType()
|
| ExtensionType()
|
||||||
| GenericType()
|
| GenericType()
|
||||||
| ColumnType()
|
|
||||||
| DataFrameType()
|
|
||||||
):
|
):
|
||||||
self.logger.warning(f"Can't make assertion for type {type}")
|
self.logger.warning(f"Can't make assertion for type {type}")
|
||||||
|
|
||||||
@@ -367,7 +425,11 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
assert_never(type)
|
assert_never(type)
|
||||||
|
|
||||||
def _make_cast_assert_message(
|
def _make_cast_assert_message(
|
||||||
self, location: Location, expr: ast.expr, type: Type
|
self,
|
||||||
|
location: Location,
|
||||||
|
expr: ast.expr,
|
||||||
|
type: Type,
|
||||||
|
extra: Optional[str] = None,
|
||||||
) -> ast.expr:
|
) -> ast.expr:
|
||||||
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
|
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
|
||||||
# f"file.py:L1:1: CastError: Cannot cast {type(expr).__name__} to Type"
|
# f"file.py:L1:1: CastError: Cannot cast {type(expr).__name__} to Type"
|
||||||
@@ -385,7 +447,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
),
|
),
|
||||||
conversion=-1,
|
conversion=-1,
|
||||||
),
|
),
|
||||||
ast.Constant(f" to {type}"),
|
ast.Constant(f" to {type}{extra or ''}"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -421,3 +483,75 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
constraint: ast.expr = self._constraint_generator.generate(expr)
|
constraint: ast.expr = self._constraint_generator.generate(expr)
|
||||||
self._constraints.append((expr, constraint))
|
self._constraints.append((expr, constraint))
|
||||||
return constraint
|
return constraint
|
||||||
|
|
||||||
|
def _is_dataframe_definition(self) -> ast.stmt:
|
||||||
|
"""
|
||||||
|
def IS_DATAFRAME_FUNC(obj) -> bool:
|
||||||
|
import pandas as pd
|
||||||
|
return isinstance(obj, pd.DataFrame)
|
||||||
|
"""
|
||||||
|
|
||||||
|
return ast.FunctionDef(
|
||||||
|
name=self.IS_DATAFRAME_FUNC,
|
||||||
|
args=ast.arguments(
|
||||||
|
posonlyargs=[ast.arg(arg="obj")],
|
||||||
|
args=[],
|
||||||
|
kwonlyargs=[],
|
||||||
|
defaults=[],
|
||||||
|
kw_defaults=[],
|
||||||
|
),
|
||||||
|
body=[
|
||||||
|
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
|
||||||
|
ast.Return(
|
||||||
|
value=ast.Call(
|
||||||
|
func=ast.Name(id="isinstance"),
|
||||||
|
args=[
|
||||||
|
ast.Name(id="obj"),
|
||||||
|
ast.Attribute(
|
||||||
|
value=ast.Name(id="pd"),
|
||||||
|
attr="DataFrame",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
keywords=[],
|
||||||
|
)
|
||||||
|
),
|
||||||
|
],
|
||||||
|
decorator_list=[],
|
||||||
|
returns=ast.Name(id="bool"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _is_column_definition(self) -> ast.stmt:
|
||||||
|
"""
|
||||||
|
def IS_COLUMN_FUNC(obj) -> bool:
|
||||||
|
import pandas as pd
|
||||||
|
return isinstance(obj, pd.Series)
|
||||||
|
"""
|
||||||
|
|
||||||
|
return ast.FunctionDef(
|
||||||
|
name=self.IS_COLUMN_FUNC,
|
||||||
|
args=ast.arguments(
|
||||||
|
posonlyargs=[ast.arg(arg="obj")],
|
||||||
|
args=[],
|
||||||
|
kwonlyargs=[],
|
||||||
|
defaults=[],
|
||||||
|
kw_defaults=[],
|
||||||
|
),
|
||||||
|
body=[
|
||||||
|
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
|
||||||
|
ast.Return(
|
||||||
|
value=ast.Call(
|
||||||
|
func=ast.Name(id="isinstance"),
|
||||||
|
args=[
|
||||||
|
ast.Name(id="obj"),
|
||||||
|
ast.Attribute(
|
||||||
|
value=ast.Name(id="pd"),
|
||||||
|
attr="Series",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
keywords=[],
|
||||||
|
)
|
||||||
|
),
|
||||||
|
],
|
||||||
|
decorator_list=[],
|
||||||
|
returns=ast.Name(id="bool"),
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user