feat(gen): add support for tuples and dataframes
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
import ast
|
import ast
|
||||||
|
import logging
|
||||||
import shutil
|
import shutil
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -13,13 +14,16 @@ from midas.checker.types import (
|
|||||||
AliasType,
|
AliasType,
|
||||||
AppliedType,
|
AppliedType,
|
||||||
BaseType,
|
BaseType,
|
||||||
|
ColumnType,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
ConstraintType,
|
ConstraintType,
|
||||||
|
DataFrameType,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
OverloadedFunction,
|
OverloadedFunction,
|
||||||
TopType,
|
TopType,
|
||||||
|
TupleType,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
UnitType,
|
UnitType,
|
||||||
@@ -40,6 +44,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
self.workdir: Path = workdir.resolve()
|
self.workdir: Path = workdir.resolve()
|
||||||
self.build_dir: Path = self.workdir / "build" / "midas"
|
self.build_dir: Path = self.workdir / "build" / "midas"
|
||||||
self.rel_src_path: Path = Path()
|
self.rel_src_path: Path = Path()
|
||||||
|
self.logger: logging.Logger = logging.getLogger("Generator")
|
||||||
|
|
||||||
self._typed_ast: TypedAST = TypedAST(
|
self._typed_ast: TypedAST = TypedAST(
|
||||||
stmts=[],
|
stmts=[],
|
||||||
@@ -327,6 +332,19 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
if bound is not None:
|
if bound is not None:
|
||||||
self._make_cast_asserts(src_location, expr, bound)
|
self._make_cast_asserts(src_location, expr, bound)
|
||||||
|
|
||||||
|
case TupleType(items=items):
|
||||||
|
self._add_assert(
|
||||||
|
ast.Call(
|
||||||
|
func=ast.Name(id="isinstance"),
|
||||||
|
args=[expr, ast.Name(id="tuple")],
|
||||||
|
keywords=[],
|
||||||
|
),
|
||||||
|
self._make_cast_assert_message(src_location, expr, type),
|
||||||
|
)
|
||||||
|
assert isinstance(expr, ast.Tuple)
|
||||||
|
for item, item_type in zip(expr.elts, items):
|
||||||
|
self._make_cast_asserts(src_location, item, item_type)
|
||||||
|
|
||||||
case (
|
case (
|
||||||
TopType()
|
TopType()
|
||||||
| Function()
|
| Function()
|
||||||
@@ -334,8 +352,10 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
| ComplexType()
|
| ComplexType()
|
||||||
| ExtensionType()
|
| ExtensionType()
|
||||||
| GenericType()
|
| GenericType()
|
||||||
|
| ColumnType()
|
||||||
|
| DataFrameType()
|
||||||
):
|
):
|
||||||
raise NotImplementedError(f"Can't make assertion for type {type}")
|
self.logger.warning(f"Can't make assertion for type {type}")
|
||||||
|
|
||||||
# Ensure exhaustiveness
|
# Ensure exhaustiveness
|
||||||
case _:
|
case _:
|
||||||
|
|||||||
@@ -7,13 +7,16 @@ from midas.checker.types import (
|
|||||||
AliasType,
|
AliasType,
|
||||||
AppliedType,
|
AppliedType,
|
||||||
BaseType,
|
BaseType,
|
||||||
|
ColumnType,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
ConstraintType,
|
ConstraintType,
|
||||||
|
DataFrameType,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
OverloadedFunction,
|
OverloadedFunction,
|
||||||
TopType,
|
TopType,
|
||||||
|
TupleType,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
UnitType,
|
UnitType,
|
||||||
@@ -30,6 +33,7 @@ class StubsGenerator:
|
|||||||
self.types: TypesRegistry = types
|
self.types: TypesRegistry = types
|
||||||
self.stubs: list[ast.stmt] = []
|
self.stubs: list[ast.stmt] = []
|
||||||
self.typing_imports: set[str] = set()
|
self.typing_imports: set[str] = set()
|
||||||
|
self.import_pandas: bool = False
|
||||||
self.protocol_idx: int = 0
|
self.protocol_idx: int = 0
|
||||||
self.stub_idx: int = 0
|
self.stub_idx: int = 0
|
||||||
self.type_var_idx: int = 0
|
self.type_var_idx: int = 0
|
||||||
@@ -38,6 +42,7 @@ class StubsGenerator:
|
|||||||
def generate_stubs(self) -> ast.Module:
|
def generate_stubs(self) -> ast.Module:
|
||||||
self.stubs = []
|
self.stubs = []
|
||||||
self.typing_imports = set()
|
self.typing_imports = set()
|
||||||
|
self.import_pandas = False
|
||||||
for name, type in self.types._types.items():
|
for name, type in self.types._types.items():
|
||||||
# Skip builtin types, not just based on name so the user can override
|
# Skip builtin types, not just based on name so the user can override
|
||||||
# TODO: check if added members on builtin type
|
# TODO: check if added members on builtin type
|
||||||
@@ -53,7 +58,7 @@ class StubsGenerator:
|
|||||||
continue
|
continue
|
||||||
self.generate_stub(name, type)
|
self.generate_stub(name, type)
|
||||||
|
|
||||||
imports = [
|
imports: list[ast.stmt] = [
|
||||||
ast.ImportFrom(
|
ast.ImportFrom(
|
||||||
module="__future__",
|
module="__future__",
|
||||||
names=[ast.alias(name="annotations")],
|
names=[ast.alias(name="annotations")],
|
||||||
@@ -70,6 +75,17 @@ class StubsGenerator:
|
|||||||
level=0,
|
level=0,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if self.import_pandas:
|
||||||
|
imports.append(
|
||||||
|
ast.Import(
|
||||||
|
names=[
|
||||||
|
ast.alias(
|
||||||
|
name="pandas",
|
||||||
|
asname="pd",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
return ast.Module(body=imports + self.stubs, type_ignores=[])
|
return ast.Module(body=imports + self.stubs, type_ignores=[])
|
||||||
|
|
||||||
def generate_stub(self, name: str, type: Type):
|
def generate_stub(self, name: str, type: Type):
|
||||||
@@ -231,6 +247,31 @@ class StubsGenerator:
|
|||||||
case ConstraintType():
|
case ConstraintType():
|
||||||
return self.dump_type(type.type)
|
return self.dump_type(type.type)
|
||||||
|
|
||||||
|
case TupleType(items=items):
|
||||||
|
return ast.Subscript(
|
||||||
|
value=ast.Name(id="tuple"),
|
||||||
|
slice=ast.Tuple(
|
||||||
|
elts=[self.dump_type(item) for item in items],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
case ColumnType(type=inner):
|
||||||
|
self.import_pandas = True
|
||||||
|
return ast.Subscript(
|
||||||
|
value=ast.Attribute(
|
||||||
|
value=ast.Name(id="pd"),
|
||||||
|
attr="Series",
|
||||||
|
),
|
||||||
|
slice=self.dump_type(inner),
|
||||||
|
)
|
||||||
|
|
||||||
|
case DataFrameType():
|
||||||
|
self.import_pandas = True
|
||||||
|
return ast.Attribute(
|
||||||
|
value=ast.Name(id="pd"),
|
||||||
|
attr="DataFrame",
|
||||||
|
)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
assert_never(type)
|
assert_never(type)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user