feat(gen): assertions for column values
This commit is contained in:
@@ -35,8 +35,8 @@ from midas.utils import TypedAST
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Scope:
|
class Scope:
|
||||||
pre_assertions: list[ast.stmt] = field(default_factory=list)
|
pre_assertions: list[ast.stmt] = field(default_factory=list[ast.stmt])
|
||||||
aliases: list[str] = field(default_factory=list)
|
aliases: list[str] = field(default_factory=list[str])
|
||||||
|
|
||||||
|
|
||||||
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||||
@@ -159,7 +159,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
alias: ast.expr = self._make_alias(expr2)
|
alias: ast.expr = self._make_alias(expr2)
|
||||||
|
|
||||||
type: Type = self._get_expr_type(expr)
|
type: Type = self._get_expr_type(expr)
|
||||||
self._make_cast_asserts(expr.location, alias, type)
|
asserts: list[ast.stmt] = self._make_cast_asserts(expr.location, alias, type)
|
||||||
|
for assert_ in asserts:
|
||||||
|
self._add_assert(assert_)
|
||||||
|
|
||||||
return alias
|
return alias
|
||||||
|
|
||||||
@@ -294,121 +296,152 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
)
|
)
|
||||||
return alias
|
return alias
|
||||||
|
|
||||||
def _add_assert(self, expr: ast.expr, message: str | ast.expr):
|
def _build_assert(self, expr: ast.expr, message: str | ast.expr) -> ast.stmt:
|
||||||
if isinstance(message, str):
|
if isinstance(message, str):
|
||||||
message = ast.Constant(value=message)
|
message = ast.Constant(value=message)
|
||||||
self._scopes[-1].pre_assertions.append(
|
return ast.Assert(
|
||||||
ast.Assert(
|
test=expr,
|
||||||
test=expr,
|
msg=message,
|
||||||
msg=message,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _add_assert(self, assertion: ast.stmt):
|
||||||
|
self._scopes[-1].pre_assertions.append(assertion)
|
||||||
|
|
||||||
def _get_expr_type(self, query: p.Expr) -> Type:
|
def _get_expr_type(self, query: p.Expr) -> Type:
|
||||||
for expr, type in self._typed_ast.judgements:
|
for expr, type in self._typed_ast.judgements:
|
||||||
if expr == query:
|
if expr == query:
|
||||||
return type
|
return type
|
||||||
raise RuntimeError(f"Cannot get type judgement for {query}")
|
raise RuntimeError(f"Cannot get type judgement for {query}")
|
||||||
|
|
||||||
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type):
|
def _make_cast_asserts(
|
||||||
|
self, src_location: Location, expr: ast.expr, type: Type
|
||||||
|
) -> list[ast.stmt]:
|
||||||
match type:
|
match type:
|
||||||
case UnknownType():
|
case UnknownType():
|
||||||
pass
|
return []
|
||||||
|
|
||||||
case BaseType(name=name):
|
case BaseType(name=name):
|
||||||
self._add_assert(
|
return [
|
||||||
ast.Call(
|
self._build_assert(
|
||||||
func=ast.Name(id="isinstance"),
|
ast.Call(
|
||||||
args=[expr, ast.Name(id=name)],
|
func=ast.Name(id="isinstance"),
|
||||||
keywords=[],
|
args=[expr, ast.Name(id=name)],
|
||||||
),
|
keywords=[],
|
||||||
self._make_cast_assert_message(src_location, expr, type),
|
),
|
||||||
)
|
self._make_cast_assert_message(src_location, expr, type),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
case DerivedType(type=base):
|
case DerivedType(type=base):
|
||||||
self._make_cast_asserts(src_location, expr, base)
|
return self._make_cast_asserts(src_location, expr, base)
|
||||||
|
|
||||||
case UnitType():
|
case UnitType():
|
||||||
self._add_assert(
|
return [
|
||||||
ast.Compare(
|
self._build_assert(
|
||||||
left=expr,
|
ast.Compare(
|
||||||
ops=[ast.Is()],
|
left=expr,
|
||||||
comparators=[
|
ops=[ast.Is()],
|
||||||
ast.Constant(value=None),
|
comparators=[
|
||||||
],
|
ast.Constant(value=None),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
self._make_cast_assert_message(src_location, expr, type),
|
||||||
),
|
),
|
||||||
self._make_cast_assert_message(src_location, expr, type),
|
]
|
||||||
)
|
|
||||||
|
|
||||||
case AppliedType(body=body):
|
case AppliedType(body=body):
|
||||||
self._make_cast_asserts(src_location, expr, body)
|
return self._make_cast_asserts(src_location, expr, body)
|
||||||
|
|
||||||
case ConstraintType(type=base, constraint=constraint):
|
case ConstraintType(type=base, constraint=constraint):
|
||||||
self._make_cast_asserts(src_location, expr, base)
|
asserts: list[ast.stmt] = self._make_cast_asserts(
|
||||||
self._make_constraint_assert(src_location, expr, constraint)
|
src_location, expr, base
|
||||||
|
)
|
||||||
|
asserts.append(
|
||||||
|
self._make_constraint_assert(src_location, expr, constraint)
|
||||||
|
)
|
||||||
|
return asserts
|
||||||
|
|
||||||
case TypeVar(bound=bound):
|
case TypeVar(bound=bound):
|
||||||
# TODO: check with type from arguments / use call-site context
|
# TODO: check with type from arguments / use call-site context
|
||||||
if bound is not None:
|
if bound is None:
|
||||||
self._make_cast_asserts(src_location, expr, bound)
|
return []
|
||||||
|
return self._make_cast_asserts(src_location, expr, bound)
|
||||||
|
|
||||||
case TupleType(items=items):
|
case TupleType(items=items):
|
||||||
self._add_assert(
|
asserts: list[ast.stmt] = [
|
||||||
ast.Call(
|
self._build_assert(
|
||||||
func=ast.Name(id="isinstance"),
|
ast.Call(
|
||||||
args=[expr, ast.Name(id="tuple")],
|
func=ast.Name(id="isinstance"),
|
||||||
keywords=[],
|
args=[expr, ast.Name(id="tuple")],
|
||||||
|
keywords=[],
|
||||||
|
),
|
||||||
|
self._make_cast_assert_message(src_location, expr, type),
|
||||||
),
|
),
|
||||||
self._make_cast_assert_message(src_location, expr, type),
|
]
|
||||||
)
|
|
||||||
assert isinstance(expr, ast.Tuple)
|
assert isinstance(expr, ast.Tuple)
|
||||||
for item, item_type in zip(expr.elts, items):
|
for item, item_type in zip(expr.elts, items):
|
||||||
self._make_cast_asserts(src_location, item, item_type)
|
asserts.extend(
|
||||||
|
self._make_cast_asserts(src_location, item, item_type)
|
||||||
|
)
|
||||||
|
return asserts
|
||||||
|
|
||||||
case DataFrameType(columns=columns):
|
case DataFrameType(columns=columns):
|
||||||
self.define_is_dataframe = True
|
self.define_is_dataframe = True
|
||||||
self._add_assert(
|
asserts: list[ast.stmt] = [
|
||||||
ast.Call(
|
self._build_assert(
|
||||||
func=ast.Name(id=self.IS_DATAFRAME_FUNC),
|
ast.Call(
|
||||||
args=[expr],
|
func=ast.Name(id=self.IS_DATAFRAME_FUNC),
|
||||||
keywords=[],
|
args=[expr],
|
||||||
),
|
keywords=[],
|
||||||
self._make_cast_assert_message(
|
|
||||||
src_location, expr, type, ": Not a dataframe"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
for column in columns:
|
|
||||||
self._add_assert(
|
|
||||||
ast.Compare(
|
|
||||||
left=ast.Constant(value=column.name),
|
|
||||||
ops=[ast.In()],
|
|
||||||
comparators=[expr],
|
|
||||||
),
|
),
|
||||||
self._make_cast_assert_message(
|
self._make_cast_assert_message(
|
||||||
src_location, expr, type, f": Missing column {column.name}"
|
src_location, expr, type, ": Not a dataframe"
|
||||||
),
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
for column in columns:
|
||||||
|
asserts.append(
|
||||||
|
self._build_assert(
|
||||||
|
ast.Compare(
|
||||||
|
left=ast.Constant(value=column.name),
|
||||||
|
ops=[ast.In()],
|
||||||
|
comparators=[expr],
|
||||||
|
),
|
||||||
|
self._make_cast_assert_message(
|
||||||
|
src_location,
|
||||||
|
expr,
|
||||||
|
type,
|
||||||
|
f": Missing column {column.name}",
|
||||||
|
),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self._make_cast_asserts(
|
asserts.extend(
|
||||||
src_location,
|
self._make_cast_asserts(
|
||||||
ast.Subscript(
|
src_location,
|
||||||
value=expr, slice=ast.Constant(value=column.name)
|
ast.Subscript(
|
||||||
),
|
value=expr, slice=ast.Constant(value=column.name)
|
||||||
column.type,
|
),
|
||||||
|
column.type,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
return asserts
|
||||||
|
|
||||||
case ColumnType(type=inner):
|
case ColumnType():
|
||||||
self.define_is_column = True
|
self.define_is_column = True
|
||||||
self._add_assert(
|
asserts: list[ast.stmt] = [
|
||||||
ast.Call(
|
self._build_assert(
|
||||||
func=ast.Name(id=self.IS_COLUMN_FUNC),
|
ast.Call(
|
||||||
args=[expr],
|
func=ast.Name(id=self.IS_COLUMN_FUNC),
|
||||||
keywords=[],
|
args=[expr],
|
||||||
|
keywords=[],
|
||||||
|
),
|
||||||
|
self._make_cast_assert_message(
|
||||||
|
src_location, expr, type, ": Not a column"
|
||||||
|
),
|
||||||
),
|
),
|
||||||
self._make_cast_assert_message(
|
]
|
||||||
src_location, expr, type, ": Not a column"
|
asserts.append(self._make_column_inner_assert(src_location, expr, type))
|
||||||
),
|
return asserts
|
||||||
)
|
|
||||||
# TODO: check value type
|
|
||||||
|
|
||||||
case (
|
case (
|
||||||
TopType()
|
TopType()
|
||||||
@@ -419,6 +452,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
| GenericType()
|
| GenericType()
|
||||||
):
|
):
|
||||||
self.logger.warning(f"Can't make assertion for type {type}")
|
self.logger.warning(f"Can't make assertion for type {type}")
|
||||||
|
return []
|
||||||
|
|
||||||
# Ensure exhaustiveness
|
# Ensure exhaustiveness
|
||||||
case _:
|
case _:
|
||||||
@@ -453,9 +487,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
|
|
||||||
def _make_constraint_assert(
|
def _make_constraint_assert(
|
||||||
self, src_location: Location, expr: ast.expr, constraint: m.Expr
|
self, src_location: Location, expr: ast.expr, constraint: m.Expr
|
||||||
):
|
) -> ast.stmt:
|
||||||
test_func: ast.expr = self._get_constraint(constraint)
|
test_func: ast.expr = self._get_constraint(constraint)
|
||||||
self._add_assert(
|
return self._build_assert(
|
||||||
ast.Call(
|
ast.Call(
|
||||||
func=test_func,
|
func=test_func,
|
||||||
args=[expr],
|
args=[expr],
|
||||||
@@ -555,3 +589,15 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
decorator_list=[],
|
decorator_list=[],
|
||||||
returns=ast.Name(id="bool"),
|
returns=ast.Name(id="bool"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _make_column_inner_assert(
|
||||||
|
self, src_location: Location, column: ast.expr, type: ColumnType
|
||||||
|
) -> ast.stmt:
|
||||||
|
# TODO: improve message, maybe chain contexts
|
||||||
|
col: ast.expr = ast.Name(id="col")
|
||||||
|
return ast.For(
|
||||||
|
target=col,
|
||||||
|
iter=column,
|
||||||
|
body=self._make_cast_asserts(src_location, col, type.type),
|
||||||
|
orelse=[],
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user