diff --git a/midas/generator/generator.py b/midas/generator/generator.py index 057cbab..f5d89d5 100644 --- a/midas/generator/generator.py +++ b/midas/generator/generator.py @@ -35,8 +35,8 @@ from midas.utils import TypedAST @dataclass class Scope: - pre_assertions: list[ast.stmt] = field(default_factory=list) - aliases: list[str] = field(default_factory=list) + pre_assertions: list[ast.stmt] = field(default_factory=list[ast.stmt]) + aliases: list[str] = field(default_factory=list[str]) class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): @@ -159,7 +159,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): alias: ast.expr = self._make_alias(expr2) type: Type = self._get_expr_type(expr) - self._make_cast_asserts(expr.location, alias, type) + asserts: list[ast.stmt] = self._make_cast_asserts(expr.location, alias, type) + for assert_ in asserts: + self._add_assert(assert_) return alias @@ -294,121 +296,152 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): ) return alias - def _add_assert(self, expr: ast.expr, message: str | ast.expr): + def _build_assert(self, expr: ast.expr, message: str | ast.expr) -> ast.stmt: if isinstance(message, str): message = ast.Constant(value=message) - self._scopes[-1].pre_assertions.append( - ast.Assert( - test=expr, - msg=message, - ) + return ast.Assert( + test=expr, + msg=message, ) + def _add_assert(self, assertion: ast.stmt): + self._scopes[-1].pre_assertions.append(assertion) + def _get_expr_type(self, query: p.Expr) -> Type: for expr, type in self._typed_ast.judgements: if expr == query: return type raise RuntimeError(f"Cannot get type judgement for {query}") - def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type): + def _make_cast_asserts( + self, src_location: Location, expr: ast.expr, type: Type + ) -> list[ast.stmt]: match type: case UnknownType(): - pass + return [] case BaseType(name=name): - self._add_assert( - ast.Call( - func=ast.Name(id="isinstance"), - args=[expr, ast.Name(id=name)], - keywords=[], - ), - self._make_cast_assert_message(src_location, expr, type), - ) + return [ + self._build_assert( + ast.Call( + func=ast.Name(id="isinstance"), + args=[expr, ast.Name(id=name)], + keywords=[], + ), + self._make_cast_assert_message(src_location, expr, type), + ) + ] case AliasType(type=base): - self._make_cast_asserts(src_location, expr, base) + return self._make_cast_asserts(src_location, expr, base) case UnitType(): - self._add_assert( - ast.Compare( - left=expr, - ops=[ast.Is()], - comparators=[ - ast.Constant(value=None), - ], + return [ + self._build_assert( + ast.Compare( + left=expr, + ops=[ast.Is()], + comparators=[ + ast.Constant(value=None), + ], + ), + self._make_cast_assert_message(src_location, expr, type), ), - self._make_cast_assert_message(src_location, expr, type), - ) + ] case AppliedType(body=body): - self._make_cast_asserts(src_location, expr, body) + return self._make_cast_asserts(src_location, expr, body) case ConstraintType(type=base, constraint=constraint): - self._make_cast_asserts(src_location, expr, base) - self._make_constraint_assert(src_location, expr, constraint) + asserts: list[ast.stmt] = self._make_cast_asserts( + src_location, expr, base + ) + asserts.append( + self._make_constraint_assert(src_location, expr, constraint) + ) + return asserts case TypeVar(bound=bound): # TODO: check with type from arguments / use call-site context - if bound is not None: - self._make_cast_asserts(src_location, expr, bound) + if bound is None: + return [] + return self._make_cast_asserts(src_location, expr, bound) case TupleType(items=items): - self._add_assert( - ast.Call( - func=ast.Name(id="isinstance"), - args=[expr, ast.Name(id="tuple")], - keywords=[], + asserts: list[ast.stmt] = [ + self._build_assert( + ast.Call( + func=ast.Name(id="isinstance"), + args=[expr, ast.Name(id="tuple")], + keywords=[], + ), + self._make_cast_assert_message(src_location, expr, type), ), - self._make_cast_assert_message(src_location, expr, type), - ) + ] assert isinstance(expr, ast.Tuple) for item, item_type in zip(expr.elts, items): - self._make_cast_asserts(src_location, item, item_type) + asserts.extend( + self._make_cast_asserts(src_location, item, item_type) + ) + return asserts case DataFrameType(columns=columns): self.define_is_dataframe = True - self._add_assert( - ast.Call( - func=ast.Name(id=self.IS_DATAFRAME_FUNC), - args=[expr], - keywords=[], - ), - self._make_cast_assert_message( - src_location, expr, type, ": Not a dataframe" - ), - ) - for column in columns: - self._add_assert( - ast.Compare( - left=ast.Constant(value=column.name), - ops=[ast.In()], - comparators=[expr], + asserts: list[ast.stmt] = [ + self._build_assert( + ast.Call( + func=ast.Name(id=self.IS_DATAFRAME_FUNC), + args=[expr], + keywords=[], ), self._make_cast_assert_message( - src_location, expr, type, f": Missing column {column.name}" + src_location, expr, type, ": Not a dataframe" ), + ), + ] + for column in columns: + asserts.append( + self._build_assert( + ast.Compare( + left=ast.Constant(value=column.name), + ops=[ast.In()], + comparators=[expr], + ), + self._make_cast_assert_message( + src_location, + expr, + type, + f": Missing column {column.name}", + ), + ) ) - self._make_cast_asserts( - src_location, - ast.Subscript( - value=expr, slice=ast.Constant(value=column.name) - ), - column.type, + asserts.extend( + self._make_cast_asserts( + src_location, + ast.Subscript( + value=expr, slice=ast.Constant(value=column.name) + ), + column.type, + ) ) + return asserts - case ColumnType(type=inner): + case ColumnType(): self.define_is_column = True - self._add_assert( - ast.Call( - func=ast.Name(id=self.IS_COLUMN_FUNC), - args=[expr], - keywords=[], + asserts: list[ast.stmt] = [ + self._build_assert( + ast.Call( + func=ast.Name(id=self.IS_COLUMN_FUNC), + args=[expr], + keywords=[], + ), + self._make_cast_assert_message( + src_location, expr, type, ": Not a column" + ), ), - self._make_cast_assert_message( - src_location, expr, type, ": Not a column" - ), - ) - # TODO: check value type + ] + asserts.append(self._make_column_inner_assert(src_location, expr, type)) + return asserts case ( TopType() @@ -419,6 +452,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): | GenericType() ): self.logger.warning(f"Can't make assertion for type {type}") + return [] # Ensure exhaustiveness case _: @@ -453,9 +487,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): def _make_constraint_assert( self, src_location: Location, expr: ast.expr, constraint: m.Expr - ): + ) -> ast.stmt: test_func: ast.expr = self._get_constraint(constraint) - self._add_assert( + return self._build_assert( ast.Call( func=test_func, args=[expr], @@ -555,3 +589,15 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): decorator_list=[], returns=ast.Name(id="bool"), ) + + def _make_column_inner_assert( + self, src_location: Location, column: ast.expr, type: ColumnType + ) -> ast.stmt: + # TODO: improve message, maybe chain contexts + col: ast.expr = ast.Name(id="col") + return ast.For( + target=col, + iter=column, + body=self._make_cast_asserts(src_location, col, type.type), + orelse=[], + )