feat(checker): type check subscript on dataframes
This commit is contained in:
51
midas/checker/frames.py
Normal file
51
midas/checker/frames.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from typing import Optional
|
||||
|
||||
from midas.checker.types import ColumnType, DataFrameType
|
||||
|
||||
|
||||
class FrameManager:
    """Helpers for reading columns from, and deriving new, data frame types.

    All methods are classmethods. The setters build and return a fresh
    ``DataFrameType``; they never modify the column list of the frame
    they are given.
    """

    @classmethod
    def set_column(
        cls, frame: DataFrameType, name: str, column: ColumnType
    ) -> DataFrameType:
        """Return a copy of ``frame`` in which ``name`` has type ``column``.

        If ``frame`` already has a column called ``name``, that entry is
        replaced and keeps its position (the last one, should several
        entries share the name). Otherwise the column is appended, with
        an index equal to the number of columns already present.
        """
        updated: list[DataFrameType.Column] = list(frame.columns)
        matches: list[int] = [
            position
            for position, existing in enumerate(updated)
            if existing.name == name
        ]
        # Reuse the slot of the last match, or target one past the end.
        slot: int = matches[-1] if matches else len(updated)
        entry: DataFrameType.Column = DataFrameType.Column(
            index=slot,
            name=name,
            type=column,
        )
        if matches:
            updated[slot] = entry
        else:
            updated.append(entry)
        return DataFrameType(columns=updated)

    @classmethod
    def set_columns(
        cls, frame: DataFrameType, names: list[str], columns: list[ColumnType]
    ) -> DataFrameType:
        """Apply ``set_column`` once per ``(name, column)`` pair, in order.

        Pairs are formed with ``zip``, so when the two lists differ in
        length the surplus entries of the longer one are ignored.
        """
        result: DataFrameType = frame
        for column_name, column_type in zip(names, columns):
            result = cls.set_column(result, column_name, column_type)
        return result

    @classmethod
    def get_column(cls, frame: DataFrameType, name: str) -> Optional[ColumnType]:
        """Return the type of the first column called ``name``, else ``None``."""
        return next(
            (entry.type for entry in frame.columns if entry.name == name),
            None,
        )

    @classmethod
    def get_columns(
        cls, frame: DataFrameType, names: list[str]
    ) -> list[Optional[ColumnType]]:
        """Look up every name in ``names``; unknown names map to ``None``.

        The result has one element per requested name, in the same order.
        """
        found: list[Optional[ColumnType]] = []
        for column_name in names:
            found.append(cls.get_column(frame, column_name))
        return found
|
||||
@@ -1,11 +1,12 @@
|
||||
import ast
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import Optional, cast
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.environment import Environment
|
||||
from midas.checker.frames import FrameManager
|
||||
from midas.checker.operators import (
|
||||
PY_COMPARATOR_METHODS,
|
||||
PY_OPERATOR_METHODS,
|
||||
@@ -629,6 +630,8 @@ class PythonTyper(
|
||||
match unfolded:
|
||||
case TupleType():
|
||||
return self._visit_tuple_subscript(unfolded, expr)
|
||||
case DataFrameType():
|
||||
return self._visit_frame_subscript(unfolded, expr)
|
||||
|
||||
operation: Optional[Type] = self.types.lookup_member(object, "__getitem__")
|
||||
if operation is None:
|
||||
@@ -1144,3 +1147,39 @@ class PythonTyper(
|
||||
expr.location, f"Invalid index type {expr.index} on {tup}"
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
def _visit_frame_subscript(
|
||||
self, frame: DataFrameType, expr: p.SubscriptExpr
|
||||
) -> Type:
|
||||
match expr.index:
|
||||
case p.LiteralExpr(value=str() as name):
|
||||
column: Optional[ColumnType] = FrameManager.get_column(frame, name)
|
||||
if column is None:
|
||||
self.reporter.error(
|
||||
expr.location, f"Unknown column '{name}' on {frame}"
|
||||
)
|
||||
return UnknownType()
|
||||
return column
|
||||
|
||||
case p.ListExpr(items=indices) if all(
|
||||
isinstance(index, p.LiteralExpr) and isinstance(index.value, str)
|
||||
for index in indices
|
||||
):
|
||||
indices = cast(list[p.LiteralExpr], indices)
|
||||
names: list[str] = [cast(str, index.value) for index in indices]
|
||||
columns: list[ColumnType] = []
|
||||
for name in names:
|
||||
column: Optional[ColumnType] = FrameManager.get_column(frame, name)
|
||||
if column is None:
|
||||
self.reporter.error(
|
||||
expr.location, f"Unknown column '{name}' on {frame}"
|
||||
)
|
||||
return UnknownType()
|
||||
columns.append(column)
|
||||
return TupleType(items=tuple(columns))
|
||||
|
||||
case _:
|
||||
self.reporter.error(
|
||||
expr.location, f"Invalid index type {expr.index} on {frame}"
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
Reference in New Issue
Block a user