feat(types): add DataFrameType and ColumnType

This commit is contained in:
2026-06-22 10:29:06 +02:00
parent f7c43837b5
commit a80da5db2c
2 changed files with 63 additions and 6 deletions

View File

@@ -21,7 +21,9 @@ from midas.checker.types import (
AliasType, AliasType,
AppliedType, AppliedType,
BaseType, BaseType,
ColumnType,
ConstraintType, ConstraintType,
DataFrameType,
Function, Function,
GenericType, GenericType,
OverloadedFunction, OverloadedFunction,
@@ -677,13 +679,26 @@ class PythonTyper(
self.reporter.warning(node.location, "ConstraintType not yet supported") self.reporter.warning(node.location, "ConstraintType not yet supported")
return UnknownType() return UnknownType()
def visit_frame_column(self, node: p.FrameColumn) -> Type: def visit_frame_column(self, node: p.FrameColumn) -> ColumnType:
self.reporter.warning(node.location, "FrameColumn not yet supported") return ColumnType(
return UnknownType() type=(
self.resolve_type_expr(node.type)
if node.type is not None
else UnknownType()
)
)
def visit_frame_type(self, node: p.FrameType) -> Type: def visit_frame_type(self, node: p.FrameType) -> Type:
self.reporter.warning(node.location, "FrameType not yet supported") return DataFrameType(
return UnknownType() columns=[
DataFrameType.Column(
index=i,
name=column.name,
type=self.visit_frame_column(column),
)
for i, column in enumerate(node.columns)
]
)
def _get_call_result( def _get_call_result(
self, self,

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import StrEnum from enum import StrEnum
from typing import Optional, assert_never from typing import Optional, assert_never, cast
import midas.ast.midas as m import midas.ast.midas as m
from midas.ast.printer import MidasPrinter from midas.ast.printer import MidasPrinter
@@ -156,6 +156,22 @@ class ConstraintType:
return f"{self.type} where {printer.print(self.constraint)}" return f"{self.type} where {printer.print(self.constraint)}"
@dataclass(frozen=True, kw_only=True)
class ColumnType:
type: Type
@dataclass(frozen=True, kw_only=True)
class DataFrameType:
columns: list[Column]
@dataclass(frozen=True, kw_only=True)
class Column:
index: int
name: Optional[str]
type: ColumnType
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type: def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
def sub_argument(arg: Function.Argument): def sub_argument(arg: Function.Argument):
return Function.Argument( return Function.Argument(
@@ -165,6 +181,13 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
required=arg.required, required=arg.required,
) )
def sub_column(col: DataFrameType.Column):
return DataFrameType.Column(
index=col.index,
name=col.name,
type=cast(ColumnType, substitute_typevars(col.type, substitutions)),
)
match type: match type:
case TopType(): case TopType():
return type return type
@@ -250,10 +273,21 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
body=substitute_typevars(body, substitutions), body=substitute_typevars(body, substitutions),
) )
case ColumnType(type=items_type):
return ColumnType(
type=substitute_typevars(items_type, substitutions),
)
case DataFrameType(columns=columns):
return DataFrameType(
columns=list(map(sub_column, columns)),
)
case UnknownType() | UnitType(): case UnknownType() | UnitType():
return type return type
case TopType() | GenericType(): case TopType() | GenericType():
raise NotImplementedError(f"Unsupported type {type}") raise NotImplementedError(f"Unsupported type {type}")
# Ensure exhaustiveness # Ensure exhaustiveness
@@ -317,6 +351,12 @@ def to_annotation(type: Type) -> str:
case ConstraintType(): case ConstraintType():
return str(type) return str(type)
case ColumnType():
return "pd.Series"
case DataFrameType():
return "pd.DataFrame"
case _: case _:
assert_never(type) assert_never(type)
@@ -342,4 +382,6 @@ Type = (
| GenericType | GenericType
| AppliedType | AppliedType
| ConstraintType | ConstraintType
| ColumnType
| DataFrameType
) )