Compare commits
5 Commits
main
...
feat/stubs
| Author | SHA1 | Date | |
|---|---|---|---|
|
11422d4364
|
|||
|
e8f8a5ca2f
|
|||
|
df8d71c0a9
|
|||
|
e4fb142f99
|
|||
|
2f8f9d633b
|
@@ -102,7 +102,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
base_name,
|
base_name,
|
||||||
member.name.lexeme,
|
member.name.lexeme,
|
||||||
member_type,
|
member_type,
|
||||||
member.kind == m.MemberKind.METHOD,
|
member.kind,
|
||||||
)
|
)
|
||||||
|
|
||||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from midas.checker.registry import TypesRegistry
|
|||||||
from midas.checker.reporter import FileReporter, Reporter
|
from midas.checker.reporter import FileReporter, Reporter
|
||||||
from midas.checker.resolver import Resolver
|
from midas.checker.resolver import Resolver
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
|
AliasType,
|
||||||
AppliedType,
|
AppliedType,
|
||||||
Function,
|
Function,
|
||||||
OverloadedFunction,
|
OverloadedFunction,
|
||||||
@@ -694,9 +695,17 @@ class PythonTyper(
|
|||||||
case UnknownType():
|
case UnknownType():
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
|
case AliasType(type=base):
|
||||||
|
return self._get_call_result(
|
||||||
|
location, base, positional, keywords, report_errors
|
||||||
|
)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
if report_errors:
|
if report_errors:
|
||||||
self.reporter.error(location, f"{callee} is not callable")
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"{callee} ({callee.__class__.__name__}) is not callable",
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _are_arguments_valid(
|
def _are_arguments_valid(
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from midas.ast.midas import MemberKind
|
||||||
from midas.checker.builtins import BUILTIN_SUBTYPES
|
from midas.checker.builtins import BUILTIN_SUBTYPES
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AliasType,
|
AliasType,
|
||||||
@@ -19,11 +21,17 @@ from midas.checker.types import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Member:
|
||||||
|
kind: MemberKind
|
||||||
|
type: Type
|
||||||
|
|
||||||
|
|
||||||
class TypesRegistry:
|
class TypesRegistry:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.logger: logging.Logger = logging.getLogger("TypesRegistry")
|
self.logger: logging.Logger = logging.getLogger("TypesRegistry")
|
||||||
self._types: dict[str, Type] = {}
|
self._types: dict[str, Type] = {}
|
||||||
self._members: dict[str, dict[str, Type]] = {}
|
self._members: dict[str, dict[str, Member]] = {}
|
||||||
|
|
||||||
def get_type(self, name: str) -> Type:
|
def get_type(self, name: str) -> Type:
|
||||||
"""Get a type from its name
|
"""Get a type from its name
|
||||||
@@ -60,26 +68,38 @@ class TypesRegistry:
|
|||||||
return type
|
return type
|
||||||
|
|
||||||
def define_member(
|
def define_member(
|
||||||
self, type_name: str, member_name: str, member_type: Type, is_method: bool
|
self,
|
||||||
|
type_name: str,
|
||||||
|
member_name: str,
|
||||||
|
member_type: Type,
|
||||||
|
kind: MemberKind,
|
||||||
):
|
):
|
||||||
members: dict[str, Type] = self._members.setdefault(type_name, {})
|
members: dict[str, Member] = self._members.setdefault(type_name, {})
|
||||||
if member_name in members:
|
if member_name in members:
|
||||||
if not is_method:
|
current: Member = members[member_name]
|
||||||
|
if current.kind != kind:
|
||||||
self.logger.error(
|
self.logger.error(
|
||||||
f"Member '{member_name}' already defined for type {type_name}"
|
f"Member '{member_name}' is already defined as a {current.kind},"
|
||||||
|
+ f" cannot define a {kind} with the same name"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
current: Type = members[member_name]
|
if kind != MemberKind.METHOD:
|
||||||
|
self.logger.error(
|
||||||
|
f"Member '{member_name}' already defined for type {type_name},"
|
||||||
|
+ " only methods can be overloaded"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
combined: Type
|
combined: Type
|
||||||
match current:
|
match current.type:
|
||||||
case OverloadedFunction(overloads=overloads):
|
case OverloadedFunction(overloads=overloads):
|
||||||
combined = OverloadedFunction(overloads=overloads + [member_type])
|
combined = OverloadedFunction(overloads=overloads + [member_type])
|
||||||
case _:
|
case _:
|
||||||
combined = OverloadedFunction(overloads=[current, member_type])
|
combined = OverloadedFunction(overloads=[current.type, member_type])
|
||||||
members[member_name] = combined
|
members[member_name] = Member(kind=current.kind, type=combined)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
members[member_name] = member_type
|
members[member_name] = Member(kind=kind, type=member_type)
|
||||||
|
|
||||||
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||||
"""Check whether `type1` is a subtype of `type2`
|
"""Check whether `type1` is a subtype of `type2`
|
||||||
@@ -297,13 +317,13 @@ class TypesRegistry:
|
|||||||
case BaseType(name=name):
|
case BaseType(name=name):
|
||||||
if name in self._members:
|
if name in self._members:
|
||||||
if member_name in self._members[name]:
|
if member_name in self._members[name]:
|
||||||
return self._members[name][member_name]
|
return self._members[name][member_name].type
|
||||||
return None
|
return None
|
||||||
|
|
||||||
case AliasType(name=name, type=base):
|
case AliasType(name=name, type=base):
|
||||||
if name in self._members:
|
if name in self._members:
|
||||||
if member_name in self._members[name]:
|
if member_name in self._members[name]:
|
||||||
return self._members[name][member_name]
|
return self._members[name][member_name].type
|
||||||
return self.lookup_member(base, member_name)
|
return self.lookup_member(base, member_name)
|
||||||
|
|
||||||
case AppliedType(name=name, body=body, args=args):
|
case AppliedType(name=name, body=body, args=args):
|
||||||
@@ -317,7 +337,7 @@ class TypesRegistry:
|
|||||||
}
|
}
|
||||||
if name in self._members:
|
if name in self._members:
|
||||||
if member_name in self._members[name]:
|
if member_name in self._members[name]:
|
||||||
member_type: Type = self._members[name][member_name]
|
member_type: Type = self._members[name][member_name].type
|
||||||
return substitute_typevars(member_type, substitutions)
|
return substitute_typevars(member_type, substitutions)
|
||||||
|
|
||||||
member_type2: Optional[Type] = self.lookup_member(body, member_name)
|
member_type2: Optional[Type] = self.lookup_member(body, member_name)
|
||||||
|
|||||||
@@ -140,6 +140,9 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
|||||||
)
|
)
|
||||||
|
|
||||||
match type:
|
match type:
|
||||||
|
case TopType():
|
||||||
|
return type
|
||||||
|
|
||||||
case BaseType(name=name) if name in substitutions:
|
case BaseType(name=name) if name in substitutions:
|
||||||
return substitutions[name]
|
return substitutions[name]
|
||||||
|
|
||||||
@@ -200,6 +203,21 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
|||||||
return substitutions[name]
|
return substitutions[name]
|
||||||
raise ValueError(f"Missing TypeVar substitution for {name}")
|
raise ValueError(f"Missing TypeVar substitution for {name}")
|
||||||
|
|
||||||
|
case GenericType(name=name, params=params, body=body):
|
||||||
|
params2: list[TypeVar] = []
|
||||||
|
for param in params:
|
||||||
|
param2: Type = substitute_typevars(param, substitutions)
|
||||||
|
if not isinstance(param2, TypeVar):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid type parameter substitution, expected TypeVar, got {param2}"
|
||||||
|
)
|
||||||
|
params2.append(param2)
|
||||||
|
return GenericType(
|
||||||
|
name=name,
|
||||||
|
params=params2,
|
||||||
|
body=substitute_typevars(body, substitutions),
|
||||||
|
)
|
||||||
|
|
||||||
case UnknownType() | UnitType():
|
case UnknownType() | UnitType():
|
||||||
return type
|
return type
|
||||||
|
|
||||||
|
|||||||
@@ -4,5 +4,6 @@ from .format import format as format
|
|||||||
from .highlight import highlight as highlight
|
from .highlight import highlight as highlight
|
||||||
from .parse import parse as parse
|
from .parse import parse as parse
|
||||||
from .registry import dump_registry as dump_registry
|
from .registry import dump_registry as dump_registry
|
||||||
|
from .stubs import stubs as stubs
|
||||||
from .types import types as types
|
from .types import types as types
|
||||||
from .validate import validate as validate
|
from .validate import validate as validate
|
||||||
|
|||||||
27
midas/cli/commands/stubs.py
Normal file
27
midas/cli/commands/stubs.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
import ast
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TextIO
|
||||||
|
|
||||||
|
import click
|
||||||
|
|
||||||
|
from midas.checker.checker import TypeChecker
|
||||||
|
from midas.generator.stubs import StubsGenerator
|
||||||
|
|
||||||
|
|
||||||
|
@click.command(help="Generate stubs from Midas definitions")
|
||||||
|
@click.argument("file", type=click.File("r"))
|
||||||
|
@click.option("-o", "--output", type=click.File("w"), default="-")
|
||||||
|
def stubs(
|
||||||
|
file: TextIO,
|
||||||
|
output: TextIO,
|
||||||
|
):
|
||||||
|
source_path: Path = Path(file.name).resolve()
|
||||||
|
|
||||||
|
checker = TypeChecker()
|
||||||
|
checker.import_midas(source_path)
|
||||||
|
|
||||||
|
generator = StubsGenerator(checker.types)
|
||||||
|
module: ast.Module = generator.generate_stubs()
|
||||||
|
module = ast.fix_missing_locations(module)
|
||||||
|
|
||||||
|
output.write(ast.unparse(module))
|
||||||
@@ -18,6 +18,7 @@ midas.add_command(commands.highlight)
|
|||||||
midas.add_command(commands.parse)
|
midas.add_command(commands.parse)
|
||||||
midas.add_command(commands.dump_registry)
|
midas.add_command(commands.dump_registry)
|
||||||
midas.add_command(commands.types)
|
midas.add_command(commands.types)
|
||||||
|
midas.add_command(commands.stubs)
|
||||||
midas.add_command(commands.validate)
|
midas.add_command(commands.validate)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
337
midas/generator/stubs.py
Normal file
337
midas/generator/stubs.py
Normal file
@@ -0,0 +1,337 @@
|
|||||||
|
import ast
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import midas.ast.midas as m
|
||||||
|
from midas.checker.registry import Member, TypesRegistry
|
||||||
|
from midas.checker.types import (
|
||||||
|
AliasType,
|
||||||
|
AppliedType,
|
||||||
|
BaseType,
|
||||||
|
ComplexType,
|
||||||
|
ExtensionType,
|
||||||
|
Function,
|
||||||
|
GenericType,
|
||||||
|
OverloadedFunction,
|
||||||
|
TopType,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
UnitType,
|
||||||
|
UnknownType,
|
||||||
|
substitute_typevars,
|
||||||
|
)
|
||||||
|
|
||||||
|
Empty = ast.Constant(value=...)
|
||||||
|
|
||||||
|
|
||||||
|
class StubsGenerator:
|
||||||
|
def __init__(self, types: TypesRegistry) -> None:
|
||||||
|
self.types: TypesRegistry = types
|
||||||
|
self.stubs: list[ast.stmt] = []
|
||||||
|
self.typing_imports: set[str] = set()
|
||||||
|
self.protocol_idx: int = 0
|
||||||
|
self.stub_idx: int = 0
|
||||||
|
self.type_var_idx: int = 0
|
||||||
|
self.substitutions: dict[str, dict[str, Type]] = {}
|
||||||
|
|
||||||
|
def generate_stubs(self) -> ast.Module:
|
||||||
|
self.stubs = []
|
||||||
|
self.typing_imports = set()
|
||||||
|
for name, type in self.types._types.items():
|
||||||
|
self.generate_stub(name, type)
|
||||||
|
|
||||||
|
imports = [
|
||||||
|
ast.ImportFrom(
|
||||||
|
module="__future__",
|
||||||
|
names=[ast.alias(name="annotations")],
|
||||||
|
level=0,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
if len(self.typing_imports) != 0:
|
||||||
|
imports.append(
|
||||||
|
ast.ImportFrom(
|
||||||
|
module="typing",
|
||||||
|
names=[
|
||||||
|
ast.alias(name=name) for name in sorted(self.typing_imports)
|
||||||
|
],
|
||||||
|
level=0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return ast.Module(body=imports + self.stubs, type_ignores=[])
|
||||||
|
|
||||||
|
def generate_stub(self, name: str, type: Type):
|
||||||
|
base_type: Type = type
|
||||||
|
|
||||||
|
members: dict[str, Member] = self.types._members.get(name, {})
|
||||||
|
if isinstance(base_type, (BaseType, TopType, UnitType)) and len(members) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
bases: list[ast.expr] = []
|
||||||
|
substitutions: dict[str, Type] = {}
|
||||||
|
bases, substitutions = self.get_bases(type)
|
||||||
|
self.substitutions[name] = substitutions
|
||||||
|
|
||||||
|
body = self.generate_body(members, substitutions)
|
||||||
|
stub = ast.ClassDef(
|
||||||
|
name=name,
|
||||||
|
bases=bases,
|
||||||
|
body=body,
|
||||||
|
keywords=[],
|
||||||
|
decorator_list=[],
|
||||||
|
)
|
||||||
|
self.add_stub(stub)
|
||||||
|
|
||||||
|
def get_bases(self, type: Type) -> tuple[list[ast.expr], dict[str, Type]]:
|
||||||
|
match type:
|
||||||
|
case AliasType(type=base):
|
||||||
|
return [self.dump_type(base)], {}
|
||||||
|
case GenericType(params=params, body=body):
|
||||||
|
self.add_typing_import("Generic")
|
||||||
|
type_vars: ast.expr
|
||||||
|
|
||||||
|
params2: list[TypeVar] = self.define_type_vars(params)
|
||||||
|
if len(params) == 1:
|
||||||
|
type_vars = ast.Name(id=params2[0].name)
|
||||||
|
else:
|
||||||
|
type_vars = ast.Tuple(
|
||||||
|
elts=[ast.Name(id=param.name) for param in params2]
|
||||||
|
)
|
||||||
|
|
||||||
|
substitutions: dict[str, TypeVar] = {
|
||||||
|
param.name: param2 for param, param2 in zip(params, params2)
|
||||||
|
}
|
||||||
|
|
||||||
|
body_bases, body_subsitutions = self.get_bases(body)
|
||||||
|
return (
|
||||||
|
body_bases
|
||||||
|
+ [
|
||||||
|
ast.Subscript(
|
||||||
|
value=ast.Name(id="Generic"),
|
||||||
|
slice=type_vars,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
body_subsitutions | substitutions,
|
||||||
|
)
|
||||||
|
case _:
|
||||||
|
return [], {}
|
||||||
|
|
||||||
|
def generate_body(
|
||||||
|
self, members: dict[str, Member], substitutions: dict[str, Type]
|
||||||
|
) -> list[ast.stmt]:
|
||||||
|
if len(members) == 0:
|
||||||
|
return [ast.Expr(value=Empty)]
|
||||||
|
|
||||||
|
body: list[ast.stmt] = []
|
||||||
|
for name, member in members.items():
|
||||||
|
type: Type = member.type
|
||||||
|
type = substitute_typevars(type, substitutions)
|
||||||
|
match member.kind:
|
||||||
|
case m.MemberKind.PROPERTY:
|
||||||
|
body.append(
|
||||||
|
ast.AnnAssign(
|
||||||
|
target=ast.Name(id=name),
|
||||||
|
annotation=self.dump_type(type),
|
||||||
|
simple=1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
case m.MemberKind.METHOD:
|
||||||
|
body.extend(self.dump_method(name, type))
|
||||||
|
return body
|
||||||
|
|
||||||
|
def dump_type(self, type: Type) -> ast.expr:
|
||||||
|
match type:
|
||||||
|
case AliasType(name=name) | GenericType(name=name) if (
|
||||||
|
name in self.substitutions
|
||||||
|
):
|
||||||
|
type = substitute_typevars(type, self.substitutions[name])
|
||||||
|
|
||||||
|
match type:
|
||||||
|
case TopType() | UnknownType():
|
||||||
|
self.add_typing_import("Any")
|
||||||
|
return ast.Name(id="Any")
|
||||||
|
case BaseType(name=name):
|
||||||
|
return ast.Name(id=name)
|
||||||
|
case AliasType(name=name):
|
||||||
|
return ast.Name(id=name)
|
||||||
|
case UnitType():
|
||||||
|
return ast.Constant(value=None)
|
||||||
|
case Function():
|
||||||
|
name: str = self.define_protocol(type)
|
||||||
|
return ast.Name(id=name)
|
||||||
|
case OverloadedFunction(overloads=overloads):
|
||||||
|
if len(overloads) == 1:
|
||||||
|
return self.dump_type(overloads[0])
|
||||||
|
return ast.BinOp(
|
||||||
|
left=self.dump_type(OverloadedFunction(overloads=overloads[:-1])),
|
||||||
|
op=ast.BitOr(),
|
||||||
|
right=self.dump_type(overloads[-1]),
|
||||||
|
)
|
||||||
|
|
||||||
|
case ComplexType():
|
||||||
|
name: str = self.new_stub_name()
|
||||||
|
self.generate_stub(name, type)
|
||||||
|
return ast.Name(id=name)
|
||||||
|
|
||||||
|
case ExtensionType():
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
case TypeVar():
|
||||||
|
return ast.Name(id=type.name)
|
||||||
|
case GenericType(name=name):
|
||||||
|
params: ast.expr
|
||||||
|
if len(type.params) == 1:
|
||||||
|
params = self.dump_type(type.params[0])
|
||||||
|
else:
|
||||||
|
params = ast.Tuple(
|
||||||
|
elts=[self.dump_type(param) for param in type.params]
|
||||||
|
)
|
||||||
|
return ast.Subscript(
|
||||||
|
value=ast.Name(id=type.name),
|
||||||
|
slice=params,
|
||||||
|
)
|
||||||
|
case AppliedType():
|
||||||
|
args: ast.expr
|
||||||
|
if len(type.args) == 1:
|
||||||
|
args = self.dump_type(type.args[0])
|
||||||
|
else:
|
||||||
|
args = ast.Tuple(elts=[self.dump_type(arg) for arg in type.args])
|
||||||
|
return ast.Subscript(
|
||||||
|
value=ast.Name(id=type.name),
|
||||||
|
slice=args,
|
||||||
|
)
|
||||||
|
|
||||||
|
def dump_method(
|
||||||
|
self, name: str, method: Type, overloaded: bool = False
|
||||||
|
) -> list[ast.stmt]:
|
||||||
|
match method:
|
||||||
|
case Function():
|
||||||
|
if overloaded:
|
||||||
|
self.add_typing_import("overload")
|
||||||
|
return [
|
||||||
|
ast.FunctionDef(
|
||||||
|
name=name,
|
||||||
|
args=self.dump_args(method, with_self=True),
|
||||||
|
returns=self.dump_type(method.returns),
|
||||||
|
body=[ast.Expr(value=Empty)],
|
||||||
|
decorator_list=[ast.Name(id="overload")] if overloaded else [],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
case OverloadedFunction(overloads=overloads):
|
||||||
|
stmts: list[ast.stmt] = []
|
||||||
|
for overload in overloads:
|
||||||
|
stmts.extend(self.dump_method(name, overload, True))
|
||||||
|
return stmts
|
||||||
|
case _:
|
||||||
|
return [
|
||||||
|
ast.AnnAssign(
|
||||||
|
target=ast.Name(id=name),
|
||||||
|
annotation=self.dump_type(method),
|
||||||
|
simple=1,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def dump_args(self, func: Function, with_self: bool = False) -> ast.arguments:
|
||||||
|
pos: list[ast.arg] = [
|
||||||
|
ast.arg(arg=f"_{arg.pos}", annotation=self.dump_type(arg.type))
|
||||||
|
for arg in func.pos_args
|
||||||
|
]
|
||||||
|
mixed: list[ast.arg] = [
|
||||||
|
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type))
|
||||||
|
for arg in func.args
|
||||||
|
]
|
||||||
|
kw: list[ast.arg] = [
|
||||||
|
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type))
|
||||||
|
for arg in func.kw_args
|
||||||
|
]
|
||||||
|
defaults: list[ast.expr] = [
|
||||||
|
Empty for arg in func.pos_args + func.args if not arg.required
|
||||||
|
]
|
||||||
|
kw_defaults: list[Optional[ast.expr]] = [
|
||||||
|
None if arg.required else Empty for arg in func.kw_args
|
||||||
|
]
|
||||||
|
if with_self:
|
||||||
|
arg = ast.arg(arg="self", annotation=None)
|
||||||
|
if len(pos) != 0:
|
||||||
|
pos.insert(0, arg)
|
||||||
|
else:
|
||||||
|
mixed.insert(0, arg)
|
||||||
|
return ast.arguments(
|
||||||
|
posonlyargs=pos,
|
||||||
|
args=mixed,
|
||||||
|
kwonlyargs=kw,
|
||||||
|
defaults=defaults,
|
||||||
|
kw_defaults=kw_defaults,
|
||||||
|
)
|
||||||
|
|
||||||
|
def define_protocol(self, func: Function) -> str:
|
||||||
|
self.add_typing_import("Protocol")
|
||||||
|
name: str = self.new_protocol_name()
|
||||||
|
protocol = ast.ClassDef(
|
||||||
|
name=name,
|
||||||
|
bases=[ast.Name(id="Protocol")],
|
||||||
|
keywords=[],
|
||||||
|
body=[
|
||||||
|
ast.FunctionDef(
|
||||||
|
name="__call__",
|
||||||
|
args=self.dump_args(func, with_self=True),
|
||||||
|
returns=self.dump_type(func.returns),
|
||||||
|
body=[ast.Expr(value=Empty)],
|
||||||
|
decorator_list=[],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
decorator_list=[],
|
||||||
|
)
|
||||||
|
self.add_stub(protocol)
|
||||||
|
return name
|
||||||
|
|
||||||
|
def new_protocol_name(self) -> str:
|
||||||
|
name: str = f"_Protocol{self.protocol_idx}"
|
||||||
|
self.protocol_idx += 1
|
||||||
|
return name
|
||||||
|
|
||||||
|
def new_stub_name(self) -> str:
|
||||||
|
name: str = f"_Stub_{self.stub_idx}"
|
||||||
|
self.stub_idx += 1
|
||||||
|
return name
|
||||||
|
|
||||||
|
def new_type_var_name(self) -> str:
|
||||||
|
name: str = f"_T{self.type_var_idx}"
|
||||||
|
self.type_var_idx += 1
|
||||||
|
return name
|
||||||
|
|
||||||
|
def add_stub(self, stub: ast.stmt):
|
||||||
|
self.stubs.append(stub)
|
||||||
|
|
||||||
|
def add_typing_import(self, name: str):
|
||||||
|
self.typing_imports.add(name)
|
||||||
|
|
||||||
|
def define_type_vars(self, vars: list[TypeVar]) -> list[TypeVar]:
|
||||||
|
vars2: list[TypeVar] = []
|
||||||
|
for var in vars:
|
||||||
|
vars2.append(self.define_type_var(var))
|
||||||
|
return vars2
|
||||||
|
|
||||||
|
def define_type_var(self, var: TypeVar) -> TypeVar:
|
||||||
|
name: str = self.new_type_var_name()
|
||||||
|
self.add_typing_import("TypeVar")
|
||||||
|
self.add_stub(
|
||||||
|
ast.Assign(
|
||||||
|
targets=[ast.Name(id=name)],
|
||||||
|
value=ast.Call(
|
||||||
|
func=ast.Name(id="TypeVar"),
|
||||||
|
args=[
|
||||||
|
ast.Constant(value=name),
|
||||||
|
],
|
||||||
|
keywords=(
|
||||||
|
[]
|
||||||
|
if var.bound is None
|
||||||
|
else [
|
||||||
|
ast.keyword(
|
||||||
|
arg="bound",
|
||||||
|
value=self.dump_type(var.bound),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return TypeVar(name=name, bound=None)
|
||||||
Reference in New Issue
Block a user