feat(checker): add aggregation ops on frame groupby

This commit is contained in:
2026-07-03 02:20:51 +02:00
parent 0c70048b62
commit a143972ef1
4 changed files with 947 additions and 41 deletions

View File

@@ -28,36 +28,37 @@ class Call:
class FrameGroupByMethodRegistry(MethodRegistry[Call]):
@method()
def mean(self, call: Call) -> Type:
bool_ = self.types.get_type("bool")
NAMED_ARGS: dict[str, str] = {
"numeric_only": "bool",
"skipna": "bool",
"engine": "str",
"engine_kwargs": "dict",
}
def _aggregate(
self, call: Call, args: list[str | tuple[str, str, bool]] = []
) -> Type:
real_args: list[Function.Argument] = []
for i, arg in enumerate(args):
match arg:
case str() as name:
arg = Function.Argument(
pos=i,
name=name,
type=self.types.get_type(self.NAMED_ARGS[name]),
required=False,
)
case (name, type, required):
arg = Function.Argument(
pos=i,
name=name,
type=self.types.get_type(type),
required=required,
)
real_args.append(arg)
signature = Function(
args=[
Function.Argument(
pos=0,
name="numeric_only",
type=bool_,
required=False,
),
Function.Argument(
pos=1,
name="skipna",
type=bool_,
required=False,
),
Function.Argument(
pos=2,
name="engine",
type=self.types.get_type("str"),
required=False,
),
Function.Argument(
pos=3,
name="engine_kwargs",
type=self.types.get_type("dict"),
required=False,
),
],
args=real_args,
returns=call.groupby.frame,
)
@@ -68,3 +69,127 @@ class FrameGroupByMethodRegistry(MethodRegistry[Call]):
keywords=call.keywords,
)
return result.result
@method()
def kurt(self, call: Call) -> Type:
    """Type-check a group-by ``kurt`` call.

    Delegates to ``_aggregate``; both arguments are given as bare names,
    which ``_aggregate`` registers as optional, positioned in list order.
    """
    return self._aggregate(call, args=["skipna", "numeric_only"])
@method()
def max(self, call: Call) -> Type:
    """Type-check a group-by ``max`` call.

    Delegates to ``_aggregate``. Bare names are registered as optional
    arguments; the tuple spells out ``(name, type name, required)`` for
    ``min_count``. Positions follow list order.
    """
    return self._aggregate(
        call,
        args=[
            "numeric_only",
            ("min_count", "int", False),
            "skipna",
            "engine",
            "engine_kwargs",
        ],
    )
@method()
def mean(self, call: Call) -> Type:
    """Type-check a group-by ``mean`` call.

    Delegates to ``_aggregate`` with four bare-name arguments, which it
    registers as optional and positions in list order.
    """
    return self._aggregate(
        call,
        args=[
            "numeric_only",
            "skipna",
            "engine",
            "engine_kwargs",
        ],
    )
@method()
def median(self, call: Call) -> Type:
    """Type-check a group-by ``median`` call.

    Delegates to ``_aggregate`` with two bare-name arguments, which it
    registers as optional and positions in list order.
    """
    return self._aggregate(call, args=["numeric_only", "skipna"])
@method()
def min(self, call: Call) -> Type:
    """Type-check a group-by ``min`` call.

    Delegates to ``_aggregate``. Bare names are registered as optional
    arguments; the tuple spells out ``(name, type name, required)`` for
    ``min_count``. Positions follow list order.
    """
    return self._aggregate(
        call,
        args=[
            "numeric_only",
            ("min_count", "int", False),
            "skipna",
            "engine",
            "engine_kwargs",
        ],
    )
@method()
def prod(self, call: Call) -> Type:
    """Type-check a group-by ``prod`` call.

    Delegates to ``_aggregate``. Bare names are registered as optional
    arguments; the tuple spells out ``(name, type name, required)`` for
    ``min_count``. Positions follow list order.
    """
    return self._aggregate(
        call,
        args=[
            "numeric_only",
            ("min_count", "int", False),
            "skipna",
        ],
    )
@method()
def std(self, call: Call) -> Type:
    """Type-check a group-by ``std`` call.

    Delegates to ``_aggregate``. ``ddof`` comes first and is declared
    explicitly as ``(name, type name, required)``; the remaining bare
    names are registered as optional arguments. Positions follow list
    order.
    """
    return self._aggregate(
        call,
        args=[
            ("ddof", "int", False),
            "engine",
            "engine_kwargs",
            "numeric_only",
            "skipna",
        ],
    )
@method()
def sum(self, call: Call) -> Type:
    """Type-check a group-by ``sum`` call.

    Delegates to ``_aggregate``. Bare names are registered as optional
    arguments; the tuple spells out ``(name, type name, required)`` for
    ``min_count``. Positions follow list order.
    """
    return self._aggregate(
        call,
        args=[
            "numeric_only",
            ("min_count", "int", False),
            "skipna",
            "engine",
            "engine_kwargs",
        ],
    )
@method()
def var(self, call: Call) -> Type:
    """Type-check a group-by ``var`` call.

    Delegates to ``_aggregate``. The leading argument is ``ddof`` (an
    optional ``int``), declared explicitly as
    ``(name, type name, required)``; the remaining bare names are
    registered as optional arguments. Positions follow list order.
    """
    return self._aggregate(
        call,
        [
            # The parameter is named ``ddof`` (as in ``std``), not ``var``;
            # the old name rejected ``gb.var(ddof=...)`` and accepted a
            # non-existent ``var=`` keyword.
            ("ddof", "int", False),
            "engine",
            "engine_kwargs",
            "numeric_only",
            "skipna",
        ],
    )

View File

@@ -227,7 +227,7 @@ class FrameMethodRegistry(MethodRegistry[Call]):
def eq(self, call: Call) -> Type:
return self._element_wise(call, "__eq__")
def _statistical(self, call: Call, kwargs: list[Function.Argument] = []) -> Type:
def _aggregate(self, call: Call, kwargs: list[Function.Argument] = []) -> Type:
with_axis = Function(
kw_args=[
Function.Argument(
@@ -269,35 +269,35 @@ class FrameMethodRegistry(MethodRegistry[Call]):
@method("kurtosis", "kurt")
def kurtosis(self, call: Call) -> Type:
return self._statistical(call)
return self._aggregate(call)
@method()
def max(self, call: Call) -> Type:
return self._statistical(call)
return self._aggregate(call)
@method()
def mean(self, call: Call) -> Type:
return self._statistical(call)
return self._aggregate(call)
@method()
def median(self, call: Call) -> Type:
return self._statistical(call)
return self._aggregate(call)
@method()
def min(self, call: Call) -> Type:
return self._statistical(call)
return self._aggregate(call)
@method()
def mode(self, call: Call) -> Type:
return self._statistical(call)
return self._aggregate(call)
@method("product", "prod")
def product(self, call: Call) -> Type:
return self._statistical(call)
return self._aggregate(call)
@method()
def std(self, call: Call) -> Type:
return self._statistical(
return self._aggregate(
call,
[
Function.Argument(
@@ -311,11 +311,11 @@ class FrameMethodRegistry(MethodRegistry[Call]):
@method()
def sum(self, call: Call) -> Type:
return self._statistical(call)
return self._aggregate(call)
@method()
def var(self, call: Call) -> Type:
return self._statistical(
return self._aggregate(
call,
[
Function.Argument(

View File

@@ -23,7 +23,7 @@ _ = df1 >= df2
_ = df1 != df2
_ = df1 == df2
# Statistical
# Aggregate
_ = df1.kurt()
_ = df1.kurtosis()
_ = df1.max()
@@ -36,3 +36,16 @@ _ = df1.product()
_ = df1.std()
_ = df1.sum()
_ = df1.var()
# Groupby
gb = df1.groupby(by="a")
_ = gb.kurt()
_ = gb.max()
_ = gb.mean()
_ = gb.median()
_ = gb.min()
_ = gb.prod()
_ = gb.std()
_ = gb.sum()
_ = gb.var()

View File

@@ -1998,6 +1998,774 @@
"type": {
"type": {}
}
},
{
"location": {
"from": "L41:20",
"to": "L41:23"
},
"expr": {
"_type": "LiteralExpr",
"value": "a"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L41:5",
"to": "L41:8"
},
"expr": {
"_type": "VariableExpr",
"name": "df1"
},
"type": {
"columns": [
{
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
}
}
]
}
},
{
"location": {
"from": "L41:5",
"to": "L41:24"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "GetExpr",
"object": {
"_type": "VariableExpr",
"name": "df1"
},
"name": "groupby"
},
"arguments": [],
"keywords": {
"by": {
"_type": "LiteralExpr",
"value": "a"
}
}
},
"type": {
"frame": {
"columns": [
{
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
}
}
]
}
}
},
{
"location": {
"from": "L43:4",
"to": "L43:6"
},
"expr": {
"_type": "VariableExpr",
"name": "gb"
},
"type": {
"frame": {
"columns": [
{
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
}
}
]
}
}
},
{
"location": {
"from": "L43:4",
"to": "L43:13"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "GetExpr",
"object": {
"_type": "VariableExpr",
"name": "gb"
},
"name": "kurt"
},
"arguments": [],
"keywords": {}
},
"type": {
"columns": [
{
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
}
}
]
}
},
{
"location": {
"from": "L44:4",
"to": "L44:6"
},
"expr": {
"_type": "VariableExpr",
"name": "gb"
},
"type": {
"frame": {
"columns": [
{
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
}
}
]
}
}
},
{
"location": {
"from": "L44:4",
"to": "L44:12"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "GetExpr",
"object": {
"_type": "VariableExpr",
"name": "gb"
},
"name": "max"
},
"arguments": [],
"keywords": {}
},
"type": {
"columns": [
{
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
}
}
]
}
},
{
"location": {
"from": "L45:4",
"to": "L45:6"
},
"expr": {
"_type": "VariableExpr",
"name": "gb"
},
"type": {
"frame": {
"columns": [
{
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
}
}
]
}
}
},
{
"location": {
"from": "L45:4",
"to": "L45:13"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "GetExpr",
"object": {
"_type": "VariableExpr",
"name": "gb"
},
"name": "mean"
},
"arguments": [],
"keywords": {}
},
"type": {
"columns": [
{
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
}
}
]
}
},
{
"location": {
"from": "L46:4",
"to": "L46:6"
},
"expr": {
"_type": "VariableExpr",
"name": "gb"
},
"type": {
"frame": {
"columns": [
{
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
}
}
]
}
}
},
{
"location": {
"from": "L46:4",
"to": "L46:15"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "GetExpr",
"object": {
"_type": "VariableExpr",
"name": "gb"
},
"name": "median"
},
"arguments": [],
"keywords": {}
},
"type": {
"columns": [
{
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
}
}
]
}
},
{
"location": {
"from": "L47:4",
"to": "L47:6"
},
"expr": {
"_type": "VariableExpr",
"name": "gb"
},
"type": {
"frame": {
"columns": [
{
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
}
}
]
}
}
},
{
"location": {
"from": "L47:4",
"to": "L47:12"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "GetExpr",
"object": {
"_type": "VariableExpr",
"name": "gb"
},
"name": "min"
},
"arguments": [],
"keywords": {}
},
"type": {
"columns": [
{
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
}
}
]
}
},
{
"location": {
"from": "L48:4",
"to": "L48:6"
},
"expr": {
"_type": "VariableExpr",
"name": "gb"
},
"type": {
"frame": {
"columns": [
{
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
}
}
]
}
}
},
{
"location": {
"from": "L48:4",
"to": "L48:13"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "GetExpr",
"object": {
"_type": "VariableExpr",
"name": "gb"
},
"name": "prod"
},
"arguments": [],
"keywords": {}
},
"type": {
"columns": [
{
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
}
}
]
}
},
{
"location": {
"from": "L49:4",
"to": "L49:6"
},
"expr": {
"_type": "VariableExpr",
"name": "gb"
},
"type": {
"frame": {
"columns": [
{
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
}
}
]
}
}
},
{
"location": {
"from": "L49:4",
"to": "L49:12"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "GetExpr",
"object": {
"_type": "VariableExpr",
"name": "gb"
},
"name": "std"
},
"arguments": [],
"keywords": {}
},
"type": {
"columns": [
{
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
}
}
]
}
},
{
"location": {
"from": "L50:4",
"to": "L50:6"
},
"expr": {
"_type": "VariableExpr",
"name": "gb"
},
"type": {
"frame": {
"columns": [
{
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
}
}
]
}
}
},
{
"location": {
"from": "L50:4",
"to": "L50:12"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "GetExpr",
"object": {
"_type": "VariableExpr",
"name": "gb"
},
"name": "sum"
},
"arguments": [],
"keywords": {}
},
"type": {
"columns": [
{
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
}
}
]
}
},
{
"location": {
"from": "L51:4",
"to": "L51:6"
},
"expr": {
"_type": "VariableExpr",
"name": "gb"
},
"type": {
"frame": {
"columns": [
{
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
}
}
]
}
}
},
{
"location": {
"from": "L51:4",
"to": "L51:12"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "GetExpr",
"object": {
"_type": "VariableExpr",
"name": "gb"
},
"name": "var"
},
"arguments": [],
"keywords": {}
},
"type": {
"columns": [
{
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
}
}
]
}
}
]
}