feat(checker): handle setting dataframe column
This commit is contained in:
@@ -1,11 +1,92 @@
|
|||||||
from typing import Optional
|
from typing import Optional, TypeGuard, cast
|
||||||
|
|
||||||
from midas.checker.types import ColumnType, DataFrameType
|
from midas.ast.location import Location
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
|
from midas.checker.reporter import FileReporter
|
||||||
|
from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType
|
||||||
|
|
||||||
|
import midas.ast.python as p
|
||||||
|
|
||||||
|
|
||||||
|
def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
    """Tell whether every expression in ``exprs`` is a ``p.LiteralExpr``.

    Declared as a ``TypeGuard`` so that a type checker narrows ``exprs`` to
    ``list[p.LiteralExpr]`` when the result is true. An empty list yields
    ``True``.
    """
    for candidate in exprs:
        if not isinstance(candidate, p.LiteralExpr):
            return False
    return True
|
||||||
|
|
||||||
|
|
||||||
class FrameManager:
|
class FrameManager:
|
||||||
|
def __init__(self, types: TypesRegistry) -> None:
    """Create a frame manager that keeps a reference to ``types``."""
    # Registry of known types, stored for use by the manager.
    self.types: TypesRegistry = types
|
||||||
|
|
||||||
|
def assign(
|
||||||
|
self,
|
||||||
|
reporter: FileReporter,
|
||||||
|
location: Location,
|
||||||
|
frame: DataFrameType,
|
||||||
|
index: p.Expr,
|
||||||
|
value_type: Type,
|
||||||
|
) -> Type:
|
||||||
|
match index:
|
||||||
|
case p.LiteralExpr(value=str() as name):
|
||||||
|
return self.assign_column(reporter, location, frame, name, value_type)
|
||||||
|
|
||||||
|
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
|
||||||
|
isinstance(idx, str) for idx in indices
|
||||||
|
):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
case _:
|
||||||
|
reporter.error(location, f"Invalid index type {index} on {frame}")
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def assign_column(
    self,
    reporter: FileReporter,
    location: Location,
    frame: DataFrameType,
    name: str,
    type: Type,
) -> Type:
    """Type the assignment of a single column ``name`` on ``frame``.

    When ``type`` is a ``ColumnType`` the result of ``_set_column`` is
    returned. Otherwise an error is reported at ``location`` and ``frame`` is
    returned unchanged.
    """
    if isinstance(type, ColumnType):
        return self._set_column(frame, name, type)

    # Only column types may be stored in a dataframe column.
    message = f"Cannot assign {type} to dataframe column. Must be a ColumnType"
    reporter.error(location, message)
    return frame
|
||||||
|
|
||||||
|
def get(
|
||||||
|
self,
|
||||||
|
reporter: FileReporter,
|
||||||
|
location: Location,
|
||||||
|
frame: DataFrameType,
|
||||||
|
index: p.Expr,
|
||||||
|
) -> Type:
|
||||||
|
match index:
|
||||||
|
case p.LiteralExpr(value=str() as name):
|
||||||
|
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
|
||||||
|
if column is None:
|
||||||
|
reporter.error(location, f"Unknown column '{name}' on {frame}")
|
||||||
|
return UnknownType()
|
||||||
|
return column
|
||||||
|
|
||||||
|
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
|
||||||
|
isinstance(index.value, str) for index in indices
|
||||||
|
):
|
||||||
|
names: list[str] = [cast(str, index.value) for index in indices]
|
||||||
|
columns: list[ColumnType] = []
|
||||||
|
for name in names:
|
||||||
|
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
|
||||||
|
if column is None:
|
||||||
|
reporter.error(location, f"Unknown column '{name}' on {frame}")
|
||||||
|
return UnknownType()
|
||||||
|
columns.append(column)
|
||||||
|
return TupleType(items=tuple(columns))
|
||||||
|
|
||||||
|
case _:
|
||||||
|
reporter.error(location, f"Invalid index type {index} on {frame}")
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def set_column(
|
def _set_column(
|
||||||
cls, frame: DataFrameType, name: str, column: ColumnType
|
cls, frame: DataFrameType, name: str, column: ColumnType
|
||||||
) -> DataFrameType:
|
) -> DataFrameType:
|
||||||
new_columns: list[DataFrameType.Column] = []
|
new_columns: list[DataFrameType.Column] = []
|
||||||
@@ -15,6 +96,7 @@ class FrameManager:
|
|||||||
if col.name == name:
|
if col.name == name:
|
||||||
index = i
|
index = i
|
||||||
replace = True
|
replace = True
|
||||||
|
# TODO: check column type here to prevent changing it
|
||||||
new_columns.append(col)
|
new_columns.append(col)
|
||||||
|
|
||||||
new_col: DataFrameType.Column = DataFrameType.Column(
|
new_col: DataFrameType.Column = DataFrameType.Column(
|
||||||
@@ -30,22 +112,22 @@ class FrameManager:
|
|||||||
return DataFrameType(columns=new_columns)
|
return DataFrameType(columns=new_columns)
|
||||||
|
|
||||||
@classmethod
def _set_columns(
    cls, frame: DataFrameType, names: list[str], columns: list[ColumnType]
) -> DataFrameType:
    """Apply ``_set_column`` for each ``(name, column)`` pair, in order.

    ``names`` and ``columns`` are paired with ``zip``, so surplus entries in
    the longer list are ignored. With no pairs, ``frame`` is returned as is.
    """
    updated: DataFrameType = frame
    for column_name, column_type in zip(names, columns):
        updated = cls._set_column(updated, column_name, column_type)
    return updated
|
||||||
|
|
||||||
@classmethod
def _get_column(cls, frame: DataFrameType, name: str) -> Optional[ColumnType]:
    """Return the type of the first column of ``frame`` called ``name``.

    Returns ``None`` when no column has that name.
    """
    matching_types = (entry.type for entry in frame.columns if entry.name == name)
    return next(matching_types, None)
|
||||||
|
|
||||||
@classmethod
def _get_columns(
    cls, frame: DataFrameType, names: list[str]
) -> list[Optional[ColumnType]]:
    """Look up each name in ``names`` on ``frame`` with ``_get_column``.

    The result has one entry per requested name, in the same order; an entry
    is ``None`` when the corresponding column does not exist.
    """
    looked_up: list[Optional[ColumnType]] = []
    for column_name in names:
        looked_up.append(cls._get_column(frame, column_name))
    return looked_up
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import ast
|
import ast
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional, cast
|
from typing import Any, Optional
|
||||||
|
|
||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
@@ -75,6 +75,7 @@ class PythonTyper(
|
|||||||
self.logger: logging.Logger = logging.getLogger("PythonTyper")
|
self.logger: logging.Logger = logging.getLogger("PythonTyper")
|
||||||
self.reporter: FileReporter = reporter.for_file(None)
|
self.reporter: FileReporter = reporter.for_file(None)
|
||||||
self.types: TypesRegistry = types
|
self.types: TypesRegistry = types
|
||||||
|
self.frame_mgr: FrameManager = FrameManager(self.types)
|
||||||
self.global_env: Environment = Preamble(self.types)
|
self.global_env: Environment = Preamble(self.types)
|
||||||
self.env: Environment = self.global_env
|
self.env: Environment = self.global_env
|
||||||
self.locals: dict[p.Expr, int] = {}
|
self.locals: dict[p.Expr, int] = {}
|
||||||
@@ -323,9 +324,15 @@ class PythonTyper(
|
|||||||
case p.VariableExpr():
|
case p.VariableExpr():
|
||||||
self._assign_var(location, target, value_type)
|
self._assign_var(location, target, value_type)
|
||||||
|
|
||||||
|
# Allow any kind of object because we disallow creating new attributes
|
||||||
case p.GetExpr(object=object, name=name):
|
case p.GetExpr(object=object, name=name):
|
||||||
self._assign_attr(location, object, name, value_type)
|
self._assign_attr(location, object, name, value_type)
|
||||||
|
|
||||||
|
# Only support variable expressions because modifying
|
||||||
|
# the underlying value would require reference types
|
||||||
|
case p.SubscriptExpr(object=p.VariableExpr() as var, index=index):
|
||||||
|
self._assign_sub(location, var, index, value_type)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
if not isinstance(target, p.VariableExpr):
|
if not isinstance(target, p.VariableExpr):
|
||||||
self.logger.warning(f"Unsupported assignment to {target}")
|
self.logger.warning(f"Unsupported assignment to {target}")
|
||||||
@@ -364,6 +371,27 @@ class PythonTyper(
|
|||||||
f"Cannot assign {value_type} to member '{object_type}.{name}' of type {member}",
|
f"Cannot assign {value_type} to member '{object_type}.{name}' of type {member}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _assign_sub(
    self,
    location: Location,
    var: p.VariableExpr,
    index: p.Expr,
    value_type: Type,
):
    """Type-check the subscript assignment ``var[index] = <value_type>``.

    When ``var`` has a dataframe type, the frame manager computes the type
    resulting from the assignment and ``var`` is rebound to it in the current
    environment. For any other type an error is reported at ``location``.
    """
    target_type: Type = self.type_of(var)

    # TODO: what happens if type is an alias of a dataframe type
    if not isinstance(target_type, DataFrameType):
        self.reporter.error(
            location,
            f"Cannot assign {value_type} to index {index} of {target_type}",
        )
        return

    updated_type: Type = self.frame_mgr.assign(
        self.reporter, location, target_type, index, value_type
    )
    self.env.assign(var.name, updated_type)
|
||||||
|
|
||||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||||
type: Type = self.type_of(stmt.value) if stmt.value is not None else UnitType()
|
type: Type = self.type_of(stmt.value) if stmt.value is not None else UnitType()
|
||||||
self.env.return_types.append(type)
|
self.env.return_types.append(type)
|
||||||
@@ -1259,35 +1287,4 @@ class PythonTyper(
|
|||||||
def _visit_frame_subscript(
    self, frame: DataFrameType, expr: p.SubscriptExpr
) -> Type:
    """Type the subscript ``expr`` on ``frame`` by delegating to the frame manager.

    Errors are reported through ``self.reporter`` at ``expr.location``.
    """
    subscript = expr.index
    where = expr.location
    return self.frame_mgr.get(self.reporter, where, frame, subscript)
|
|
||||||
|
|||||||
@@ -128,6 +128,10 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
|||||||
|
|
||||||
case p.GetExpr():
|
case p.GetExpr():
|
||||||
target.accept(self)
|
target.accept(self)
|
||||||
|
|
||||||
|
case p.SubscriptExpr():
|
||||||
|
target.accept(self)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
raise Exception(f"Unsupported assignment to {target}")
|
raise Exception(f"Unsupported assignment to {target}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user