diff --git a/README.md b/README.md index f20aad0..9799d0f 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,7 @@ This framework is being developed as part of a Bachelor's Thesis by Louis Herede - [Highlighting](#highlighting) - [Dumping the AST](#dumping-the-ast) - [Dumping the Registry](#dumping-the-registry) + - [Generating Stubs](#generating-stubs) - [Showing Type Judgements](#showing-type-judgements) - [Validating Definitions](#validating-definitions) - [Tests](#tests) @@ -116,6 +117,14 @@ midas dump-registry -t types.midas This command processes the given Midas definitions and dumps the contents of the types registry. +### Generating Stubs + +```shell +midas stubs types.midas -o stubs.pyi +``` + +This command generate Python stubs from a Midas definition file + ### Showing Type Judgements ```shell diff --git a/midas/checker/midas.py b/midas/checker/midas.py index e79719c..13e83b6 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -173,7 +173,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type base_name, member.name.lexeme, member_type, - member.kind == m.MemberKind.METHOD, + member.kind, ) def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: diff --git a/midas/checker/python.py b/midas/checker/python.py index 6b2892b..22ea98c 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -16,6 +16,7 @@ from midas.checker.registry import TypesRegistry from midas.checker.reporter import FileReporter, Reporter from midas.checker.resolver import Resolver from midas.checker.types import ( + AliasType, AppliedType, Function, OverloadedFunction, @@ -698,9 +699,17 @@ class PythonTyper( case UnknownType(): return UnknownType() + case AliasType(type=base): + return self._get_call_result( + location, base, positional, keywords, report_errors + ) + case _: if report_errors: - self.reporter.error(location, f"{callee} is not callable") + self.reporter.error( + location, + f"{callee} ({callee.__class__.__name__}) is not callable", + ) return None def _are_arguments_valid( diff --git a/midas/checker/registry.py b/midas/checker/registry.py index b8f7dfa..b787f20 100644 --- a/midas/checker/registry.py +++ b/midas/checker/registry.py @@ -1,6 +1,8 @@ import logging +from dataclasses import dataclass from typing import Optional +from midas.ast.midas import MemberKind from midas.checker.builtins import BUILTIN_SUBTYPES from midas.checker.types import ( AliasType, @@ -22,11 +24,17 @@ from midas.checker.types import ( ) +@dataclass +class Member: + kind: MemberKind + type: Type + + class TypesRegistry: def __init__(self) -> None: self.logger: logging.Logger = logging.getLogger("TypesRegistry") self._types: dict[str, Type] = {} - self._members: dict[str, dict[str, Type]] = {} + self._members: dict[str, dict[str, Member]] = {} self._predicates: dict[str, Predicate] = {} def get_type(self, name: str) -> Type: @@ -64,26 +72,38 @@ class TypesRegistry: return type def define_member( - self, type_name: str, member_name: str, member_type: Type, is_method: bool + self, + type_name: str, + member_name: str, + member_type: Type, + kind: MemberKind, ): - members: dict[str, Type] = self._members.setdefault(type_name, {}) + members: dict[str, Member] = self._members.setdefault(type_name, {}) if member_name in members: - if not is_method: + current: Member = members[member_name] + if current.kind != kind: self.logger.error( - f"Member '{member_name}' already defined for type {type_name}" + f"Member '{member_name}' is already defined as a {current.kind}," + + f" cannot define a {kind} with the same name" ) return - current: Type = members[member_name] + if kind != MemberKind.METHOD: + self.logger.error( + f"Member '{member_name}' already defined for type {type_name}," + + " only methods can be overloaded" + ) + return + combined: Type - match current: + match current.type: case OverloadedFunction(overloads=overloads): combined = OverloadedFunction(overloads=overloads + [member_type]) case _: - combined = OverloadedFunction(overloads=[current, member_type]) - members[member_name] = combined + combined = OverloadedFunction(overloads=[current.type, member_type]) + members[member_name] = Member(kind=current.kind, type=combined) else: - members[member_name] = member_type + members[member_name] = Member(kind=kind, type=member_type) def define_predicate(self, name: str, predicate: Predicate): if name in self._predicates: @@ -327,13 +347,13 @@ class TypesRegistry: case BaseType(name=name): if name in self._members: if member_name in self._members[name]: - return self._members[name][member_name] + return self._members[name][member_name].type return None case AliasType(name=name, type=base): if name in self._members: if member_name in self._members[name]: - return self._members[name][member_name] + return self._members[name][member_name].type return self.lookup_member(base, member_name) case AppliedType(name=name, body=body, args=args): @@ -347,7 +367,7 @@ class TypesRegistry: } if name in self._members: if member_name in self._members[name]: - member_type: Type = self._members[name][member_name] + member_type: Type = self._members[name][member_name].type return substitute_typevars(member_type, substitutions) member_type2: Optional[Type] = self.lookup_member(body, member_name) diff --git a/midas/checker/types.py b/midas/checker/types.py index 60fb9c7..4ebda35 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -166,6 +166,9 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type: ) match type: + case TopType(): + return type + case BaseType(name=name) if name in substitutions: return substitutions[name] @@ -232,6 +235,21 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type: return substitutions[name] raise ValueError(f"Missing TypeVar substitution for {name}") + case GenericType(name=name, params=params, body=body): + params2: list[TypeVar] = [] + for param in params: + param2: Type = substitute_typevars(param, substitutions) + if not isinstance(param2, TypeVar): + raise ValueError( + f"Invalid type parameter substitution, expected TypeVar, got {param2}" + ) + params2.append(param2) + return GenericType( + name=name, + params=params2, + body=substitute_typevars(body, substitutions), + ) + case UnknownType() | UnitType(): return type diff --git a/midas/checker/variance.py b/midas/checker/variance.py index b620958..e7f5ac0 100644 --- a/midas/checker/variance.py +++ b/midas/checker/variance.py @@ -1,6 +1,6 @@ from typing import Literal, Optional, cast -from midas.checker.registry import TypesRegistry +from midas.checker.registry import Member, TypesRegistry from midas.checker.types import ( AppliedType, ConstraintType, @@ -54,9 +54,9 @@ class VarianceInferrer: self.tracker = Tracker(type.params) self.walk(type.body, 1, type.name) - members: dict[str, Type] = self.types._members.get(type.name, {}) + members: dict[str, Member] = self.types._members.get(type.name, {}) for name, member in members.items(): - self.walk(member, 1, type.name, [f"member:'{name}'"]) + self.walk(member.type, 1, type.name, [f"member:'{name}'"]) return GenericType( name=type.name, diff --git a/midas/cli/commands/__init__.py b/midas/cli/commands/__init__.py index 4a3c2a8..0a38d36 100644 --- a/midas/cli/commands/__init__.py +++ b/midas/cli/commands/__init__.py @@ -4,5 +4,6 @@ from .format import format as format from .highlight import highlight as highlight from .parse import parse as parse from .registry import dump_registry as dump_registry +from .stubs import stubs as stubs from .types import types as types from .validate import validate as validate diff --git a/midas/cli/commands/registry.py b/midas/cli/commands/registry.py index 7c0c521..12c222b 100644 --- a/midas/cli/commands/registry.py +++ b/midas/cli/commands/registry.py @@ -10,6 +10,7 @@ import click from midas.ast.printer import MidasPrinter from midas.checker.checker import TypeChecker +from midas.checker.registry import Member from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type @@ -38,7 +39,7 @@ def dump_registry( print("##### Types #####") for name, type in checker.types._types.items(): - members: dict[str, Type] = checker.types._members.get(name, {}) + members: dict[str, Member] = checker.types._members.get(name, {}) params: str = "" if isinstance(type, GenericType): params = ", ".join(map(str, type.params)) @@ -46,8 +47,9 @@ def dump_registry( print(f"{name}{params} = {base_type(type)}") if len(members) != 0: print(" " * 4 + "Members:") - for member_name, member_type in members.items(): - print(" " * 8 + f"{member_name}: {member_type}") + for member_name, member in members.items(): + kind: str = member.kind.name + print(" " * 8 + f"({kind:8}) {member_name}: {member.type}") print("##### Predicates #####") printer = MidasPrinter() diff --git a/midas/cli/commands/stubs.py b/midas/cli/commands/stubs.py new file mode 100644 index 0000000..98b3cd4 --- /dev/null +++ b/midas/cli/commands/stubs.py @@ -0,0 +1,27 @@ +import ast +from pathlib import Path +from typing import TextIO + +import click + +from midas.checker.checker import TypeChecker +from midas.generator.stubs import StubsGenerator + + +@click.command(help="Generate stubs from Midas definitions") +@click.argument("file", type=click.File("r")) +@click.option("-o", "--output", type=click.File("w"), default="-") +def stubs( + file: TextIO, + output: TextIO, +): + source_path: Path = Path(file.name).resolve() + + checker = TypeChecker() + checker.import_midas(source_path) + + generator = StubsGenerator(checker.types) + module: ast.Module = generator.generate_stubs() + module = ast.fix_missing_locations(module) + + output.write(ast.unparse(module)) diff --git a/midas/cli/highlighter.py b/midas/cli/highlighter.py index 1303f94..ce6c2b2 100644 --- a/midas/cli/highlighter.py +++ b/midas/cli/highlighter.py @@ -228,6 +228,13 @@ class PythonHighlighter( for item in expr.items: item.accept(self) + def visit_dict_expr(self, expr: p.DictExpr) -> None: + for key in expr.keys: + if key is not None: + key.accept(self) + for value in expr.values: + value.accept(self) + def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None: expr.object.accept(self) expr.index.accept(self) @@ -240,6 +247,10 @@ class PythonHighlighter( if expr.step is not None: expr.step.accept(self) + def visit_raw_expr(self, expr: p.RawExpr) -> None: ... + + def visit_raw_stmt(self, stmt: p.RawStmt) -> None: ... + class MidasHighlighter( Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None] @@ -266,8 +277,9 @@ class MidasHighlighter( def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: self.wrap(stmt, "predicate") self.wrap(LocatableToken(stmt.name), "predicate-name") - stmt.type.accept(self) - stmt.condition.accept(self) + for spec in stmt.params: + self._visit_param_spec(spec) + stmt.body.accept(self) def visit_logical_expr(self, expr: m.LogicalExpr) -> None: self.wrap(expr, "logical-expr") @@ -283,6 +295,14 @@ class MidasHighlighter( self.wrap(expr, "unary-expr") expr.right.accept(self) + def visit_call_expr(self, expr: m.CallExpr) -> None: + self.wrap(expr, "call-expr") + expr.callee.accept(self) + for arg in expr.arguments: + arg.accept(self) + for arg in expr.keywords.values(): + arg.accept(self) + def visit_get_expr(self, expr: m.GetExpr) -> None: self.wrap(expr, "get-expr") expr.expr.accept(self) @@ -318,8 +338,7 @@ class MidasHighlighter( def visit_function_type(self, type: m.FunctionType) -> None: self.wrap(type, "function") - for arg in type.pos_args + type.args + type.kw_args: - arg.type.accept(self) + self._visit_param_spec(type.params) type.returns.accept(self) def visit_extension_type(self, type: m.ExtensionType) -> None: @@ -327,6 +346,10 @@ class MidasHighlighter( type.base.accept(self) type.extension.accept(self) + def _visit_param_spec(self, spec: m.ParamSpec) -> None: + for param in spec.pos + spec.mixed + spec.kw: + param.type.accept(self) + class DiagnosticsHighlighter(Highlighter): EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css" diff --git a/midas/cli/main.py b/midas/cli/main.py index 084adf4..05db15e 100644 --- a/midas/cli/main.py +++ b/midas/cli/main.py @@ -18,6 +18,7 @@ midas.add_command(commands.highlight) midas.add_command(commands.parse) midas.add_command(commands.dump_registry) midas.add_command(commands.types) +midas.add_command(commands.stubs) midas.add_command(commands.validate) diff --git a/midas/generator/stubs.py b/midas/generator/stubs.py new file mode 100644 index 0000000..d54c948 --- /dev/null +++ b/midas/generator/stubs.py @@ -0,0 +1,368 @@ +import ast +from typing import Optional, assert_never + +import midas.ast.midas as m +from midas.checker.registry import Member, TypesRegistry +from midas.checker.types import ( + AliasType, + AppliedType, + BaseType, + ComplexType, + ConstraintType, + ExtensionType, + Function, + GenericType, + OverloadedFunction, + TopType, + Type, + TypeVar, + UnitType, + UnknownType, + Variance, + substitute_typevars, +) + +Empty = ast.Constant(value=...) + + +class StubsGenerator: + def __init__(self, types: TypesRegistry) -> None: + self.types: TypesRegistry = types + self.stubs: list[ast.stmt] = [] + self.typing_imports: set[str] = set() + self.protocol_idx: int = 0 + self.stub_idx: int = 0 + self.type_var_idx: int = 0 + self.substitutions: dict[str, dict[str, Type]] = {} + + def generate_stubs(self) -> ast.Module: + self.stubs = [] + self.typing_imports = set() + for name, type in self.types._types.items(): + self.generate_stub(name, type) + + imports = [ + ast.ImportFrom( + module="__future__", + names=[ast.alias(name="annotations")], + level=0, + ) + ] + if len(self.typing_imports) != 0: + imports.append( + ast.ImportFrom( + module="typing", + names=[ + ast.alias(name=name) for name in sorted(self.typing_imports) + ], + level=0, + ) + ) + return ast.Module(body=imports + self.stubs, type_ignores=[]) + + def generate_stub(self, name: str, type: Type): + base_type: Type = type + + members: dict[str, Member] = self.types._members.get(name, {}) + if isinstance(base_type, (BaseType, TopType, UnitType)) and len(members) == 0: + return + + bases: list[ast.expr] = [] + substitutions: dict[str, Type] = {} + bases, substitutions = self.get_bases(type) + self.substitutions[name] = substitutions + + body = self.generate_body(members, substitutions) + stub = ast.ClassDef( + name=name, + bases=bases, + body=body, + keywords=[], + decorator_list=[], + ) + self.add_stub(stub) + + def get_bases(self, type: Type) -> tuple[list[ast.expr], dict[str, Type]]: + match type: + case AliasType(type=base): + return [self.dump_type(base)], {} + + case GenericType(params=params, body=body): + self.add_typing_import("Generic") + type_vars: ast.expr + + params2: list[TypeVar] = self.define_type_vars(params) + if len(params) == 1: + type_vars = ast.Name(id=params2[0].name) + else: + type_vars = ast.Tuple( + elts=[ast.Name(id=param.name) for param in params2] + ) + + substitutions: dict[str, TypeVar] = { + param.name: param2 for param, param2 in zip(params, params2) + } + + body_bases, body_subsitutions = self.get_bases(body) + return ( + body_bases + + [ + ast.Subscript( + value=ast.Name(id="Generic"), + slice=type_vars, + ) + ], + body_subsitutions | substitutions, + ) + + case _: + return [], {} + + def generate_body( + self, members: dict[str, Member], substitutions: dict[str, Type] + ) -> list[ast.stmt]: + if len(members) == 0: + return [ast.Expr(value=Empty)] + + body: list[ast.stmt] = [] + for name, member in members.items(): + type: Type = member.type + type = substitute_typevars(type, substitutions) + match member.kind: + case m.MemberKind.PROPERTY: + body.append( + ast.AnnAssign( + target=ast.Name(id=name), + annotation=self.dump_type(type), + simple=1, + ) + ) + case m.MemberKind.METHOD: + body.extend(self.dump_method(name, type)) + return body + + def dump_type(self, type: Type) -> ast.expr: + match type: + case AliasType(name=name) | GenericType(name=name) if ( + name in self.substitutions + ): + type = substitute_typevars(type, self.substitutions[name]) + + match type: + case TopType() | UnknownType(): + self.add_typing_import("Any") + return ast.Name(id="Any") + + case BaseType(name=name): + return ast.Name(id=name) + + case AliasType(name=name): + return ast.Name(id=name) + + case UnitType(): + return ast.Constant(value=None) + + case Function(): + name: str = self.define_protocol(type) + return ast.Name(id=name) + + case OverloadedFunction(overloads=overloads): + if len(overloads) == 1: + return self.dump_type(overloads[0]) + return ast.BinOp( + left=self.dump_type(OverloadedFunction(overloads=overloads[:-1])), + op=ast.BitOr(), + right=self.dump_type(overloads[-1]), + ) + + case ComplexType(): + name: str = self.new_stub_name() + self.generate_stub(name, type) + return ast.Name(id=name) + + case ExtensionType(): + raise NotImplementedError + + case TypeVar(): + return ast.Name(id=type.name) + + case GenericType(name=name): + params: ast.expr + if len(type.params) == 1: + params = self.dump_type(type.params[0]) + else: + params = ast.Tuple( + elts=[self.dump_type(param) for param in type.params] + ) + return ast.Subscript( + value=ast.Name(id=type.name), + slice=params, + ) + + case AppliedType(): + args: ast.expr + if len(type.args) == 1: + args = self.dump_type(type.args[0]) + else: + args = ast.Tuple(elts=[self.dump_type(arg) for arg in type.args]) + return ast.Subscript( + value=ast.Name(id=type.name), + slice=args, + ) + + case ConstraintType(): + return self.dump_type(type.type) + + case _: + assert_never(type) + + def dump_method( + self, name: str, method: Type, overloaded: bool = False + ) -> list[ast.stmt]: + match method: + case Function(): + if overloaded: + self.add_typing_import("overload") + return [ + ast.FunctionDef( + name=name, + args=self.dump_args(method, with_self=True), + returns=self.dump_type(method.returns), + body=[ast.Expr(value=Empty)], + decorator_list=[ast.Name(id="overload")] if overloaded else [], + ) + ] + case OverloadedFunction(overloads=overloads): + stmts: list[ast.stmt] = [] + for overload in overloads: + stmts.extend(self.dump_method(name, overload, True)) + return stmts + case _: + return [ + ast.AnnAssign( + target=ast.Name(id=name), + annotation=self.dump_type(method), + simple=1, + ) + ] + + def dump_args(self, func: Function, with_self: bool = False) -> ast.arguments: + pos: list[ast.arg] = [ + ast.arg(arg=f"_{arg.pos}", annotation=self.dump_type(arg.type)) + for arg in func.pos_args + ] + mixed: list[ast.arg] = [ + ast.arg(arg=arg.name, annotation=self.dump_type(arg.type)) + for arg in func.args + ] + kw: list[ast.arg] = [ + ast.arg(arg=arg.name, annotation=self.dump_type(arg.type)) + for arg in func.kw_args + ] + defaults: list[ast.expr] = [ + Empty for arg in func.pos_args + func.args if not arg.required + ] + kw_defaults: list[Optional[ast.expr]] = [ + None if arg.required else Empty for arg in func.kw_args + ] + if with_self: + arg = ast.arg(arg="self", annotation=None) + if len(pos) != 0: + pos.insert(0, arg) + else: + mixed.insert(0, arg) + return ast.arguments( + posonlyargs=pos, + args=mixed, + kwonlyargs=kw, + defaults=defaults, + kw_defaults=kw_defaults, + ) + + def define_protocol(self, func: Function) -> str: + self.add_typing_import("Protocol") + name: str = self.new_protocol_name() + protocol = ast.ClassDef( + name=name, + bases=[ast.Name(id="Protocol")], + keywords=[], + body=[ + ast.FunctionDef( + name="__call__", + args=self.dump_args(func, with_self=True), + returns=self.dump_type(func.returns), + body=[ast.Expr(value=Empty)], + decorator_list=[], + ), + ], + decorator_list=[], + ) + self.add_stub(protocol) + return name + + def new_protocol_name(self) -> str: + name: str = f"_Protocol{self.protocol_idx}" + self.protocol_idx += 1 + return name + + def new_stub_name(self) -> str: + name: str = f"_Stub_{self.stub_idx}" + self.stub_idx += 1 + return name + + def new_type_var_name(self) -> str: + name: str = f"_T{self.type_var_idx}" + self.type_var_idx += 1 + return name + + def add_stub(self, stub: ast.stmt): + self.stubs.append(stub) + + def add_typing_import(self, name: str): + self.typing_imports.add(name) + + def define_type_vars(self, vars: list[TypeVar]) -> list[TypeVar]: + vars2: list[TypeVar] = [] + for var in vars: + vars2.append(self.define_type_var(var)) + return vars2 + + def define_type_var(self, var: TypeVar) -> TypeVar: + name: str = self.new_type_var_name() + self.add_typing_import("TypeVar") + + kwargs: list[ast.keyword] = [] + if var.bound is not None: + kwargs.append( + ast.keyword( + arg="bound", + value=self.dump_type(var.bound), + ) + ) + if var.variance == Variance.COVARIANT: + kwargs.append( + ast.keyword( + arg="covariant", + value=ast.Constant(value=True), + ) + ) + elif var.variance == Variance.CONTRAVARIANT: + kwargs.append( + ast.keyword( + arg="contravariant", + value=ast.Constant(value=True), + ) + ) + self.add_stub( + ast.Assign( + targets=[ast.Name(id=name)], + value=ast.Call( + func=ast.Name(id="TypeVar"), + args=[ + ast.Constant(value=name), + ], + keywords=kwargs, + ), + ) + ) + return TypeVar(name=name, bound=None)