Compare commits
6 Commits
9c7a93412c
...
9145496587
| Author | SHA1 | Date | |
|---|---|---|---|
|
9145496587
|
|||
|
9144995b79
|
|||
|
45e27ee04e
|
|||
|
83bd3793df
|
|||
|
c2a5517d09
|
|||
|
27a18580a5
|
133
midas/checker/frames.py
Normal file
133
midas/checker/frames.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
from typing import Optional, TypeGuard, cast
|
||||||
|
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
|
from midas.checker.reporter import FileReporter
|
||||||
|
from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType
|
||||||
|
|
||||||
|
import midas.ast.python as p
|
||||||
|
|
||||||
|
|
||||||
|
def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
|
||||||
|
return all(isinstance(expr, p.LiteralExpr) for expr in exprs)
|
||||||
|
|
||||||
|
|
||||||
|
class FrameManager:
|
||||||
|
def __init__(self, types: TypesRegistry) -> None:
|
||||||
|
self.types: TypesRegistry = types
|
||||||
|
|
||||||
|
def assign(
|
||||||
|
self,
|
||||||
|
reporter: FileReporter,
|
||||||
|
location: Location,
|
||||||
|
frame: DataFrameType,
|
||||||
|
index: p.Expr,
|
||||||
|
value_type: Type,
|
||||||
|
) -> Type:
|
||||||
|
match index:
|
||||||
|
case p.LiteralExpr(value=str() as name):
|
||||||
|
return self.assign_column(reporter, location, frame, name, value_type)
|
||||||
|
|
||||||
|
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
|
||||||
|
isinstance(idx, str) for idx in indices
|
||||||
|
):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
case _:
|
||||||
|
reporter.error(location, f"Invalid index type {index} on {frame}")
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def assign_column(
|
||||||
|
self,
|
||||||
|
reporter: FileReporter,
|
||||||
|
location: Location,
|
||||||
|
frame: DataFrameType,
|
||||||
|
name: str,
|
||||||
|
type: Type,
|
||||||
|
) -> Type:
|
||||||
|
if not isinstance(type, ColumnType):
|
||||||
|
reporter.error(
|
||||||
|
location,
|
||||||
|
f"Cannot assign {type} to dataframe column. Must be a ColumnType",
|
||||||
|
)
|
||||||
|
return frame
|
||||||
|
return self._set_column(frame, name, type)
|
||||||
|
|
||||||
|
def get(
|
||||||
|
self,
|
||||||
|
reporter: FileReporter,
|
||||||
|
location: Location,
|
||||||
|
frame: DataFrameType,
|
||||||
|
index: p.Expr,
|
||||||
|
) -> Type:
|
||||||
|
match index:
|
||||||
|
case p.LiteralExpr(value=str() as name):
|
||||||
|
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
|
||||||
|
if column is None:
|
||||||
|
reporter.error(location, f"Unknown column '{name}' on {frame}")
|
||||||
|
return UnknownType()
|
||||||
|
return column
|
||||||
|
|
||||||
|
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
|
||||||
|
isinstance(index.value, str) for index in indices
|
||||||
|
):
|
||||||
|
names: list[str] = [cast(str, index.value) for index in indices]
|
||||||
|
columns: list[ColumnType] = []
|
||||||
|
for name in names:
|
||||||
|
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
|
||||||
|
if column is None:
|
||||||
|
reporter.error(location, f"Unknown column '{name}' on {frame}")
|
||||||
|
return UnknownType()
|
||||||
|
columns.append(column)
|
||||||
|
return TupleType(items=tuple(columns))
|
||||||
|
|
||||||
|
case _:
|
||||||
|
reporter.error(location, f"Invalid index type {index} on {frame}")
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _set_column(
|
||||||
|
cls, frame: DataFrameType, name: str, column: ColumnType
|
||||||
|
) -> DataFrameType:
|
||||||
|
new_columns: list[DataFrameType.Column] = []
|
||||||
|
index: int = len(frame.columns)
|
||||||
|
replace: bool = False
|
||||||
|
for i, col in enumerate(frame.columns):
|
||||||
|
if col.name == name:
|
||||||
|
index = i
|
||||||
|
replace = True
|
||||||
|
# TODO: check column type here to prevent changing it
|
||||||
|
new_columns.append(col)
|
||||||
|
|
||||||
|
new_col: DataFrameType.Column = DataFrameType.Column(
|
||||||
|
index=index,
|
||||||
|
name=name,
|
||||||
|
type=column,
|
||||||
|
)
|
||||||
|
if replace:
|
||||||
|
new_columns[index] = new_col
|
||||||
|
else:
|
||||||
|
new_columns.append(new_col)
|
||||||
|
|
||||||
|
return DataFrameType(columns=new_columns)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _set_columns(
|
||||||
|
cls, frame: DataFrameType, names: list[str], columns: list[ColumnType]
|
||||||
|
) -> DataFrameType:
|
||||||
|
for name, col in zip(names, columns):
|
||||||
|
frame = cls._set_column(frame, name, col)
|
||||||
|
return frame
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_column(cls, frame: DataFrameType, name: str) -> Optional[ColumnType]:
|
||||||
|
for col in frame.columns:
|
||||||
|
if col.name == name:
|
||||||
|
return col.type
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_columns(
|
||||||
|
cls, frame: DataFrameType, names: list[str]
|
||||||
|
) -> list[Optional[ColumnType]]:
|
||||||
|
return [cls._get_column(frame, name) for name in names]
|
||||||
@@ -6,6 +6,7 @@ from typing import Optional
|
|||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
from midas.checker.environment import Environment
|
from midas.checker.environment import Environment
|
||||||
|
from midas.checker.frames import FrameManager
|
||||||
from midas.checker.operators import (
|
from midas.checker.operators import (
|
||||||
PY_COMPARATOR_METHODS,
|
PY_COMPARATOR_METHODS,
|
||||||
PY_OPERATOR_METHODS,
|
PY_OPERATOR_METHODS,
|
||||||
@@ -18,9 +19,12 @@ from midas.checker.resolver import Resolver
|
|||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AliasType,
|
AliasType,
|
||||||
AppliedType,
|
AppliedType,
|
||||||
|
ColumnType,
|
||||||
|
DataFrameType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
OverloadedFunction,
|
OverloadedFunction,
|
||||||
|
TupleType,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
UnitType,
|
UnitType,
|
||||||
@@ -67,6 +71,7 @@ class PythonTyper(
|
|||||||
self.logger: logging.Logger = logging.getLogger("PythonTyper")
|
self.logger: logging.Logger = logging.getLogger("PythonTyper")
|
||||||
self.reporter: FileReporter = reporter.for_file(None)
|
self.reporter: FileReporter = reporter.for_file(None)
|
||||||
self.types: TypesRegistry = types
|
self.types: TypesRegistry = types
|
||||||
|
self.frame_mgr: FrameManager = FrameManager(self.types)
|
||||||
self.global_env: Environment = Preamble(self.types)
|
self.global_env: Environment = Preamble(self.types)
|
||||||
self.env: Environment = self.global_env
|
self.env: Environment = self.global_env
|
||||||
self.locals: dict[p.Expr, int] = {}
|
self.locals: dict[p.Expr, int] = {}
|
||||||
@@ -309,9 +314,15 @@ class PythonTyper(
|
|||||||
case p.VariableExpr():
|
case p.VariableExpr():
|
||||||
self._assign_var(location, target, value_type)
|
self._assign_var(location, target, value_type)
|
||||||
|
|
||||||
|
# Allow any kind of object because we disallow creating new attributes
|
||||||
case p.GetExpr(object=object, name=name):
|
case p.GetExpr(object=object, name=name):
|
||||||
self._assign_attr(location, object, name, value_type)
|
self._assign_attr(location, object, name, value_type)
|
||||||
|
|
||||||
|
# Only support variable expressions because modifying
|
||||||
|
# the underlying value would require reference types
|
||||||
|
case p.SubscriptExpr(object=p.VariableExpr() as var, index=index):
|
||||||
|
self._assign_sub(location, var, index, value_type)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
if not isinstance(target, p.VariableExpr):
|
if not isinstance(target, p.VariableExpr):
|
||||||
self.logger.warning(f"Unsupported assignment to {target}")
|
self.logger.warning(f"Unsupported assignment to {target}")
|
||||||
@@ -350,6 +361,27 @@ class PythonTyper(
|
|||||||
f"Cannot assign {value_type} to member '{object_type}.{name}' of type {member}",
|
f"Cannot assign {value_type} to member '{object_type}.{name}' of type {member}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _assign_sub(
|
||||||
|
self,
|
||||||
|
location: Location,
|
||||||
|
var: p.VariableExpr,
|
||||||
|
index: p.Expr,
|
||||||
|
value_type: Type,
|
||||||
|
):
|
||||||
|
var_type: Type = self.type_of(var)
|
||||||
|
# TODO: what happens if type is an alias of a dataframe type
|
||||||
|
match var_type:
|
||||||
|
case DataFrameType() as frame:
|
||||||
|
new_type: Type = self.frame_mgr.assign(
|
||||||
|
self.reporter, location, frame, index, value_type
|
||||||
|
)
|
||||||
|
self.env.assign(var.name, new_type)
|
||||||
|
case _:
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Cannot assign {value_type} to index {index} of {var_type}",
|
||||||
|
)
|
||||||
|
|
||||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||||
type: Type = self.type_of(stmt.value) if stmt.value is not None else UnitType()
|
type: Type = self.type_of(stmt.value) if stmt.value is not None else UnitType()
|
||||||
self.env.return_types.append(type)
|
self.env.return_types.append(type)
|
||||||
@@ -622,6 +654,13 @@ class PythonTyper(
|
|||||||
|
|
||||||
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> Type:
|
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> Type:
|
||||||
object: Type = self.type_of(expr.object)
|
object: Type = self.type_of(expr.object)
|
||||||
|
unfolded: Type = unfold_type(object)
|
||||||
|
match unfolded:
|
||||||
|
case TupleType():
|
||||||
|
return self._visit_tuple_subscript(unfolded, expr)
|
||||||
|
case DataFrameType():
|
||||||
|
return self._visit_frame_subscript(unfolded, expr)
|
||||||
|
|
||||||
operation: Optional[Type] = self.types.lookup_member(object, "__getitem__")
|
operation: Optional[Type] = self.types.lookup_member(object, "__getitem__")
|
||||||
if operation is None:
|
if operation is None:
|
||||||
self.reporter.error(
|
self.reporter.error(
|
||||||
@@ -659,13 +698,26 @@ class PythonTyper(
|
|||||||
self.reporter.warning(node.location, "ConstraintType not yet supported")
|
self.reporter.warning(node.location, "ConstraintType not yet supported")
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
def visit_frame_column(self, node: p.FrameColumn) -> Type:
|
def visit_frame_column(self, node: p.FrameColumn) -> ColumnType:
|
||||||
self.reporter.warning(node.location, "FrameColumn not yet supported")
|
return ColumnType(
|
||||||
return UnknownType()
|
type=(
|
||||||
|
self.resolve_type_expr(node.type)
|
||||||
|
if node.type is not None
|
||||||
|
else UnknownType()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def visit_frame_type(self, node: p.FrameType) -> Type:
|
def visit_frame_type(self, node: p.FrameType) -> Type:
|
||||||
self.reporter.warning(node.location, "FrameType not yet supported")
|
return DataFrameType(
|
||||||
return UnknownType()
|
columns=[
|
||||||
|
DataFrameType.Column(
|
||||||
|
index=i,
|
||||||
|
name=column.name,
|
||||||
|
type=self.visit_frame_column(column),
|
||||||
|
)
|
||||||
|
for i, column in enumerate(node.columns)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def _get_call_result(
|
def _get_call_result(
|
||||||
self,
|
self,
|
||||||
@@ -1108,3 +1160,23 @@ class PythonTyper(
|
|||||||
return p.BaseType(location=location, base=name, param=None)
|
return p.BaseType(location=location, base=name, param=None)
|
||||||
case _:
|
case _:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _visit_tuple_subscript(self, tup: TupleType, expr: p.SubscriptExpr) -> Type:
|
||||||
|
match expr.index:
|
||||||
|
case p.LiteralExpr(value=int() as index):
|
||||||
|
if index < 0 or index >= len(tup.items):
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location, f"Index {index} out of range for tuple {tup}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
return tup.items[index]
|
||||||
|
case _:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location, f"Invalid index type {expr.index} on {tup}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def _visit_frame_subscript(
|
||||||
|
self, frame: DataFrameType, expr: p.SubscriptExpr
|
||||||
|
) -> Type:
|
||||||
|
return self.frame_mgr.get(self.reporter, expr.location, frame, expr.index)
|
||||||
|
|||||||
@@ -128,6 +128,10 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
|||||||
|
|
||||||
case p.GetExpr():
|
case p.GetExpr():
|
||||||
target.accept(self)
|
target.accept(self)
|
||||||
|
|
||||||
|
case p.SubscriptExpr():
|
||||||
|
target.accept(self)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
raise Exception(f"Unsupported assignment to {target}")
|
raise Exception(f"Unsupported assignment to {target}")
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Optional, assert_never
|
from typing import Optional, assert_never, cast
|
||||||
|
|
||||||
import midas.ast.midas as m
|
import midas.ast.midas as m
|
||||||
from midas.ast.printer import MidasPrinter
|
from midas.ast.printer import MidasPrinter
|
||||||
@@ -156,6 +156,37 @@ class ConstraintType:
|
|||||||
return f"{self.type} where {printer.print(self.constraint)}"
|
return f"{self.type} where {printer.print(self.constraint)}"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class TupleType:
|
||||||
|
items: tuple[Type, ...]
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"({', '.join(map(str, self.items))})"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class ColumnType:
|
||||||
|
type: Type
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"Column[{self.type}]"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class DataFrameType:
|
||||||
|
columns: list[Column]
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
schema: list[str] = [f"{col.name}: {col.type}" for col in self.columns]
|
||||||
|
return f"Frame[{', '.join(schema)}]"
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Column:
|
||||||
|
index: int
|
||||||
|
name: Optional[str]
|
||||||
|
type: ColumnType
|
||||||
|
|
||||||
|
|
||||||
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||||
def sub_argument(arg: Function.Argument):
|
def sub_argument(arg: Function.Argument):
|
||||||
return Function.Argument(
|
return Function.Argument(
|
||||||
@@ -165,6 +196,13 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
|||||||
required=arg.required,
|
required=arg.required,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def sub_column(col: DataFrameType.Column):
|
||||||
|
return DataFrameType.Column(
|
||||||
|
index=col.index,
|
||||||
|
name=col.name,
|
||||||
|
type=cast(ColumnType, substitute_typevars(col.type, substitutions)),
|
||||||
|
)
|
||||||
|
|
||||||
match type:
|
match type:
|
||||||
case TopType():
|
case TopType():
|
||||||
return type
|
return type
|
||||||
@@ -250,10 +288,26 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
|||||||
body=substitute_typevars(body, substitutions),
|
body=substitute_typevars(body, substitutions),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
case TupleType(items=items):
|
||||||
|
return TupleType(
|
||||||
|
items=tuple(substitute_typevars(item, substitutions) for item in items),
|
||||||
|
)
|
||||||
|
|
||||||
|
case ColumnType(type=items_type):
|
||||||
|
return ColumnType(
|
||||||
|
type=substitute_typevars(items_type, substitutions),
|
||||||
|
)
|
||||||
|
|
||||||
|
case DataFrameType(columns=columns):
|
||||||
|
return DataFrameType(
|
||||||
|
columns=list(map(sub_column, columns)),
|
||||||
|
)
|
||||||
|
|
||||||
case UnknownType() | UnitType():
|
case UnknownType() | UnitType():
|
||||||
return type
|
return type
|
||||||
|
|
||||||
case TopType() | GenericType():
|
case TopType() | GenericType():
|
||||||
|
|
||||||
raise NotImplementedError(f"Unsupported type {type}")
|
raise NotImplementedError(f"Unsupported type {type}")
|
||||||
|
|
||||||
# Ensure exhaustiveness
|
# Ensure exhaustiveness
|
||||||
@@ -317,6 +371,15 @@ def to_annotation(type: Type) -> str:
|
|||||||
case ConstraintType():
|
case ConstraintType():
|
||||||
return str(type)
|
return str(type)
|
||||||
|
|
||||||
|
case TupleType(items=items):
|
||||||
|
return f"Tuple[{', '.join(map(to_annotation, items))}]"
|
||||||
|
|
||||||
|
case ColumnType():
|
||||||
|
return "pd.Series"
|
||||||
|
|
||||||
|
case DataFrameType():
|
||||||
|
return "pd.DataFrame"
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
assert_never(type)
|
assert_never(type)
|
||||||
|
|
||||||
@@ -342,4 +405,7 @@ Type = (
|
|||||||
| GenericType
|
| GenericType
|
||||||
| AppliedType
|
| AppliedType
|
||||||
| ConstraintType
|
| ConstraintType
|
||||||
|
| TupleType
|
||||||
|
| ColumnType
|
||||||
|
| DataFrameType
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import ast
|
import ast
|
||||||
|
import logging
|
||||||
import shutil
|
import shutil
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -13,13 +14,16 @@ from midas.checker.types import (
|
|||||||
AliasType,
|
AliasType,
|
||||||
AppliedType,
|
AppliedType,
|
||||||
BaseType,
|
BaseType,
|
||||||
|
ColumnType,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
ConstraintType,
|
ConstraintType,
|
||||||
|
DataFrameType,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
OverloadedFunction,
|
OverloadedFunction,
|
||||||
TopType,
|
TopType,
|
||||||
|
TupleType,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
UnitType,
|
UnitType,
|
||||||
@@ -40,6 +44,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
self.workdir: Path = workdir.resolve()
|
self.workdir: Path = workdir.resolve()
|
||||||
self.build_dir: Path = self.workdir / "build" / "midas"
|
self.build_dir: Path = self.workdir / "build" / "midas"
|
||||||
self.rel_src_path: Path = Path()
|
self.rel_src_path: Path = Path()
|
||||||
|
self.logger: logging.Logger = logging.getLogger("Generator")
|
||||||
|
|
||||||
self._typed_ast: TypedAST = TypedAST(
|
self._typed_ast: TypedAST = TypedAST(
|
||||||
stmts=[],
|
stmts=[],
|
||||||
@@ -327,6 +332,19 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
if bound is not None:
|
if bound is not None:
|
||||||
self._make_cast_asserts(src_location, expr, bound)
|
self._make_cast_asserts(src_location, expr, bound)
|
||||||
|
|
||||||
|
case TupleType(items=items):
|
||||||
|
self._add_assert(
|
||||||
|
ast.Call(
|
||||||
|
func=ast.Name(id="isinstance"),
|
||||||
|
args=[expr, ast.Name(id="tuple")],
|
||||||
|
keywords=[],
|
||||||
|
),
|
||||||
|
self._make_cast_assert_message(src_location, expr, type),
|
||||||
|
)
|
||||||
|
assert isinstance(expr, ast.Tuple)
|
||||||
|
for item, item_type in zip(expr.elts, items):
|
||||||
|
self._make_cast_asserts(src_location, item, item_type)
|
||||||
|
|
||||||
case (
|
case (
|
||||||
TopType()
|
TopType()
|
||||||
| Function()
|
| Function()
|
||||||
@@ -334,8 +352,10 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
| ComplexType()
|
| ComplexType()
|
||||||
| ExtensionType()
|
| ExtensionType()
|
||||||
| GenericType()
|
| GenericType()
|
||||||
|
| ColumnType()
|
||||||
|
| DataFrameType()
|
||||||
):
|
):
|
||||||
raise NotImplementedError(f"Can't make assertion for type {type}")
|
self.logger.warning(f"Can't make assertion for type {type}")
|
||||||
|
|
||||||
# Ensure exhaustiveness
|
# Ensure exhaustiveness
|
||||||
case _:
|
case _:
|
||||||
|
|||||||
@@ -7,13 +7,16 @@ from midas.checker.types import (
|
|||||||
AliasType,
|
AliasType,
|
||||||
AppliedType,
|
AppliedType,
|
||||||
BaseType,
|
BaseType,
|
||||||
|
ColumnType,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
ConstraintType,
|
ConstraintType,
|
||||||
|
DataFrameType,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
OverloadedFunction,
|
OverloadedFunction,
|
||||||
TopType,
|
TopType,
|
||||||
|
TupleType,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
UnitType,
|
UnitType,
|
||||||
@@ -30,6 +33,7 @@ class StubsGenerator:
|
|||||||
self.types: TypesRegistry = types
|
self.types: TypesRegistry = types
|
||||||
self.stubs: list[ast.stmt] = []
|
self.stubs: list[ast.stmt] = []
|
||||||
self.typing_imports: set[str] = set()
|
self.typing_imports: set[str] = set()
|
||||||
|
self.import_pandas: bool = False
|
||||||
self.protocol_idx: int = 0
|
self.protocol_idx: int = 0
|
||||||
self.stub_idx: int = 0
|
self.stub_idx: int = 0
|
||||||
self.type_var_idx: int = 0
|
self.type_var_idx: int = 0
|
||||||
@@ -38,6 +42,7 @@ class StubsGenerator:
|
|||||||
def generate_stubs(self) -> ast.Module:
|
def generate_stubs(self) -> ast.Module:
|
||||||
self.stubs = []
|
self.stubs = []
|
||||||
self.typing_imports = set()
|
self.typing_imports = set()
|
||||||
|
self.import_pandas = False
|
||||||
for name, type in self.types._types.items():
|
for name, type in self.types._types.items():
|
||||||
# Skip builtin types, not just based on name so the user can override
|
# Skip builtin types, not just based on name so the user can override
|
||||||
# TODO: check if added members on builtin type
|
# TODO: check if added members on builtin type
|
||||||
@@ -53,7 +58,7 @@ class StubsGenerator:
|
|||||||
continue
|
continue
|
||||||
self.generate_stub(name, type)
|
self.generate_stub(name, type)
|
||||||
|
|
||||||
imports = [
|
imports: list[ast.stmt] = [
|
||||||
ast.ImportFrom(
|
ast.ImportFrom(
|
||||||
module="__future__",
|
module="__future__",
|
||||||
names=[ast.alias(name="annotations")],
|
names=[ast.alias(name="annotations")],
|
||||||
@@ -70,6 +75,17 @@ class StubsGenerator:
|
|||||||
level=0,
|
level=0,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if self.import_pandas:
|
||||||
|
imports.append(
|
||||||
|
ast.Import(
|
||||||
|
names=[
|
||||||
|
ast.alias(
|
||||||
|
name="pandas",
|
||||||
|
asname="pd",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
return ast.Module(body=imports + self.stubs, type_ignores=[])
|
return ast.Module(body=imports + self.stubs, type_ignores=[])
|
||||||
|
|
||||||
def generate_stub(self, name: str, type: Type):
|
def generate_stub(self, name: str, type: Type):
|
||||||
@@ -231,6 +247,31 @@ class StubsGenerator:
|
|||||||
case ConstraintType():
|
case ConstraintType():
|
||||||
return self.dump_type(type.type)
|
return self.dump_type(type.type)
|
||||||
|
|
||||||
|
case TupleType(items=items):
|
||||||
|
return ast.Subscript(
|
||||||
|
value=ast.Name(id="tuple"),
|
||||||
|
slice=ast.Tuple(
|
||||||
|
elts=[self.dump_type(item) for item in items],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
case ColumnType(type=inner):
|
||||||
|
self.import_pandas = True
|
||||||
|
return ast.Subscript(
|
||||||
|
value=ast.Attribute(
|
||||||
|
value=ast.Name(id="pd"),
|
||||||
|
attr="Series",
|
||||||
|
),
|
||||||
|
slice=self.dump_type(inner),
|
||||||
|
)
|
||||||
|
|
||||||
|
case DataFrameType():
|
||||||
|
self.import_pandas = True
|
||||||
|
return ast.Attribute(
|
||||||
|
value=ast.Name(id="pd"),
|
||||||
|
attr="DataFrame",
|
||||||
|
)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
assert_never(type)
|
assert_never(type)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user