diff --git a/midas/checker/frames.py b/midas/checker/frames.py index 0dbd95b..50ae080 100644 --- a/midas/checker/frames.py +++ b/midas/checker/frames.py @@ -1,11 +1,16 @@ -from typing import Optional, TypeGuard, cast +from __future__ import annotations +from typing import TYPE_CHECKING, Optional, TypeGuard, cast + +import midas.ast.python as p from midas.ast.location import Location +from midas.checker.frame_methods import FRAME_METHODS, FrameMethod from midas.checker.registry import TypesRegistry from midas.checker.reporter import FileReporter from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType -import midas.ast.python as p +if TYPE_CHECKING: + from midas.checker.python import PythonTyper, TypedExpr def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]: @@ -131,3 +136,16 @@ class FrameManager: cls, frame: DataFrameType, names: list[str] ) -> list[Optional[ColumnType]]: return [cls._get_column(frame, name) for name in names] + + def call_method( + self, + typer: PythonTyper, + reporter: FileReporter, + location: Location, + frame: DataFrameType, + method: str, + positional: list[TypedExpr], + keywords: dict[str, TypedExpr], + ) -> Type: + function: FrameMethod = FRAME_METHODS[method] + return function(typer, self, reporter, location, frame, positional, keywords) diff --git a/midas/checker/python.py b/midas/checker/python.py index cd46cdd..d61238b 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -8,6 +8,7 @@ from midas.ast.location import Location from midas.ast.printer import MidasPrinter from midas.checker.environment import Environment from midas.checker.evaluator import Evaluator +from midas.checker.frame_methods import FRAME_METHODS from midas.checker.frames import FrameManager from midas.checker.operators import ( PY_COMPARATOR_METHODS, @@ -515,13 +516,31 @@ class PythonTyper( case p.VariableExpr(name="TypeVar"): return self.define_typevar(expr) or UnknownType() - callee: Type = self.type_of(expr.callee) positional: list[TypedExpr] = [ (arg, self.type_of(arg)) for arg in expr.arguments ] keywords: dict[str, TypedExpr] = { name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items() } + + match expr.callee: + case p.GetExpr(object=obj, name=method): + obj_type: Type = self.type_of(obj) + unfolded: Type = unfold_type(obj_type) + if isinstance(unfolded, DataFrameType) and self._is_frame_method( + method + ): + return self.frame_mgr.call_method( + self, + self.reporter, + expr.location, + unfolded, + method, + positional, + keywords, + ) + + callee: Type = self.type_of(expr.callee) return ( self._get_call_result( location=expr.location, @@ -1288,3 +1307,6 @@ class PythonTyper( self, frame: DataFrameType, expr: p.SubscriptExpr ) -> Type: return self.frame_mgr.get(self.reporter, expr.location, frame, expr.index) + + def _is_frame_method(self, method: str) -> bool: + return method in FRAME_METHODS