diff --git a/midas/checker/frames/frame_groupby_methods.py b/midas/checker/frames/frame_groupby_methods.py
index d3f882d..615b2b3 100644
--- a/midas/checker/frames/frame_groupby_methods.py
+++ b/midas/checker/frames/frame_groupby_methods.py
@@ -28,36 +28,52 @@ class Call:
 
 
 class FrameGroupByMethodRegistry(MethodRegistry[Call]):
-    @method()
-    def mean(self, call: Call) -> Type:
-        bool_ = self.types.get_type("bool")
+    # Optional arguments shared by several aggregations, mapped to the name of
+    # the type each one accepts.
+    NAMED_ARGS: dict[str, str] = {
+        "numeric_only": "bool",
+        "skipna": "bool",
+        "engine": "str",
+        "engine_kwargs": "dict",
+    }
+
+    def _aggregate(
+        self,
+        call: Call,
+        args: list[str | tuple[str, str, bool]] | None = None,
+    ) -> Type:
+        """Return the result type of an aggregation call on a grouped frame.
+
+        Entries of ``args`` describe the accepted arguments in positional
+        order.  A bare name is looked up in ``NAMED_ARGS`` and is always
+        optional; a ``(name, type name, required)`` tuple spells the
+        argument out.  The signature built from them returns the frame
+        being grouped.
+
+        Raises:
+            KeyError: a bare name is missing from ``NAMED_ARGS``.
+            TypeError: an entry is neither a name nor a 3-item sequence.
+        """
+        real_args: list[Function.Argument] = []
+        for i, spec in enumerate(args or ()):
+            match spec:
+                case str() as name:
+                    type_name, required = self.NAMED_ARGS[name], False
+                case (name, type_name, required):
+                    pass  # explicit spec: every field is already bound
+                case _:
+                    raise TypeError(f"invalid argument spec: {spec!r}")
+            real_args.append(
+                Function.Argument(
+                    pos=i,
+                    name=name,
+                    type=self.types.get_type(type_name),
+                    required=required,
+                )
+            )
+
         signature = Function(
-            args=[
-                Function.Argument(
-                    pos=0,
-                    name="numeric_only",
-                    type=bool_,
-                    required=False,
-                ),
-                Function.Argument(
-                    pos=1,
-                    name="skipna",
-                    type=bool_,
-                    required=False,
-                ),
-                Function.Argument(
-                    pos=2,
-                    name="engine",
-                    type=self.types.get_type("str"),
-                    required=False,
-                ),
-                Function.Argument(
-                    pos=3,
-                    name="engine_kwargs",
-                    type=self.types.get_type("dict"),
-                    required=False,
-                ),
-            ],
+            args=real_args,
             returns=call.groupby.frame,
         )
 
@@ -68,3 +84,136 @@ class FrameGroupByMethodRegistry(MethodRegistry[Call]):
             keywords=call.keywords,
         )
         return result.result
+
+    @method()
+    def kurt(self, call: Call) -> Type:
+        """Return the result type of ``kurt`` on a grouped frame."""
+        return self._aggregate(
+            call,
+            [
+                "skipna",
+                "numeric_only",
+            ],
+        )
+
+    @method()
+    def max(self, call: Call) -> Type:
+        """Return the result type of ``max`` on a grouped frame."""
+        return self._aggregate(
+            call,
+            [
+                "numeric_only",
+                (
+                    "min_count",
+                    "int",
+                    False,
+                ),
+                "skipna",
+                "engine",
+                "engine_kwargs",
+            ],
+        )
+
+    @method()
+    def mean(self, call: Call) -> Type:
+        """Return the result type of ``mean`` on a grouped frame."""
+        return self._aggregate(
+            call,
+            ["numeric_only", "skipna", "engine", "engine_kwargs"],
+        )
+
+    @method()
+    def median(self, call: Call) -> Type:
+        """Return the result type of ``median`` on a grouped frame."""
+        return self._aggregate(
+            call,
+            ["numeric_only", "skipna"],
+        )
+
+    @method()
+    def min(self, call: Call) -> Type:
+        """Return the result type of ``min`` on a grouped frame."""
+        return self._aggregate(
+            call,
+            [
+                "numeric_only",
+                (
+                    "min_count",
+                    "int",
+                    False,
+                ),
+                "skipna",
+                "engine",
+                "engine_kwargs",
+            ],
+        )
+
+    @method()
+    def prod(self, call: Call) -> Type:
+        """Return the result type of ``prod`` on a grouped frame."""
+        return self._aggregate(
+            call,
+            [
+                "numeric_only",
+                (
+                    "min_count",
+                    "int",
+                    False,
+                ),
+                "skipna",
+            ],
+        )
+
+    @method()
+    def std(self, call: Call) -> Type:
+        """Return the result type of ``std`` on a grouped frame."""
+        return self._aggregate(
+            call,
+            [
+                (
+                    "ddof",
+                    "int",
+                    False,
+                ),
+                "engine",
+                "engine_kwargs",
+                "numeric_only",
+                "skipna",
+            ],
+        )
+
+    @method()
+    def sum(self, call: Call) -> Type:
+        """Return the result type of ``sum`` on a grouped frame."""
+        return self._aggregate(
+            call,
+            [
+                "numeric_only",
+                (
+                    "min_count",
+                    "int",
+                    False,
+                ),
+                "skipna",
+                "engine",
+                "engine_kwargs",
+            ],
+        )
+
+    @method()
+    def var(self, call: Call) -> Type:
+        """Return the result type of ``var`` on a grouped frame."""
+        return self._aggregate(
+            call,
+            [
+                (
+                    "ddof",
+                    "int",
+                    False,
+                ),
+                "engine",
+                "engine_kwargs",
+                "numeric_only",
+                "skipna",
+            ],
+        )
diff --git a/midas/checker/frames/frame_methods.py b/midas/checker/frames/frame_methods.py
index 5c7b4b4..2bd8c8b 100644
--- a/midas/checker/frames/frame_methods.py
+++ b/midas/checker/frames/frame_methods.py
@@ -227,7 +227,7 @@ class FrameMethodRegistry(MethodRegistry[Call]):
     def eq(self, call: Call) -> Type:
         return self._element_wise(call, "__eq__")
 
-    def _statistical(self, call: Call, kwargs: list[Function.Argument] = []) -> Type:
+    def _aggregate(self, call: Call, kwargs: list[Function.Argument] = []) -> Type:
         with_axis = Function(
             kw_args=[
                 Function.Argument(
@@ -269,35 +269,35 @@ class FrameMethodRegistry(MethodRegistry[Call]):
 
     @method("kurtosis", "kurt")
     def kurtosis(self, call: Call) -> Type:
-        return self._statistical(call)
+        return self._aggregate(call)
 
     @method()
def max(self, call: Call) -> Type: - return self._statistical(call) + return self._aggregate(call) @method() def mean(self, call: Call) -> Type: - return self._statistical(call) + return self._aggregate(call) @method() def median(self, call: Call) -> Type: - return self._statistical(call) + return self._aggregate(call) @method() def min(self, call: Call) -> Type: - return self._statistical(call) + return self._aggregate(call) @method() def mode(self, call: Call) -> Type: - return self._statistical(call) + return self._aggregate(call) @method("product", "prod") def product(self, call: Call) -> Type: - return self._statistical(call) + return self._aggregate(call) @method() def std(self, call: Call) -> Type: - return self._statistical( + return self._aggregate( call, [ Function.Argument( @@ -311,11 +311,11 @@ class FrameMethodRegistry(MethodRegistry[Call]): @method() def sum(self, call: Call) -> Type: - return self._statistical(call) + return self._aggregate(call) @method() def var(self, call: Call) -> Type: - return self._statistical( + return self._aggregate( call, [ Function.Argument( diff --git a/tests/cases/checker/09_frame_ops.py b/tests/cases/checker/09_frame_ops.py index 3bd25cf..0c2f00c 100644 --- a/tests/cases/checker/09_frame_ops.py +++ b/tests/cases/checker/09_frame_ops.py @@ -23,7 +23,7 @@ _ = df1 >= df2 _ = df1 != df2 _ = df1 == df2 -# Statistical +# Aggregate _ = df1.kurt() _ = df1.kurtosis() _ = df1.max() @@ -36,3 +36,16 @@ _ = df1.product() _ = df1.std() _ = df1.sum() _ = df1.var() + +# Groupby +gb = df1.groupby(by="a") + +_ = gb.kurt() +_ = gb.max() +_ = gb.mean() +_ = gb.median() +_ = gb.min() +_ = gb.prod() +_ = gb.std() +_ = gb.sum() +_ = gb.var() diff --git a/tests/cases/checker/09_frame_ops.py.ref.json b/tests/cases/checker/09_frame_ops.py.ref.json index 551b8dc..d7e9b4a 100644 --- a/tests/cases/checker/09_frame_ops.py.ref.json +++ b/tests/cases/checker/09_frame_ops.py.ref.json @@ -1998,6 +1998,774 @@ "type": { "type": {} } + }, + { + 
"location": { + "from": "L41:20", + "to": "L41:23" + }, + "expr": { + "_type": "LiteralExpr", + "value": "a" + }, + "type": { + "name": "str" + } + }, + { + "location": { + "from": "L41:5", + "to": "L41:8" + }, + "expr": { + "_type": "VariableExpr", + "name": "df1" + }, + "type": { + "columns": [ + { + "index": 0, + "name": "a", + "type": { + "type": { + "name": "int" + } + } + }, + { + "index": 1, + "name": "b", + "type": { + "type": { + "name": "float" + } + } + } + ] + } + }, + { + "location": { + "from": "L41:5", + "to": "L41:24" + }, + "expr": { + "_type": "CallExpr", + "callee": { + "_type": "GetExpr", + "object": { + "_type": "VariableExpr", + "name": "df1" + }, + "name": "groupby" + }, + "arguments": [], + "keywords": { + "by": { + "_type": "LiteralExpr", + "value": "a" + } + } + }, + "type": { + "frame": { + "columns": [ + { + "index": 0, + "name": "a", + "type": { + "type": { + "name": "int" + } + } + }, + { + "index": 1, + "name": "b", + "type": { + "type": { + "name": "float" + } + } + } + ] + } + } + }, + { + "location": { + "from": "L43:4", + "to": "L43:6" + }, + "expr": { + "_type": "VariableExpr", + "name": "gb" + }, + "type": { + "frame": { + "columns": [ + { + "index": 0, + "name": "a", + "type": { + "type": { + "name": "int" + } + } + }, + { + "index": 1, + "name": "b", + "type": { + "type": { + "name": "float" + } + } + } + ] + } + } + }, + { + "location": { + "from": "L43:4", + "to": "L43:13" + }, + "expr": { + "_type": "CallExpr", + "callee": { + "_type": "GetExpr", + "object": { + "_type": "VariableExpr", + "name": "gb" + }, + "name": "kurt" + }, + "arguments": [], + "keywords": {} + }, + "type": { + "columns": [ + { + "index": 0, + "name": "a", + "type": { + "type": { + "name": "int" + } + } + }, + { + "index": 1, + "name": "b", + "type": { + "type": { + "name": "float" + } + } + } + ] + } + }, + { + "location": { + "from": "L44:4", + "to": "L44:6" + }, + "expr": { + "_type": "VariableExpr", + "name": "gb" + }, + "type": { + "frame": { + 
"columns": [ + { + "index": 0, + "name": "a", + "type": { + "type": { + "name": "int" + } + } + }, + { + "index": 1, + "name": "b", + "type": { + "type": { + "name": "float" + } + } + } + ] + } + } + }, + { + "location": { + "from": "L44:4", + "to": "L44:12" + }, + "expr": { + "_type": "CallExpr", + "callee": { + "_type": "GetExpr", + "object": { + "_type": "VariableExpr", + "name": "gb" + }, + "name": "max" + }, + "arguments": [], + "keywords": {} + }, + "type": { + "columns": [ + { + "index": 0, + "name": "a", + "type": { + "type": { + "name": "int" + } + } + }, + { + "index": 1, + "name": "b", + "type": { + "type": { + "name": "float" + } + } + } + ] + } + }, + { + "location": { + "from": "L45:4", + "to": "L45:6" + }, + "expr": { + "_type": "VariableExpr", + "name": "gb" + }, + "type": { + "frame": { + "columns": [ + { + "index": 0, + "name": "a", + "type": { + "type": { + "name": "int" + } + } + }, + { + "index": 1, + "name": "b", + "type": { + "type": { + "name": "float" + } + } + } + ] + } + } + }, + { + "location": { + "from": "L45:4", + "to": "L45:13" + }, + "expr": { + "_type": "CallExpr", + "callee": { + "_type": "GetExpr", + "object": { + "_type": "VariableExpr", + "name": "gb" + }, + "name": "mean" + }, + "arguments": [], + "keywords": {} + }, + "type": { + "columns": [ + { + "index": 0, + "name": "a", + "type": { + "type": { + "name": "int" + } + } + }, + { + "index": 1, + "name": "b", + "type": { + "type": { + "name": "float" + } + } + } + ] + } + }, + { + "location": { + "from": "L46:4", + "to": "L46:6" + }, + "expr": { + "_type": "VariableExpr", + "name": "gb" + }, + "type": { + "frame": { + "columns": [ + { + "index": 0, + "name": "a", + "type": { + "type": { + "name": "int" + } + } + }, + { + "index": 1, + "name": "b", + "type": { + "type": { + "name": "float" + } + } + } + ] + } + } + }, + { + "location": { + "from": "L46:4", + "to": "L46:15" + }, + "expr": { + "_type": "CallExpr", + "callee": { + "_type": "GetExpr", + "object": { + "_type": 
"VariableExpr", + "name": "gb" + }, + "name": "median" + }, + "arguments": [], + "keywords": {} + }, + "type": { + "columns": [ + { + "index": 0, + "name": "a", + "type": { + "type": { + "name": "int" + } + } + }, + { + "index": 1, + "name": "b", + "type": { + "type": { + "name": "float" + } + } + } + ] + } + }, + { + "location": { + "from": "L47:4", + "to": "L47:6" + }, + "expr": { + "_type": "VariableExpr", + "name": "gb" + }, + "type": { + "frame": { + "columns": [ + { + "index": 0, + "name": "a", + "type": { + "type": { + "name": "int" + } + } + }, + { + "index": 1, + "name": "b", + "type": { + "type": { + "name": "float" + } + } + } + ] + } + } + }, + { + "location": { + "from": "L47:4", + "to": "L47:12" + }, + "expr": { + "_type": "CallExpr", + "callee": { + "_type": "GetExpr", + "object": { + "_type": "VariableExpr", + "name": "gb" + }, + "name": "min" + }, + "arguments": [], + "keywords": {} + }, + "type": { + "columns": [ + { + "index": 0, + "name": "a", + "type": { + "type": { + "name": "int" + } + } + }, + { + "index": 1, + "name": "b", + "type": { + "type": { + "name": "float" + } + } + } + ] + } + }, + { + "location": { + "from": "L48:4", + "to": "L48:6" + }, + "expr": { + "_type": "VariableExpr", + "name": "gb" + }, + "type": { + "frame": { + "columns": [ + { + "index": 0, + "name": "a", + "type": { + "type": { + "name": "int" + } + } + }, + { + "index": 1, + "name": "b", + "type": { + "type": { + "name": "float" + } + } + } + ] + } + } + }, + { + "location": { + "from": "L48:4", + "to": "L48:13" + }, + "expr": { + "_type": "CallExpr", + "callee": { + "_type": "GetExpr", + "object": { + "_type": "VariableExpr", + "name": "gb" + }, + "name": "prod" + }, + "arguments": [], + "keywords": {} + }, + "type": { + "columns": [ + { + "index": 0, + "name": "a", + "type": { + "type": { + "name": "int" + } + } + }, + { + "index": 1, + "name": "b", + "type": { + "type": { + "name": "float" + } + } + } + ] + } + }, + { + "location": { + "from": "L49:4", + "to": 
"L49:6" + }, + "expr": { + "_type": "VariableExpr", + "name": "gb" + }, + "type": { + "frame": { + "columns": [ + { + "index": 0, + "name": "a", + "type": { + "type": { + "name": "int" + } + } + }, + { + "index": 1, + "name": "b", + "type": { + "type": { + "name": "float" + } + } + } + ] + } + } + }, + { + "location": { + "from": "L49:4", + "to": "L49:12" + }, + "expr": { + "_type": "CallExpr", + "callee": { + "_type": "GetExpr", + "object": { + "_type": "VariableExpr", + "name": "gb" + }, + "name": "std" + }, + "arguments": [], + "keywords": {} + }, + "type": { + "columns": [ + { + "index": 0, + "name": "a", + "type": { + "type": { + "name": "int" + } + } + }, + { + "index": 1, + "name": "b", + "type": { + "type": { + "name": "float" + } + } + } + ] + } + }, + { + "location": { + "from": "L50:4", + "to": "L50:6" + }, + "expr": { + "_type": "VariableExpr", + "name": "gb" + }, + "type": { + "frame": { + "columns": [ + { + "index": 0, + "name": "a", + "type": { + "type": { + "name": "int" + } + } + }, + { + "index": 1, + "name": "b", + "type": { + "type": { + "name": "float" + } + } + } + ] + } + } + }, + { + "location": { + "from": "L50:4", + "to": "L50:12" + }, + "expr": { + "_type": "CallExpr", + "callee": { + "_type": "GetExpr", + "object": { + "_type": "VariableExpr", + "name": "gb" + }, + "name": "sum" + }, + "arguments": [], + "keywords": {} + }, + "type": { + "columns": [ + { + "index": 0, + "name": "a", + "type": { + "type": { + "name": "int" + } + } + }, + { + "index": 1, + "name": "b", + "type": { + "type": { + "name": "float" + } + } + } + ] + } + }, + { + "location": { + "from": "L51:4", + "to": "L51:6" + }, + "expr": { + "_type": "VariableExpr", + "name": "gb" + }, + "type": { + "frame": { + "columns": [ + { + "index": 0, + "name": "a", + "type": { + "type": { + "name": "int" + } + } + }, + { + "index": 1, + "name": "b", + "type": { + "type": { + "name": "float" + } + } + } + ] + } + } + }, + { + "location": { + "from": "L51:4", + "to": "L51:12" + }, + 
"expr": { + "_type": "CallExpr", + "callee": { + "_type": "GetExpr", + "object": { + "_type": "VariableExpr", + "name": "gb" + }, + "name": "var" + }, + "arguments": [], + "keywords": {} + }, + "type": { + "columns": [ + { + "index": 0, + "name": "a", + "type": { + "type": { + "name": "int" + } + } + }, + { + "index": 1, + "name": "b", + "type": { + "type": { + "name": "float" + } + } + } + ] + } } ] } \ No newline at end of file