diff --git a/midas/checker/frames/column_methods.py b/midas/checker/frames/column_methods.py index da22a43..91909ff 100644 --- a/midas/checker/frames/column_methods.py +++ b/midas/checker/frames/column_methods.py @@ -1,13 +1,23 @@ from __future__ import annotations +import ast from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import midas.ast.python as p from midas.ast.location import Location -from midas.checker.frames.utils import MethodRegistry +from midas.checker.dispatcher import CallResult +from midas.checker.frames.utils import MethodRegistry, method from midas.checker.types import ( + ColumnGroupBy, ColumnType, + Function, + GenericType, + TopType, + Type, + TypeVar, + UnknownType, + unfold_type, ) if TYPE_CHECKING: @@ -24,4 +34,183 @@ class Call: keywords: dict[str, TypedExpr] -class ColumnMethodRegistry(MethodRegistry[Call]): ... +class ColumnMethodRegistry(MethodRegistry[Call]): + @method("add", "__add__") + def add(self, call: Call) -> Type: + # TODO: support add with scalar + # TODO: check operation exists on inner column types + + column2: Optional[ColumnType] = None + + col_type1: Type = call.column.type + new_column: Type = ColumnType(type=UnknownType()) + if len(call.positional) != 0: + other: Type = call.positional[0][1] + unfolded_other: Type = unfold_type(other) + if isinstance(unfolded_other, ColumnType): + column2 = unfolded_other + col_type2: Type = column2.type + if self.types.are_equivalent(col_type2, col_type1): + new_column = ColumnType(type=col_type1) + + # Build signature with new column type and generic operand + param_type: TypeVar = TypeVar(name="T", bound=None) + signature = GenericType( + name="add", + params=[param_type], + body=Function( + args=[ + Function.Argument( + pos=0, + name="other", + type=ColumnType(type=param_type), + required=True, + ), + ], + returns=new_column, + ), + ) + + # Map arguments and compute result type + result: CallResult = self.dispatcher.get_result( + location=call.location, + callee=signature, + positional=call.positional, + keywords=call.keywords, + ) + if result.is_valid: + self._assert_same_length( + call.call_expr, call.column_expr, call.positional[0][0] + ) + + return result.result + + @method() + def mean(self, call: Call) -> Type: + signature = Function( + kw_args=[ + Function.Argument( + pos=0, + name="axis", + type=TopType(), + required=False, + ) + ], + returns=ColumnType(type=TopType()), + ) + + result: CallResult = self.dispatcher.get_result( + location=call.location, + callee=signature, + positional=call.positional, + keywords=call.keywords, + ) + return result.result + + @method() + def groupby(self, call: Call) -> Type: + bool_: Type = self.types.get_type("bool") + function: Function = Function( + args=[ + Function.Argument( + pos=0, + name="by", + type=TopType(), + required=False, + ), + Function.Argument( + pos=1, + name="level", + type=TopType(), + required=False, + ), + ], + kw_args=[ + Function.Argument( + pos=2, + name="as_index", + type=bool_, + required=False, + ), + Function.Argument( + pos=3, + name="sort", + type=bool_, + required=False, + ), + Function.Argument( + pos=4, + name="group_keys", + type=bool_, + required=False, + ), + Function.Argument( + pos=5, + name="observed", + type=bool_, + required=False, + ), + Function.Argument( + pos=6, + name="dropna", + type=bool_, + required=False, + ), + ], + returns=ColumnGroupBy(column=call.column), + ) + + result: CallResult = self.dispatcher.get_result( + location=call.location, + callee=function, + positional=call.positional, + keywords=call.keywords, + ) + return result.result + + def _assert_same_length(self, call_expr: p.Expr, column1: p.Expr, column2: p.Expr): + func_name: str = "__midas_column_same_length__" + self.assertions.define( + func_name, + ast.FunctionDef( + name=func_name, + args=ast.arguments( + posonlyargs=[], + args=[ + ast.arg(arg="column1"), + ast.arg(arg="column2"), + ], + kwonlyargs=[], + defaults=[], + kw_defaults=[], + ), + body=[ + ast.Return( + value=ast.Compare( + left=ast.Attribute( + value=ast.Name(id="column1"), + attr="size", + ), + ops=[ast.Eq()], + comparators=[ + ast.Attribute( + value=ast.Name(id="column2"), + attr="size", + ) + ], + ) + ) + ], + decorator_list=[], + ), + ) + self.assertions.add( + bound_expr=call_expr, + inputs=[column1, column2], + builder=lambda c1, c2: ast.Call( + func=ast.Name(id=func_name), + args=[c1, c2], + keywords=[], + ), + message="Columns must have the same length", + )