16 Commits

Author SHA1 Message Date
11422d4364 feat(cli): add stubs command 2026-06-17 12:19:54 +02:00
e8f8a5ca2f feat(gen): add base for stubs generator 2026-06-17 12:19:32 +02:00
df8d71c0a9 fix(checker): handle calls to AliasType 2026-06-17 12:14:20 +02:00
e4fb142f99 fix(checker): allow substitutyping type vars in GenericType and TopType 2026-06-17 12:12:07 +02:00
2f8f9d633b fix(checker): store member kind in registry 2026-06-17 12:11:16 +02:00
a4a2ed5d64 Merge pull request 'Dictionaries' (#13) from feat/dictionaries into main
Reviewed-on: #13
2026-06-16 18:42:12 +00:00
e5cb90aff6 fix(checker): make builtin type constructor parameter optional 2026-06-16 20:40:48 +02:00
75f8e4af53 feat(checker): type check dictionaries 2026-06-16 20:40:10 +02:00
42c2d7a098 feat(parser): add dictionary expression 2026-06-16 20:35:39 +02:00
5ce3b4abed Merge pull request 'Cast assertions and generator tests' (#12) from feat/cast-assertions into main
Reviewed-on: #12
2026-06-16 12:57:49 +00:00
2a8b7d559c tests: add simple gen test 2026-06-16 14:56:59 +02:00
da38cad23d feat(tests): add generator tester 2026-06-16 14:56:59 +02:00
591012d059 fix(checker): allow calling AppliedType and UnknownType 2026-06-16 14:56:58 +02:00
4b1087d6b9 fix(cli): improve dump-registry command output 2026-06-16 14:56:57 +02:00
732f7b0796 feat(checker): add environment preamble
this adds some builtin functions such as the builtin type constructors
2026-06-16 14:56:56 +02:00
c4062c9595 fix(checker): allow inferred return to be subtype of hint 2026-06-16 14:56:47 +02:00
18 changed files with 580 additions and 17 deletions

View File

@@ -157,6 +157,11 @@ class ListExpr:
items: list[Expr] items: list[Expr]
class DictExpr:
keys: list[Optional[Expr]]
values: list[Expr]
class SubscriptExpr: class SubscriptExpr:
object: Expr object: Expr
index: Expr index: Expr

View File

@@ -745,6 +745,27 @@ class PythonAstPrinter(
self._mark_last() self._mark_last()
item.accept(self) item.accept(self)
def visit_dict_expr(self, expr: p.DictExpr) -> None:
self._write_line("DictExpr")
with self._child_level():
self._write_line("keys")
with self._child_level():
for i, key in enumerate(expr.keys):
self._idx = i
if i == len(expr.keys) - 1:
self._mark_last()
if key is None:
self._write_line("None")
else:
key.accept(self)
self._write_line("values", last=True)
with self._child_level():
for i, value in enumerate(expr.values):
self._idx = i
if i == len(expr.values) - 1:
self._mark_last()
value.accept(self)
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None: def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
self._write_line("SubscriptExpr") self._write_line("SubscriptExpr")
with self._child_level(): with self._child_level():

View File

@@ -259,6 +259,9 @@ class Expr(ABC):
@abstractmethod @abstractmethod
def visit_list_expr(self, expr: ListExpr) -> T: ... def visit_list_expr(self, expr: ListExpr) -> T: ...
@abstractmethod
def visit_dict_expr(self, expr: DictExpr) -> T: ...
@abstractmethod @abstractmethod
def visit_subscript_expr(self, expr: SubscriptExpr) -> T: ... def visit_subscript_expr(self, expr: SubscriptExpr) -> T: ...
@@ -370,6 +373,15 @@ class ListExpr(Expr):
return visitor.visit_list_expr(self) return visitor.visit_list_expr(self)
@dataclass(frozen=True)
class DictExpr(Expr):
keys: list[Optional[Expr]]
values: list[Expr]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_dict_expr(self)
@dataclass(frozen=True) @dataclass(frozen=True)
class SubscriptExpr(Expr): class SubscriptExpr(Expr):
object: Expr object: Expr

View File

@@ -150,3 +150,32 @@ extend list[T] {
prop __doc__: str prop __doc__: str
} }
extend dict[K, V] {
def copy: fn() -> dict[K, V]
def keys: fn() -> list[K] // TODO: use builtin types
def values: fn() -> list[V] // TODO: use builtin types
// def items: fn() -> list[tuple[K, V]] // TODO: use builtin types
// def get: fn(key: K, default: None = None, /) -> V | None
def get: fn(key: K, default: V, /) -> V
// def get: fn[T](key: K, default: T, /) -> V | T
def pop: fn(key: K, /) -> V
def pop: fn(key: K, default: V, /) -> V
// def pop: fn[T](key: K, default: T, /) -> V | T
def __len__: fn() -> int
def __getitem__: fn(key: K, /) -> V
def __setitem__: fn(key: K, value: V, /) -> None
def __delitem__: fn(key: K, /) -> None
// def __iter__: fn() -> Iterator[K]
def __eq__: fn(value: object, /) -> bool
// def __reversed__: fn() -> Iterator[K]
def __or__: fn(value: dict[K, V], /) -> dict[K, V]
// def __or__: fn[K2, V2](value: dict[K2, V2], /) -> dict[K | K2, V | V2]
def __ror__: fn(value: dict[K, V], /) -> dict[K, V]
// def __ror__: fn[K2, V2](value: dict[K2, V2], /) -> dict[K | K2, V | V2]
// def __ior__: fn(value: SupportsKeysAndGetItem[K, V], /) -> dict[K, V]
// def __ior__: fn(value: Iterable[tuple[K, V]], /) -> dict[K, V]
}

View File

@@ -39,3 +39,14 @@ def define_builtins(reg: TypesRegistry):
body=BaseType(name="list"), body=BaseType(name="list"),
), ),
) )
dict = reg.define_type(
"dict",
GenericType(
name="dict",
params=[
TypeVar(name="K", bound=None),
TypeVar(name="V", bound=None),
],
body=BaseType(name="dict"),
),
)

View File

@@ -102,7 +102,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
base_name, base_name,
member.name.lexeme, member.name.lexeme,
member_type, member_type,
member.kind == m.MemberKind.METHOD, member.kind,
) )
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:

View File

@@ -61,7 +61,7 @@ class Preamble(Environment):
# TODO: more specific arg types # TODO: more specific arg types
self._def_function( self._def_function(
name=name, name=name,
pos=[Param("object", TopType())], pos=[Param("object", TopType(), required=False)],
returns=self._types.get_type(name), returns=self._types.get_type(name),
) )

View File

@@ -12,6 +12,7 @@ from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter from midas.checker.reporter import FileReporter, Reporter
from midas.checker.resolver import Resolver from midas.checker.resolver import Resolver
from midas.checker.types import ( from midas.checker.types import (
AliasType,
AppliedType, AppliedType,
Function, Function,
OverloadedFunction, OverloadedFunction,
@@ -552,6 +553,46 @@ class PythonTyper(
) )
return self.types.apply_generic(list_type, [UnknownType()]) return self.types.apply_generic(list_type, [UnknownType()])
def visit_dict_expr(self, expr: p.DictExpr) -> Type:
dict_type: Type = self.types.get_type("dict")
key_types: list[Type] = []
value_types: list[Type] = []
for key, value in zip(expr.keys, expr.values):
if key is None:
self.reporter.warning(
value.location, "Dictionary unpacking not supported"
)
continue
key_types.append(self.type_of(key))
value_types.append(self.type_of(value))
key_types = self.types.reduce_types(key_types)
value_types = self.types.reduce_types(value_types)
if len(key_types) == 0 or len(value_types) == 0:
return dict_type
key_type: Type = UnknownType()
value_type: Type = UnknownType()
if len(key_types) == 1:
key_type = key_types[0]
else:
self.reporter.error(
expr.location,
f"Heterogeneous dict keys: {key_types}",
)
if len(value_types) == 1:
value_type = value_types[0]
else:
self.reporter.error(
expr.location,
f"Heterogeneous dict values: {value_types}",
)
return self.types.apply_generic(dict_type, [key_type, value_type])
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> Type: def visit_subscript_expr(self, expr: p.SubscriptExpr) -> Type:
object: Type = self.type_of(expr.object) object: Type = self.type_of(expr.object)
operation: Optional[Type] = self.types.lookup_member(object, "__getitem__") operation: Optional[Type] = self.types.lookup_member(object, "__getitem__")
@@ -654,9 +695,17 @@ class PythonTyper(
case UnknownType(): case UnknownType():
return UnknownType() return UnknownType()
case AliasType(type=base):
return self._get_call_result(
location, base, positional, keywords, report_errors
)
case _: case _:
if report_errors: if report_errors:
self.reporter.error(location, f"{callee} is not callable") self.reporter.error(
location,
f"{callee} ({callee.__class__.__name__}) is not callable",
)
return None return None
def _are_arguments_valid( def _are_arguments_valid(

View File

@@ -1,6 +1,8 @@
import logging import logging
from dataclasses import dataclass
from typing import Optional from typing import Optional
from midas.ast.midas import MemberKind
from midas.checker.builtins import BUILTIN_SUBTYPES from midas.checker.builtins import BUILTIN_SUBTYPES
from midas.checker.types import ( from midas.checker.types import (
AliasType, AliasType,
@@ -19,11 +21,17 @@ from midas.checker.types import (
) )
@dataclass
class Member:
kind: MemberKind
type: Type
class TypesRegistry: class TypesRegistry:
def __init__(self) -> None: def __init__(self) -> None:
self.logger: logging.Logger = logging.getLogger("TypesRegistry") self.logger: logging.Logger = logging.getLogger("TypesRegistry")
self._types: dict[str, Type] = {} self._types: dict[str, Type] = {}
self._members: dict[str, dict[str, Type]] = {} self._members: dict[str, dict[str, Member]] = {}
def get_type(self, name: str) -> Type: def get_type(self, name: str) -> Type:
"""Get a type from its name """Get a type from its name
@@ -60,26 +68,38 @@ class TypesRegistry:
return type return type
def define_member( def define_member(
self, type_name: str, member_name: str, member_type: Type, is_method: bool self,
type_name: str,
member_name: str,
member_type: Type,
kind: MemberKind,
): ):
members: dict[str, Type] = self._members.setdefault(type_name, {}) members: dict[str, Member] = self._members.setdefault(type_name, {})
if member_name in members: if member_name in members:
if not is_method: current: Member = members[member_name]
if current.kind != kind:
self.logger.error( self.logger.error(
f"Member '{member_name}' already defined for type {type_name}" f"Member '{member_name}' is already defined as a {current.kind},"
+ f" cannot define a {kind} with the same name"
) )
return return
current: Type = members[member_name] if kind != MemberKind.METHOD:
self.logger.error(
f"Member '{member_name}' already defined for type {type_name},"
+ " only methods can be overloaded"
)
return
combined: Type combined: Type
match current: match current.type:
case OverloadedFunction(overloads=overloads): case OverloadedFunction(overloads=overloads):
combined = OverloadedFunction(overloads=overloads + [member_type]) combined = OverloadedFunction(overloads=overloads + [member_type])
case _: case _:
combined = OverloadedFunction(overloads=[current, member_type]) combined = OverloadedFunction(overloads=[current.type, member_type])
members[member_name] = combined members[member_name] = Member(kind=current.kind, type=combined)
else: else:
members[member_name] = member_type members[member_name] = Member(kind=kind, type=member_type)
def is_subtype(self, type1: Type, type2: Type) -> bool: def is_subtype(self, type1: Type, type2: Type) -> bool:
"""Check whether `type1` is a subtype of `type2` """Check whether `type1` is a subtype of `type2`
@@ -297,13 +317,13 @@ class TypesRegistry:
case BaseType(name=name): case BaseType(name=name):
if name in self._members: if name in self._members:
if member_name in self._members[name]: if member_name in self._members[name]:
return self._members[name][member_name] return self._members[name][member_name].type
return None return None
case AliasType(name=name, type=base): case AliasType(name=name, type=base):
if name in self._members: if name in self._members:
if member_name in self._members[name]: if member_name in self._members[name]:
return self._members[name][member_name] return self._members[name][member_name].type
return self.lookup_member(base, member_name) return self.lookup_member(base, member_name)
case AppliedType(name=name, body=body, args=args): case AppliedType(name=name, body=body, args=args):
@@ -317,7 +337,7 @@ class TypesRegistry:
} }
if name in self._members: if name in self._members:
if member_name in self._members[name]: if member_name in self._members[name]:
member_type: Type = self._members[name][member_name] member_type: Type = self._members[name][member_name].type
return substitute_typevars(member_type, substitutions) return substitute_typevars(member_type, substitutions)
member_type2: Optional[Type] = self.lookup_member(body, member_name) member_type2: Optional[Type] = self.lookup_member(body, member_name)

View File

@@ -213,6 +213,13 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
for item in expr.items: for item in expr.items:
self.resolve(item) self.resolve(item)
def visit_dict_expr(self, expr: p.DictExpr) -> None:
for key in expr.keys:
if key is not None:
self.resolve(key)
for value in expr.values:
self.resolve(value)
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None: def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
self.resolve(expr.object) self.resolve(expr.object)
self.resolve(expr.index) self.resolve(expr.index)

View File

@@ -140,6 +140,9 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
) )
match type: match type:
case TopType():
return type
case BaseType(name=name) if name in substitutions: case BaseType(name=name) if name in substitutions:
return substitutions[name] return substitutions[name]
@@ -200,6 +203,21 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
return substitutions[name] return substitutions[name]
raise ValueError(f"Missing TypeVar substitution for {name}") raise ValueError(f"Missing TypeVar substitution for {name}")
case GenericType(name=name, params=params, body=body):
params2: list[TypeVar] = []
for param in params:
param2: Type = substitute_typevars(param, substitutions)
if not isinstance(param2, TypeVar):
raise ValueError(
f"Invalid type parameter substitution, expected TypeVar, got {param2}"
)
params2.append(param2)
return GenericType(
name=name,
params=params2,
body=substitute_typevars(body, substitutions),
)
case UnknownType() | UnitType(): case UnknownType() | UnitType():
return type return type

View File

@@ -4,5 +4,6 @@ from .format import format as format
from .highlight import highlight as highlight from .highlight import highlight as highlight
from .parse import parse as parse from .parse import parse as parse
from .registry import dump_registry as dump_registry from .registry import dump_registry as dump_registry
from .stubs import stubs as stubs
from .types import types as types from .types import types as types
from .validate import validate as validate from .validate import validate as validate

View File

@@ -0,0 +1,27 @@
import ast
from pathlib import Path
from typing import TextIO
import click
from midas.checker.checker import TypeChecker
from midas.generator.stubs import StubsGenerator
@click.command(help="Generate stubs from Midas definitions")
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-")
def stubs(
file: TextIO,
output: TextIO,
):
source_path: Path = Path(file.name).resolve()
checker = TypeChecker()
checker.import_midas(source_path)
generator = StubsGenerator(checker.types)
module: ast.Module = generator.generate_stubs()
module = ast.fix_missing_locations(module)
output.write(ast.unparse(module))

View File

@@ -18,6 +18,7 @@ midas.add_command(commands.highlight)
midas.add_command(commands.parse) midas.add_command(commands.parse)
midas.add_command(commands.dump_registry) midas.add_command(commands.dump_registry)
midas.add_command(commands.types) midas.add_command(commands.types)
midas.add_command(commands.stubs)
midas.add_command(commands.validate) midas.add_command(commands.validate)

View File

@@ -4,8 +4,8 @@ from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from midas.ast.location import Location
import midas.ast.python as p import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.types import ( from midas.checker.types import (
AliasType, AliasType,
AppliedType, AppliedType,
@@ -139,6 +139,12 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
elts=[item.accept(self) for item in expr.items], elts=[item.accept(self) for item in expr.items],
) )
def visit_dict_expr(self, expr: p.DictExpr) -> ast.expr:
return ast.Dict(
keys=[key.accept(self) if key is not None else None for key in expr.keys],
values=[value.accept(self) for value in expr.values],
)
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> ast.expr: def visit_subscript_expr(self, expr: p.SubscriptExpr) -> ast.expr:
return ast.Subscript( return ast.Subscript(
value=expr.object.accept(self), value=expr.object.accept(self),

337
midas/generator/stubs.py Normal file
View File

@@ -0,0 +1,337 @@
import ast
from typing import Optional
import midas.ast.midas as m
from midas.checker.registry import Member, TypesRegistry
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ComplexType,
ExtensionType,
Function,
GenericType,
OverloadedFunction,
TopType,
Type,
TypeVar,
UnitType,
UnknownType,
substitute_typevars,
)
Empty = ast.Constant(value=...)
class StubsGenerator:
def __init__(self, types: TypesRegistry) -> None:
self.types: TypesRegistry = types
self.stubs: list[ast.stmt] = []
self.typing_imports: set[str] = set()
self.protocol_idx: int = 0
self.stub_idx: int = 0
self.type_var_idx: int = 0
self.substitutions: dict[str, dict[str, Type]] = {}
def generate_stubs(self) -> ast.Module:
self.stubs = []
self.typing_imports = set()
for name, type in self.types._types.items():
self.generate_stub(name, type)
imports = [
ast.ImportFrom(
module="__future__",
names=[ast.alias(name="annotations")],
level=0,
)
]
if len(self.typing_imports) != 0:
imports.append(
ast.ImportFrom(
module="typing",
names=[
ast.alias(name=name) for name in sorted(self.typing_imports)
],
level=0,
)
)
return ast.Module(body=imports + self.stubs, type_ignores=[])
def generate_stub(self, name: str, type: Type):
base_type: Type = type
members: dict[str, Member] = self.types._members.get(name, {})
if isinstance(base_type, (BaseType, TopType, UnitType)) and len(members) == 0:
return
bases: list[ast.expr] = []
substitutions: dict[str, Type] = {}
bases, substitutions = self.get_bases(type)
self.substitutions[name] = substitutions
body = self.generate_body(members, substitutions)
stub = ast.ClassDef(
name=name,
bases=bases,
body=body,
keywords=[],
decorator_list=[],
)
self.add_stub(stub)
def get_bases(self, type: Type) -> tuple[list[ast.expr], dict[str, Type]]:
match type:
case AliasType(type=base):
return [self.dump_type(base)], {}
case GenericType(params=params, body=body):
self.add_typing_import("Generic")
type_vars: ast.expr
params2: list[TypeVar] = self.define_type_vars(params)
if len(params) == 1:
type_vars = ast.Name(id=params2[0].name)
else:
type_vars = ast.Tuple(
elts=[ast.Name(id=param.name) for param in params2]
)
substitutions: dict[str, TypeVar] = {
param.name: param2 for param, param2 in zip(params, params2)
}
body_bases, body_subsitutions = self.get_bases(body)
return (
body_bases
+ [
ast.Subscript(
value=ast.Name(id="Generic"),
slice=type_vars,
)
],
body_subsitutions | substitutions,
)
case _:
return [], {}
def generate_body(
self, members: dict[str, Member], substitutions: dict[str, Type]
) -> list[ast.stmt]:
if len(members) == 0:
return [ast.Expr(value=Empty)]
body: list[ast.stmt] = []
for name, member in members.items():
type: Type = member.type
type = substitute_typevars(type, substitutions)
match member.kind:
case m.MemberKind.PROPERTY:
body.append(
ast.AnnAssign(
target=ast.Name(id=name),
annotation=self.dump_type(type),
simple=1,
)
)
case m.MemberKind.METHOD:
body.extend(self.dump_method(name, type))
return body
def dump_type(self, type: Type) -> ast.expr:
match type:
case AliasType(name=name) | GenericType(name=name) if (
name in self.substitutions
):
type = substitute_typevars(type, self.substitutions[name])
match type:
case TopType() | UnknownType():
self.add_typing_import("Any")
return ast.Name(id="Any")
case BaseType(name=name):
return ast.Name(id=name)
case AliasType(name=name):
return ast.Name(id=name)
case UnitType():
return ast.Constant(value=None)
case Function():
name: str = self.define_protocol(type)
return ast.Name(id=name)
case OverloadedFunction(overloads=overloads):
if len(overloads) == 1:
return self.dump_type(overloads[0])
return ast.BinOp(
left=self.dump_type(OverloadedFunction(overloads=overloads[:-1])),
op=ast.BitOr(),
right=self.dump_type(overloads[-1]),
)
case ComplexType():
name: str = self.new_stub_name()
self.generate_stub(name, type)
return ast.Name(id=name)
case ExtensionType():
raise NotImplementedError
case TypeVar():
return ast.Name(id=type.name)
case GenericType(name=name):
params: ast.expr
if len(type.params) == 1:
params = self.dump_type(type.params[0])
else:
params = ast.Tuple(
elts=[self.dump_type(param) for param in type.params]
)
return ast.Subscript(
value=ast.Name(id=type.name),
slice=params,
)
case AppliedType():
args: ast.expr
if len(type.args) == 1:
args = self.dump_type(type.args[0])
else:
args = ast.Tuple(elts=[self.dump_type(arg) for arg in type.args])
return ast.Subscript(
value=ast.Name(id=type.name),
slice=args,
)
def dump_method(
self, name: str, method: Type, overloaded: bool = False
) -> list[ast.stmt]:
match method:
case Function():
if overloaded:
self.add_typing_import("overload")
return [
ast.FunctionDef(
name=name,
args=self.dump_args(method, with_self=True),
returns=self.dump_type(method.returns),
body=[ast.Expr(value=Empty)],
decorator_list=[ast.Name(id="overload")] if overloaded else [],
)
]
case OverloadedFunction(overloads=overloads):
stmts: list[ast.stmt] = []
for overload in overloads:
stmts.extend(self.dump_method(name, overload, True))
return stmts
case _:
return [
ast.AnnAssign(
target=ast.Name(id=name),
annotation=self.dump_type(method),
simple=1,
)
]
def dump_args(self, func: Function, with_self: bool = False) -> ast.arguments:
pos: list[ast.arg] = [
ast.arg(arg=f"_{arg.pos}", annotation=self.dump_type(arg.type))
for arg in func.pos_args
]
mixed: list[ast.arg] = [
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type))
for arg in func.args
]
kw: list[ast.arg] = [
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type))
for arg in func.kw_args
]
defaults: list[ast.expr] = [
Empty for arg in func.pos_args + func.args if not arg.required
]
kw_defaults: list[Optional[ast.expr]] = [
None if arg.required else Empty for arg in func.kw_args
]
if with_self:
arg = ast.arg(arg="self", annotation=None)
if len(pos) != 0:
pos.insert(0, arg)
else:
mixed.insert(0, arg)
return ast.arguments(
posonlyargs=pos,
args=mixed,
kwonlyargs=kw,
defaults=defaults,
kw_defaults=kw_defaults,
)
def define_protocol(self, func: Function) -> str:
self.add_typing_import("Protocol")
name: str = self.new_protocol_name()
protocol = ast.ClassDef(
name=name,
bases=[ast.Name(id="Protocol")],
keywords=[],
body=[
ast.FunctionDef(
name="__call__",
args=self.dump_args(func, with_self=True),
returns=self.dump_type(func.returns),
body=[ast.Expr(value=Empty)],
decorator_list=[],
),
],
decorator_list=[],
)
self.add_stub(protocol)
return name
def new_protocol_name(self) -> str:
name: str = f"_Protocol{self.protocol_idx}"
self.protocol_idx += 1
return name
def new_stub_name(self) -> str:
name: str = f"_Stub_{self.stub_idx}"
self.stub_idx += 1
return name
def new_type_var_name(self) -> str:
name: str = f"_T{self.type_var_idx}"
self.type_var_idx += 1
return name
def add_stub(self, stub: ast.stmt):
self.stubs.append(stub)
def add_typing_import(self, name: str):
self.typing_imports.add(name)
def define_type_vars(self, vars: list[TypeVar]) -> list[TypeVar]:
vars2: list[TypeVar] = []
for var in vars:
vars2.append(self.define_type_var(var))
return vars2
def define_type_var(self, var: TypeVar) -> TypeVar:
name: str = self.new_type_var_name()
self.add_typing_import("TypeVar")
self.add_stub(
ast.Assign(
targets=[ast.Name(id=name)],
value=ast.Call(
func=ast.Name(id="TypeVar"),
args=[
ast.Constant(value=name),
],
keywords=(
[]
if var.bound is None
else [
ast.keyword(
arg="bound",
value=self.dump_type(var.bound),
)
]
),
),
)
)
return TypeVar(name=name, bound=None)

View File

@@ -10,6 +10,7 @@ from midas.ast.python import (
CastExpr, CastExpr,
CompareExpr, CompareExpr,
ConstraintType, ConstraintType,
DictExpr,
Expr, Expr,
ExpressionStmt, ExpressionStmt,
ForStmt, ForStmt,
@@ -447,6 +448,16 @@ class PythonParser:
items=[self.parse_expr(item) for item in items], items=[self.parse_expr(item) for item in items],
) )
case ast.Dict(keys=keys, values=values):
return DictExpr(
location=location,
keys=[
self.parse_expr(key) if key is not None else None
for key in keys
],
values=[self.parse_expr(value) for value in values],
)
case ast.Subscript(value=value, slice=index): case ast.Subscript(value=value, slice=index):
return SubscriptExpr( return SubscriptExpr(
location=location, location=location,

View File

@@ -9,6 +9,7 @@ from midas.ast.python import (
CastExpr, CastExpr,
CompareExpr, CompareExpr,
ConstraintType, ConstraintType,
DictExpr,
Expr, Expr,
ExpressionStmt, ExpressionStmt,
ForStmt, ForStmt,
@@ -278,6 +279,13 @@ class PythonAstJsonSerializer(
"items": [item.accept(self) for item in expr.items], "items": [item.accept(self) for item in expr.items],
} }
def visit_dict_expr(self, expr: DictExpr) -> dict:
return {
"_type": "DictExpr",
"keys": [self._serialize_optional(key) for key in expr.keys],
"values": self._serialize_list(expr.values),
}
def visit_subscript_expr(self, expr: SubscriptExpr) -> dict: def visit_subscript_expr(self, expr: SubscriptExpr) -> dict:
return { return {
"_type": "SubscriptExpr", "_type": "SubscriptExpr",