3 Commits

5 changed files with 72 additions and 10 deletions

View File

@@ -18,6 +18,7 @@ from midas.checker.types import (
UnknownType,
unfold_type,
)
from midas.generator.collector import AssertionCollector
if TYPE_CHECKING:
from midas.checker.python import PythonTyper, TypedExpr
@@ -39,6 +40,7 @@ def frame_method(*names: str):
class Call:
location: Location
frame: DataFrameType
frame_expr: p.Expr
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
@@ -77,6 +79,10 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
def dispatcher(self) -> CallDispatcher[p.Expr]:
return self.typer.dispatcher
@property
def assertions(self) -> AssertionCollector:
return self.typer.assertions
def call(
self,
method: str,
@@ -100,6 +106,7 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
by_name: dict[str, DataFrameType.Column] = {}
frame2: Optional[DataFrameType] = None
# Get map of operand's columns by name, if there is at least 1 operand, which is a dataframe
if len(call.positional) != 0:
other: Type = call.positional[0][1]
unfolded_other: Type = unfold_type(other)
@@ -109,6 +116,10 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
col.name: col for col in frame2.columns if col.name is not None
}
# Compute new schema:
# Step 1: for all columns in frame1:
# - if present in frame2 with equivalent type -> add to schema as is
# - if not -> add to schema as unknown
in_frame1: set[str] = set()
for column in call.frame.columns:
if column.name is not None:
@@ -129,6 +140,8 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
)
new_columns.append(new_column)
# Step 2: for all columns in frame2
# - if not in frame1 -> add to schema as unknown
if frame2 is not None:
for column in frame2.columns:
if column.name in in_frame1:
@@ -141,6 +154,7 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
)
)
# Build signature with new schema and generic operand
signature = Function(
args=[
Function.Argument(
@@ -153,6 +167,7 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
returns=DataFrameType(columns=new_columns),
)
# Map arguments and compute result type
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,

View File

@@ -142,12 +142,14 @@ class FrameManager:
method: str,
location: Location,
frame: DataFrameType,
frame_expr: p.Expr,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
call: Call = Call(
location=location,
frame=frame,
frame_expr=frame_expr,
positional=positional,
keywords=keywords,
)

View File

@@ -36,6 +36,7 @@ from midas.checker.types import (
Variance,
unfold_type,
)
from midas.generator.collector import AssertionCollector
from midas.parser.python import PythonParser
from midas.utils import TypedAST
@@ -87,6 +88,7 @@ class PythonTyper(
self.dispatcher: CallDispatcher[p.Expr] = CallDispatcher[p.Expr](
self.types, self.reporter
)
self.assertions: AssertionCollector = AssertionCollector()
def set_reporter(self, reporter: FileReporter):
self.reporter = reporter
@@ -209,23 +211,24 @@ class PythonTyper(
def call_method(
self,
location: Location,
obj: Type,
obj: TypedExpr,
method_name: str,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Optional[Type]:
unfolded: Type = unfold_type(obj)
unfolded: Type = unfold_type(obj[1])
match unfolded:
case DataFrameType():
return self.frame_mgr.call(
method=method_name,
location=location,
frame=unfolded,
frame_expr=obj[0],
positional=positional,
keywords=keywords,
)
method: Optional[Type] = self.types.lookup_member(obj, method_name)
method: Optional[Type] = self.types.lookup_member(obj[1], method_name)
if method is None:
raise UndefinedMethodException
@@ -520,7 +523,9 @@ class PythonTyper(
result: Optional[Type]
try:
result = self.call_method(location, left, method, [(right_expr, right)], {})
result = self.call_method(
location, (left_expr, left), method, [(right_expr, right)], {}
)
except UndefinedMethodException:
self.reporter.error(
location,
@@ -543,7 +548,9 @@ class PythonTyper(
result: Optional[Type]
try:
result = self.call_method(expr.location, operand, method, [], {})
result = self.call_method(
expr.location, (expr.right, operand), method, [], {}
)
except UndefinedMethodException:
self.reporter.error(
expr.location,
@@ -571,11 +578,12 @@ class PythonTyper(
unfolded: Type = unfold_type(obj_type)
if isinstance(unfolded, DataFrameType):
return self.frame_mgr.call(
method,
expr.location,
unfolded,
positional,
keywords,
method=method,
location=expr.location,
frame=unfolded,
frame_expr=obj,
positional=positional,
keywords=keywords,
)
callee: Type = self.type_of(expr.callee)

View File

@@ -0,0 +1,22 @@
import ast
import midas.ast.python as p
class AssertionCollector:
def __init__(self):
self.assertions: list[tuple[p.Expr, list[ast.expr]]] = []
self.definitions: dict[str, ast.stmt] = {}
def add(self, assertion):
self.assertions.append(assertion)
def define(self, name: str, stmt: ast.stmt):
if name not in self.definitions:
self.definitions[name] = stmt
def get_definitions(self) -> list[ast.stmt]:
return list(self.definitions.values())
def get_assertions(self) -> list[tuple[p.Expr, list[ast.expr]]]:
return self.assertions

View File

@@ -91,6 +91,21 @@ class StubsGenerator:
def generate_stub(self, name: str, type: Type):
base_type: Type = type
# TODO: improve
match type:
case DerivedType(name=name_) | GenericType(name=name_) if name_ == name:
pass
case UnitType() if name == "None":
pass
case TopType() if name == "Any":
pass
case _:
alias = ast.Assign(
targets=[ast.Name(id=name)], value=self.dump_type(type)
)
self.add_stub(alias)
return
members: dict[str, Member] = self.types._members.get(name, {})
if isinstance(base_type, (BaseType, TopType, UnitType)) and len(members) == 0:
return