feat(checker): add arithmetic binary ops on frames

This commit is contained in:
2026-07-03 00:38:56 +02:00
parent be2fd4c837
commit 74c07c9afb

View File

@@ -142,10 +142,8 @@ class FrameMethodRegistry(MethodRegistry[Call]):
return DataFrameType(columns=new_columns)
@method("add", "__add__")
def add(self, call: Call) -> Type:
# TODO: support add with scalar, sequence, Series, dict
def _element_wise(self, call: Call, method: str) -> Type:
# TODO: support scalar, sequence, Series, dict operand
# Build signature with new schema and generic operand
signature = Function(
args=[
@@ -156,7 +154,7 @@ class FrameMethodRegistry(MethodRegistry[Call]):
required=True,
),
],
returns=self._element_binary_op(call, "__add__"),
returns=self._element_binary_op(call, method),
)
# Map arguments and compute result type
@@ -173,6 +171,34 @@ class FrameMethodRegistry(MethodRegistry[Call]):
return result.result
@method("add", "__add__")
def add(self, call: Call) -> Type:
return self._element_wise(call, "__add__")
@method("sub", "__sub__")
def sub(self, call: Call) -> Type:
return self._element_wise(call, "__sub__")
@method("mul", "__mul__")
def mul(self, call: Call) -> Type:
return self._element_wise(call, "__mul__")
@method("div", "truediv", "__truediv__")
def truediv(self, call: Call) -> Type:
return self._element_wise(call, "__truediv__")
@method("floordiv", "__floordiv__")
def floordiv(self, call: Call) -> Type:
return self._element_wise(call, "__floordiv__")
@method("mod", "__mod__")
def mod(self, call: Call) -> Type:
return self._element_wise(call, "__mod__")
@method("pow", "__pow__")
def pow(self, call: Call) -> Type:
return self._element_wise(call, "__pow__")
@method()
def mean(self, call: Call) -> Type:
with_axis = Function(