diff --git a/midas/checker/python.py b/midas/checker/python.py index fa1ec3c..29855b1 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -1306,6 +1306,12 @@ class PythonTyper( return False return True + case DataFrameType() | ColumnType(): + self.reporter.error( + expr.location, f"Cannot cast {lit_value!r} to {target_type}" + ) + return False + case _: self.reporter.info( expr.location, f"Cannot evaluate cast to {target_type} statically" diff --git a/midas/generator/generator.py b/midas/generator/generator.py index 1792dcd..057cbab 100644 --- a/midas/generator/generator.py +++ b/midas/generator/generator.py @@ -40,6 +40,9 @@ class Scope: class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): + IS_DATAFRAME_FUNC = "__midas_is_dataframe__" + IS_COLUMN_FUNC = "__midas_is_column__" + def __init__(self, workdir: Path, types: TypesRegistry) -> None: self.workdir: Path = workdir.resolve() self.build_dir: Path = self.workdir / "build" / "midas" @@ -58,12 +61,24 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types) self._constraints: list[tuple[m.Expr, ast.expr]] = [] + self.define_is_dataframe: bool = False + self.define_is_column: bool = False + def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST: self.rel_src_path = src_path.resolve().relative_to(self.workdir) self._typed_ast = typed_ast body: list[ast.stmt] = self._visit_body(typed_ast.stmts) predicates: list[ast.stmt] = self._constraint_generator.get_definitions() - module = ast.Module(body=predicates + body, type_ignores=[]) + + body = predicates + body + + if self.define_is_dataframe: + body = [self._is_dataframe_definition()] + body + + if self.define_is_column: + body = [self._is_column_definition()] + body + + module = ast.Module(body=body, type_ignores=[]) module = ast.fix_missing_locations(module) return module @@ -350,6 +365,51 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): for item, item_type in zip(expr.elts, items): self._make_cast_asserts(src_location, item, item_type) + case DataFrameType(columns=columns): + self.define_is_dataframe = True + self._add_assert( + ast.Call( + func=ast.Name(id=self.IS_DATAFRAME_FUNC), + args=[expr], + keywords=[], + ), + self._make_cast_assert_message( + src_location, expr, type, ": Not a dataframe" + ), + ) + for column in columns: + self._add_assert( + ast.Compare( + left=ast.Constant(value=column.name), + ops=[ast.In()], + comparators=[expr], + ), + self._make_cast_assert_message( + src_location, expr, type, f": Missing column {column.name}" + ), + ) + self._make_cast_asserts( + src_location, + ast.Subscript( + value=expr, slice=ast.Constant(value=column.name) + ), + column.type, + ) + + case ColumnType(type=inner): + self.define_is_column = True + self._add_assert( + ast.Call( + func=ast.Name(id=self.IS_COLUMN_FUNC), + args=[expr], + keywords=[], + ), + self._make_cast_assert_message( + src_location, expr, type, ": Not a column" + ), + ) + # TODO: check value type + case ( TopType() | Function() @@ -357,8 +417,6 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): | ComplexType() | ExtensionType() | GenericType() - | ColumnType() - | DataFrameType() ): self.logger.warning(f"Can't make assertion for type {type}") @@ -367,7 +425,11 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): assert_never(type) def _make_cast_assert_message( - self, location: Location, expr: ast.expr, type: Type + self, + location: Location, + expr: ast.expr, + type: Type, + extra: Optional[str] = None, ) -> ast.expr: loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}" # f"file.py:L1:1: CastError: Cannot cast {type(expr).__name__} to Type" @@ -385,7 +447,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): ), conversion=-1, ), - ast.Constant(f" to {type}"), + ast.Constant(f" to {type}{extra or ''}"), ] ) @@ -421,3 +483,75 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): constraint: ast.expr = self._constraint_generator.generate(expr) self._constraints.append((expr, constraint)) return constraint + + def _is_dataframe_definition(self) -> ast.stmt: + """ + def IS_DATAFRAME_FUNC(obj) -> bool: + import pandas as pd + return isinstance(obj, pd.DataFrame) + """ + + return ast.FunctionDef( + name=self.IS_DATAFRAME_FUNC, + args=ast.arguments( + posonlyargs=[ast.arg(arg="obj")], + args=[], + kwonlyargs=[], + defaults=[], + kw_defaults=[], + ), + body=[ + ast.Import(names=[ast.alias(name="pandas", asname="pd")]), + ast.Return( + value=ast.Call( + func=ast.Name(id="isinstance"), + args=[ + ast.Name(id="obj"), + ast.Attribute( + value=ast.Name(id="pd"), + attr="DataFrame", + ), + ], + keywords=[], + ) + ), + ], + decorator_list=[], + returns=ast.Name(id="bool"), + ) + + def _is_column_definition(self) -> ast.stmt: + """ + def IS_COLUMN_FUNC(obj) -> bool: + import pandas as pd + return isinstance(obj, pd.Series) + """ + + return ast.FunctionDef( + name=self.IS_COLUMN_FUNC, + args=ast.arguments( + posonlyargs=[ast.arg(arg="obj")], + args=[], + kwonlyargs=[], + defaults=[], + kw_defaults=[], + ), + body=[ + ast.Import(names=[ast.alias(name="pandas", asname="pd")]), + ast.Return( + value=ast.Call( + func=ast.Name(id="isinstance"), + args=[ + ast.Name(id="obj"), + ast.Attribute( + value=ast.Name(id="pd"), + attr="Series", + ), + ], + keywords=[], + ) + ), + ], + decorator_list=[], + returns=ast.Name(id="bool"), + )