Compare commits
3 Commits
main
...
feat/simpl
| Author | SHA1 | Date | |
|---|---|---|---|
|
f7a36f61b6
|
|||
|
ad2fabf471
|
|||
|
a59a58d21a
|
@@ -18,6 +18,7 @@ from midas.checker.types import (
|
||||
UnknownType,
|
||||
unfold_type,
|
||||
)
|
||||
from midas.generator.collector import AssertionCollector
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import PythonTyper, TypedExpr
|
||||
@@ -39,6 +40,7 @@ def frame_method(*names: str):
|
||||
class Call:
|
||||
location: Location
|
||||
frame: DataFrameType
|
||||
frame_expr: p.Expr
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
@@ -77,6 +79,10 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
def dispatcher(self) -> CallDispatcher[p.Expr]:
|
||||
return self.typer.dispatcher
|
||||
|
||||
@property
|
||||
def assertions(self) -> AssertionCollector:
|
||||
return self.typer.assertions
|
||||
|
||||
def call(
|
||||
self,
|
||||
method: str,
|
||||
@@ -100,6 +106,7 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
|
||||
by_name: dict[str, DataFrameType.Column] = {}
|
||||
frame2: Optional[DataFrameType] = None
|
||||
# Get map of operand's columns by name, if there is at least 1 operand, which is a dataframe
|
||||
if len(call.positional) != 0:
|
||||
other: Type = call.positional[0][1]
|
||||
unfolded_other: Type = unfold_type(other)
|
||||
@@ -109,6 +116,10 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
col.name: col for col in frame2.columns if col.name is not None
|
||||
}
|
||||
|
||||
# Compute new schema:
|
||||
# Step 1: for all columns in frame1:
|
||||
# - if present in frame2 with equivalent type -> add to schema as is
|
||||
# - if not -> add to schema as unknown
|
||||
in_frame1: set[str] = set()
|
||||
for column in call.frame.columns:
|
||||
if column.name is not None:
|
||||
@@ -129,6 +140,8 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
)
|
||||
new_columns.append(new_column)
|
||||
|
||||
# Step 2: for all columns in frame2
|
||||
# - if not in frame1 -> add to schema as unknown
|
||||
if frame2 is not None:
|
||||
for column in frame2.columns:
|
||||
if column.name in in_frame1:
|
||||
@@ -141,6 +154,7 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
)
|
||||
)
|
||||
|
||||
# Build signature with new schema and generic operand
|
||||
signature = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
@@ -153,6 +167,7 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
returns=DataFrameType(columns=new_columns),
|
||||
)
|
||||
|
||||
# Map arguments and compute result type
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
|
||||
@@ -142,12 +142,14 @@ class FrameManager:
|
||||
method: str,
|
||||
location: Location,
|
||||
frame: DataFrameType,
|
||||
frame_expr: p.Expr,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Type:
|
||||
call: Call = Call(
|
||||
location=location,
|
||||
frame=frame,
|
||||
frame_expr=frame_expr,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
|
||||
@@ -36,6 +36,7 @@ from midas.checker.types import (
|
||||
Variance,
|
||||
unfold_type,
|
||||
)
|
||||
from midas.generator.collector import AssertionCollector
|
||||
from midas.parser.python import PythonParser
|
||||
from midas.utils import TypedAST
|
||||
|
||||
@@ -87,6 +88,7 @@ class PythonTyper(
|
||||
self.dispatcher: CallDispatcher[p.Expr] = CallDispatcher[p.Expr](
|
||||
self.types, self.reporter
|
||||
)
|
||||
self.assertions: AssertionCollector = AssertionCollector()
|
||||
|
||||
def set_reporter(self, reporter: FileReporter):
|
||||
self.reporter = reporter
|
||||
@@ -209,23 +211,24 @@ class PythonTyper(
|
||||
def call_method(
|
||||
self,
|
||||
location: Location,
|
||||
obj: Type,
|
||||
obj: TypedExpr,
|
||||
method_name: str,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Optional[Type]:
|
||||
unfolded: Type = unfold_type(obj)
|
||||
unfolded: Type = unfold_type(obj[1])
|
||||
match unfolded:
|
||||
case DataFrameType():
|
||||
return self.frame_mgr.call(
|
||||
method=method_name,
|
||||
location=location,
|
||||
frame=unfolded,
|
||||
frame_expr=obj[0],
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
|
||||
method: Optional[Type] = self.types.lookup_member(obj, method_name)
|
||||
method: Optional[Type] = self.types.lookup_member(obj[1], method_name)
|
||||
if method is None:
|
||||
raise UndefinedMethodException
|
||||
|
||||
@@ -520,7 +523,9 @@ class PythonTyper(
|
||||
|
||||
result: Optional[Type]
|
||||
try:
|
||||
result = self.call_method(location, left, method, [(right_expr, right)], {})
|
||||
result = self.call_method(
|
||||
location, (left_expr, left), method, [(right_expr, right)], {}
|
||||
)
|
||||
except UndefinedMethodException:
|
||||
self.reporter.error(
|
||||
location,
|
||||
@@ -543,7 +548,9 @@ class PythonTyper(
|
||||
|
||||
result: Optional[Type]
|
||||
try:
|
||||
result = self.call_method(expr.location, operand, method, [], {})
|
||||
result = self.call_method(
|
||||
expr.location, (expr.right, operand), method, [], {}
|
||||
)
|
||||
except UndefinedMethodException:
|
||||
self.reporter.error(
|
||||
expr.location,
|
||||
@@ -571,11 +578,12 @@ class PythonTyper(
|
||||
unfolded: Type = unfold_type(obj_type)
|
||||
if isinstance(unfolded, DataFrameType):
|
||||
return self.frame_mgr.call(
|
||||
method,
|
||||
expr.location,
|
||||
unfolded,
|
||||
positional,
|
||||
keywords,
|
||||
method=method,
|
||||
location=expr.location,
|
||||
frame=unfolded,
|
||||
frame_expr=obj,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
|
||||
callee: Type = self.type_of(expr.callee)
|
||||
|
||||
22
midas/generator/collector.py
Normal file
22
midas/generator/collector.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import ast
|
||||
|
||||
import midas.ast.python as p
|
||||
|
||||
|
||||
class AssertionCollector:
|
||||
def __init__(self):
|
||||
self.assertions: list[tuple[p.Expr, list[ast.expr]]] = []
|
||||
self.definitions: dict[str, ast.stmt] = {}
|
||||
|
||||
def add(self, assertion):
|
||||
self.assertions.append(assertion)
|
||||
|
||||
def define(self, name: str, stmt: ast.stmt):
|
||||
if name not in self.definitions:
|
||||
self.definitions[name] = stmt
|
||||
|
||||
def get_definitions(self) -> list[ast.stmt]:
|
||||
return list(self.definitions.values())
|
||||
|
||||
def get_assertions(self) -> list[tuple[p.Expr, list[ast.expr]]]:
|
||||
return self.assertions
|
||||
@@ -91,6 +91,21 @@ class StubsGenerator:
|
||||
def generate_stub(self, name: str, type: Type):
|
||||
base_type: Type = type
|
||||
|
||||
# TODO: improve
|
||||
match type:
|
||||
case DerivedType(name=name_) | GenericType(name=name_) if name_ == name:
|
||||
pass
|
||||
case UnitType() if name == "None":
|
||||
pass
|
||||
case TopType() if name == "Any":
|
||||
pass
|
||||
case _:
|
||||
alias = ast.Assign(
|
||||
targets=[ast.Name(id=name)], value=self.dump_type(type)
|
||||
)
|
||||
self.add_stub(alias)
|
||||
return
|
||||
|
||||
members: dict[str, Member] = self.types._members.get(name, {})
|
||||
if isinstance(base_type, (BaseType, TopType, UnitType)) and len(members) == 0:
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user