refactor: add MethodResolver class

This commit is contained in:
2026-06-25 22:14:25 +02:00
parent 894d5a7196
commit 5b3e87afcb
3 changed files with 137 additions and 117 deletions

View File

@@ -1,8 +1,10 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Protocol from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional
from midas.ast.location import Location from midas.ast.location import Location
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter from midas.checker.reporter import FileReporter
from midas.checker.types import ( from midas.checker.types import (
ColumnType, ColumnType,
@@ -14,111 +16,134 @@ from midas.checker.types import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from midas.checker.frames import FrameManager
from midas.checker.python import PythonTyper, TypedExpr from midas.checker.python import PythonTyper, TypedExpr
class FrameMethod(Protocol): @dataclass(frozen=True, kw_only=True)
class Call:
location: Location
frame: DataFrameType
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
def _is_method(obj: object, method: str) -> bool:
if not callable(obj):
return False
if not hasattr(obj, "__method_names__"):
return False
return method in obj.__method_names__ # type: ignore
class MethodResolver:
@staticmethod
def frame_method(*names: str):
def wrapper(func):
names_: tuple[str, ...] = names
if len(names_) == 0:
names_ = (func.__name__,)
setattr(func, "__method_names__", names_)
return func
return wrapper
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
@property @property
def __name__(self) -> str: ... def reporter(self) -> FileReporter:
def __call__( return self.typer.reporter
@property
def types(self) -> TypesRegistry:
return self.typer.types
def _get_method_by_name(self, method: str) -> Optional[Callable]:
    """Find the attribute of this resolver registered under *method*.

    Args:
        method: The frame-method name to look up.

    Returns:
        The first attribute (in ``dir(self)`` order) that ``_is_method``
        accepts for *method*, or ``None`` when nothing matches.
    """
    # NOTE: getattr is applied to every name from dir(self), so properties
    # on the instance are evaluated during the scan as well.
    for name in dir(self):
        attr = getattr(self, name)
        if _is_method(attr, method):
            return attr
    return None
def call(
self, self,
typer: PythonTyper, method: str,
manager: FrameManager, call: Call,
reporter: FileReporter, ) -> Type:
location: Location, func: Optional[Callable] = self._get_method_by_name(method)
frame: DataFrameType, if func is None:
positional: list[TypedExpr], self.reporter.error(call.location, f"Unknown method {method}")
keywords: dict[str, TypedExpr], return UnknownType()
) -> Type: ... return func(call)
@frame_method("add", "__add__")
def add(
self,
call: Call,
) -> Type:
new_columns: list[DataFrameType.Column] = []
FRAME_METHODS: dict[str, FrameMethod] = {} by_name: dict[str, DataFrameType.Column] = {}
frame2: Optional[DataFrameType] = None
if len(call.positional) != 0:
other: Type = call.positional[0][1]
unfolded_other: Type = unfold_type(other)
if isinstance(unfolded_other, DataFrameType):
frame2 = unfolded_other
by_name = {
col.name: col for col in frame2.columns if col.name is not None
}
in_frame1: set[str] = set()
for column in call.frame.columns:
if column.name is not None:
in_frame1.add(column.name)
def frame_method(*names: str): col_type1: Type = column.type
def wrapper(func: FrameMethod): col_type: Type = ColumnType(type=UnknownType())
names_: tuple[str, ...] = names if column.name in by_name:
if len(names_) == 0: column2 = by_name[column.name]
names_ = (func.__name__,) col_type2: Type = column2.type
for name in names_: if self.types.are_equivalent(col_type2, col_type1):
FRAME_METHODS[name] = func col_type = col_type1
return func
return wrapper new_column = DataFrameType.Column(
index=column.index,
name=column.name,
@frame_method("add", "__add__") type=col_type,
def add(
typer: PythonTyper,
manager: FrameManager,
reporter: FileReporter,
location: Location,
frame: DataFrameType,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
new_columns: list[DataFrameType.Column] = []
by_name: dict[str, DataFrameType.Column] = {}
frame2: Optional[DataFrameType] = None
if len(positional) != 0:
other: Type = positional[0][1]
unfolded_other: Type = unfold_type(other)
if isinstance(unfolded_other, DataFrameType):
frame2 = unfolded_other
by_name = {col.name: col for col in frame2.columns if col.name is not None}
in_frame1: set[str] = set()
for column in frame.columns:
if column.name is not None:
in_frame1.add(column.name)
col_type1: Type = column.type
col_type: Type = ColumnType(type=UnknownType())
if column.name in by_name:
column2 = by_name[column.name]
col_type2: Type = column2.type
if manager.types.are_equivalent(col_type2, col_type1):
col_type = col_type1
new_column = DataFrameType.Column(
index=column.index,
name=column.name,
type=col_type,
)
new_columns.append(new_column)
if frame2 is not None:
for column in frame2.columns:
if column.name in in_frame1:
continue
new_columns.append(
DataFrameType.Column(
index=len(new_columns),
name=column.name,
type=ColumnType(type=UnknownType()),
)
) )
new_columns.append(new_column)
signature = Function( if frame2 is not None:
args=[ for column in frame2.columns:
Function.Argument( if column.name in in_frame1:
pos=0, continue
name="other", new_columns.append(
type=DataFrameType(columns=[]), DataFrameType.Column(
required=True, index=len(new_columns),
), name=column.name,
], type=ColumnType(type=UnknownType()),
returns=DataFrameType(columns=new_columns), )
) )
return ( signature = Function(
typer._get_call_result( args=[
location=location, Function.Argument(
callee=signature, pos=0,
positional=positional, name="other",
keywords=keywords, type=DataFrameType(columns=[]),
required=True,
),
],
returns=DataFrameType(columns=new_columns),
)
return (
self.typer._get_call_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
or UnknownType()
) )
or UnknownType()
)

View File

@@ -4,8 +4,7 @@ from typing import TYPE_CHECKING, Optional, TypeGuard, cast
import midas.ast.python as p import midas.ast.python as p
from midas.ast.location import Location from midas.ast.location import Location
from midas.checker.frame_methods import FRAME_METHODS, FrameMethod from midas.checker.frame_methods import Call, MethodResolver
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter from midas.checker.reporter import FileReporter
from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType
@@ -18,8 +17,9 @@ def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
class FrameManager: class FrameManager:
def __init__(self, types: TypesRegistry) -> None: def __init__(self, typer: PythonTyper) -> None:
self.types: TypesRegistry = types self.typer: PythonTyper = typer
self.method_resolver: MethodResolver = MethodResolver(self.typer)
def assign( def assign(
self, self,
@@ -137,15 +137,18 @@ class FrameManager:
) -> list[Optional[ColumnType]]: ) -> list[Optional[ColumnType]]:
return [cls._get_column(frame, name) for name in names] return [cls._get_column(frame, name) for name in names]
def call_method( def call(
self, self,
typer: PythonTyper, method: str,
reporter: FileReporter,
location: Location, location: Location,
frame: DataFrameType, frame: DataFrameType,
method: str,
positional: list[TypedExpr], positional: list[TypedExpr],
keywords: dict[str, TypedExpr], keywords: dict[str, TypedExpr],
) -> Type: ) -> Type:
function: FrameMethod = FRAME_METHODS[method] call: Call = Call(
return function(typer, self, reporter, location, frame, positional, keywords) location=location,
frame=frame,
positional=positional,
keywords=keywords,
)
return self.method_resolver.call(method, call)

View File

@@ -8,7 +8,6 @@ from midas.ast.location import Location
from midas.ast.printer import MidasPrinter from midas.ast.printer import MidasPrinter
from midas.checker.environment import Environment from midas.checker.environment import Environment
from midas.checker.evaluator import Evaluator from midas.checker.evaluator import Evaluator
from midas.checker.frame_methods import FRAME_METHODS
from midas.checker.frames import FrameManager from midas.checker.frames import FrameManager
from midas.checker.operators import ( from midas.checker.operators import (
PY_COMPARATOR_METHODS, PY_COMPARATOR_METHODS,
@@ -76,7 +75,7 @@ class PythonTyper(
self.logger: logging.Logger = logging.getLogger("PythonTyper") self.logger: logging.Logger = logging.getLogger("PythonTyper")
self.reporter: FileReporter = reporter.for_file(None) self.reporter: FileReporter = reporter.for_file(None)
self.types: TypesRegistry = types self.types: TypesRegistry = types
self.frame_mgr: FrameManager = FrameManager(self.types) self.frame_mgr: FrameManager = FrameManager(self)
self.global_env: Environment = Preamble(self.types) self.global_env: Environment = Preamble(self.types)
self.env: Environment = self.global_env self.env: Environment = self.global_env
self.locals: dict[p.Expr, int] = {} self.locals: dict[p.Expr, int] = {}
@@ -527,15 +526,11 @@ class PythonTyper(
case p.GetExpr(object=obj, name=method): case p.GetExpr(object=obj, name=method):
obj_type: Type = self.type_of(obj) obj_type: Type = self.type_of(obj)
unfolded: Type = unfold_type(obj_type) unfolded: Type = unfold_type(obj_type)
if isinstance(unfolded, DataFrameType) and self._is_frame_method( if isinstance(unfolded, DataFrameType):
method return self.frame_mgr.call(
): method,
return self.frame_mgr.call_method(
self,
self.reporter,
expr.location, expr.location,
unfolded, unfolded,
method,
positional, positional,
keywords, keywords,
) )
@@ -1307,6 +1302,3 @@ class PythonTyper(
self, frame: DataFrameType, expr: p.SubscriptExpr self, frame: DataFrameType, expr: p.SubscriptExpr
) -> Type: ) -> Type:
return self.frame_mgr.get(self.reporter, expr.location, frame, expr.index) return self.frame_mgr.get(self.reporter, expr.location, frame, expr.index)
def _is_frame_method(self, method: str) -> bool:
return method in FRAME_METHODS