fix(checker): pass AST expression to method registry

This commit is contained in:
2026-07-01 22:34:02 +02:00
parent ad2fabf471
commit f7a36f61b6
3 changed files with 28 additions and 10 deletions

View File

@@ -40,6 +40,7 @@ def frame_method(*names: str):
class Call:
location: Location
frame: DataFrameType
frame_expr: p.Expr
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
@@ -105,6 +106,7 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
by_name: dict[str, DataFrameType.Column] = {}
frame2: Optional[DataFrameType] = None
# Get map of operand's columns by name, if there is at least 1 operand, which is a dataframe
if len(call.positional) != 0:
other: Type = call.positional[0][1]
unfolded_other: Type = unfold_type(other)
@@ -114,6 +116,10 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
col.name: col for col in frame2.columns if col.name is not None
}
# Compute new schema:
# Step 1: for all columns in frame1:
# - if present in frame2 with equivalent type -> add to schema as is
# - if not -> add to schema as unknown
in_frame1: set[str] = set()
for column in call.frame.columns:
if column.name is not None:
@@ -134,6 +140,8 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
)
new_columns.append(new_column)
# Step 2: for all columns in frame2
# - if not in frame1 -> add to schema as unknown
if frame2 is not None:
for column in frame2.columns:
if column.name in in_frame1:
@@ -146,6 +154,7 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
)
)
# Build signature with new schema and generic operand
signature = Function(
args=[
Function.Argument(
@@ -158,6 +167,7 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
returns=DataFrameType(columns=new_columns),
)
# Map arguments and compute result type
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,

View File

@@ -142,12 +142,14 @@ class FrameManager:
method: str,
location: Location,
frame: DataFrameType,
frame_expr: p.Expr,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
call: Call = Call(
location=location,
frame=frame,
frame_expr=frame_expr,
positional=positional,
keywords=keywords,
)

View File

@@ -211,23 +211,24 @@ class PythonTyper(
def call_method(
self,
location: Location,
obj: Type,
obj: TypedExpr,
method_name: str,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Optional[Type]:
unfolded: Type = unfold_type(obj)
unfolded: Type = unfold_type(obj[1])
match unfolded:
case DataFrameType():
return self.frame_mgr.call(
method=method_name,
location=location,
frame=unfolded,
frame_expr=obj[0],
positional=positional,
keywords=keywords,
)
method: Optional[Type] = self.types.lookup_member(obj, method_name)
method: Optional[Type] = self.types.lookup_member(obj[1], method_name)
if method is None:
raise UndefinedMethodException
@@ -522,7 +523,9 @@ class PythonTyper(
result: Optional[Type]
try:
result = self.call_method(location, left, method, [(right_expr, right)], {})
result = self.call_method(
location, (left_expr, left), method, [(right_expr, right)], {}
)
except UndefinedMethodException:
self.reporter.error(
location,
@@ -545,7 +548,9 @@ class PythonTyper(
result: Optional[Type]
try:
result = self.call_method(expr.location, operand, method, [], {})
result = self.call_method(
expr.location, (expr.right, operand), method, [], {}
)
except UndefinedMethodException:
self.reporter.error(
expr.location,
@@ -573,11 +578,12 @@ class PythonTyper(
unfolded: Type = unfold_type(obj_type)
if isinstance(unfolded, DataFrameType):
return self.frame_mgr.call(
method,
expr.location,
unfolded,
positional,
keywords,
method=method,
location=expr.location,
frame=unfolded,
frame_expr=obj,
positional=positional,
keywords=keywords,
)
callee: Type = self.type_of(expr.callee)