refactor: add MethodResolver class
This commit is contained in:
@@ -1,8 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Optional, Protocol
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Callable, Optional
|
||||||
|
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
from midas.checker.reporter import FileReporter
|
from midas.checker.reporter import FileReporter
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
ColumnType,
|
ColumnType,
|
||||||
@@ -14,111 +16,134 @@ from midas.checker.types import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from midas.checker.frames import FrameManager
|
|
||||||
from midas.checker.python import PythonTyper, TypedExpr
|
from midas.checker.python import PythonTyper, TypedExpr
|
||||||
|
|
||||||
|
|
||||||
class FrameMethod(Protocol):
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Call:
|
||||||
|
location: Location
|
||||||
|
frame: DataFrameType
|
||||||
|
positional: list[TypedExpr]
|
||||||
|
keywords: dict[str, TypedExpr]
|
||||||
|
|
||||||
|
|
||||||
|
def _is_method(obj: object, method: str) -> bool:
|
||||||
|
if not callable(obj):
|
||||||
|
return False
|
||||||
|
if not hasattr(obj, "__method_names__"):
|
||||||
|
return False
|
||||||
|
return method in obj.__method_names__ # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class MethodResolver:
|
||||||
|
@staticmethod
|
||||||
|
def frame_method(*names: str):
|
||||||
|
def wrapper(func):
|
||||||
|
names_: tuple[str, ...] = names
|
||||||
|
if len(names_) == 0:
|
||||||
|
names_ = (func.__name__,)
|
||||||
|
setattr(func, "__method_names__", names_)
|
||||||
|
return func
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
def __init__(self, typer: PythonTyper) -> None:
|
||||||
|
self.typer: PythonTyper = typer
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def __name__(self) -> str: ...
|
def reporter(self) -> FileReporter:
|
||||||
def __call__(
|
return self.typer.reporter
|
||||||
|
|
||||||
|
@property
|
||||||
|
def types(self) -> TypesRegistry:
|
||||||
|
return self.typer.types
|
||||||
|
|
||||||
|
def _get_method_by_name(self, method: str) -> Optional[Callable]:
|
||||||
|
for name in dir(self):
|
||||||
|
attr = getattr(self, name)
|
||||||
|
if _is_method(attr, method):
|
||||||
|
return attr
|
||||||
|
return None
|
||||||
|
|
||||||
|
def call(
|
||||||
self,
|
self,
|
||||||
typer: PythonTyper,
|
method: str,
|
||||||
manager: FrameManager,
|
call: Call,
|
||||||
reporter: FileReporter,
|
) -> Type:
|
||||||
location: Location,
|
func: Optional[Callable] = self._get_method_by_name(method)
|
||||||
frame: DataFrameType,
|
if func is None:
|
||||||
positional: list[TypedExpr],
|
self.reporter.error(call.location, f"Unknown method {method}")
|
||||||
keywords: dict[str, TypedExpr],
|
return UnknownType()
|
||||||
) -> Type: ...
|
return func(call)
|
||||||
|
|
||||||
|
@frame_method("add", "__add__")
|
||||||
|
def add(
|
||||||
|
self,
|
||||||
|
call: Call,
|
||||||
|
) -> Type:
|
||||||
|
new_columns: list[DataFrameType.Column] = []
|
||||||
|
|
||||||
FRAME_METHODS: dict[str, FrameMethod] = {}
|
by_name: dict[str, DataFrameType.Column] = {}
|
||||||
|
frame2: Optional[DataFrameType] = None
|
||||||
|
if len(call.positional) != 0:
|
||||||
|
other: Type = call.positional[0][1]
|
||||||
|
unfolded_other: Type = unfold_type(other)
|
||||||
|
if isinstance(unfolded_other, DataFrameType):
|
||||||
|
frame2 = unfolded_other
|
||||||
|
by_name = {
|
||||||
|
col.name: col for col in frame2.columns if col.name is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
in_frame1: set[str] = set()
|
||||||
|
for column in call.frame.columns:
|
||||||
|
if column.name is not None:
|
||||||
|
in_frame1.add(column.name)
|
||||||
|
|
||||||
def frame_method(*names: str):
|
col_type1: Type = column.type
|
||||||
def wrapper(func: FrameMethod):
|
col_type: Type = ColumnType(type=UnknownType())
|
||||||
names_: tuple[str, ...] = names
|
if column.name in by_name:
|
||||||
if len(names_) == 0:
|
column2 = by_name[column.name]
|
||||||
names_ = (func.__name__,)
|
col_type2: Type = column2.type
|
||||||
for name in names_:
|
if self.types.are_equivalent(col_type2, col_type1):
|
||||||
FRAME_METHODS[name] = func
|
col_type = col_type1
|
||||||
return func
|
|
||||||
|
|
||||||
return wrapper
|
new_column = DataFrameType.Column(
|
||||||
|
index=column.index,
|
||||||
|
name=column.name,
|
||||||
@frame_method("add", "__add__")
|
type=col_type,
|
||||||
def add(
|
|
||||||
typer: PythonTyper,
|
|
||||||
manager: FrameManager,
|
|
||||||
reporter: FileReporter,
|
|
||||||
location: Location,
|
|
||||||
frame: DataFrameType,
|
|
||||||
positional: list[TypedExpr],
|
|
||||||
keywords: dict[str, TypedExpr],
|
|
||||||
) -> Type:
|
|
||||||
new_columns: list[DataFrameType.Column] = []
|
|
||||||
|
|
||||||
by_name: dict[str, DataFrameType.Column] = {}
|
|
||||||
frame2: Optional[DataFrameType] = None
|
|
||||||
if len(positional) != 0:
|
|
||||||
other: Type = positional[0][1]
|
|
||||||
unfolded_other: Type = unfold_type(other)
|
|
||||||
if isinstance(unfolded_other, DataFrameType):
|
|
||||||
frame2 = unfolded_other
|
|
||||||
by_name = {col.name: col for col in frame2.columns if col.name is not None}
|
|
||||||
|
|
||||||
in_frame1: set[str] = set()
|
|
||||||
for column in frame.columns:
|
|
||||||
if column.name is not None:
|
|
||||||
in_frame1.add(column.name)
|
|
||||||
|
|
||||||
col_type1: Type = column.type
|
|
||||||
col_type: Type = ColumnType(type=UnknownType())
|
|
||||||
if column.name in by_name:
|
|
||||||
column2 = by_name[column.name]
|
|
||||||
col_type2: Type = column2.type
|
|
||||||
if manager.types.are_equivalent(col_type2, col_type1):
|
|
||||||
col_type = col_type1
|
|
||||||
|
|
||||||
new_column = DataFrameType.Column(
|
|
||||||
index=column.index,
|
|
||||||
name=column.name,
|
|
||||||
type=col_type,
|
|
||||||
)
|
|
||||||
new_columns.append(new_column)
|
|
||||||
|
|
||||||
if frame2 is not None:
|
|
||||||
for column in frame2.columns:
|
|
||||||
if column.name in in_frame1:
|
|
||||||
continue
|
|
||||||
new_columns.append(
|
|
||||||
DataFrameType.Column(
|
|
||||||
index=len(new_columns),
|
|
||||||
name=column.name,
|
|
||||||
type=ColumnType(type=UnknownType()),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
new_columns.append(new_column)
|
||||||
|
|
||||||
signature = Function(
|
if frame2 is not None:
|
||||||
args=[
|
for column in frame2.columns:
|
||||||
Function.Argument(
|
if column.name in in_frame1:
|
||||||
pos=0,
|
continue
|
||||||
name="other",
|
new_columns.append(
|
||||||
type=DataFrameType(columns=[]),
|
DataFrameType.Column(
|
||||||
required=True,
|
index=len(new_columns),
|
||||||
),
|
name=column.name,
|
||||||
],
|
type=ColumnType(type=UnknownType()),
|
||||||
returns=DataFrameType(columns=new_columns),
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
signature = Function(
|
||||||
typer._get_call_result(
|
args=[
|
||||||
location=location,
|
Function.Argument(
|
||||||
callee=signature,
|
pos=0,
|
||||||
positional=positional,
|
name="other",
|
||||||
keywords=keywords,
|
type=DataFrameType(columns=[]),
|
||||||
|
required=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
returns=DataFrameType(columns=new_columns),
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
self.typer._get_call_result(
|
||||||
|
location=call.location,
|
||||||
|
callee=signature,
|
||||||
|
positional=call.positional,
|
||||||
|
keywords=call.keywords,
|
||||||
|
)
|
||||||
|
or UnknownType()
|
||||||
)
|
)
|
||||||
or UnknownType()
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -4,8 +4,7 @@ from typing import TYPE_CHECKING, Optional, TypeGuard, cast
|
|||||||
|
|
||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
from midas.checker.frame_methods import FRAME_METHODS, FrameMethod
|
from midas.checker.frame_methods import Call, MethodResolver
|
||||||
from midas.checker.registry import TypesRegistry
|
|
||||||
from midas.checker.reporter import FileReporter
|
from midas.checker.reporter import FileReporter
|
||||||
from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType
|
from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType
|
||||||
|
|
||||||
@@ -18,8 +17,9 @@ def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
|
|||||||
|
|
||||||
|
|
||||||
class FrameManager:
|
class FrameManager:
|
||||||
def __init__(self, types: TypesRegistry) -> None:
|
def __init__(self, typer: PythonTyper) -> None:
|
||||||
self.types: TypesRegistry = types
|
self.typer: PythonTyper = typer
|
||||||
|
self.method_resolver: MethodResolver = MethodResolver(self.typer)
|
||||||
|
|
||||||
def assign(
|
def assign(
|
||||||
self,
|
self,
|
||||||
@@ -137,15 +137,18 @@ class FrameManager:
|
|||||||
) -> list[Optional[ColumnType]]:
|
) -> list[Optional[ColumnType]]:
|
||||||
return [cls._get_column(frame, name) for name in names]
|
return [cls._get_column(frame, name) for name in names]
|
||||||
|
|
||||||
def call_method(
|
def call(
|
||||||
self,
|
self,
|
||||||
typer: PythonTyper,
|
method: str,
|
||||||
reporter: FileReporter,
|
|
||||||
location: Location,
|
location: Location,
|
||||||
frame: DataFrameType,
|
frame: DataFrameType,
|
||||||
method: str,
|
|
||||||
positional: list[TypedExpr],
|
positional: list[TypedExpr],
|
||||||
keywords: dict[str, TypedExpr],
|
keywords: dict[str, TypedExpr],
|
||||||
) -> Type:
|
) -> Type:
|
||||||
function: FrameMethod = FRAME_METHODS[method]
|
call: Call = Call(
|
||||||
return function(typer, self, reporter, location, frame, positional, keywords)
|
location=location,
|
||||||
|
frame=frame,
|
||||||
|
positional=positional,
|
||||||
|
keywords=keywords,
|
||||||
|
)
|
||||||
|
return self.method_resolver.call(method, call)
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from midas.ast.location import Location
|
|||||||
from midas.ast.printer import MidasPrinter
|
from midas.ast.printer import MidasPrinter
|
||||||
from midas.checker.environment import Environment
|
from midas.checker.environment import Environment
|
||||||
from midas.checker.evaluator import Evaluator
|
from midas.checker.evaluator import Evaluator
|
||||||
from midas.checker.frame_methods import FRAME_METHODS
|
|
||||||
from midas.checker.frames import FrameManager
|
from midas.checker.frames import FrameManager
|
||||||
from midas.checker.operators import (
|
from midas.checker.operators import (
|
||||||
PY_COMPARATOR_METHODS,
|
PY_COMPARATOR_METHODS,
|
||||||
@@ -76,7 +75,7 @@ class PythonTyper(
|
|||||||
self.logger: logging.Logger = logging.getLogger("PythonTyper")
|
self.logger: logging.Logger = logging.getLogger("PythonTyper")
|
||||||
self.reporter: FileReporter = reporter.for_file(None)
|
self.reporter: FileReporter = reporter.for_file(None)
|
||||||
self.types: TypesRegistry = types
|
self.types: TypesRegistry = types
|
||||||
self.frame_mgr: FrameManager = FrameManager(self.types)
|
self.frame_mgr: FrameManager = FrameManager(self)
|
||||||
self.global_env: Environment = Preamble(self.types)
|
self.global_env: Environment = Preamble(self.types)
|
||||||
self.env: Environment = self.global_env
|
self.env: Environment = self.global_env
|
||||||
self.locals: dict[p.Expr, int] = {}
|
self.locals: dict[p.Expr, int] = {}
|
||||||
@@ -527,15 +526,11 @@ class PythonTyper(
|
|||||||
case p.GetExpr(object=obj, name=method):
|
case p.GetExpr(object=obj, name=method):
|
||||||
obj_type: Type = self.type_of(obj)
|
obj_type: Type = self.type_of(obj)
|
||||||
unfolded: Type = unfold_type(obj_type)
|
unfolded: Type = unfold_type(obj_type)
|
||||||
if isinstance(unfolded, DataFrameType) and self._is_frame_method(
|
if isinstance(unfolded, DataFrameType):
|
||||||
method
|
return self.frame_mgr.call(
|
||||||
):
|
method,
|
||||||
return self.frame_mgr.call_method(
|
|
||||||
self,
|
|
||||||
self.reporter,
|
|
||||||
expr.location,
|
expr.location,
|
||||||
unfolded,
|
unfolded,
|
||||||
method,
|
|
||||||
positional,
|
positional,
|
||||||
keywords,
|
keywords,
|
||||||
)
|
)
|
||||||
@@ -1307,6 +1302,3 @@ class PythonTyper(
|
|||||||
self, frame: DataFrameType, expr: p.SubscriptExpr
|
self, frame: DataFrameType, expr: p.SubscriptExpr
|
||||||
) -> Type:
|
) -> Type:
|
||||||
return self.frame_mgr.get(self.reporter, expr.location, frame, expr.index)
|
return self.frame_mgr.get(self.reporter, expr.location, frame, expr.index)
|
||||||
|
|
||||||
def _is_frame_method(self, method: str) -> bool:
|
|
||||||
return method in FRAME_METHODS
|
|
||||||
|
|||||||
Reference in New Issue
Block a user