diff --git a/midas/checker/python.py b/midas/checker/python.py index b210f2e..5ea0c85 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -27,6 +27,7 @@ from midas.checker.types import ( Function, GenericType, OverloadedFunction, + TupleType, Type, TypeVar, UnitType, @@ -642,6 +643,11 @@ class PythonTyper( def visit_subscript_expr(self, expr: p.SubscriptExpr) -> Type: object: Type = self.type_of(expr.object) + unfolded: Type = unfold_type(object) + match unfolded: + case TupleType(): + return self._visit_tuple_subscript(unfolded, expr) + operation: Optional[Type] = self.types.lookup_member(object, "__getitem__") if operation is None: self.reporter.error( @@ -1231,3 +1237,18 @@ class PythonTyper( expr.location, f"Cannot evaluate cast to {target_type} statically" ) return False + + def _visit_tuple_subscript(self, tup: TupleType, expr: p.SubscriptExpr) -> Type: + match expr.index: + case p.LiteralExpr(value=int() as index): + if index < 0 or index >= len(tup.items): + self.reporter.error( + expr.location, f"Index {index} out of range for tuple {tup}" + ) + return UnknownType() + return tup.items[index] + case _: + self.reporter.error( + expr.location, f"Invalid index type {expr.index} on {tup}" + ) + return UnknownType() diff --git a/midas/checker/types.py b/midas/checker/types.py index 471e805..07acfd4 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -156,6 +156,14 @@ class ConstraintType: return f"{self.type} where {printer.print(self.constraint)}" +@dataclass(frozen=True, kw_only=True) +class TupleType: + items: tuple[Type, ...] + + def __str__(self) -> str: + return f"({', '.join(map(str, self.items))})" + + @dataclass(frozen=True, kw_only=True) class ColumnType: type: Type @@ -280,6 +288,11 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type: body=substitute_typevars(body, substitutions), ) + case TupleType(items=items): + return TupleType( + items=tuple(substitute_typevars(item, substitutions) for item in items), + ) + case ColumnType(type=items_type): return ColumnType( type=substitute_typevars(items_type, substitutions), @@ -358,6 +371,9 @@ def to_annotation(type: Type) -> str: case ConstraintType(): return str(type) + case TupleType(items=items): + return f"Tuple[{', '.join(map(to_annotation, items))}]" + case ColumnType(): return "pd.Series" @@ -389,6 +405,7 @@ Type = ( | GenericType | AppliedType | ConstraintType + | TupleType | ColumnType | DataFrameType )