diff --git a/midas/checker/frames.py b/midas/checker/frames.py index 069143a..0dbd95b 100644 --- a/midas/checker/frames.py +++ b/midas/checker/frames.py @@ -1,11 +1,92 @@ -from typing import Optional +from typing import Optional, TypeGuard, cast -from midas.checker.types import ColumnType, DataFrameType +from midas.ast.location import Location +from midas.checker.registry import TypesRegistry +from midas.checker.reporter import FileReporter +from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType + +import midas.ast.python as p + + +def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]: + return all(isinstance(expr, p.LiteralExpr) for expr in exprs) class FrameManager: + def __init__(self, types: TypesRegistry) -> None: + self.types: TypesRegistry = types + + def assign( + self, + reporter: FileReporter, + location: Location, + frame: DataFrameType, + index: p.Expr, + value_type: Type, + ) -> Type: + match index: + case p.LiteralExpr(value=str() as name): + return self.assign_column(reporter, location, frame, name, value_type) + + case p.ListExpr(items=indices) if is_list_of_literals(indices) and all( + isinstance(idx, str) for idx in indices + ): + raise NotImplementedError + + case _: + reporter.error(location, f"Invalid index type {index} on {frame}") + return UnknownType() + + def assign_column( + self, + reporter: FileReporter, + location: Location, + frame: DataFrameType, + name: str, + type: Type, + ) -> Type: + if not isinstance(type, ColumnType): + reporter.error( + location, + f"Cannot assign {type} to dataframe column. Must be a ColumnType", + ) + return frame + return self._set_column(frame, name, type) + + def get( + self, + reporter: FileReporter, + location: Location, + frame: DataFrameType, + index: p.Expr, + ) -> Type: + match index: + case p.LiteralExpr(value=str() as name): + column: Optional[ColumnType] = FrameManager._get_column(frame, name) + if column is None: + reporter.error(location, f"Unknown column '{name}' on {frame}") + return UnknownType() + return column + + case p.ListExpr(items=indices) if is_list_of_literals(indices) and all( + isinstance(index.value, str) for index in indices + ): + names: list[str] = [cast(str, index.value) for index in indices] + columns: list[ColumnType] = [] + for name in names: + column: Optional[ColumnType] = FrameManager._get_column(frame, name) + if column is None: + reporter.error(location, f"Unknown column '{name}' on {frame}") + return UnknownType() + columns.append(column) + return TupleType(items=tuple(columns)) + + case _: + reporter.error(location, f"Invalid index type {index} on {frame}") + return UnknownType() + @classmethod - def set_column( + def _set_column( cls, frame: DataFrameType, name: str, column: ColumnType ) -> DataFrameType: new_columns: list[DataFrameType.Column] = [] @@ -15,6 +96,7 @@ class FrameManager: if col.name == name: index = i replace = True + # TODO: check column type here to prevent changing it new_columns.append(col) new_col: DataFrameType.Column = DataFrameType.Column( @@ -30,22 +112,22 @@ class FrameManager: return DataFrameType(columns=new_columns) @classmethod - def set_columns( + def _set_columns( cls, frame: DataFrameType, names: list[str], columns: list[ColumnType] ) -> DataFrameType: for name, col in zip(names, columns): - frame = cls.set_column(frame, name, col) + frame = cls._set_column(frame, name, col) return frame @classmethod - def get_column(cls, frame: DataFrameType, name: str) -> Optional[ColumnType]: + def _get_column(cls, frame: DataFrameType, name: str) -> Optional[ColumnType]: for col in frame.columns: if col.name == name: return col.type return None @classmethod - def get_columns( + def _get_columns( cls, frame: DataFrameType, names: list[str] ) -> list[Optional[ColumnType]]: - return [cls.get_column(frame, name) for name in names] + return [cls._get_column(frame, name) for name in names] diff --git a/midas/checker/python.py b/midas/checker/python.py index ceb9e6a..cd46cdd 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -1,7 +1,7 @@ import ast import logging from dataclasses import dataclass -from typing import Any, Optional, cast +from typing import Any, Optional import midas.ast.python as p from midas.ast.location import Location @@ -75,6 +75,7 @@ class PythonTyper( self.logger: logging.Logger = logging.getLogger("PythonTyper") self.reporter: FileReporter = reporter.for_file(None) self.types: TypesRegistry = types + self.frame_mgr: FrameManager = FrameManager(self.types) self.global_env: Environment = Preamble(self.types) self.env: Environment = self.global_env self.locals: dict[p.Expr, int] = {} @@ -323,9 +324,15 @@ class PythonTyper( case p.VariableExpr(): self._assign_var(location, target, value_type) + # Allow any kind of object because we disallow creating new attributes case p.GetExpr(object=object, name=name): self._assign_attr(location, object, name, value_type) + # Only support variable expressions because modifying + # the underlying value would require reference types + case p.SubscriptExpr(object=p.VariableExpr() as var, index=index): + self._assign_sub(location, var, index, value_type) + case _: if not isinstance(target, p.VariableExpr): self.logger.warning(f"Unsupported assignment to {target}") @@ -364,6 +371,27 @@ class PythonTyper( f"Cannot assign {value_type} to member '{object_type}.{name}' of type {member}", ) + def _assign_sub( + self, + location: Location, + var: p.VariableExpr, + index: p.Expr, + value_type: Type, + ): + var_type: Type = self.type_of(var) + # TODO: what happens if type is an alias of a dataframe type + match var_type: + case DataFrameType() as frame: + new_type: Type = self.frame_mgr.assign( + self.reporter, location, frame, index, value_type + ) + self.env.assign(var.name, new_type) + case _: + self.reporter.error( + location, + f"Cannot assign {value_type} to index {index} of {var_type}", + ) + def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: type: Type = self.type_of(stmt.value) if stmt.value is not None else UnitType() self.env.return_types.append(type) @@ -1259,35 +1287,4 @@ class PythonTyper( def _visit_frame_subscript( self, frame: DataFrameType, expr: p.SubscriptExpr ) -> Type: - match expr.index: - case p.LiteralExpr(value=str() as name): - column: Optional[ColumnType] = FrameManager.get_column(frame, name) - if column is None: - self.reporter.error( - expr.location, f"Unknown column '{name}' on {frame}" - ) - return UnknownType() - return column - - case p.ListExpr(items=indices) if all( - isinstance(index, p.LiteralExpr) and isinstance(index.value, str) - for index in indices - ): - indices = cast(list[p.LiteralExpr], indices) - names: list[str] = [cast(str, index.value) for index in indices] - columns: list[ColumnType] = [] - for name in names: - column: Optional[ColumnType] = FrameManager.get_column(frame, name) - if column is None: - self.reporter.error( - expr.location, f"Unknown column '{name}' on {frame}" - ) - return UnknownType() - columns.append(column) - return TupleType(items=tuple(columns)) - - case _: - self.reporter.error( - expr.location, f"Invalid index type {expr.index} on {frame}" - ) - return UnknownType() + return self.frame_mgr.get(self.reporter, expr.location, frame, expr.index) diff --git a/midas/checker/resolver.py b/midas/checker/resolver.py index 3bf73d7..753f4db 100644 --- a/midas/checker/resolver.py +++ b/midas/checker/resolver.py @@ -128,6 +128,10 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]): case p.GetExpr(): target.accept(self) + + case p.SubscriptExpr(): + target.accept(self) + case _: raise Exception(f"Unsupported assignment to {target}")