Compare commits
9 Commits
main
...
feat/simpl
| Author | SHA1 | Date | |
|---|---|---|---|
|
5d20f8ec3e
|
|||
|
955c2233ed
|
|||
|
ff69b65171
|
|||
|
8df01afd8c
|
|||
|
47b2dfdd73
|
|||
|
bd4d793ce0
|
|||
|
f7a36f61b6
|
|||
|
ad2fabf471
|
|||
|
a59a58d21a
|
@@ -678,6 +678,10 @@ In the following example, a runtime check would be generated to ensure that the
|
||||
caption: [Typing of `cast` expression],
|
||||
)
|
||||
|
||||
#gc.warning[
|
||||
Assertions are statements inserted just before a statement using a `cast` expression. This means that the expression is evaluated _before_ its actual intended usage location, which might cause issues if you rely on logical operator short-circuiting. See @eager-eval for more information.
|
||||
]
|
||||
|
||||
There may be some cases where the cost of checking a value at runtime is simply not worth the safety, for example when dealing with a big dataset. If do wish so, you can use `unsafe_cast` which will only tell the type checker the type of the value, without generating a runtime assertion. This maps to the default behavior of `typing`'s own `cast` function.
|
||||
|
||||
If the value passed to `cast` or `unsafe_cast` is a literal (e.g. an integer, a string, a list of literals, etc.), the assertion is evaluated _at compile-time_ and no runtime assertion is generated.
|
||||
@@ -695,3 +699,26 @@ If the value passed to `cast` or `unsafe_cast` is a literal (e.g. an integer, a
|
||||
== Generating Stubs (`stubs`) <cmd-stubs>
|
||||
== Showing Type Judgements (`types`) <cmd-types>
|
||||
== Validating Definitions (`validate`) <cmd-validate>
|
||||
|
||||
= Known limitations <limitations>
|
||||
|
||||
== Eager evaluation in runtime assertions <eager-eval>
|
||||
|
||||
The process of generating assertions to ensure safety at runtime, mainly for `cast` expressions, leads to the creation of aliases for the expressions being casted. These alias definitions eagerly evaluate before the assertion, and most importantly before the real usage location. This means that you should avoid using `cast` expressions inside logical expressions like `and` or `or`, because the normal "short-circuit" behavior will be irrelevant to the evaluations of the operands.
|
||||
|
||||
For example:
|
||||
|
||||
#figure(
|
||||
```py
|
||||
def foo():
|
||||
print("Foo")
|
||||
return True
|
||||
def bar():
|
||||
print("Bar")
|
||||
return True
|
||||
result = foo() or bar()
|
||||
# Foo
|
||||
# Bar
|
||||
```,
|
||||
caption: [Runtime assertions may eagerly evaluate expressions and bypass logical operator's short-circuit],
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
|
||||
@@ -18,6 +19,7 @@ from midas.checker.types import (
|
||||
UnknownType,
|
||||
unfold_type,
|
||||
)
|
||||
from midas.generator.collector import AssertionCollector
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import PythonTyper, TypedExpr
|
||||
@@ -38,7 +40,9 @@ def frame_method(*names: str):
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Call:
|
||||
location: Location
|
||||
call_expr: p.Expr
|
||||
frame: DataFrameType
|
||||
frame_expr: p.Expr
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
@@ -77,6 +81,10 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
def dispatcher(self) -> CallDispatcher[p.Expr]:
|
||||
return self.typer.dispatcher
|
||||
|
||||
@property
|
||||
def assertions(self) -> AssertionCollector:
|
||||
return self.typer.assertions
|
||||
|
||||
def call(
|
||||
self,
|
||||
method: str,
|
||||
@@ -100,6 +108,7 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
|
||||
by_name: dict[str, DataFrameType.Column] = {}
|
||||
frame2: Optional[DataFrameType] = None
|
||||
# Get map of operand's columns by name, if there is at least 1 operand, which is a dataframe
|
||||
if len(call.positional) != 0:
|
||||
other: Type = call.positional[0][1]
|
||||
unfolded_other: Type = unfold_type(other)
|
||||
@@ -109,6 +118,10 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
col.name: col for col in frame2.columns if col.name is not None
|
||||
}
|
||||
|
||||
# Compute new schema:
|
||||
# Step 1: for all columns in frame1:
|
||||
# - if present in frame2 with equivalent type -> add to schema as is
|
||||
# - if not -> add to schema as unknown
|
||||
in_frame1: set[str] = set()
|
||||
for column in call.frame.columns:
|
||||
if column.name is not None:
|
||||
@@ -129,6 +142,8 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
)
|
||||
new_columns.append(new_column)
|
||||
|
||||
# Step 2: for all columns in frame2
|
||||
# - if not in frame1 -> add to schema as unknown
|
||||
if frame2 is not None:
|
||||
for column in frame2.columns:
|
||||
if column.name in in_frame1:
|
||||
@@ -141,6 +156,7 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
)
|
||||
)
|
||||
|
||||
# Build signature with new schema and generic operand
|
||||
signature = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
@@ -153,12 +169,18 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
returns=DataFrameType(columns=new_columns),
|
||||
)
|
||||
|
||||
# Map arguments and compute result type
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
if result.is_valid:
|
||||
self._assert_same_length(
|
||||
call.call_expr, call.frame_expr, call.positional[0][0]
|
||||
)
|
||||
|
||||
return result.result
|
||||
|
||||
@frame_method()
|
||||
@@ -199,3 +221,50 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
def _assert_same_length(self, call_expr: p.Expr, frame1: p.Expr, frame2: p.Expr):
|
||||
func_name: str = "__midas_frame_same_length__"
|
||||
self.assertions.define(
|
||||
func_name,
|
||||
ast.FunctionDef(
|
||||
name=func_name,
|
||||
args=ast.arguments(
|
||||
posonlyargs=[],
|
||||
args=[
|
||||
ast.arg(arg="frame1"),
|
||||
ast.arg(arg="frame2"),
|
||||
],
|
||||
kwonlyargs=[],
|
||||
defaults=[],
|
||||
kw_defaults=[],
|
||||
),
|
||||
body=[
|
||||
ast.Return(
|
||||
value=ast.Compare(
|
||||
left=ast.Attribute(
|
||||
value=ast.Name(id="frame1"),
|
||||
attr="size",
|
||||
),
|
||||
ops=[ast.Eq()],
|
||||
comparators=[
|
||||
ast.Attribute(
|
||||
value=ast.Name(id="frame2"),
|
||||
attr="size",
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
],
|
||||
decorator_list=[],
|
||||
),
|
||||
)
|
||||
self.assertions.add(
|
||||
bound_expr=call_expr,
|
||||
inputs=[frame1, frame2],
|
||||
builder=lambda f1, f2: ast.Call(
|
||||
func=ast.Name(id=func_name),
|
||||
args=[f1, f2],
|
||||
keywords=[],
|
||||
),
|
||||
message="DataFrames must have the same length",
|
||||
)
|
||||
|
||||
@@ -141,13 +141,17 @@ class FrameManager:
|
||||
self,
|
||||
method: str,
|
||||
location: Location,
|
||||
call_expr: p.Expr,
|
||||
frame: DataFrameType,
|
||||
frame_expr: p.Expr,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Type:
|
||||
call: Call = Call(
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
frame=frame,
|
||||
frame_expr=frame_expr,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
|
||||
@@ -28,6 +28,7 @@ from midas.checker.types import (
|
||||
DerivedType,
|
||||
Function,
|
||||
GenericType,
|
||||
TopType,
|
||||
TupleType,
|
||||
Type,
|
||||
TypeVar,
|
||||
@@ -36,6 +37,7 @@ from midas.checker.types import (
|
||||
Variance,
|
||||
unfold_type,
|
||||
)
|
||||
from midas.generator.collector import AssertionCollector
|
||||
from midas.parser.python import PythonParser
|
||||
from midas.utils import TypedAST
|
||||
|
||||
@@ -87,6 +89,7 @@ class PythonTyper(
|
||||
self.dispatcher: CallDispatcher[p.Expr] = CallDispatcher[p.Expr](
|
||||
self.types, self.reporter
|
||||
)
|
||||
self.assertions: AssertionCollector = AssertionCollector()
|
||||
|
||||
def set_reporter(self, reporter: FileReporter):
|
||||
self.reporter = reporter
|
||||
@@ -113,6 +116,7 @@ class PythonTyper(
|
||||
stmts=stmts,
|
||||
judgements=self.judgements,
|
||||
evaluated_casts=self.evaluated_casts,
|
||||
assertions=self.assertions,
|
||||
)
|
||||
|
||||
def judge(self, expr: p.Expr, type: Type):
|
||||
@@ -209,23 +213,26 @@ class PythonTyper(
|
||||
def call_method(
|
||||
self,
|
||||
location: Location,
|
||||
obj: Type,
|
||||
call_expr: p.Expr,
|
||||
obj: TypedExpr,
|
||||
method_name: str,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Optional[Type]:
|
||||
unfolded: Type = unfold_type(obj)
|
||||
unfolded: Type = unfold_type(obj[1])
|
||||
match unfolded:
|
||||
case DataFrameType():
|
||||
return self.frame_mgr.call(
|
||||
method=method_name,
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
frame=unfolded,
|
||||
frame_expr=obj[0],
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
|
||||
method: Optional[Type] = self.types.lookup_member(obj, method_name)
|
||||
method: Optional[Type] = self.types.lookup_member(obj[1], method_name)
|
||||
if method is None:
|
||||
raise UndefinedMethodException
|
||||
|
||||
@@ -499,7 +506,9 @@ class PythonTyper(
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
|
||||
return self._visit_binary_expr(
|
||||
expr.location, expr, expr.left, expr.right, method
|
||||
)
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
||||
method: Optional[str] = PY_COMPARATOR_METHODS.get(expr.operator.__class__)
|
||||
@@ -510,17 +519,31 @@ class PythonTyper(
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
|
||||
return self._visit_binary_expr(
|
||||
expr.location, expr, expr.left, expr.right, method
|
||||
)
|
||||
|
||||
def _visit_binary_expr(
|
||||
self, location: Location, left_expr: p.Expr, right_expr: p.Expr, method: str
|
||||
self,
|
||||
location: Location,
|
||||
expr: p.Expr,
|
||||
left_expr: p.Expr,
|
||||
right_expr: p.Expr,
|
||||
method: str,
|
||||
) -> Type:
|
||||
left: Type = self.type_of(left_expr)
|
||||
right: Type = self.type_of(right_expr)
|
||||
|
||||
result: Optional[Type]
|
||||
try:
|
||||
result = self.call_method(location, left, method, [(right_expr, right)], {})
|
||||
result = self.call_method(
|
||||
location=location,
|
||||
call_expr=expr,
|
||||
obj=(left_expr, left),
|
||||
method_name=method,
|
||||
positional=[(right_expr, right)],
|
||||
keywords={},
|
||||
)
|
||||
except UndefinedMethodException:
|
||||
self.reporter.error(
|
||||
location,
|
||||
@@ -543,7 +566,14 @@ class PythonTyper(
|
||||
|
||||
result: Optional[Type]
|
||||
try:
|
||||
result = self.call_method(expr.location, operand, method, [], {})
|
||||
result = self.call_method(
|
||||
location=expr.location,
|
||||
call_expr=expr,
|
||||
obj=(expr.right, operand),
|
||||
method_name=method,
|
||||
positional=[],
|
||||
keywords={},
|
||||
)
|
||||
except UndefinedMethodException:
|
||||
self.reporter.error(
|
||||
expr.location,
|
||||
@@ -571,11 +601,13 @@ class PythonTyper(
|
||||
unfolded: Type = unfold_type(obj_type)
|
||||
if isinstance(unfolded, DataFrameType):
|
||||
return self.frame_mgr.call(
|
||||
method,
|
||||
expr.location,
|
||||
unfolded,
|
||||
positional,
|
||||
keywords,
|
||||
method=method,
|
||||
location=expr.location,
|
||||
call_expr=expr,
|
||||
frame=unfolded,
|
||||
frame_expr=obj,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
|
||||
callee: Type = self.type_of(expr.callee)
|
||||
@@ -936,6 +968,17 @@ class PythonTyper(
|
||||
self, expr: p.CastExpr, subject_type: Type, target_type: Type, lit_value: Any
|
||||
) -> bool:
|
||||
match target_type:
|
||||
case TopType():
|
||||
return True
|
||||
|
||||
case UnitType():
|
||||
if lit_value is not None:
|
||||
self.reporter.error(
|
||||
expr.location, f"Value {lit_value!r} is not None"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
case DerivedType(type=base):
|
||||
return self._evaluate_cast_statically(
|
||||
expr, subject_type, base, lit_value
|
||||
|
||||
59
midas/generator/collector.py
Normal file
59
midas/generator/collector.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
|
||||
import midas.ast.python as p
|
||||
|
||||
AssertionBuilder = Callable[..., ast.expr]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Assertion:
|
||||
bound_expr: p.Expr
|
||||
inputs: list[p.Expr]
|
||||
builder: AssertionBuilder
|
||||
message: str
|
||||
|
||||
def is_bound_to(self, expr: p.Expr) -> bool:
|
||||
return expr == self.bound_expr
|
||||
|
||||
|
||||
class AssertionCollector:
|
||||
def __init__(self):
|
||||
self.assertions: list[Assertion] = []
|
||||
self.definitions: dict[str, ast.stmt] = {}
|
||||
|
||||
def add(
|
||||
self,
|
||||
bound_expr: p.Expr,
|
||||
inputs: list[p.Expr],
|
||||
builder: AssertionBuilder,
|
||||
message: str,
|
||||
):
|
||||
self.assertions.append(
|
||||
Assertion(
|
||||
bound_expr=bound_expr,
|
||||
inputs=inputs,
|
||||
builder=builder,
|
||||
message=message,
|
||||
)
|
||||
)
|
||||
|
||||
def remove(self, assertion: Assertion):
|
||||
try:
|
||||
self.assertions.remove(assertion)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def define(self, name: str, stmt: ast.stmt):
|
||||
if name not in self.definitions:
|
||||
self.definitions[name] = stmt
|
||||
|
||||
def get_definitions(self) -> list[ast.stmt]:
|
||||
return list(self.definitions.values())
|
||||
|
||||
def get_assertions(self) -> list[Assertion]:
|
||||
return self.assertions
|
||||
|
||||
def get_assertions_for(self, expr: p.Expr) -> list[Assertion]:
|
||||
return list(filter(lambda a: a.is_bound_to(expr), self.assertions))
|
||||
@@ -30,6 +30,7 @@ from midas.checker.types import (
|
||||
UnitType,
|
||||
UnknownType,
|
||||
)
|
||||
from midas.generator.collector import Assertion, AssertionCollector
|
||||
from midas.generator.constraints import ConstraintGenerator
|
||||
from midas.generator.stubs import StubsGenerator
|
||||
from midas.utils import TypedAST
|
||||
@@ -55,10 +56,12 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
stmts=[],
|
||||
judgements=[],
|
||||
evaluated_casts=[],
|
||||
assertions=AssertionCollector(),
|
||||
)
|
||||
self._alias_count: int = 0
|
||||
self._predicate_count: int = 0
|
||||
self._scopes: list[Scope] = []
|
||||
self._aliases: list[tuple[p.Expr, ast.expr]] = []
|
||||
|
||||
self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
|
||||
self._constraints: list[tuple[m.Expr, ast.expr]] = []
|
||||
@@ -71,7 +74,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
|
||||
def generate_ast(self, typed_ast: TypedAST) -> ast.AST:
|
||||
self._typed_ast = typed_ast
|
||||
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
|
||||
body: list[ast.stmt] = self._visit_body(typed_ast.stmts, can_be_empty=True)
|
||||
predicates: list[ast.stmt] = self._constraint_generator.get_definitions()
|
||||
|
||||
body = predicates + body
|
||||
@@ -129,39 +132,48 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
output: str = ast.unparse(module)
|
||||
out_path.write_text(output)
|
||||
|
||||
def convert(self, expr: p.Expr) -> ast.expr:
|
||||
for expr2, alias in self._aliases:
|
||||
if expr2 == expr:
|
||||
return alias
|
||||
assertions = self._typed_ast.assertions.get_assertions_for(expr)
|
||||
if len(assertions) != 0:
|
||||
return self._apply_assertions(expr, assertions)
|
||||
return expr.accept(self)
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> ast.expr:
|
||||
return ast.BinOp(
|
||||
left=expr.left.accept(self),
|
||||
left=self.convert(expr.left),
|
||||
op=expr.operator,
|
||||
right=expr.right.accept(self),
|
||||
right=self.convert(expr.right),
|
||||
)
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> ast.expr:
|
||||
return ast.Compare(
|
||||
left=expr.left.accept(self),
|
||||
left=self.convert(expr.left),
|
||||
ops=[expr.operator],
|
||||
comparators=[expr.right.accept(self)],
|
||||
comparators=[self.convert(expr.right)],
|
||||
)
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> ast.expr:
|
||||
return ast.UnaryOp(
|
||||
op=expr.operator,
|
||||
operand=expr.right.accept(self),
|
||||
operand=self.convert(expr.right),
|
||||
)
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> ast.expr:
|
||||
return ast.Call(
|
||||
func=expr.callee.accept(self),
|
||||
args=[arg.accept(self) for arg in expr.arguments],
|
||||
func=self.convert(expr.callee),
|
||||
args=[self.convert(arg) for arg in expr.arguments],
|
||||
keywords=[
|
||||
ast.keyword(arg=name, value=arg.accept(self))
|
||||
ast.keyword(arg=name, value=self.convert(arg))
|
||||
for name, arg in expr.keywords.items()
|
||||
],
|
||||
)
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> ast.expr:
|
||||
return ast.Attribute(
|
||||
value=expr.object.accept(self),
|
||||
value=self.convert(expr.object),
|
||||
attr=expr.name,
|
||||
)
|
||||
|
||||
@@ -174,16 +186,16 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> ast.expr:
|
||||
return ast.BoolOp(
|
||||
op=expr.operator,
|
||||
values=[expr.left.accept(self), expr.right.accept(self)],
|
||||
values=[self.convert(expr.left), self.convert(expr.right)],
|
||||
)
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
|
||||
expr2: ast.expr = expr.expr.accept(self)
|
||||
expr2: ast.expr = self.convert(expr.expr)
|
||||
|
||||
if expr in self._typed_ast.evaluated_casts or expr.unsafe:
|
||||
return expr2
|
||||
|
||||
alias: ast.expr = self._make_alias(expr2)
|
||||
alias: ast.expr = self._make_alias(expr.expr, expr2)
|
||||
|
||||
type: Type = self._get_expr_type(expr)
|
||||
asserts: list[ast.stmt] = self._make_cast_asserts(expr.location, alias, type)
|
||||
@@ -194,38 +206,38 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> ast.expr:
|
||||
return ast.IfExp(
|
||||
test=expr.test.accept(self),
|
||||
body=expr.if_true.accept(self),
|
||||
orelse=expr.if_false.accept(self),
|
||||
test=self.convert(expr.test),
|
||||
body=self.convert(expr.if_true),
|
||||
orelse=self.convert(expr.if_false),
|
||||
)
|
||||
|
||||
def visit_list_expr(self, expr: p.ListExpr) -> ast.expr:
|
||||
return ast.List(
|
||||
elts=[item.accept(self) for item in expr.items],
|
||||
elts=[self.convert(item) for item in expr.items],
|
||||
)
|
||||
|
||||
def visit_dict_expr(self, expr: p.DictExpr) -> ast.expr:
|
||||
return ast.Dict(
|
||||
keys=[key.accept(self) if key is not None else None for key in expr.keys],
|
||||
values=[value.accept(self) for value in expr.values],
|
||||
keys=[self.convert(key) if key is not None else None for key in expr.keys],
|
||||
values=[self.convert(value) for value in expr.values],
|
||||
)
|
||||
|
||||
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> ast.expr:
|
||||
return ast.Subscript(
|
||||
value=expr.object.accept(self),
|
||||
slice=expr.index.accept(self),
|
||||
value=self.convert(expr.object),
|
||||
slice=self.convert(expr.index),
|
||||
)
|
||||
|
||||
def visit_slice_expr(self, expr: p.SliceExpr) -> ast.expr:
|
||||
return ast.Slice(
|
||||
lower=expr.lower.accept(self) if expr.lower is not None else None,
|
||||
upper=expr.upper.accept(self) if expr.upper is not None else None,
|
||||
step=expr.step.accept(self) if expr.step is not None else None,
|
||||
lower=self.convert(expr.lower) if expr.lower is not None else None,
|
||||
upper=self.convert(expr.upper) if expr.upper is not None else None,
|
||||
step=self.convert(expr.step) if expr.step is not None else None,
|
||||
)
|
||||
|
||||
def visit_tuple_expr(self, expr: p.TupleExpr) -> ast.expr:
|
||||
return ast.Tuple(
|
||||
elts=[item.accept(self) for item in expr.items],
|
||||
elts=[self.convert(item) for item in expr.items],
|
||||
)
|
||||
|
||||
def visit_raw_expr(self, expr: p.RawExpr) -> ast.expr:
|
||||
@@ -233,7 +245,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
|
||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> ast.stmt:
|
||||
return ast.Expr(
|
||||
value=stmt.expr.accept(self),
|
||||
value=self.convert(stmt.expr),
|
||||
)
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> ast.stmt:
|
||||
@@ -246,12 +258,12 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
kwonlyargs=[ast.arg(arg=arg.name) for arg in stmt.kwonlyargs],
|
||||
kwarg=None,
|
||||
defaults=[
|
||||
arg.default.accept(self)
|
||||
self.convert(arg.default)
|
||||
for arg in stmt.posonlyargs + stmt.args
|
||||
if arg.default is not None
|
||||
],
|
||||
kw_defaults=[
|
||||
arg.default.accept(self) if arg.default is not None else None
|
||||
self.convert(arg.default) if arg.default is not None else None
|
||||
for arg in stmt.kwonlyargs
|
||||
],
|
||||
),
|
||||
@@ -265,20 +277,20 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
|
||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> ast.stmt:
|
||||
return ast.Assign(
|
||||
targets=[target.accept(self) for target in stmt.targets],
|
||||
value=stmt.value.accept(self),
|
||||
targets=[self.convert(target) for target in stmt.targets],
|
||||
value=self.convert(stmt.value),
|
||||
)
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> ast.stmt:
|
||||
return ast.Return(
|
||||
value=stmt.value.accept(self) if stmt.value is not None else None,
|
||||
value=self.convert(stmt.value) if stmt.value is not None else None,
|
||||
)
|
||||
|
||||
def visit_if_stmt(self, stmt: p.IfStmt) -> ast.stmt:
|
||||
return ast.If(
|
||||
test=stmt.test.accept(self),
|
||||
test=self.convert(stmt.test),
|
||||
body=self._visit_body(stmt.body),
|
||||
orelse=self._visit_body(stmt.orelse),
|
||||
orelse=self._visit_body(stmt.orelse, can_be_empty=True),
|
||||
)
|
||||
|
||||
def visit_pass(self, stmt: p.Pass) -> ast.stmt:
|
||||
@@ -286,8 +298,8 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
|
||||
def visit_for_stmt(self, stmt: p.ForStmt) -> ast.stmt:
|
||||
return ast.For(
|
||||
target=stmt.target.accept(self),
|
||||
iter=stmt.iterator.accept(self),
|
||||
target=self.convert(stmt.target),
|
||||
iter=self.convert(stmt.iterator),
|
||||
body=self._visit_body(stmt.body),
|
||||
orelse=[],
|
||||
)
|
||||
@@ -295,7 +307,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
def visit_raw_stmt(self, stmt: p.RawStmt) -> ast.stmt:
|
||||
return stmt.stmt
|
||||
|
||||
def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]:
|
||||
def _visit_body(
|
||||
self, stmts: list[p.Stmt], can_be_empty: bool = False
|
||||
) -> list[ast.stmt]:
|
||||
generated: list[ast.stmt] = []
|
||||
for stmt in stmts:
|
||||
scope = Scope()
|
||||
@@ -313,9 +327,11 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
# Remove redundant pass statements
|
||||
if len(generated) > 1:
|
||||
generated = [stmt for stmt in generated if not isinstance(stmt, ast.Pass)]
|
||||
if len(generated) == 0 and not can_be_empty:
|
||||
generated = [ast.Pass()]
|
||||
return generated
|
||||
|
||||
def _make_alias(self, expr: ast.expr) -> ast.expr:
|
||||
def _make_alias(self, node: p.Expr, expr: ast.expr) -> ast.expr:
|
||||
name: str = f"__midas_a{self._alias_count}__"
|
||||
alias = ast.Name(id=name)
|
||||
self._alias_count += 1
|
||||
@@ -326,6 +342,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
value=expr,
|
||||
)
|
||||
)
|
||||
self._aliases.append((node, alias))
|
||||
return alias
|
||||
|
||||
def _build_assert(self, expr: ast.expr, message: str | ast.expr) -> ast.stmt:
|
||||
@@ -640,3 +657,30 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
body=body,
|
||||
orelse=[],
|
||||
)
|
||||
|
||||
def _convert_assertion(self, assertion: Assertion) -> ast.stmt:
|
||||
inputs: list[ast.expr] = []
|
||||
|
||||
for input in assertion.inputs:
|
||||
converted: ast.expr = self.convert(input)
|
||||
alias: ast.expr = self._make_alias(input, converted)
|
||||
inputs.append(alias)
|
||||
|
||||
test: ast.expr = assertion.builder(*inputs)
|
||||
location: Location = assertion.bound_expr.location
|
||||
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
|
||||
return self._build_assert(
|
||||
test, f"{loc_str}: AssertionError: {assertion.message}"
|
||||
)
|
||||
|
||||
def _apply_assertions(self, expr: p.Expr, assertions: list[Assertion]) -> ast.expr:
|
||||
for assertion in assertions:
|
||||
assert_stmt: ast.stmt
|
||||
assert_stmt = self._convert_assertion(assertion)
|
||||
self._add_assert(assert_stmt)
|
||||
|
||||
# Mutating list in frozen dataclass
|
||||
# Not ideal but easiest way to avoid duplicate assertions
|
||||
self._typed_ast.assertions.remove(assertion)
|
||||
|
||||
return expr.accept(self)
|
||||
|
||||
@@ -91,6 +91,21 @@ class StubsGenerator:
|
||||
def generate_stub(self, name: str, type: Type):
|
||||
base_type: Type = type
|
||||
|
||||
# TODO: improve
|
||||
match type:
|
||||
case DerivedType(name=name_) | GenericType(name=name_) if name_ == name:
|
||||
pass
|
||||
case UnitType() if name == "None":
|
||||
pass
|
||||
case TopType() if name == "Any":
|
||||
pass
|
||||
case _:
|
||||
alias = ast.Assign(
|
||||
targets=[ast.Name(id=name)], value=self.dump_type(type)
|
||||
)
|
||||
self.add_stub(alias)
|
||||
return
|
||||
|
||||
members: dict[str, Member] = self.types._members.get(name, {})
|
||||
if isinstance(base_type, (BaseType, TopType, UnitType)) and len(members) == 0:
|
||||
return
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any, Callable, Optional
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.checker.types import Type
|
||||
from midas.generator.collector import AssertionCollector
|
||||
|
||||
AllowRepeat = Callable[[object], bool]
|
||||
|
||||
@@ -63,3 +64,4 @@ class TypedAST:
|
||||
stmts: list[p.Stmt]
|
||||
judgements: list[tuple[p.Expr, Type]]
|
||||
evaluated_casts: list[p.CastExpr]
|
||||
assertions: AssertionCollector
|
||||
|
||||
Reference in New Issue
Block a user