feat(checker): type check subscript on dataframes

This commit is contained in:
2026-06-23 12:27:31 +02:00
parent 83bd3793df
commit 45e27ee04e
2 changed files with 91 additions and 1 deletions

51
midas/checker/frames.py Normal file
View File

@@ -0,0 +1,51 @@
from typing import Optional
from midas.checker.types import ColumnType, DataFrameType
class FrameManager:
@classmethod
def set_column(
cls, frame: DataFrameType, name: str, column: ColumnType
) -> DataFrameType:
new_columns: list[DataFrameType.Column] = []
index: int = len(frame.columns)
replace: bool = False
for i, col in enumerate(frame.columns):
if col.name == name:
index = i
replace = True
new_columns.append(col)
new_col: DataFrameType.Column = DataFrameType.Column(
index=index,
name=name,
type=column,
)
if replace:
new_columns[index] = new_col
else:
new_columns.append(new_col)
return DataFrameType(columns=new_columns)
@classmethod
def set_columns(
cls, frame: DataFrameType, names: list[str], columns: list[ColumnType]
) -> DataFrameType:
for name, col in zip(names, columns):
frame = cls.set_column(frame, name, col)
return frame
@classmethod
def get_column(cls, frame: DataFrameType, name: str) -> Optional[ColumnType]:
for col in frame.columns:
if col.name == name:
return col.type
return None
@classmethod
def get_columns(
cls, frame: DataFrameType, names: list[str]
) -> list[Optional[ColumnType]]:
return [cls.get_column(frame, name) for name in names]

View File

@@ -1,11 +1,12 @@
import ast
import logging
from dataclasses import dataclass
from typing import Optional
from typing import Optional, cast
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.environment import Environment
from midas.checker.frames import FrameManager
from midas.checker.operators import (
PY_COMPARATOR_METHODS,
PY_OPERATOR_METHODS,
@@ -629,6 +630,8 @@ class PythonTyper(
match unfolded:
case TupleType():
return self._visit_tuple_subscript(unfolded, expr)
case DataFrameType():
return self._visit_frame_subscript(unfolded, expr)
operation: Optional[Type] = self.types.lookup_member(object, "__getitem__")
if operation is None:
@@ -1144,3 +1147,39 @@ class PythonTyper(
expr.location, f"Invalid index type {expr.index} on {tup}"
)
return UnknownType()
def _visit_frame_subscript(
self, frame: DataFrameType, expr: p.SubscriptExpr
) -> Type:
match expr.index:
case p.LiteralExpr(value=str() as name):
column: Optional[ColumnType] = FrameManager.get_column(frame, name)
if column is None:
self.reporter.error(
expr.location, f"Unknown column '{name}' on {frame}"
)
return UnknownType()
return column
case p.ListExpr(items=indices) if all(
isinstance(index, p.LiteralExpr) and isinstance(index.value, str)
for index in indices
):
indices = cast(list[p.LiteralExpr], indices)
names: list[str] = [cast(str, index.value) for index in indices]
columns: list[ColumnType] = []
for name in names:
column: Optional[ColumnType] = FrameManager.get_column(frame, name)
if column is None:
self.reporter.error(
expr.location, f"Unknown column '{name}' on {frame}"
)
return UnknownType()
columns.append(column)
return TupleType(items=tuple(columns))
case _:
self.reporter.error(
expr.location, f"Invalid index type {expr.index} on {frame}"
)
return UnknownType()