feat(types): add DataFrameType and ColumnType

This commit is contained in:
2026-06-22 10:29:06 +02:00
parent f7c43837b5
commit a80da5db2c
2 changed files with 63 additions and 6 deletions

View File

@@ -21,7 +21,9 @@ from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ColumnType,
ConstraintType,
DataFrameType,
Function,
GenericType,
OverloadedFunction,
@@ -677,13 +679,26 @@ class PythonTyper(
self.reporter.warning(node.location, "ConstraintType not yet supported")
return UnknownType()
def visit_frame_column(self, node: p.FrameColumn) -> Type:
self.reporter.warning(node.location, "FrameColumn not yet supported")
return UnknownType()
def visit_frame_column(self, node: p.FrameColumn) -> ColumnType:
return ColumnType(
type=(
self.resolve_type_expr(node.type)
if node.type is not None
else UnknownType()
)
)
def visit_frame_type(self, node: p.FrameType) -> Type:
self.reporter.warning(node.location, "FrameType not yet supported")
return UnknownType()
return DataFrameType(
columns=[
DataFrameType.Column(
index=i,
name=column.name,
type=self.visit_frame_column(column),
)
for i, column in enumerate(node.columns)
]
)
def _get_call_result(
self,

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Optional, assert_never
from typing import Optional, assert_never, cast
import midas.ast.midas as m
from midas.ast.printer import MidasPrinter
@@ -156,6 +156,22 @@ class ConstraintType:
return f"{self.type} where {printer.print(self.constraint)}"
@dataclass(frozen=True, kw_only=True)
class ColumnType:
type: Type
@dataclass(frozen=True, kw_only=True)
class DataFrameType:
columns: list[Column]
@dataclass(frozen=True, kw_only=True)
class Column:
index: int
name: Optional[str]
type: ColumnType
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
def sub_argument(arg: Function.Argument):
return Function.Argument(
@@ -165,6 +181,13 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
required=arg.required,
)
def sub_column(col: DataFrameType.Column):
return DataFrameType.Column(
index=col.index,
name=col.name,
type=cast(ColumnType, substitute_typevars(col.type, substitutions)),
)
match type:
case TopType():
return type
@@ -250,10 +273,21 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
body=substitute_typevars(body, substitutions),
)
case ColumnType(type=items_type):
return ColumnType(
type=substitute_typevars(items_type, substitutions),
)
case DataFrameType(columns=columns):
return DataFrameType(
columns=list(map(sub_column, columns)),
)
case UnknownType() | UnitType():
return type
case TopType() | GenericType():
raise NotImplementedError(f"Unsupported type {type}")
# Ensure exhaustiveness
@@ -317,6 +351,12 @@ def to_annotation(type: Type) -> str:
case ConstraintType():
return str(type)
case ColumnType():
return "pd.Series"
case DataFrameType():
return "pd.DataFrame"
case _:
assert_never(type)
@@ -342,4 +382,6 @@ Type = (
| GenericType
| AppliedType
| ConstraintType
| ColumnType
| DataFrameType
)