feat(gen): add support for tuples and dataframes

This commit is contained in:
2026-06-23 14:45:19 +02:00
parent 3bdbc80079
commit cc5e7af143
2 changed files with 63 additions and 2 deletions

View File

@@ -1,4 +1,5 @@
import ast
import logging
import shutil
from dataclasses import dataclass, field
from pathlib import Path
@@ -13,13 +14,16 @@ from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ColumnType,
ComplexType,
ConstraintType,
DataFrameType,
ExtensionType,
Function,
GenericType,
OverloadedFunction,
TopType,
TupleType,
Type,
TypeVar,
UnitType,
@@ -40,6 +44,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self.workdir: Path = workdir.resolve()
self.build_dir: Path = self.workdir / "build" / "midas"
self.rel_src_path: Path = Path()
self.logger: logging.Logger = logging.getLogger("Generator")
self._typed_ast: TypedAST = TypedAST(
stmts=[],
@@ -332,6 +337,19 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
if bound is not None:
self._make_cast_asserts(src_location, expr, bound)
case TupleType(items=items):
self._add_assert(
ast.Call(
func=ast.Name(id="isinstance"),
args=[expr, ast.Name(id="tuple")],
keywords=[],
),
self._make_cast_assert_message(src_location, expr, type),
)
assert isinstance(expr, ast.Tuple)
for item, item_type in zip(expr.elts, items):
self._make_cast_asserts(src_location, item, item_type)
case (
TopType()
| Function()
@@ -339,8 +357,10 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
| ComplexType()
| ExtensionType()
| GenericType()
| ColumnType()
| DataFrameType()
):
raise NotImplementedError(f"Can't make assertion for type {type}")
self.logger.warning(f"Can't make assertion for type {type}")
# Ensure exhaustiveness
case _:

View File

@@ -7,13 +7,16 @@ from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ColumnType,
ComplexType,
ConstraintType,
DataFrameType,
ExtensionType,
Function,
GenericType,
OverloadedFunction,
TopType,
TupleType,
Type,
TypeVar,
UnitType,
@@ -30,6 +33,7 @@ class StubsGenerator:
self.types: TypesRegistry = types
self.stubs: list[ast.stmt] = []
self.typing_imports: set[str] = set()
self.import_pandas: bool = False
self.protocol_idx: int = 0
self.stub_idx: int = 0
self.type_var_idx: int = 0
@@ -38,6 +42,7 @@ class StubsGenerator:
def generate_stubs(self) -> ast.Module:
self.stubs = []
self.typing_imports = set()
self.import_pandas = False
for name, type in self.types._types.items():
# Skip builtin types, not just based on name so the user can override
# TODO: check if added members on builtin type
@@ -53,7 +58,7 @@ class StubsGenerator:
continue
self.generate_stub(name, type)
imports = [
imports: list[ast.stmt] = [
ast.ImportFrom(
module="__future__",
names=[ast.alias(name="annotations")],
@@ -70,6 +75,17 @@ class StubsGenerator:
level=0,
)
)
if self.import_pandas:
imports.append(
ast.Import(
names=[
ast.alias(
name="pandas",
asname="pd",
)
],
)
)
return ast.Module(body=imports + self.stubs, type_ignores=[])
def generate_stub(self, name: str, type: Type):
@@ -231,6 +247,31 @@ class StubsGenerator:
case ConstraintType():
return self.dump_type(type.type)
case TupleType(items=items):
return ast.Subscript(
value=ast.Name(id="tuple"),
slice=ast.Tuple(
elts=[self.dump_type(item) for item in items],
),
)
case ColumnType(type=inner):
self.import_pandas = True
return ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="pd"),
attr="Series",
),
slice=self.dump_type(inner),
)
case DataFrameType():
self.import_pandas = True
return ast.Attribute(
value=ast.Name(id="pd"),
attr="DataFrame",
)
case _:
assert_never(type)