fix(checker): delegate frame aggregate methods to columns

This commit is contained in:
2026-07-03 11:42:35 +02:00
parent 733c8736b8
commit 7a6e01cff8
3 changed files with 61 additions and 186 deletions

View File

@@ -5,9 +5,15 @@ from typing import TYPE_CHECKING
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.dispatcher import CallResult
from midas.checker.frames.utils import MethodRegistry, method
from midas.checker.types import FrameGroupBy, Function, Type
from midas.checker.types import (
ColumnGroupBy,
ColumnType,
DataFrameType,
FrameGroupBy,
Type,
UnknownType,
)
if TYPE_CHECKING:
from midas.checker.python import TypedExpr
@@ -35,161 +41,63 @@ class FrameGroupByMethodRegistry(MethodRegistry[Call]):
"engine_kwargs": "dict",
}
def _aggregate(
self, call: Call, args: list[str | tuple[str, str, bool]] = []
) -> Type:
real_args: list[Function.Argument] = []
for i, arg in enumerate(args):
match arg:
case str() as name:
arg = Function.Argument(
pos=i,
name=name,
type=self.types.get_type(self.NAMED_ARGS[name]),
required=False,
)
case (name, type, required):
arg = Function.Argument(
pos=i,
name=name,
type=self.types.get_type(type),
required=required,
)
real_args.append(arg)
def _aggregate(self, call: Call, method: str) -> Type:
new_columns: list[DataFrameType.Column] = []
signature = Function(
args=real_args,
returns=call.groupby.frame,
)
for column in call.groupby.frame.columns:
column_groupby: ColumnGroupBy = ColumnGroupBy(column=column.type)
result_type: Type = self.typer.call_method(
location=call.location,
call_expr=call.call_expr,
obj=(call.groupby_expr, column_groupby),
method_name=method,
positional=call.positional,
keywords=call.keywords,
)
if not isinstance(result_type, ColumnType):
result_type = ColumnType(type=UnknownType())
new_columns.append(
DataFrameType.Column(
index=column.index,
name=column.name,
type=result_type,
)
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
return result.result
return DataFrameType(columns=new_columns)
@method()
def kurt(self, call: Call) -> Type:
return self._aggregate(
call,
[
"skipna",
"numeric_only",
],
)
return self._aggregate(call, "kurt")
@method()
def max(self, call: Call) -> Type:
return self._aggregate(
call,
[
"numeric_only",
(
"min_count",
"int",
False,
),
"skipna",
"engine",
"engine_kwargs",
],
)
return self._aggregate(call, "max")
@method()
def mean(self, call: Call) -> Type:
return self._aggregate(
call,
["numeric_only", "skipna", "engine", "engine_kwargs"],
)
return self._aggregate(call, "mean")
@method()
def median(self, call: Call) -> Type:
return self._aggregate(
call,
["numeric_only", "skipna"],
)
return self._aggregate(call, "median")
@method()
def min(self, call: Call) -> Type:
return self._aggregate(
call,
[
"numeric_only",
(
"min_count",
"int",
False,
),
"skipna",
"engine",
"engine_kwargs",
],
)
return self._aggregate(call, "min")
@method()
def prod(self, call: Call) -> Type:
return self._aggregate(
call,
[
"numeric_only",
(
"min_count",
"int",
False,
),
"skipna",
],
)
return self._aggregate(call, "prod")
@method()
def std(self, call: Call) -> Type:
return self._aggregate(
call,
[
(
"ddof",
"int",
False,
),
"engine",
"engine_kwargs",
"numeric_only",
"skipna",
],
)
return self._aggregate(call, "std")
@method()
def sum(self, call: Call) -> Type:
return self._aggregate(
call,
[
"numeric_only",
(
"min_count",
"int",
False,
),
"skipna",
"engine",
"engine_kwargs",
],
)
return self._aggregate(call, "sum")
@method()
def var(self, call: Call) -> Type:
return self._aggregate(
call,
[
(
"var",
"int",
False,
),
"engine",
"engine_kwargs",
"numeric_only",
"skipna",
],
)
return self._aggregate(call, "var")

View File

@@ -222,7 +222,7 @@ class PythonTyper(
method_name: str,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Optional[Type]:
) -> Type:
unfolded: Type = unfold_type(obj[1])
match unfolded:
case DataFrameType():
@@ -580,9 +580,8 @@ class PythonTyper(
right: TypedExpr,
method: str,
) -> Type:
result: Optional[Type]
try:
result = self.call_method(
return self.call_method(
location=location,
call_expr=expr,
obj=left,
@@ -597,8 +596,6 @@ class PythonTyper(
)
return UnknownType()
return result or UnknownType()
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
method: Optional[str] = PY_UNARY_METHODS.get(expr.operator.__class__)
if method is None:
@@ -610,9 +607,8 @@ class PythonTyper(
operand: Type = self.type_of(expr.right)
result: Optional[Type]
try:
result = self.call_method(
return self.call_method(
location=expr.location,
call_expr=expr,
obj=(expr.right, operand),
@@ -627,8 +623,6 @@ class PythonTyper(
)
return UnknownType()
return result or UnknownType()
def visit_call_expr(self, expr: p.CallExpr) -> Type:
match expr.callee:
case p.VariableExpr(name="TypeVar"):
@@ -644,16 +638,13 @@ class PythonTyper(
match expr.callee:
case p.GetExpr(object=obj, name=method):
obj_type: Type = self.type_of(obj)
return (
self.call_method(
location=expr.location,
call_expr=expr,
obj=(obj, obj_type),
method_name=method,
positional=positional,
keywords=keywords,
)
or UnknownType()
return self.call_method(
location=expr.location,
call_expr=expr,
obj=(obj, obj_type),
method_name=method,
positional=positional,
keywords=keywords,
)
callee: Type = self.type_of(expr.callee)

View File

@@ -2150,18 +2150,14 @@
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
"type": {}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
"type": {}
}
}
]
@@ -2300,18 +2296,14 @@
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
"type": {}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
"type": {}
}
}
]
@@ -2525,18 +2517,14 @@
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
"type": {}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
"type": {}
}
}
]
@@ -2600,18 +2588,14 @@
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
"type": {}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
"type": {}
}
}
]
@@ -2675,18 +2659,14 @@
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
"type": {}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
"type": {}
}
}
]
@@ -2750,18 +2730,14 @@
"index": 0,
"name": "a",
"type": {
"type": {
"name": "int"
}
"type": {}
}
},
{
"index": 1,
"name": "b",
"type": {
"type": {
"name": "float"
}
"type": {}
}
}
]