feat(types): add FrameGroupBy type
This commit is contained in:
@@ -12,6 +12,7 @@ from midas.checker.reporter import FileReporter
|
||||
from midas.checker.types import (
|
||||
ColumnType,
|
||||
DataFrameType,
|
||||
FrameGroupBy,
|
||||
Function,
|
||||
OverloadedFunction,
|
||||
TopType,
|
||||
@@ -222,6 +223,67 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
)
|
||||
return result.result
|
||||
|
||||
@frame_method()
|
||||
def groupby(self, call: Call) -> Type:
|
||||
bool_: Type = self.types.get_type("bool")
|
||||
function: Function = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="by",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="level",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
pos=2,
|
||||
name="as_index",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=3,
|
||||
name="sort",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=4,
|
||||
name="group_keys",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=5,
|
||||
name="observed",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=6,
|
||||
name="dropna",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
returns=FrameGroupBy(frame=call.frame),
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=function,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
def _assert_same_length(self, call_expr: p.Expr, frame1: p.Expr, frame2: p.Expr):
|
||||
func_name: str = "__midas_frame_same_length__"
|
||||
self.assertions.define(
|
||||
|
||||
@@ -187,6 +187,14 @@ class DataFrameType:
|
||||
type: ColumnType
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class FrameGroupBy:
|
||||
frame: DataFrameType
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"FrameGroupBy[{self.frame}]"
|
||||
|
||||
|
||||
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||
def sub_argument(arg: Function.Argument):
|
||||
return Function.Argument(
|
||||
@@ -305,11 +313,15 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||
columns=list(map(sub_column, columns)),
|
||||
)
|
||||
|
||||
case FrameGroupBy(frame=frame):
|
||||
return FrameGroupBy(
|
||||
frame=cast(DataFrameType, substitute_typevars(frame, substitutions))
|
||||
)
|
||||
|
||||
case UnknownType() | UnitType():
|
||||
return type
|
||||
|
||||
case TopType() | GenericType():
|
||||
|
||||
raise NotImplementedError(f"Unsupported type {type}")
|
||||
|
||||
# Ensure exhaustiveness
|
||||
@@ -382,6 +394,9 @@ def to_annotation(type: Type) -> str:
|
||||
case DataFrameType():
|
||||
return "pd.DataFrame"
|
||||
|
||||
case FrameGroupBy():
|
||||
return "pd.api.typing.DataFrameGroupBy"
|
||||
|
||||
case _:
|
||||
assert_never(type)
|
||||
|
||||
@@ -410,4 +425,5 @@ Type = (
|
||||
| TupleType
|
||||
| ColumnType
|
||||
| DataFrameType
|
||||
| FrameGroupBy
|
||||
)
|
||||
|
||||
@@ -20,6 +20,7 @@ from midas.checker.types import (
|
||||
DataFrameType,
|
||||
DerivedType,
|
||||
ExtensionType,
|
||||
FrameGroupBy,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
@@ -366,7 +367,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
self, src_location: Location, expr: ast.expr, type: Type
|
||||
) -> list[ast.stmt]:
|
||||
match type:
|
||||
case UnknownType():
|
||||
case UnknownType() | TopType():
|
||||
return []
|
||||
|
||||
case BaseType(name=name):
|
||||
@@ -497,12 +498,12 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
return asserts
|
||||
|
||||
case (
|
||||
TopType()
|
||||
| Function()
|
||||
Function()
|
||||
| OverloadedFunction()
|
||||
| ComplexType()
|
||||
| ExtensionType()
|
||||
| GenericType()
|
||||
| FrameGroupBy()
|
||||
):
|
||||
self.logger.warning(f"Can't make assertion for type {type}")
|
||||
return []
|
||||
|
||||
@@ -12,6 +12,7 @@ from midas.checker.types import (
|
||||
DataFrameType,
|
||||
DerivedType,
|
||||
ExtensionType,
|
||||
FrameGroupBy,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
@@ -287,6 +288,19 @@ class StubsGenerator:
|
||||
attr="DataFrame",
|
||||
)
|
||||
|
||||
case FrameGroupBy():
|
||||
self.import_pandas = True
|
||||
return ast.Attribute(
|
||||
value=ast.Attribute(
|
||||
value=ast.Attribute(
|
||||
value=ast.Name(id="pd"),
|
||||
attr="api",
|
||||
),
|
||||
attr="typing",
|
||||
),
|
||||
attr="DataFrameGroupBy",
|
||||
)
|
||||
|
||||
case _:
|
||||
assert_never(type)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user