feat(types): add FrameGroupBy type

This commit is contained in:
2026-07-02 17:45:18 +02:00
parent 5d20f8ec3e
commit b5acae4078
4 changed files with 97 additions and 4 deletions

View File

@@ -12,6 +12,7 @@ from midas.checker.reporter import FileReporter
from midas.checker.types import (
ColumnType,
DataFrameType,
FrameGroupBy,
Function,
OverloadedFunction,
TopType,
@@ -222,6 +223,67 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
)
return result.result
@frame_method()
def groupby(self, call: Call) -> Type:
    """Type-check a ``groupby`` call made on a frame.

    A synthetic ``Function`` signature is assembled and handed to
    ``self.dispatcher`` together with the arguments found in ``call``:

    * ``by`` and ``level`` are optional positional arguments typed as
      ``TopType`` (presumably "accept anything" -- confirm against the
      dispatcher's handling of ``TopType``).
    * ``as_index``, ``sort``, ``group_keys``, ``observed`` and ``dropna``
      are optional keyword arguments typed as ``bool``.

    The signature's return type is a ``FrameGroupBy`` wrapping
    ``call.frame``.

    Args:
        call: The call being checked; its ``frame``, ``location``,
            ``positional`` and ``keywords`` attributes are read.

    Returns:
        The ``result`` attribute of the dispatcher's ``CallResult``.
    """
    boolean: Type = self.types.get_type("bool")

    positional_names = ("by", "level")
    flag_names = ("as_index", "sort", "group_keys", "observed", "dropna")

    # A fresh TopType() is created per argument, as the literal list did.
    positional_params = [
        Function.Argument(pos=index, name=label, type=TopType(), required=False)
        for index, label in enumerate(positional_names)
    ]
    # Keyword positions continue right after the positional ones (2..6).
    keyword_params = [
        Function.Argument(pos=index, name=label, type=boolean, required=False)
        for index, label in enumerate(flag_names, start=len(positional_names))
    ]

    signature: Function = Function(
        args=positional_params,
        kw_args=keyword_params,
        returns=FrameGroupBy(frame=call.frame),
    )
    outcome: CallResult = self.dispatcher.get_result(
        location=call.location,
        callee=signature,
        positional=call.positional,
        keywords=call.keywords,
    )
    return outcome.result
def _assert_same_length(self, call_expr: p.Expr, frame1: p.Expr, frame2: p.Expr):
func_name: str = "__midas_frame_same_length__"
self.assertions.define(

View File

@@ -187,6 +187,14 @@ class DataFrameType:
type: ColumnType
@dataclass(frozen=True, kw_only=True)
class FrameGroupBy:
    """Immutable type node that wraps the type of a data frame.

    Instances are created with a keyword-only ``frame`` argument.
    """

    # Type of the data frame this group-by type wraps.
    frame: DataFrameType

    def __str__(self) -> str:
        """Render as ``FrameGroupBy[<frame>]`` using the frame's own formatting."""
        return "FrameGroupBy[{}]".format(self.frame)
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
def sub_argument(arg: Function.Argument):
return Function.Argument(
@@ -305,11 +313,15 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
columns=list(map(sub_column, columns)),
)
case FrameGroupBy(frame=frame):
return FrameGroupBy(
frame=cast(DataFrameType, substitute_typevars(frame, substitutions))
)
case UnknownType() | UnitType():
return type
case TopType() | GenericType():
raise NotImplementedError(f"Unsupported type {type}")
# Ensure exhaustiveness
@@ -382,6 +394,9 @@ def to_annotation(type: Type) -> str:
case DataFrameType():
return "pd.DataFrame"
case FrameGroupBy():
return "pd.api.typing.DataFrameGroupBy"
case _:
assert_never(type)
@@ -410,4 +425,5 @@ Type = (
| TupleType
| ColumnType
| DataFrameType
| FrameGroupBy
)

View File

@@ -20,6 +20,7 @@ from midas.checker.types import (
DataFrameType,
DerivedType,
ExtensionType,
FrameGroupBy,
Function,
GenericType,
OverloadedFunction,
@@ -366,7 +367,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self, src_location: Location, expr: ast.expr, type: Type
) -> list[ast.stmt]:
match type:
case UnknownType():
case UnknownType() | TopType():
return []
case BaseType(name=name):
@@ -497,12 +498,12 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
return asserts
case (
TopType()
| Function()
Function()
| OverloadedFunction()
| ComplexType()
| ExtensionType()
| GenericType()
| FrameGroupBy()
):
self.logger.warning(f"Can't make assertion for type {type}")
return []

View File

@@ -12,6 +12,7 @@ from midas.checker.types import (
DataFrameType,
DerivedType,
ExtensionType,
FrameGroupBy,
Function,
GenericType,
OverloadedFunction,
@@ -287,6 +288,19 @@ class StubsGenerator:
attr="DataFrame",
)
case FrameGroupBy():
self.import_pandas = True
return ast.Attribute(
value=ast.Attribute(
value=ast.Attribute(
value=ast.Name(id="pd"),
attr="api",
),
attr="typing",
),
attr="DataFrameGroupBy",
)
case _:
assert_never(type)