feat(checker): lookup dataframe methods
This commit is contained in:
@@ -1,11 +1,16 @@
|
|||||||
from typing import Optional, TypeGuard, cast
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Optional, TypeGuard, cast
|
||||||
|
|
||||||
|
import midas.ast.python as p
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
|
from midas.checker.frame_methods import FRAME_METHODS, FrameMethod
|
||||||
from midas.checker.registry import TypesRegistry
|
from midas.checker.registry import TypesRegistry
|
||||||
from midas.checker.reporter import FileReporter
|
from midas.checker.reporter import FileReporter
|
||||||
from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType
|
from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType
|
||||||
|
|
||||||
import midas.ast.python as p
|
if TYPE_CHECKING:
|
||||||
|
from midas.checker.python import PythonTyper, TypedExpr
|
||||||
|
|
||||||
|
|
||||||
def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
|
def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
|
||||||
@@ -131,3 +136,16 @@ class FrameManager:
|
|||||||
cls, frame: DataFrameType, names: list[str]
|
cls, frame: DataFrameType, names: list[str]
|
||||||
) -> list[Optional[ColumnType]]:
|
) -> list[Optional[ColumnType]]:
|
||||||
return [cls._get_column(frame, name) for name in names]
|
return [cls._get_column(frame, name) for name in names]
|
||||||
|
|
||||||
|
def call_method(
|
||||||
|
self,
|
||||||
|
typer: PythonTyper,
|
||||||
|
reporter: FileReporter,
|
||||||
|
location: Location,
|
||||||
|
frame: DataFrameType,
|
||||||
|
method: str,
|
||||||
|
positional: list[TypedExpr],
|
||||||
|
keywords: dict[str, TypedExpr],
|
||||||
|
) -> Type:
|
||||||
|
function: FrameMethod = FRAME_METHODS[method]
|
||||||
|
return function(typer, self, reporter, location, frame, positional, keywords)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from midas.ast.location import Location
|
|||||||
from midas.ast.printer import MidasPrinter
|
from midas.ast.printer import MidasPrinter
|
||||||
from midas.checker.environment import Environment
|
from midas.checker.environment import Environment
|
||||||
from midas.checker.evaluator import Evaluator
|
from midas.checker.evaluator import Evaluator
|
||||||
|
from midas.checker.frame_methods import FRAME_METHODS
|
||||||
from midas.checker.frames import FrameManager
|
from midas.checker.frames import FrameManager
|
||||||
from midas.checker.operators import (
|
from midas.checker.operators import (
|
||||||
PY_COMPARATOR_METHODS,
|
PY_COMPARATOR_METHODS,
|
||||||
@@ -515,13 +516,31 @@ class PythonTyper(
|
|||||||
case p.VariableExpr(name="TypeVar"):
|
case p.VariableExpr(name="TypeVar"):
|
||||||
return self.define_typevar(expr) or UnknownType()
|
return self.define_typevar(expr) or UnknownType()
|
||||||
|
|
||||||
callee: Type = self.type_of(expr.callee)
|
|
||||||
positional: list[TypedExpr] = [
|
positional: list[TypedExpr] = [
|
||||||
(arg, self.type_of(arg)) for arg in expr.arguments
|
(arg, self.type_of(arg)) for arg in expr.arguments
|
||||||
]
|
]
|
||||||
keywords: dict[str, TypedExpr] = {
|
keywords: dict[str, TypedExpr] = {
|
||||||
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
|
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
match expr.callee:
|
||||||
|
case p.GetExpr(object=obj, name=method):
|
||||||
|
obj_type: Type = self.type_of(obj)
|
||||||
|
unfolded: Type = unfold_type(obj_type)
|
||||||
|
if isinstance(unfolded, DataFrameType) and self._is_frame_method(
|
||||||
|
method
|
||||||
|
):
|
||||||
|
return self.frame_mgr.call_method(
|
||||||
|
self,
|
||||||
|
self.reporter,
|
||||||
|
expr.location,
|
||||||
|
unfolded,
|
||||||
|
method,
|
||||||
|
positional,
|
||||||
|
keywords,
|
||||||
|
)
|
||||||
|
|
||||||
|
callee: Type = self.type_of(expr.callee)
|
||||||
return (
|
return (
|
||||||
self._get_call_result(
|
self._get_call_result(
|
||||||
location=expr.location,
|
location=expr.location,
|
||||||
@@ -1288,3 +1307,6 @@ class PythonTyper(
|
|||||||
self, frame: DataFrameType, expr: p.SubscriptExpr
|
self, frame: DataFrameType, expr: p.SubscriptExpr
|
||||||
) -> Type:
|
) -> Type:
|
||||||
return self.frame_mgr.get(self.reporter, expr.location, frame, expr.index)
|
return self.frame_mgr.get(self.reporter, expr.location, frame, expr.index)
|
||||||
|
|
||||||
|
def _is_frame_method(self, method: str) -> bool:
|
||||||
|
return method in FRAME_METHODS
|
||||||
|
|||||||
Reference in New Issue
Block a user