fix(checker): delegate frame aggregate methods to columns
This commit is contained in:
@@ -5,9 +5,15 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallResult
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import FrameGroupBy, Function, Type
|
||||
from midas.checker.types import (
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
DataFrameType,
|
||||
FrameGroupBy,
|
||||
Type,
|
||||
UnknownType,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import TypedExpr
|
||||
@@ -35,161 +41,63 @@ class FrameGroupByMethodRegistry(MethodRegistry[Call]):
|
||||
"engine_kwargs": "dict",
|
||||
}
|
||||
|
||||
def _aggregate(
|
||||
self, call: Call, args: list[str | tuple[str, str, bool]] = []
|
||||
) -> Type:
|
||||
real_args: list[Function.Argument] = []
|
||||
for i, arg in enumerate(args):
|
||||
match arg:
|
||||
case str() as name:
|
||||
arg = Function.Argument(
|
||||
pos=i,
|
||||
name=name,
|
||||
type=self.types.get_type(self.NAMED_ARGS[name]),
|
||||
required=False,
|
||||
)
|
||||
case (name, type, required):
|
||||
arg = Function.Argument(
|
||||
pos=i,
|
||||
name=name,
|
||||
type=self.types.get_type(type),
|
||||
required=required,
|
||||
)
|
||||
real_args.append(arg)
|
||||
def _aggregate(self, call: Call, method: str) -> Type:
|
||||
new_columns: list[DataFrameType.Column] = []
|
||||
|
||||
signature = Function(
|
||||
args=real_args,
|
||||
returns=call.groupby.frame,
|
||||
)
|
||||
for column in call.groupby.frame.columns:
|
||||
column_groupby: ColumnGroupBy = ColumnGroupBy(column=column.type)
|
||||
result_type: Type = self.typer.call_method(
|
||||
location=call.location,
|
||||
call_expr=call.call_expr,
|
||||
obj=(call.groupby_expr, column_groupby),
|
||||
method_name=method,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
if not isinstance(result_type, ColumnType):
|
||||
result_type = ColumnType(type=UnknownType())
|
||||
new_columns.append(
|
||||
DataFrameType.Column(
|
||||
index=column.index,
|
||||
name=column.name,
|
||||
type=result_type,
|
||||
)
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
return DataFrameType(columns=new_columns)
|
||||
|
||||
@method()
|
||||
def kurt(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"skipna",
|
||||
"numeric_only",
|
||||
],
|
||||
)
|
||||
return self._aggregate(call, "kurt")
|
||||
|
||||
@method()
|
||||
def max(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
],
|
||||
)
|
||||
return self._aggregate(call, "max")
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
["numeric_only", "skipna", "engine", "engine_kwargs"],
|
||||
)
|
||||
return self._aggregate(call, "mean")
|
||||
|
||||
@method()
|
||||
def median(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
["numeric_only", "skipna"],
|
||||
)
|
||||
return self._aggregate(call, "median")
|
||||
|
||||
@method()
|
||||
def min(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
],
|
||||
)
|
||||
return self._aggregate(call, "min")
|
||||
|
||||
@method()
|
||||
def prod(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
],
|
||||
)
|
||||
return self._aggregate(call, "prod")
|
||||
|
||||
@method()
|
||||
def std(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
(
|
||||
"ddof",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
"numeric_only",
|
||||
"skipna",
|
||||
],
|
||||
)
|
||||
return self._aggregate(call, "std")
|
||||
|
||||
@method()
|
||||
def sum(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
],
|
||||
)
|
||||
return self._aggregate(call, "sum")
|
||||
|
||||
@method()
|
||||
def var(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
(
|
||||
"var",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
"numeric_only",
|
||||
"skipna",
|
||||
],
|
||||
)
|
||||
return self._aggregate(call, "var")
|
||||
|
||||
@@ -222,7 +222,7 @@ class PythonTyper(
|
||||
method_name: str,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Optional[Type]:
|
||||
) -> Type:
|
||||
unfolded: Type = unfold_type(obj[1])
|
||||
match unfolded:
|
||||
case DataFrameType():
|
||||
@@ -580,9 +580,8 @@ class PythonTyper(
|
||||
right: TypedExpr,
|
||||
method: str,
|
||||
) -> Type:
|
||||
result: Optional[Type]
|
||||
try:
|
||||
result = self.call_method(
|
||||
return self.call_method(
|
||||
location=location,
|
||||
call_expr=expr,
|
||||
obj=left,
|
||||
@@ -597,8 +596,6 @@ class PythonTyper(
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
return result or UnknownType()
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
|
||||
method: Optional[str] = PY_UNARY_METHODS.get(expr.operator.__class__)
|
||||
if method is None:
|
||||
@@ -610,9 +607,8 @@ class PythonTyper(
|
||||
|
||||
operand: Type = self.type_of(expr.right)
|
||||
|
||||
result: Optional[Type]
|
||||
try:
|
||||
result = self.call_method(
|
||||
return self.call_method(
|
||||
location=expr.location,
|
||||
call_expr=expr,
|
||||
obj=(expr.right, operand),
|
||||
@@ -627,8 +623,6 @@ class PythonTyper(
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
return result or UnknownType()
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
||||
match expr.callee:
|
||||
case p.VariableExpr(name="TypeVar"):
|
||||
@@ -644,16 +638,13 @@ class PythonTyper(
|
||||
match expr.callee:
|
||||
case p.GetExpr(object=obj, name=method):
|
||||
obj_type: Type = self.type_of(obj)
|
||||
return (
|
||||
self.call_method(
|
||||
location=expr.location,
|
||||
call_expr=expr,
|
||||
obj=(obj, obj_type),
|
||||
method_name=method,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
or UnknownType()
|
||||
return self.call_method(
|
||||
location=expr.location,
|
||||
call_expr=expr,
|
||||
obj=(obj, obj_type),
|
||||
method_name=method,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
|
||||
callee: Type = self.type_of(expr.callee)
|
||||
|
||||
@@ -2150,18 +2150,14 @@
|
||||
"index": 0,
|
||||
"name": "a",
|
||||
"type": {
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
"type": {}
|
||||
}
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"name": "b",
|
||||
"type": {
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
"type": {}
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -2300,18 +2296,14 @@
|
||||
"index": 0,
|
||||
"name": "a",
|
||||
"type": {
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
"type": {}
|
||||
}
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"name": "b",
|
||||
"type": {
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
"type": {}
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -2525,18 +2517,14 @@
|
||||
"index": 0,
|
||||
"name": "a",
|
||||
"type": {
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
"type": {}
|
||||
}
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"name": "b",
|
||||
"type": {
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
"type": {}
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -2600,18 +2588,14 @@
|
||||
"index": 0,
|
||||
"name": "a",
|
||||
"type": {
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
"type": {}
|
||||
}
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"name": "b",
|
||||
"type": {
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
"type": {}
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -2675,18 +2659,14 @@
|
||||
"index": 0,
|
||||
"name": "a",
|
||||
"type": {
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
"type": {}
|
||||
}
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"name": "b",
|
||||
"type": {
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
"type": {}
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -2750,18 +2730,14 @@
|
||||
"index": 0,
|
||||
"name": "a",
|
||||
"type": {
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
"type": {}
|
||||
}
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"name": "b",
|
||||
"type": {
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
"type": {}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user