feat(checker): handle setting dataframe column

This commit is contained in:
2026-06-23 14:02:13 +02:00
parent c1b5284f72
commit 3bdbc80079
3 changed files with 124 additions and 41 deletions

View File

@@ -1,11 +1,92 @@
from typing import Optional from typing import Optional, TypeGuard, cast
from midas.checker.types import ColumnType, DataFrameType from midas.ast.location import Location
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType
import midas.ast.python as p
def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
return all(isinstance(expr, p.LiteralExpr) for expr in exprs)
class FrameManager: class FrameManager:
def __init__(self, types: TypesRegistry) -> None:
self.types: TypesRegistry = types
def assign(
self,
reporter: FileReporter,
location: Location,
frame: DataFrameType,
index: p.Expr,
value_type: Type,
) -> Type:
match index:
case p.LiteralExpr(value=str() as name):
return self.assign_column(reporter, location, frame, name, value_type)
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
isinstance(idx, str) for idx in indices
):
raise NotImplementedError
case _:
reporter.error(location, f"Invalid index type {index} on {frame}")
return UnknownType()
def assign_column(
self,
reporter: FileReporter,
location: Location,
frame: DataFrameType,
name: str,
type: Type,
) -> Type:
if not isinstance(type, ColumnType):
reporter.error(
location,
f"Cannot assign {type} to dataframe column. Must be a ColumnType",
)
return frame
return self._set_column(frame, name, type)
def get(
self,
reporter: FileReporter,
location: Location,
frame: DataFrameType,
index: p.Expr,
) -> Type:
match index:
case p.LiteralExpr(value=str() as name):
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
if column is None:
reporter.error(location, f"Unknown column '{name}' on {frame}")
return UnknownType()
return column
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
isinstance(index.value, str) for index in indices
):
names: list[str] = [cast(str, index.value) for index in indices]
columns: list[ColumnType] = []
for name in names:
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
if column is None:
reporter.error(location, f"Unknown column '{name}' on {frame}")
return UnknownType()
columns.append(column)
return TupleType(items=tuple(columns))
case _:
reporter.error(location, f"Invalid index type {index} on {frame}")
return UnknownType()
@classmethod @classmethod
def set_column( def _set_column(
cls, frame: DataFrameType, name: str, column: ColumnType cls, frame: DataFrameType, name: str, column: ColumnType
) -> DataFrameType: ) -> DataFrameType:
new_columns: list[DataFrameType.Column] = [] new_columns: list[DataFrameType.Column] = []
@@ -15,6 +96,7 @@ class FrameManager:
if col.name == name: if col.name == name:
index = i index = i
replace = True replace = True
# TODO: check column type here to prevent changing it
new_columns.append(col) new_columns.append(col)
new_col: DataFrameType.Column = DataFrameType.Column( new_col: DataFrameType.Column = DataFrameType.Column(
@@ -30,22 +112,22 @@ class FrameManager:
return DataFrameType(columns=new_columns) return DataFrameType(columns=new_columns)
@classmethod @classmethod
def set_columns( def _set_columns(
cls, frame: DataFrameType, names: list[str], columns: list[ColumnType] cls, frame: DataFrameType, names: list[str], columns: list[ColumnType]
) -> DataFrameType: ) -> DataFrameType:
for name, col in zip(names, columns): for name, col in zip(names, columns):
frame = cls.set_column(frame, name, col) frame = cls._set_column(frame, name, col)
return frame return frame
@classmethod @classmethod
def get_column(cls, frame: DataFrameType, name: str) -> Optional[ColumnType]: def _get_column(cls, frame: DataFrameType, name: str) -> Optional[ColumnType]:
for col in frame.columns: for col in frame.columns:
if col.name == name: if col.name == name:
return col.type return col.type
return None return None
@classmethod @classmethod
def get_columns( def _get_columns(
cls, frame: DataFrameType, names: list[str] cls, frame: DataFrameType, names: list[str]
) -> list[Optional[ColumnType]]: ) -> list[Optional[ColumnType]]:
return [cls.get_column(frame, name) for name in names] return [cls._get_column(frame, name) for name in names]

View File

@@ -1,7 +1,7 @@
import ast import ast
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional, cast from typing import Any, Optional
import midas.ast.python as p import midas.ast.python as p
from midas.ast.location import Location from midas.ast.location import Location
@@ -75,6 +75,7 @@ class PythonTyper(
self.logger: logging.Logger = logging.getLogger("PythonTyper") self.logger: logging.Logger = logging.getLogger("PythonTyper")
self.reporter: FileReporter = reporter.for_file(None) self.reporter: FileReporter = reporter.for_file(None)
self.types: TypesRegistry = types self.types: TypesRegistry = types
self.frame_mgr: FrameManager = FrameManager(self.types)
self.global_env: Environment = Preamble(self.types) self.global_env: Environment = Preamble(self.types)
self.env: Environment = self.global_env self.env: Environment = self.global_env
self.locals: dict[p.Expr, int] = {} self.locals: dict[p.Expr, int] = {}
@@ -323,9 +324,15 @@ class PythonTyper(
case p.VariableExpr(): case p.VariableExpr():
self._assign_var(location, target, value_type) self._assign_var(location, target, value_type)
# Allow any kind of object because we disallow creating new attributes
case p.GetExpr(object=object, name=name): case p.GetExpr(object=object, name=name):
self._assign_attr(location, object, name, value_type) self._assign_attr(location, object, name, value_type)
# Only support variable expressions because modifying
# the underlying value would require reference types
case p.SubscriptExpr(object=p.VariableExpr() as var, index=index):
self._assign_sub(location, var, index, value_type)
case _: case _:
if not isinstance(target, p.VariableExpr): if not isinstance(target, p.VariableExpr):
self.logger.warning(f"Unsupported assignment to {target}") self.logger.warning(f"Unsupported assignment to {target}")
@@ -364,6 +371,27 @@ class PythonTyper(
f"Cannot assign {value_type} to member '{object_type}.{name}' of type {member}", f"Cannot assign {value_type} to member '{object_type}.{name}' of type {member}",
) )
def _assign_sub(
self,
location: Location,
var: p.VariableExpr,
index: p.Expr,
value_type: Type,
):
var_type: Type = self.type_of(var)
# TODO: what happens if type is an alias of a dataframe type
match var_type:
case DataFrameType() as frame:
new_type: Type = self.frame_mgr.assign(
self.reporter, location, frame, index, value_type
)
self.env.assign(var.name, new_type)
case _:
self.reporter.error(
location,
f"Cannot assign {value_type} to index {index} of {var_type}",
)
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
type: Type = self.type_of(stmt.value) if stmt.value is not None else UnitType() type: Type = self.type_of(stmt.value) if stmt.value is not None else UnitType()
self.env.return_types.append(type) self.env.return_types.append(type)
@@ -1259,35 +1287,4 @@ class PythonTyper(
def _visit_frame_subscript( def _visit_frame_subscript(
self, frame: DataFrameType, expr: p.SubscriptExpr self, frame: DataFrameType, expr: p.SubscriptExpr
) -> Type: ) -> Type:
match expr.index: return self.frame_mgr.get(self.reporter, expr.location, frame, expr.index)
case p.LiteralExpr(value=str() as name):
column: Optional[ColumnType] = FrameManager.get_column(frame, name)
if column is None:
self.reporter.error(
expr.location, f"Unknown column '{name}' on {frame}"
)
return UnknownType()
return column
case p.ListExpr(items=indices) if all(
isinstance(index, p.LiteralExpr) and isinstance(index.value, str)
for index in indices
):
indices = cast(list[p.LiteralExpr], indices)
names: list[str] = [cast(str, index.value) for index in indices]
columns: list[ColumnType] = []
for name in names:
column: Optional[ColumnType] = FrameManager.get_column(frame, name)
if column is None:
self.reporter.error(
expr.location, f"Unknown column '{name}' on {frame}"
)
return UnknownType()
columns.append(column)
return TupleType(items=tuple(columns))
case _:
self.reporter.error(
expr.location, f"Invalid index type {expr.index} on {frame}"
)
return UnknownType()

View File

@@ -128,6 +128,10 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
case p.GetExpr(): case p.GetExpr():
target.accept(self) target.accept(self)
case p.SubscriptExpr():
target.accept(self)
case _: case _:
raise Exception(f"Unsupported assignment to {target}") raise Exception(f"Unsupported assignment to {target}")