feat(gen): add support for tuples and dataframes

This commit is contained in:
2026-06-23 14:45:19 +02:00
parent 9144995b79
commit 9145496587
2 changed files with 63 additions and 2 deletions

View File

@@ -1,4 +1,5 @@
import ast import ast
import logging
import shutil import shutil
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
@@ -13,13 +14,16 @@ from midas.checker.types import (
AliasType, AliasType,
AppliedType, AppliedType,
BaseType, BaseType,
ColumnType,
ComplexType, ComplexType,
ConstraintType, ConstraintType,
DataFrameType,
ExtensionType, ExtensionType,
Function, Function,
GenericType, GenericType,
OverloadedFunction, OverloadedFunction,
TopType, TopType,
TupleType,
Type, Type,
TypeVar, TypeVar,
UnitType, UnitType,
@@ -40,6 +44,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self.workdir: Path = workdir.resolve() self.workdir: Path = workdir.resolve()
self.build_dir: Path = self.workdir / "build" / "midas" self.build_dir: Path = self.workdir / "build" / "midas"
self.rel_src_path: Path = Path() self.rel_src_path: Path = Path()
self.logger: logging.Logger = logging.getLogger("Generator")
self._typed_ast: TypedAST = TypedAST( self._typed_ast: TypedAST = TypedAST(
stmts=[], stmts=[],
@@ -327,6 +332,19 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
if bound is not None: if bound is not None:
self._make_cast_asserts(src_location, expr, bound) self._make_cast_asserts(src_location, expr, bound)
case TupleType(items=items):
self._add_assert(
ast.Call(
func=ast.Name(id="isinstance"),
args=[expr, ast.Name(id="tuple")],
keywords=[],
),
self._make_cast_assert_message(src_location, expr, type),
)
assert isinstance(expr, ast.Tuple)
for item, item_type in zip(expr.elts, items):
self._make_cast_asserts(src_location, item, item_type)
case ( case (
TopType() TopType()
| Function() | Function()
@@ -334,8 +352,10 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
| ComplexType() | ComplexType()
| ExtensionType() | ExtensionType()
| GenericType() | GenericType()
| ColumnType()
| DataFrameType()
): ):
raise NotImplementedError(f"Can't make assertion for type {type}") self.logger.warning(f"Can't make assertion for type {type}")
# Ensure exhaustiveness # Ensure exhaustiveness
case _: case _:

View File

@@ -7,13 +7,16 @@ from midas.checker.types import (
AliasType, AliasType,
AppliedType, AppliedType,
BaseType, BaseType,
ColumnType,
ComplexType, ComplexType,
ConstraintType, ConstraintType,
DataFrameType,
ExtensionType, ExtensionType,
Function, Function,
GenericType, GenericType,
OverloadedFunction, OverloadedFunction,
TopType, TopType,
TupleType,
Type, Type,
TypeVar, TypeVar,
UnitType, UnitType,
@@ -30,6 +33,7 @@ class StubsGenerator:
self.types: TypesRegistry = types self.types: TypesRegistry = types
self.stubs: list[ast.stmt] = [] self.stubs: list[ast.stmt] = []
self.typing_imports: set[str] = set() self.typing_imports: set[str] = set()
self.import_pandas: bool = False
self.protocol_idx: int = 0 self.protocol_idx: int = 0
self.stub_idx: int = 0 self.stub_idx: int = 0
self.type_var_idx: int = 0 self.type_var_idx: int = 0
@@ -38,6 +42,7 @@ class StubsGenerator:
def generate_stubs(self) -> ast.Module: def generate_stubs(self) -> ast.Module:
self.stubs = [] self.stubs = []
self.typing_imports = set() self.typing_imports = set()
self.import_pandas = False
for name, type in self.types._types.items(): for name, type in self.types._types.items():
# Skip builtin types, not just based on name so the user can override # Skip builtin types, not just based on name so the user can override
# TODO: check if added members on builtin type # TODO: check if added members on builtin type
@@ -53,7 +58,7 @@ class StubsGenerator:
continue continue
self.generate_stub(name, type) self.generate_stub(name, type)
imports = [ imports: list[ast.stmt] = [
ast.ImportFrom( ast.ImportFrom(
module="__future__", module="__future__",
names=[ast.alias(name="annotations")], names=[ast.alias(name="annotations")],
@@ -70,6 +75,17 @@ class StubsGenerator:
level=0, level=0,
) )
) )
if self.import_pandas:
imports.append(
ast.Import(
names=[
ast.alias(
name="pandas",
asname="pd",
)
],
)
)
return ast.Module(body=imports + self.stubs, type_ignores=[]) return ast.Module(body=imports + self.stubs, type_ignores=[])
def generate_stub(self, name: str, type: Type): def generate_stub(self, name: str, type: Type):
@@ -231,6 +247,31 @@ class StubsGenerator:
case ConstraintType(): case ConstraintType():
return self.dump_type(type.type) return self.dump_type(type.type)
case TupleType(items=items):
return ast.Subscript(
value=ast.Name(id="tuple"),
slice=ast.Tuple(
elts=[self.dump_type(item) for item in items],
),
)
case ColumnType(type=inner):
self.import_pandas = True
return ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="pd"),
attr="Series",
),
slice=self.dump_type(inner),
)
case DataFrameType():
self.import_pandas = True
return ast.Attribute(
value=ast.Name(id="pd"),
attr="DataFrame",
)
case _: case _:
assert_never(type) assert_never(type)