refactor: add MethodResolver class

This commit is contained in:
2026-06-25 22:14:25 +02:00
parent 894d5a7196
commit 5b3e87afcb
3 changed files with 137 additions and 117 deletions

View File

@@ -1,8 +1,10 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Protocol from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional
from midas.ast.location import Location from midas.ast.location import Location
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter from midas.checker.reporter import FileReporter
from midas.checker.types import ( from midas.checker.types import (
ColumnType, ColumnType,
@@ -14,63 +16,86 @@ from midas.checker.types import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from midas.checker.frames import FrameManager
from midas.checker.python import PythonTyper, TypedExpr from midas.checker.python import PythonTyper, TypedExpr
class FrameMethod(Protocol): @dataclass(frozen=True, kw_only=True)
@property class Call:
def __name__(self) -> str: ... location: Location
def __call__( frame: DataFrameType
self, positional: list[TypedExpr]
typer: PythonTyper, keywords: dict[str, TypedExpr]
manager: FrameManager,
reporter: FileReporter,
location: Location,
frame: DataFrameType,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type: ...
FRAME_METHODS: dict[str, FrameMethod] = {} def _is_method(obj: object, method: str) -> bool:
if not callable(obj):
return False
if not hasattr(obj, "__method_names__"):
return False
return method in obj.__method_names__ # type: ignore
def frame_method(*names: str): class MethodResolver:
def wrapper(func: FrameMethod): @staticmethod
def frame_method(*names: str):
def wrapper(func):
names_: tuple[str, ...] = names names_: tuple[str, ...] = names
if len(names_) == 0: if len(names_) == 0:
names_ = (func.__name__,) names_ = (func.__name__,)
for name in names_: setattr(func, "__method_names__", names_)
FRAME_METHODS[name] = func
return func return func
return wrapper return wrapper
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
@frame_method("add", "__add__") @property
def add( def reporter(self) -> FileReporter:
typer: PythonTyper, return self.typer.reporter
manager: FrameManager,
reporter: FileReporter, @property
location: Location, def types(self) -> TypesRegistry:
frame: DataFrameType, return self.typer.types
positional: list[TypedExpr],
keywords: dict[str, TypedExpr], def _get_method_by_name(self, method: str) -> Optional[Callable]:
) -> Type: for name in dir(self):
attr = getattr(self, name)
if _is_method(attr, method):
return attr
return None
def call(
self,
method: str,
call: Call,
) -> Type:
func: Optional[Callable] = self._get_method_by_name(method)
if func is None:
self.reporter.error(call.location, f"Unknown method {method}")
return UnknownType()
return func(call)
@frame_method("add", "__add__")
def add(
self,
call: Call,
) -> Type:
new_columns: list[DataFrameType.Column] = [] new_columns: list[DataFrameType.Column] = []
by_name: dict[str, DataFrameType.Column] = {} by_name: dict[str, DataFrameType.Column] = {}
frame2: Optional[DataFrameType] = None frame2: Optional[DataFrameType] = None
if len(positional) != 0: if len(call.positional) != 0:
other: Type = positional[0][1] other: Type = call.positional[0][1]
unfolded_other: Type = unfold_type(other) unfolded_other: Type = unfold_type(other)
if isinstance(unfolded_other, DataFrameType): if isinstance(unfolded_other, DataFrameType):
frame2 = unfolded_other frame2 = unfolded_other
by_name = {col.name: col for col in frame2.columns if col.name is not None} by_name = {
col.name: col for col in frame2.columns if col.name is not None
}
in_frame1: set[str] = set() in_frame1: set[str] = set()
for column in frame.columns: for column in call.frame.columns:
if column.name is not None: if column.name is not None:
in_frame1.add(column.name) in_frame1.add(column.name)
@@ -79,7 +104,7 @@ def add(
if column.name in by_name: if column.name in by_name:
column2 = by_name[column.name] column2 = by_name[column.name]
col_type2: Type = column2.type col_type2: Type = column2.type
if manager.types.are_equivalent(col_type2, col_type1): if self.types.are_equivalent(col_type2, col_type1):
col_type = col_type1 col_type = col_type1
new_column = DataFrameType.Column( new_column = DataFrameType.Column(
@@ -114,11 +139,11 @@ def add(
) )
return ( return (
typer._get_call_result( self.typer._get_call_result(
location=location, location=call.location,
callee=signature, callee=signature,
positional=positional, positional=call.positional,
keywords=keywords, keywords=call.keywords,
) )
or UnknownType() or UnknownType()
) )

View File

@@ -4,8 +4,7 @@ from typing import TYPE_CHECKING, Optional, TypeGuard, cast
import midas.ast.python as p import midas.ast.python as p
from midas.ast.location import Location from midas.ast.location import Location
from midas.checker.frame_methods import FRAME_METHODS, FrameMethod from midas.checker.frame_methods import Call, MethodResolver
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter from midas.checker.reporter import FileReporter
from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType
@@ -18,8 +17,9 @@ def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
class FrameManager: class FrameManager:
def __init__(self, types: TypesRegistry) -> None: def __init__(self, typer: PythonTyper) -> None:
self.types: TypesRegistry = types self.typer: PythonTyper = typer
self.method_resolver: MethodResolver = MethodResolver(self.typer)
def assign( def assign(
self, self,
@@ -137,15 +137,18 @@ class FrameManager:
) -> list[Optional[ColumnType]]: ) -> list[Optional[ColumnType]]:
return [cls._get_column(frame, name) for name in names] return [cls._get_column(frame, name) for name in names]
def call_method( def call(
self, self,
typer: PythonTyper, method: str,
reporter: FileReporter,
location: Location, location: Location,
frame: DataFrameType, frame: DataFrameType,
method: str,
positional: list[TypedExpr], positional: list[TypedExpr],
keywords: dict[str, TypedExpr], keywords: dict[str, TypedExpr],
) -> Type: ) -> Type:
function: FrameMethod = FRAME_METHODS[method] call: Call = Call(
return function(typer, self, reporter, location, frame, positional, keywords) location=location,
frame=frame,
positional=positional,
keywords=keywords,
)
return self.method_resolver.call(method, call)

View File

@@ -8,7 +8,6 @@ from midas.ast.location import Location
from midas.ast.printer import MidasPrinter from midas.ast.printer import MidasPrinter
from midas.checker.environment import Environment from midas.checker.environment import Environment
from midas.checker.evaluator import Evaluator from midas.checker.evaluator import Evaluator
from midas.checker.frame_methods import FRAME_METHODS
from midas.checker.frames import FrameManager from midas.checker.frames import FrameManager
from midas.checker.operators import ( from midas.checker.operators import (
PY_COMPARATOR_METHODS, PY_COMPARATOR_METHODS,
@@ -76,7 +75,7 @@ class PythonTyper(
self.logger: logging.Logger = logging.getLogger("PythonTyper") self.logger: logging.Logger = logging.getLogger("PythonTyper")
self.reporter: FileReporter = reporter.for_file(None) self.reporter: FileReporter = reporter.for_file(None)
self.types: TypesRegistry = types self.types: TypesRegistry = types
self.frame_mgr: FrameManager = FrameManager(self.types) self.frame_mgr: FrameManager = FrameManager(self)
self.global_env: Environment = Preamble(self.types) self.global_env: Environment = Preamble(self.types)
self.env: Environment = self.global_env self.env: Environment = self.global_env
self.locals: dict[p.Expr, int] = {} self.locals: dict[p.Expr, int] = {}
@@ -527,15 +526,11 @@ class PythonTyper(
case p.GetExpr(object=obj, name=method): case p.GetExpr(object=obj, name=method):
obj_type: Type = self.type_of(obj) obj_type: Type = self.type_of(obj)
unfolded: Type = unfold_type(obj_type) unfolded: Type = unfold_type(obj_type)
if isinstance(unfolded, DataFrameType) and self._is_frame_method( if isinstance(unfolded, DataFrameType):
method return self.frame_mgr.call(
): method,
return self.frame_mgr.call_method(
self,
self.reporter,
expr.location, expr.location,
unfolded, unfolded,
method,
positional, positional,
keywords, keywords,
) )
@@ -1307,6 +1302,3 @@ class PythonTyper(
self, frame: DataFrameType, expr: p.SubscriptExpr self, frame: DataFrameType, expr: p.SubscriptExpr
) -> Type: ) -> Type:
return self.frame_mgr.get(self.reporter, expr.location, frame, expr.index) return self.frame_mgr.get(self.reporter, expr.location, frame, expr.index)
def _is_frame_method(self, method: str) -> bool:
return method in FRAME_METHODS