feat(types): add DataFrameType and ColumnType
This commit is contained in:
@@ -21,7 +21,9 @@ from midas.checker.types import (
|
|||||||
AliasType,
|
AliasType,
|
||||||
AppliedType,
|
AppliedType,
|
||||||
BaseType,
|
BaseType,
|
||||||
|
ColumnType,
|
||||||
ConstraintType,
|
ConstraintType,
|
||||||
|
DataFrameType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
OverloadedFunction,
|
OverloadedFunction,
|
||||||
@@ -677,13 +679,26 @@ class PythonTyper(
|
|||||||
self.reporter.warning(node.location, "ConstraintType not yet supported")
|
self.reporter.warning(node.location, "ConstraintType not yet supported")
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
def visit_frame_column(self, node: p.FrameColumn) -> Type:
|
def visit_frame_column(self, node: p.FrameColumn) -> ColumnType:
|
||||||
self.reporter.warning(node.location, "FrameColumn not yet supported")
|
return ColumnType(
|
||||||
return UnknownType()
|
type=(
|
||||||
|
self.resolve_type_expr(node.type)
|
||||||
|
if node.type is not None
|
||||||
|
else UnknownType()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def visit_frame_type(self, node: p.FrameType) -> Type:
|
def visit_frame_type(self, node: p.FrameType) -> Type:
|
||||||
self.reporter.warning(node.location, "FrameType not yet supported")
|
return DataFrameType(
|
||||||
return UnknownType()
|
columns=[
|
||||||
|
DataFrameType.Column(
|
||||||
|
index=i,
|
||||||
|
name=column.name,
|
||||||
|
type=self.visit_frame_column(column),
|
||||||
|
)
|
||||||
|
for i, column in enumerate(node.columns)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def _get_call_result(
|
def _get_call_result(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Optional, assert_never
|
from typing import Optional, assert_never, cast
|
||||||
|
|
||||||
import midas.ast.midas as m
|
import midas.ast.midas as m
|
||||||
from midas.ast.printer import MidasPrinter
|
from midas.ast.printer import MidasPrinter
|
||||||
@@ -156,6 +156,22 @@ class ConstraintType:
|
|||||||
return f"{self.type} where {printer.print(self.constraint)}"
|
return f"{self.type} where {printer.print(self.constraint)}"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class ColumnType:
|
||||||
|
type: Type
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class DataFrameType:
|
||||||
|
columns: list[Column]
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Column:
|
||||||
|
index: int
|
||||||
|
name: Optional[str]
|
||||||
|
type: ColumnType
|
||||||
|
|
||||||
|
|
||||||
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||||
def sub_argument(arg: Function.Argument):
|
def sub_argument(arg: Function.Argument):
|
||||||
return Function.Argument(
|
return Function.Argument(
|
||||||
@@ -165,6 +181,13 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
|||||||
required=arg.required,
|
required=arg.required,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def sub_column(col: DataFrameType.Column):
|
||||||
|
return DataFrameType.Column(
|
||||||
|
index=col.index,
|
||||||
|
name=col.name,
|
||||||
|
type=cast(ColumnType, substitute_typevars(col.type, substitutions)),
|
||||||
|
)
|
||||||
|
|
||||||
match type:
|
match type:
|
||||||
case TopType():
|
case TopType():
|
||||||
return type
|
return type
|
||||||
@@ -250,10 +273,21 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
|||||||
body=substitute_typevars(body, substitutions),
|
body=substitute_typevars(body, substitutions),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
case ColumnType(type=items_type):
|
||||||
|
return ColumnType(
|
||||||
|
type=substitute_typevars(items_type, substitutions),
|
||||||
|
)
|
||||||
|
|
||||||
|
case DataFrameType(columns=columns):
|
||||||
|
return DataFrameType(
|
||||||
|
columns=list(map(sub_column, columns)),
|
||||||
|
)
|
||||||
|
|
||||||
case UnknownType() | UnitType():
|
case UnknownType() | UnitType():
|
||||||
return type
|
return type
|
||||||
|
|
||||||
case TopType() | GenericType():
|
case TopType() | GenericType():
|
||||||
|
|
||||||
raise NotImplementedError(f"Unsupported type {type}")
|
raise NotImplementedError(f"Unsupported type {type}")
|
||||||
|
|
||||||
# Ensure exhaustiveness
|
# Ensure exhaustiveness
|
||||||
@@ -317,6 +351,12 @@ def to_annotation(type: Type) -> str:
|
|||||||
case ConstraintType():
|
case ConstraintType():
|
||||||
return str(type)
|
return str(type)
|
||||||
|
|
||||||
|
case ColumnType():
|
||||||
|
return "pd.Series"
|
||||||
|
|
||||||
|
case DataFrameType():
|
||||||
|
return "pd.DataFrame"
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
assert_never(type)
|
assert_never(type)
|
||||||
|
|
||||||
@@ -342,4 +382,6 @@ Type = (
|
|||||||
| GenericType
|
| GenericType
|
||||||
| AppliedType
|
| AppliedType
|
||||||
| ConstraintType
|
| ConstraintType
|
||||||
|
| ColumnType
|
||||||
|
| DataFrameType
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user