Stubs generator #17

Merged
HEL merged 9 commits from feat/stubs-gen into main 2026-06-20 15:44:35 +00:00
12 changed files with 503 additions and 25 deletions

View File

@@ -18,6 +18,7 @@ This framework is being developed as part of a Bachelor's Thesis by Louis Herede
- [Highlighting](#highlighting) - [Highlighting](#highlighting)
- [Dumping the AST](#dumping-the-ast) - [Dumping the AST](#dumping-the-ast)
- [Dumping the Registry](#dumping-the-registry) - [Dumping the Registry](#dumping-the-registry)
- [Generating Stubs](#generating-stubs)
- [Showing Type Judgements](#showing-type-judgements) - [Showing Type Judgements](#showing-type-judgements)
- [Validating Definitions](#validating-definitions) - [Validating Definitions](#validating-definitions)
- [Tests](#tests) - [Tests](#tests)
@@ -116,6 +117,14 @@ midas dump-registry -t types.midas
This command processes the given Midas definitions and dumps the contents of the types registry. This command processes the given Midas definitions and dumps the contents of the types registry.
### Generating Stubs
```shell
midas stubs types.midas -o stubs.pyi
```
This command generate Python stubs from a Midas definition file
### Showing Type Judgements ### Showing Type Judgements
```shell ```shell

View File

@@ -173,7 +173,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
base_name, base_name,
member.name.lexeme, member.name.lexeme,
member_type, member_type,
member.kind == m.MemberKind.METHOD, member.kind,
) )
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:

View File

@@ -16,6 +16,7 @@ from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter from midas.checker.reporter import FileReporter, Reporter
from midas.checker.resolver import Resolver from midas.checker.resolver import Resolver
from midas.checker.types import ( from midas.checker.types import (
AliasType,
AppliedType, AppliedType,
Function, Function,
OverloadedFunction, OverloadedFunction,
@@ -698,9 +699,17 @@ class PythonTyper(
case UnknownType(): case UnknownType():
return UnknownType() return UnknownType()
case AliasType(type=base):
return self._get_call_result(
location, base, positional, keywords, report_errors
)
case _: case _:
if report_errors: if report_errors:
self.reporter.error(location, f"{callee} is not callable") self.reporter.error(
location,
f"{callee} ({callee.__class__.__name__}) is not callable",
)
return None return None
def _are_arguments_valid( def _are_arguments_valid(

View File

@@ -1,6 +1,8 @@
import logging import logging
from dataclasses import dataclass
from typing import Optional from typing import Optional
from midas.ast.midas import MemberKind
from midas.checker.builtins import BUILTIN_SUBTYPES from midas.checker.builtins import BUILTIN_SUBTYPES
from midas.checker.types import ( from midas.checker.types import (
AliasType, AliasType,
@@ -22,11 +24,17 @@ from midas.checker.types import (
) )
@dataclass
class Member:
kind: MemberKind
type: Type
class TypesRegistry: class TypesRegistry:
def __init__(self) -> None: def __init__(self) -> None:
self.logger: logging.Logger = logging.getLogger("TypesRegistry") self.logger: logging.Logger = logging.getLogger("TypesRegistry")
self._types: dict[str, Type] = {} self._types: dict[str, Type] = {}
self._members: dict[str, dict[str, Type]] = {} self._members: dict[str, dict[str, Member]] = {}
self._predicates: dict[str, Predicate] = {} self._predicates: dict[str, Predicate] = {}
def get_type(self, name: str) -> Type: def get_type(self, name: str) -> Type:
@@ -64,26 +72,38 @@ class TypesRegistry:
return type return type
def define_member( def define_member(
self, type_name: str, member_name: str, member_type: Type, is_method: bool self,
type_name: str,
member_name: str,
member_type: Type,
kind: MemberKind,
): ):
members: dict[str, Type] = self._members.setdefault(type_name, {}) members: dict[str, Member] = self._members.setdefault(type_name, {})
if member_name in members: if member_name in members:
if not is_method: current: Member = members[member_name]
if current.kind != kind:
self.logger.error( self.logger.error(
f"Member '{member_name}' already defined for type {type_name}" f"Member '{member_name}' is already defined as a {current.kind},"
+ f" cannot define a {kind} with the same name"
) )
return return
current: Type = members[member_name] if kind != MemberKind.METHOD:
self.logger.error(
f"Member '{member_name}' already defined for type {type_name},"
+ " only methods can be overloaded"
)
return
combined: Type combined: Type
match current: match current.type:
case OverloadedFunction(overloads=overloads): case OverloadedFunction(overloads=overloads):
combined = OverloadedFunction(overloads=overloads + [member_type]) combined = OverloadedFunction(overloads=overloads + [member_type])
case _: case _:
combined = OverloadedFunction(overloads=[current, member_type]) combined = OverloadedFunction(overloads=[current.type, member_type])
members[member_name] = combined members[member_name] = Member(kind=current.kind, type=combined)
else: else:
members[member_name] = member_type members[member_name] = Member(kind=kind, type=member_type)
def define_predicate(self, name: str, predicate: Predicate): def define_predicate(self, name: str, predicate: Predicate):
if name in self._predicates: if name in self._predicates:
@@ -327,13 +347,13 @@ class TypesRegistry:
case BaseType(name=name): case BaseType(name=name):
if name in self._members: if name in self._members:
if member_name in self._members[name]: if member_name in self._members[name]:
return self._members[name][member_name] return self._members[name][member_name].type
return None return None
case AliasType(name=name, type=base): case AliasType(name=name, type=base):
if name in self._members: if name in self._members:
if member_name in self._members[name]: if member_name in self._members[name]:
return self._members[name][member_name] return self._members[name][member_name].type
return self.lookup_member(base, member_name) return self.lookup_member(base, member_name)
case AppliedType(name=name, body=body, args=args): case AppliedType(name=name, body=body, args=args):
@@ -347,7 +367,7 @@ class TypesRegistry:
} }
if name in self._members: if name in self._members:
if member_name in self._members[name]: if member_name in self._members[name]:
member_type: Type = self._members[name][member_name] member_type: Type = self._members[name][member_name].type
return substitute_typevars(member_type, substitutions) return substitute_typevars(member_type, substitutions)
member_type2: Optional[Type] = self.lookup_member(body, member_name) member_type2: Optional[Type] = self.lookup_member(body, member_name)

View File

@@ -166,6 +166,9 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
) )
match type: match type:
case TopType():
return type
case BaseType(name=name) if name in substitutions: case BaseType(name=name) if name in substitutions:
return substitutions[name] return substitutions[name]
@@ -232,6 +235,21 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
return substitutions[name] return substitutions[name]
raise ValueError(f"Missing TypeVar substitution for {name}") raise ValueError(f"Missing TypeVar substitution for {name}")
case GenericType(name=name, params=params, body=body):
params2: list[TypeVar] = []
for param in params:
param2: Type = substitute_typevars(param, substitutions)
if not isinstance(param2, TypeVar):
raise ValueError(
f"Invalid type parameter substitution, expected TypeVar, got {param2}"
)
params2.append(param2)
return GenericType(
name=name,
params=params2,
body=substitute_typevars(body, substitutions),
)
case UnknownType() | UnitType(): case UnknownType() | UnitType():
return type return type

View File

@@ -1,6 +1,6 @@
from typing import Literal, Optional, cast from typing import Literal, Optional, cast
from midas.checker.registry import TypesRegistry from midas.checker.registry import Member, TypesRegistry
from midas.checker.types import ( from midas.checker.types import (
AppliedType, AppliedType,
ConstraintType, ConstraintType,
@@ -54,9 +54,9 @@ class VarianceInferrer:
self.tracker = Tracker(type.params) self.tracker = Tracker(type.params)
self.walk(type.body, 1, type.name) self.walk(type.body, 1, type.name)
members: dict[str, Type] = self.types._members.get(type.name, {}) members: dict[str, Member] = self.types._members.get(type.name, {})
for name, member in members.items(): for name, member in members.items():
self.walk(member, 1, type.name, [f"member:'{name}'"]) self.walk(member.type, 1, type.name, [f"member:'{name}'"])
return GenericType( return GenericType(
name=type.name, name=type.name,

View File

@@ -4,5 +4,6 @@ from .format import format as format
from .highlight import highlight as highlight from .highlight import highlight as highlight
from .parse import parse as parse from .parse import parse as parse
from .registry import dump_registry as dump_registry from .registry import dump_registry as dump_registry
from .stubs import stubs as stubs
from .types import types as types from .types import types as types
from .validate import validate as validate from .validate import validate as validate

View File

@@ -10,6 +10,7 @@ import click
from midas.ast.printer import MidasPrinter from midas.ast.printer import MidasPrinter
from midas.checker.checker import TypeChecker from midas.checker.checker import TypeChecker
from midas.checker.registry import Member
from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type
@@ -38,7 +39,7 @@ def dump_registry(
print("##### Types #####") print("##### Types #####")
for name, type in checker.types._types.items(): for name, type in checker.types._types.items():
members: dict[str, Type] = checker.types._members.get(name, {}) members: dict[str, Member] = checker.types._members.get(name, {})
params: str = "" params: str = ""
if isinstance(type, GenericType): if isinstance(type, GenericType):
params = ", ".join(map(str, type.params)) params = ", ".join(map(str, type.params))
@@ -46,8 +47,9 @@ def dump_registry(
print(f"{name}{params} = {base_type(type)}") print(f"{name}{params} = {base_type(type)}")
if len(members) != 0: if len(members) != 0:
print(" " * 4 + "Members:") print(" " * 4 + "Members:")
for member_name, member_type in members.items(): for member_name, member in members.items():
print(" " * 8 + f"{member_name}: {member_type}") kind: str = member.kind.name
print(" " * 8 + f"({kind:8}) {member_name}: {member.type}")
print("##### Predicates #####") print("##### Predicates #####")
printer = MidasPrinter() printer = MidasPrinter()

View File

@@ -0,0 +1,27 @@
import ast
from pathlib import Path
from typing import TextIO
import click
from midas.checker.checker import TypeChecker
from midas.generator.stubs import StubsGenerator
@click.command(help="Generate stubs from Midas definitions")
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-")
def stubs(
file: TextIO,
output: TextIO,
):
source_path: Path = Path(file.name).resolve()
checker = TypeChecker()
checker.import_midas(source_path)
generator = StubsGenerator(checker.types)
module: ast.Module = generator.generate_stubs()
module = ast.fix_missing_locations(module)
output.write(ast.unparse(module))

View File

@@ -228,6 +228,13 @@ class PythonHighlighter(
for item in expr.items: for item in expr.items:
item.accept(self) item.accept(self)
def visit_dict_expr(self, expr: p.DictExpr) -> None:
for key in expr.keys:
if key is not None:
key.accept(self)
for value in expr.values:
value.accept(self)
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None: def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
expr.object.accept(self) expr.object.accept(self)
expr.index.accept(self) expr.index.accept(self)
@@ -240,6 +247,10 @@ class PythonHighlighter(
if expr.step is not None: if expr.step is not None:
expr.step.accept(self) expr.step.accept(self)
def visit_raw_expr(self, expr: p.RawExpr) -> None: ...
def visit_raw_stmt(self, stmt: p.RawStmt) -> None: ...
class MidasHighlighter( class MidasHighlighter(
Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None] Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None]
@@ -266,8 +277,9 @@ class MidasHighlighter(
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
self.wrap(stmt, "predicate") self.wrap(stmt, "predicate")
self.wrap(LocatableToken(stmt.name), "predicate-name") self.wrap(LocatableToken(stmt.name), "predicate-name")
stmt.type.accept(self) for spec in stmt.params:
stmt.condition.accept(self) self._visit_param_spec(spec)
stmt.body.accept(self)
def visit_logical_expr(self, expr: m.LogicalExpr) -> None: def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
self.wrap(expr, "logical-expr") self.wrap(expr, "logical-expr")
@@ -283,6 +295,14 @@ class MidasHighlighter(
self.wrap(expr, "unary-expr") self.wrap(expr, "unary-expr")
expr.right.accept(self) expr.right.accept(self)
def visit_call_expr(self, expr: m.CallExpr) -> None:
self.wrap(expr, "call-expr")
expr.callee.accept(self)
for arg in expr.arguments:
arg.accept(self)
for arg in expr.keywords.values():
arg.accept(self)
def visit_get_expr(self, expr: m.GetExpr) -> None: def visit_get_expr(self, expr: m.GetExpr) -> None:
self.wrap(expr, "get-expr") self.wrap(expr, "get-expr")
expr.expr.accept(self) expr.expr.accept(self)
@@ -318,8 +338,7 @@ class MidasHighlighter(
def visit_function_type(self, type: m.FunctionType) -> None: def visit_function_type(self, type: m.FunctionType) -> None:
self.wrap(type, "function") self.wrap(type, "function")
for arg in type.pos_args + type.args + type.kw_args: self._visit_param_spec(type.params)
arg.type.accept(self)
type.returns.accept(self) type.returns.accept(self)
def visit_extension_type(self, type: m.ExtensionType) -> None: def visit_extension_type(self, type: m.ExtensionType) -> None:
@@ -327,6 +346,10 @@ class MidasHighlighter(
type.base.accept(self) type.base.accept(self)
type.extension.accept(self) type.extension.accept(self)
def _visit_param_spec(self, spec: m.ParamSpec) -> None:
for param in spec.pos + spec.mixed + spec.kw:
param.type.accept(self)
class DiagnosticsHighlighter(Highlighter): class DiagnosticsHighlighter(Highlighter):
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css" EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"

View File

@@ -18,6 +18,7 @@ midas.add_command(commands.highlight)
midas.add_command(commands.parse) midas.add_command(commands.parse)
midas.add_command(commands.dump_registry) midas.add_command(commands.dump_registry)
midas.add_command(commands.types) midas.add_command(commands.types)
midas.add_command(commands.stubs)
midas.add_command(commands.validate) midas.add_command(commands.validate)

368
midas/generator/stubs.py Normal file
View File

@@ -0,0 +1,368 @@
import ast
from typing import Optional, assert_never
import midas.ast.midas as m
from midas.checker.registry import Member, TypesRegistry
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ComplexType,
ConstraintType,
ExtensionType,
Function,
GenericType,
OverloadedFunction,
TopType,
Type,
TypeVar,
UnitType,
UnknownType,
Variance,
substitute_typevars,
)
Empty = ast.Constant(value=...)
class StubsGenerator:
def __init__(self, types: TypesRegistry) -> None:
self.types: TypesRegistry = types
self.stubs: list[ast.stmt] = []
self.typing_imports: set[str] = set()
self.protocol_idx: int = 0
self.stub_idx: int = 0
self.type_var_idx: int = 0
self.substitutions: dict[str, dict[str, Type]] = {}
def generate_stubs(self) -> ast.Module:
self.stubs = []
self.typing_imports = set()
for name, type in self.types._types.items():
self.generate_stub(name, type)
imports = [
ast.ImportFrom(
module="__future__",
names=[ast.alias(name="annotations")],
level=0,
)
]
if len(self.typing_imports) != 0:
imports.append(
ast.ImportFrom(
module="typing",
names=[
ast.alias(name=name) for name in sorted(self.typing_imports)
],
level=0,
)
)
return ast.Module(body=imports + self.stubs, type_ignores=[])
def generate_stub(self, name: str, type: Type):
base_type: Type = type
members: dict[str, Member] = self.types._members.get(name, {})
if isinstance(base_type, (BaseType, TopType, UnitType)) and len(members) == 0:
return
bases: list[ast.expr] = []
substitutions: dict[str, Type] = {}
bases, substitutions = self.get_bases(type)
self.substitutions[name] = substitutions
body = self.generate_body(members, substitutions)
stub = ast.ClassDef(
name=name,
bases=bases,
body=body,
keywords=[],
decorator_list=[],
)
self.add_stub(stub)
def get_bases(self, type: Type) -> tuple[list[ast.expr], dict[str, Type]]:
match type:
case AliasType(type=base):
return [self.dump_type(base)], {}
case GenericType(params=params, body=body):
self.add_typing_import("Generic")
type_vars: ast.expr
params2: list[TypeVar] = self.define_type_vars(params)
if len(params) == 1:
type_vars = ast.Name(id=params2[0].name)
else:
type_vars = ast.Tuple(
elts=[ast.Name(id=param.name) for param in params2]
)
substitutions: dict[str, TypeVar] = {
param.name: param2 for param, param2 in zip(params, params2)
}
body_bases, body_subsitutions = self.get_bases(body)
return (
body_bases
+ [
ast.Subscript(
value=ast.Name(id="Generic"),
slice=type_vars,
)
],
body_subsitutions | substitutions,
)
case _:
return [], {}
def generate_body(
self, members: dict[str, Member], substitutions: dict[str, Type]
) -> list[ast.stmt]:
if len(members) == 0:
return [ast.Expr(value=Empty)]
body: list[ast.stmt] = []
for name, member in members.items():
type: Type = member.type
type = substitute_typevars(type, substitutions)
match member.kind:
case m.MemberKind.PROPERTY:
body.append(
ast.AnnAssign(
target=ast.Name(id=name),
annotation=self.dump_type(type),
simple=1,
)
)
case m.MemberKind.METHOD:
body.extend(self.dump_method(name, type))
return body
def dump_type(self, type: Type) -> ast.expr:
match type:
case AliasType(name=name) | GenericType(name=name) if (
name in self.substitutions
):
type = substitute_typevars(type, self.substitutions[name])
match type:
case TopType() | UnknownType():
self.add_typing_import("Any")
return ast.Name(id="Any")
case BaseType(name=name):
return ast.Name(id=name)
case AliasType(name=name):
return ast.Name(id=name)
case UnitType():
return ast.Constant(value=None)
case Function():
name: str = self.define_protocol(type)
return ast.Name(id=name)
case OverloadedFunction(overloads=overloads):
if len(overloads) == 1:
return self.dump_type(overloads[0])
return ast.BinOp(
left=self.dump_type(OverloadedFunction(overloads=overloads[:-1])),
op=ast.BitOr(),
right=self.dump_type(overloads[-1]),
)
case ComplexType():
name: str = self.new_stub_name()
self.generate_stub(name, type)
return ast.Name(id=name)
case ExtensionType():
raise NotImplementedError
case TypeVar():
return ast.Name(id=type.name)
case GenericType(name=name):
params: ast.expr
if len(type.params) == 1:
params = self.dump_type(type.params[0])
else:
params = ast.Tuple(
elts=[self.dump_type(param) for param in type.params]
)
return ast.Subscript(
value=ast.Name(id=type.name),
slice=params,
)
case AppliedType():
args: ast.expr
if len(type.args) == 1:
args = self.dump_type(type.args[0])
else:
args = ast.Tuple(elts=[self.dump_type(arg) for arg in type.args])
return ast.Subscript(
value=ast.Name(id=type.name),
slice=args,
)
case ConstraintType():
return self.dump_type(type.type)
case _:
assert_never(type)
def dump_method(
self, name: str, method: Type, overloaded: bool = False
) -> list[ast.stmt]:
match method:
case Function():
if overloaded:
self.add_typing_import("overload")
return [
ast.FunctionDef(
name=name,
args=self.dump_args(method, with_self=True),
returns=self.dump_type(method.returns),
body=[ast.Expr(value=Empty)],
decorator_list=[ast.Name(id="overload")] if overloaded else [],
)
]
case OverloadedFunction(overloads=overloads):
stmts: list[ast.stmt] = []
for overload in overloads:
stmts.extend(self.dump_method(name, overload, True))
return stmts
case _:
return [
ast.AnnAssign(
target=ast.Name(id=name),
annotation=self.dump_type(method),
simple=1,
)
]
def dump_args(self, func: Function, with_self: bool = False) -> ast.arguments:
pos: list[ast.arg] = [
ast.arg(arg=f"_{arg.pos}", annotation=self.dump_type(arg.type))
for arg in func.pos_args
]
mixed: list[ast.arg] = [
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type))
for arg in func.args
]
kw: list[ast.arg] = [
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type))
for arg in func.kw_args
]
defaults: list[ast.expr] = [
Empty for arg in func.pos_args + func.args if not arg.required
]
kw_defaults: list[Optional[ast.expr]] = [
None if arg.required else Empty for arg in func.kw_args
]
if with_self:
arg = ast.arg(arg="self", annotation=None)
if len(pos) != 0:
pos.insert(0, arg)
else:
mixed.insert(0, arg)
return ast.arguments(
posonlyargs=pos,
args=mixed,
kwonlyargs=kw,
defaults=defaults,
kw_defaults=kw_defaults,
)
def define_protocol(self, func: Function) -> str:
self.add_typing_import("Protocol")
name: str = self.new_protocol_name()
protocol = ast.ClassDef(
name=name,
bases=[ast.Name(id="Protocol")],
keywords=[],
body=[
ast.FunctionDef(
name="__call__",
args=self.dump_args(func, with_self=True),
returns=self.dump_type(func.returns),
body=[ast.Expr(value=Empty)],
decorator_list=[],
),
],
decorator_list=[],
)
self.add_stub(protocol)
return name
def new_protocol_name(self) -> str:
name: str = f"_Protocol{self.protocol_idx}"
self.protocol_idx += 1
return name
def new_stub_name(self) -> str:
name: str = f"_Stub_{self.stub_idx}"
self.stub_idx += 1
return name
def new_type_var_name(self) -> str:
name: str = f"_T{self.type_var_idx}"
self.type_var_idx += 1
return name
def add_stub(self, stub: ast.stmt):
self.stubs.append(stub)
def add_typing_import(self, name: str):
self.typing_imports.add(name)
def define_type_vars(self, vars: list[TypeVar]) -> list[TypeVar]:
vars2: list[TypeVar] = []
for var in vars:
vars2.append(self.define_type_var(var))
return vars2
def define_type_var(self, var: TypeVar) -> TypeVar:
name: str = self.new_type_var_name()
self.add_typing_import("TypeVar")
kwargs: list[ast.keyword] = []
if var.bound is not None:
kwargs.append(
ast.keyword(
arg="bound",
value=self.dump_type(var.bound),
)
)
if var.variance == Variance.COVARIANT:
kwargs.append(
ast.keyword(
arg="covariant",
value=ast.Constant(value=True),
)
)
elif var.variance == Variance.CONTRAVARIANT:
kwargs.append(
ast.keyword(
arg="contravariant",
value=ast.Constant(value=True),
)
)
self.add_stub(
ast.Assign(
targets=[ast.Name(id=name)],
value=ast.Call(
func=ast.Name(id="TypeVar"),
args=[
ast.Constant(value=name),
],
keywords=kwargs,
),
)
)
return TypeVar(name=name, bound=None)