Files
midas/midas/checker/frames.py
LordBaryhobal ff69b65171 feat(checker): add same length assertion on frames
Safely adding two dataframes is only possible if they have the same size; otherwise, null values would have to be added dynamically to pad the shorter dataframe.
2026-07-02 17:14:05 +02:00

159 lines
5.1 KiB
Python

from __future__ import annotations
from typing import TYPE_CHECKING, Optional, TypeGuard, cast
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.frame_methods import Call, MethodRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType
if TYPE_CHECKING:
from midas.checker.python import PythonTyper, TypedExpr
def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
return all(isinstance(expr, p.LiteralExpr) for expr in exprs)
class FrameManager:
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
self.method_resolver: MethodRegistry = MethodRegistry(self.typer)
def assign(
self,
reporter: FileReporter,
location: Location,
frame: DataFrameType,
index: p.Expr,
value_type: Type,
) -> Type:
match index:
case p.LiteralExpr(value=str() as name):
return self.assign_column(reporter, location, frame, name, value_type)
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
isinstance(idx, str) for idx in indices
):
raise NotImplementedError
case _:
reporter.error(location, f"Invalid index type {index} on {frame}")
return UnknownType()
def assign_column(
self,
reporter: FileReporter,
location: Location,
frame: DataFrameType,
name: str,
type: Type,
) -> Type:
if not isinstance(type, ColumnType):
reporter.error(
location,
f"Cannot assign {type} to dataframe column. Must be a ColumnType",
)
return frame
return self._set_column(frame, name, type)
def get(
self,
reporter: FileReporter,
location: Location,
frame: DataFrameType,
index: p.Expr,
) -> Type:
match index:
case p.LiteralExpr(value=str() as name):
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
if column is None:
reporter.error(location, f"Unknown column '{name}' on {frame}")
return UnknownType()
return column
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
isinstance(index.value, str) for index in indices
):
names: list[str] = [cast(str, index.value) for index in indices]
columns: list[ColumnType] = []
for name in names:
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
if column is None:
reporter.error(location, f"Unknown column '{name}' on {frame}")
return UnknownType()
columns.append(column)
return TupleType(items=tuple(columns))
case _:
reporter.error(location, f"Invalid index type {index} on {frame}")
return UnknownType()
@classmethod
def _set_column(
cls, frame: DataFrameType, name: str, column: ColumnType
) -> DataFrameType:
new_columns: list[DataFrameType.Column] = []
index: int = len(frame.columns)
replace: bool = False
for i, col in enumerate(frame.columns):
if col.name == name:
index = i
replace = True
# TODO: check column type here to prevent changing it
new_columns.append(col)
new_col: DataFrameType.Column = DataFrameType.Column(
index=index,
name=name,
type=column,
)
if replace:
new_columns[index] = new_col
else:
new_columns.append(new_col)
return DataFrameType(columns=new_columns)
@classmethod
def _set_columns(
cls, frame: DataFrameType, names: list[str], columns: list[ColumnType]
) -> DataFrameType:
for name, col in zip(names, columns):
frame = cls._set_column(frame, name, col)
return frame
@classmethod
def _get_column(cls, frame: DataFrameType, name: str) -> Optional[ColumnType]:
for col in frame.columns:
if col.name == name:
return col.type
return None
@classmethod
def _get_columns(
cls, frame: DataFrameType, names: list[str]
) -> list[Optional[ColumnType]]:
return [cls._get_column(frame, name) for name in names]
def call(
self,
method: str,
location: Location,
call_expr: p.Expr,
frame: DataFrameType,
frame_expr: p.Expr,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
call: Call = Call(
location=location,
call_expr=call_expr,
frame=frame,
frame_expr=frame_expr,
positional=positional,
keywords=keywords,
)
return self.method_resolver.call(method, call)