fix(checker): pass AST expression to method registry
This commit is contained in:
@@ -40,6 +40,7 @@ def frame_method(*names: str):
|
||||
class Call:
|
||||
location: Location
|
||||
frame: DataFrameType
|
||||
frame_expr: p.Expr
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
@@ -105,6 +106,7 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
|
||||
by_name: dict[str, DataFrameType.Column] = {}
|
||||
frame2: Optional[DataFrameType] = None
|
||||
# Get map of operand's columns by name, if there is at least 1 operand, which is a dataframe
|
||||
if len(call.positional) != 0:
|
||||
other: Type = call.positional[0][1]
|
||||
unfolded_other: Type = unfold_type(other)
|
||||
@@ -114,6 +116,10 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
col.name: col for col in frame2.columns if col.name is not None
|
||||
}
|
||||
|
||||
# Compute new schema:
|
||||
# Step 1: for all columns in frame1:
|
||||
# - if present in frame2 with equivalent type -> add to schema as is
|
||||
# - if not -> add to schema as unknown
|
||||
in_frame1: set[str] = set()
|
||||
for column in call.frame.columns:
|
||||
if column.name is not None:
|
||||
@@ -134,6 +140,8 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
)
|
||||
new_columns.append(new_column)
|
||||
|
||||
# Step 2: for all columns in frame2
|
||||
# - if not in frame1 -> add to schema as unknown
|
||||
if frame2 is not None:
|
||||
for column in frame2.columns:
|
||||
if column.name in in_frame1:
|
||||
@@ -146,6 +154,7 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
)
|
||||
)
|
||||
|
||||
# Build signature with new schema and generic operand
|
||||
signature = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
@@ -158,6 +167,7 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
returns=DataFrameType(columns=new_columns),
|
||||
)
|
||||
|
||||
# Map arguments and compute result type
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
|
||||
@@ -142,12 +142,14 @@ class FrameManager:
|
||||
method: str,
|
||||
location: Location,
|
||||
frame: DataFrameType,
|
||||
frame_expr: p.Expr,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Type:
|
||||
call: Call = Call(
|
||||
location=location,
|
||||
frame=frame,
|
||||
frame_expr=frame_expr,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
|
||||
@@ -211,23 +211,24 @@ class PythonTyper(
|
||||
def call_method(
|
||||
self,
|
||||
location: Location,
|
||||
obj: Type,
|
||||
obj: TypedExpr,
|
||||
method_name: str,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Optional[Type]:
|
||||
unfolded: Type = unfold_type(obj)
|
||||
unfolded: Type = unfold_type(obj[1])
|
||||
match unfolded:
|
||||
case DataFrameType():
|
||||
return self.frame_mgr.call(
|
||||
method=method_name,
|
||||
location=location,
|
||||
frame=unfolded,
|
||||
frame_expr=obj[0],
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
|
||||
method: Optional[Type] = self.types.lookup_member(obj, method_name)
|
||||
method: Optional[Type] = self.types.lookup_member(obj[1], method_name)
|
||||
if method is None:
|
||||
raise UndefinedMethodException
|
||||
|
||||
@@ -522,7 +523,9 @@ class PythonTyper(
|
||||
|
||||
result: Optional[Type]
|
||||
try:
|
||||
result = self.call_method(location, left, method, [(right_expr, right)], {})
|
||||
result = self.call_method(
|
||||
location, (left_expr, left), method, [(right_expr, right)], {}
|
||||
)
|
||||
except UndefinedMethodException:
|
||||
self.reporter.error(
|
||||
location,
|
||||
@@ -545,7 +548,9 @@ class PythonTyper(
|
||||
|
||||
result: Optional[Type]
|
||||
try:
|
||||
result = self.call_method(expr.location, operand, method, [], {})
|
||||
result = self.call_method(
|
||||
expr.location, (expr.right, operand), method, [], {}
|
||||
)
|
||||
except UndefinedMethodException:
|
||||
self.reporter.error(
|
||||
expr.location,
|
||||
@@ -573,11 +578,12 @@ class PythonTyper(
|
||||
unfolded: Type = unfold_type(obj_type)
|
||||
if isinstance(unfolded, DataFrameType):
|
||||
return self.frame_mgr.call(
|
||||
method,
|
||||
expr.location,
|
||||
unfolded,
|
||||
positional,
|
||||
keywords,
|
||||
method=method,
|
||||
location=expr.location,
|
||||
frame=unfolded,
|
||||
frame_expr=obj,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
|
||||
callee: Type = self.type_of(expr.callee)
|
||||
|
||||
Reference in New Issue
Block a user