feat(checker): add add/mean/groupby on columns
This commit is contained in:
@@ -1,13 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.frames.utils import MethodRegistry
|
||||
from midas.checker.dispatcher import CallResult
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import (
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
Function,
|
||||
GenericType,
|
||||
TopType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnknownType,
|
||||
unfold_type,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -24,4 +34,183 @@ class Call:
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
|
||||
class ColumnMethodRegistry(MethodRegistry[Call]): ...
|
||||
class ColumnMethodRegistry(MethodRegistry[Call]):
|
||||
@method("add", "__add__")
|
||||
def add(self, call: Call) -> Type:
|
||||
# TODO: support add with scalar
|
||||
# TODO: check operation exists on inner column types
|
||||
|
||||
column2: Optional[ColumnType] = None
|
||||
|
||||
col_type1: Type = call.column.type
|
||||
new_column: Type = ColumnType(type=UnknownType())
|
||||
if len(call.positional) != 0:
|
||||
other: Type = call.positional[0][1]
|
||||
unfolded_other: Type = unfold_type(other)
|
||||
if isinstance(unfolded_other, ColumnType):
|
||||
column2 = unfolded_other
|
||||
col_type2: Type = column2.type
|
||||
if self.types.are_equivalent(col_type2, col_type1):
|
||||
new_column = ColumnType(type=col_type1)
|
||||
|
||||
# Build signature with new column type and generic operand
|
||||
param_type: TypeVar = TypeVar(name="T", bound=None)
|
||||
signature = GenericType(
|
||||
name="add",
|
||||
params=[param_type],
|
||||
body=Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="other",
|
||||
type=ColumnType(type=param_type),
|
||||
required=True,
|
||||
),
|
||||
],
|
||||
returns=new_column,
|
||||
),
|
||||
)
|
||||
|
||||
# Map arguments and compute result type
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
if result.is_valid:
|
||||
self._assert_same_length(
|
||||
call.call_expr, call.column_expr, call.positional[0][0]
|
||||
)
|
||||
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
signature = Function(
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="axis",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
)
|
||||
],
|
||||
returns=ColumnType(type=TopType()),
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def groupby(self, call: Call) -> Type:
|
||||
bool_: Type = self.types.get_type("bool")
|
||||
function: Function = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="by",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="level",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
pos=2,
|
||||
name="as_index",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=3,
|
||||
name="sort",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=4,
|
||||
name="group_keys",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=5,
|
||||
name="observed",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=6,
|
||||
name="dropna",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
returns=ColumnGroupBy(column=call.column),
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=function,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
def _assert_same_length(self, call_expr: p.Expr, column1: p.Expr, column2: p.Expr):
|
||||
func_name: str = "__midas_column_same_length__"
|
||||
self.assertions.define(
|
||||
func_name,
|
||||
ast.FunctionDef(
|
||||
name=func_name,
|
||||
args=ast.arguments(
|
||||
posonlyargs=[],
|
||||
args=[
|
||||
ast.arg(arg="column1"),
|
||||
ast.arg(arg="column2"),
|
||||
],
|
||||
kwonlyargs=[],
|
||||
defaults=[],
|
||||
kw_defaults=[],
|
||||
),
|
||||
body=[
|
||||
ast.Return(
|
||||
value=ast.Compare(
|
||||
left=ast.Attribute(
|
||||
value=ast.Name(id="column1"),
|
||||
attr="size",
|
||||
),
|
||||
ops=[ast.Eq()],
|
||||
comparators=[
|
||||
ast.Attribute(
|
||||
value=ast.Name(id="column2"),
|
||||
attr="size",
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
],
|
||||
decorator_list=[],
|
||||
),
|
||||
)
|
||||
self.assertions.add(
|
||||
bound_expr=call_expr,
|
||||
inputs=[column1, column2],
|
||||
builder=lambda c1, c2: ast.Call(
|
||||
func=ast.Name(id=func_name),
|
||||
args=[c1, c2],
|
||||
keywords=[],
|
||||
),
|
||||
message="Columns must have the same length",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user