From a80da5db2c192100c47f3cd51bdaf141b3dfa646 Mon Sep 17 00:00:00 2001
From: LordBaryhobal
Date: Mon, 22 Jun 2026 10:29:06 +0200
Subject: [PATCH] feat(types): add DataFrameType and ColumnType

---
Review notes: added docstrings to the new types and visitor methods,
narrowed the return annotation of visit_frame_type to DataFrameType, and
removed a stray blank line ahead of the NotImplementedError raise in
substitute_typevars. The commit id above and the blob ids on the index
lines predate these review edits.

 midas/checker/python.py | 29 +++++++++++++++++++++++------
 midas/checker/types.py  | 49 ++++++++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 71 insertions(+), 7 deletions(-)

diff --git a/midas/checker/python.py b/midas/checker/python.py
index ffa1600..b210f2e 100644
--- a/midas/checker/python.py
+++ b/midas/checker/python.py
@@ -21,7 +21,9 @@ from midas.checker.types import (
     AliasType,
     AppliedType,
     BaseType,
+    ColumnType,
     ConstraintType,
+    DataFrameType,
     Function,
     GenericType,
     OverloadedFunction,
@@ -677,13 +679,28 @@ class PythonTyper(
         self.reporter.warning(node.location, "ConstraintType not yet supported")
         return UnknownType()
 
-    def visit_frame_column(self, node: p.FrameColumn) -> Type:
-        self.reporter.warning(node.location, "FrameColumn not yet supported")
-        return UnknownType()
+    def visit_frame_column(self, node: p.FrameColumn) -> ColumnType:
+        """Resolve a frame column; an untyped column gets ``UnknownType``."""
+        return ColumnType(
+            type=(
+                self.resolve_type_expr(node.type)
+                if node.type is not None
+                else UnknownType()
+            )
+        )
 
-    def visit_frame_type(self, node: p.FrameType) -> Type:
-        self.reporter.warning(node.location, "FrameType not yet supported")
-        return UnknownType()
+    def visit_frame_type(self, node: p.FrameType) -> DataFrameType:
+        """Build a ``DataFrameType`` from the node's columns, in order."""
+        return DataFrameType(
+            columns=[
+                DataFrameType.Column(
+                    index=i,
+                    name=column.name,
+                    type=self.visit_frame_column(column),
+                )
+                for i, column in enumerate(node.columns)
+            ]
+        )
 
     def _get_call_result(
         self,
diff --git a/midas/checker/types.py b/midas/checker/types.py
index 4ebda35..d580f62 100644
--- a/midas/checker/types.py
+++ b/midas/checker/types.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass, field
 from enum import StrEnum
-from typing import Optional, assert_never
+from typing import Optional, assert_never, cast
 
 import midas.ast.midas as m
 from midas.ast.printer import MidasPrinter
@@ -156,6 +156,28 @@ class ConstraintType:
         return f"{self.type} where {printer.print(self.constraint)}"
 
 
+@dataclass(frozen=True, kw_only=True)
+class ColumnType:
+    """Type of a single data frame column, wrapping the type of its items."""
+
+    type: Type
+
+
+@dataclass(frozen=True, kw_only=True)
+class DataFrameType:
+    """Type of a data frame, described by its ordered list of columns."""
+
+    columns: list[Column]
+
+    @dataclass(frozen=True, kw_only=True)
+    class Column:
+        """A data frame column: position, optional name, and column type."""
+
+        index: int
+        name: Optional[str]
+        type: ColumnType
+
+
 def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
     def sub_argument(arg: Function.Argument):
         return Function.Argument(
@@ -165,6 +187,13 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
             required=arg.required,
         )
 
+    def sub_column(col: DataFrameType.Column):
+        return DataFrameType.Column(
+            index=col.index,
+            name=col.name,
+            type=cast(ColumnType, substitute_typevars(col.type, substitutions)),
+        )
+
     match type:
         case TopType():
             return type
@@ -250,6 +279,16 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
                 body=substitute_typevars(body, substitutions),
             )
 
+        case ColumnType(type=items_type):
+            return ColumnType(
+                type=substitute_typevars(items_type, substitutions),
+            )
+
+        case DataFrameType(columns=columns):
+            return DataFrameType(
+                columns=list(map(sub_column, columns)),
+            )
+
         case UnknownType() | UnitType():
             return type
 
@@ -317,6 +356,12 @@ def to_annotation(type: Type) -> str:
         case ConstraintType():
             return str(type)
 
+        case ColumnType():
+            return "pd.Series"
+
+        case DataFrameType():
+            return "pd.DataFrame"
+
         case _:
             assert_never(type)
 
@@ -342,4 +387,6 @@ Type = (
     | GenericType
     | AppliedType
     | ConstraintType
+    | ColumnType
+    | DataFrameType
 )