5 Commits

8 changed files with 428 additions and 15 deletions

View File

@@ -102,7 +102,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
base_name, base_name,
member.name.lexeme, member.name.lexeme,
member_type, member_type,
member.kind == m.MemberKind.METHOD, member.kind,
) )
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:

View File

@@ -12,6 +12,7 @@ from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter from midas.checker.reporter import FileReporter, Reporter
from midas.checker.resolver import Resolver from midas.checker.resolver import Resolver
from midas.checker.types import ( from midas.checker.types import (
AliasType,
AppliedType, AppliedType,
Function, Function,
OverloadedFunction, OverloadedFunction,
@@ -694,9 +695,17 @@ class PythonTyper(
case UnknownType(): case UnknownType():
return UnknownType() return UnknownType()
case AliasType(type=base):
return self._get_call_result(
location, base, positional, keywords, report_errors
)
case _: case _:
if report_errors: if report_errors:
self.reporter.error(location, f"{callee} is not callable") self.reporter.error(
location,
f"{callee} ({callee.__class__.__name__}) is not callable",
)
return None return None
def _are_arguments_valid( def _are_arguments_valid(

View File

@@ -1,6 +1,8 @@
import logging import logging
from dataclasses import dataclass
from typing import Optional from typing import Optional
from midas.ast.midas import MemberKind
from midas.checker.builtins import BUILTIN_SUBTYPES from midas.checker.builtins import BUILTIN_SUBTYPES
from midas.checker.types import ( from midas.checker.types import (
AliasType, AliasType,
@@ -19,11 +21,17 @@ from midas.checker.types import (
) )
@dataclass
class Member:
kind: MemberKind
type: Type
class TypesRegistry: class TypesRegistry:
def __init__(self) -> None: def __init__(self) -> None:
self.logger: logging.Logger = logging.getLogger("TypesRegistry") self.logger: logging.Logger = logging.getLogger("TypesRegistry")
self._types: dict[str, Type] = {} self._types: dict[str, Type] = {}
self._members: dict[str, dict[str, Type]] = {} self._members: dict[str, dict[str, Member]] = {}
def get_type(self, name: str) -> Type: def get_type(self, name: str) -> Type:
"""Get a type from its name """Get a type from its name
@@ -60,26 +68,38 @@ class TypesRegistry:
return type return type
def define_member( def define_member(
self, type_name: str, member_name: str, member_type: Type, is_method: bool self,
type_name: str,
member_name: str,
member_type: Type,
kind: MemberKind,
): ):
members: dict[str, Type] = self._members.setdefault(type_name, {}) members: dict[str, Member] = self._members.setdefault(type_name, {})
if member_name in members: if member_name in members:
if not is_method: current: Member = members[member_name]
if current.kind != kind:
self.logger.error( self.logger.error(
f"Member '{member_name}' already defined for type {type_name}" f"Member '{member_name}' is already defined as a {current.kind},"
+ f" cannot define a {kind} with the same name"
) )
return return
current: Type = members[member_name] if kind != MemberKind.METHOD:
self.logger.error(
f"Member '{member_name}' already defined for type {type_name},"
+ " only methods can be overloaded"
)
return
combined: Type combined: Type
match current: match current.type:
case OverloadedFunction(overloads=overloads): case OverloadedFunction(overloads=overloads):
combined = OverloadedFunction(overloads=overloads + [member_type]) combined = OverloadedFunction(overloads=overloads + [member_type])
case _: case _:
combined = OverloadedFunction(overloads=[current, member_type]) combined = OverloadedFunction(overloads=[current.type, member_type])
members[member_name] = combined members[member_name] = Member(kind=current.kind, type=combined)
else: else:
members[member_name] = member_type members[member_name] = Member(kind=kind, type=member_type)
def is_subtype(self, type1: Type, type2: Type) -> bool: def is_subtype(self, type1: Type, type2: Type) -> bool:
"""Check whether `type1` is a subtype of `type2` """Check whether `type1` is a subtype of `type2`
@@ -297,13 +317,13 @@ class TypesRegistry:
case BaseType(name=name): case BaseType(name=name):
if name in self._members: if name in self._members:
if member_name in self._members[name]: if member_name in self._members[name]:
return self._members[name][member_name] return self._members[name][member_name].type
return None return None
case AliasType(name=name, type=base): case AliasType(name=name, type=base):
if name in self._members: if name in self._members:
if member_name in self._members[name]: if member_name in self._members[name]:
return self._members[name][member_name] return self._members[name][member_name].type
return self.lookup_member(base, member_name) return self.lookup_member(base, member_name)
case AppliedType(name=name, body=body, args=args): case AppliedType(name=name, body=body, args=args):
@@ -317,7 +337,7 @@ class TypesRegistry:
} }
if name in self._members: if name in self._members:
if member_name in self._members[name]: if member_name in self._members[name]:
member_type: Type = self._members[name][member_name] member_type: Type = self._members[name][member_name].type
return substitute_typevars(member_type, substitutions) return substitute_typevars(member_type, substitutions)
member_type2: Optional[Type] = self.lookup_member(body, member_name) member_type2: Optional[Type] = self.lookup_member(body, member_name)

View File

@@ -140,6 +140,9 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
) )
match type: match type:
case TopType():
return type
case BaseType(name=name) if name in substitutions: case BaseType(name=name) if name in substitutions:
return substitutions[name] return substitutions[name]
@@ -200,6 +203,21 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
return substitutions[name] return substitutions[name]
raise ValueError(f"Missing TypeVar substitution for {name}") raise ValueError(f"Missing TypeVar substitution for {name}")
case GenericType(name=name, params=params, body=body):
params2: list[TypeVar] = []
for param in params:
param2: Type = substitute_typevars(param, substitutions)
if not isinstance(param2, TypeVar):
raise ValueError(
f"Invalid type parameter substitution, expected TypeVar, got {param2}"
)
params2.append(param2)
return GenericType(
name=name,
params=params2,
body=substitute_typevars(body, substitutions),
)
case UnknownType() | UnitType(): case UnknownType() | UnitType():
return type return type

View File

@@ -4,5 +4,6 @@ from .format import format as format
from .highlight import highlight as highlight from .highlight import highlight as highlight
from .parse import parse as parse from .parse import parse as parse
from .registry import dump_registry as dump_registry from .registry import dump_registry as dump_registry
from .stubs import stubs as stubs
from .types import types as types from .types import types as types
from .validate import validate as validate from .validate import validate as validate

View File

@@ -0,0 +1,27 @@
import ast
from pathlib import Path
from typing import TextIO
import click
from midas.checker.checker import TypeChecker
from midas.generator.stubs import StubsGenerator
@click.command(help="Generate stubs from Midas definitions")
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-")
def stubs(
file: TextIO,
output: TextIO,
):
source_path: Path = Path(file.name).resolve()
checker = TypeChecker()
checker.import_midas(source_path)
generator = StubsGenerator(checker.types)
module: ast.Module = generator.generate_stubs()
module = ast.fix_missing_locations(module)
output.write(ast.unparse(module))

View File

@@ -18,6 +18,7 @@ midas.add_command(commands.highlight)
midas.add_command(commands.parse) midas.add_command(commands.parse)
midas.add_command(commands.dump_registry) midas.add_command(commands.dump_registry)
midas.add_command(commands.types) midas.add_command(commands.types)
midas.add_command(commands.stubs)
midas.add_command(commands.validate) midas.add_command(commands.validate)

337
midas/generator/stubs.py Normal file
View File

@@ -0,0 +1,337 @@
import ast
from typing import Optional
import midas.ast.midas as m
from midas.checker.registry import Member, TypesRegistry
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ComplexType,
ExtensionType,
Function,
GenericType,
OverloadedFunction,
TopType,
Type,
TypeVar,
UnitType,
UnknownType,
substitute_typevars,
)
Empty = ast.Constant(value=...)
class StubsGenerator:
def __init__(self, types: TypesRegistry) -> None:
self.types: TypesRegistry = types
self.stubs: list[ast.stmt] = []
self.typing_imports: set[str] = set()
self.protocol_idx: int = 0
self.stub_idx: int = 0
self.type_var_idx: int = 0
self.substitutions: dict[str, dict[str, Type]] = {}
def generate_stubs(self) -> ast.Module:
self.stubs = []
self.typing_imports = set()
for name, type in self.types._types.items():
self.generate_stub(name, type)
imports = [
ast.ImportFrom(
module="__future__",
names=[ast.alias(name="annotations")],
level=0,
)
]
if len(self.typing_imports) != 0:
imports.append(
ast.ImportFrom(
module="typing",
names=[
ast.alias(name=name) for name in sorted(self.typing_imports)
],
level=0,
)
)
return ast.Module(body=imports + self.stubs, type_ignores=[])
def generate_stub(self, name: str, type: Type):
base_type: Type = type
members: dict[str, Member] = self.types._members.get(name, {})
if isinstance(base_type, (BaseType, TopType, UnitType)) and len(members) == 0:
return
bases: list[ast.expr] = []
substitutions: dict[str, Type] = {}
bases, substitutions = self.get_bases(type)
self.substitutions[name] = substitutions
body = self.generate_body(members, substitutions)
stub = ast.ClassDef(
name=name,
bases=bases,
body=body,
keywords=[],
decorator_list=[],
)
self.add_stub(stub)
def get_bases(self, type: Type) -> tuple[list[ast.expr], dict[str, Type]]:
match type:
case AliasType(type=base):
return [self.dump_type(base)], {}
case GenericType(params=params, body=body):
self.add_typing_import("Generic")
type_vars: ast.expr
params2: list[TypeVar] = self.define_type_vars(params)
if len(params) == 1:
type_vars = ast.Name(id=params2[0].name)
else:
type_vars = ast.Tuple(
elts=[ast.Name(id=param.name) for param in params2]
)
substitutions: dict[str, TypeVar] = {
param.name: param2 for param, param2 in zip(params, params2)
}
body_bases, body_subsitutions = self.get_bases(body)
return (
body_bases
+ [
ast.Subscript(
value=ast.Name(id="Generic"),
slice=type_vars,
)
],
body_subsitutions | substitutions,
)
case _:
return [], {}
def generate_body(
self, members: dict[str, Member], substitutions: dict[str, Type]
) -> list[ast.stmt]:
if len(members) == 0:
return [ast.Expr(value=Empty)]
body: list[ast.stmt] = []
for name, member in members.items():
type: Type = member.type
type = substitute_typevars(type, substitutions)
match member.kind:
case m.MemberKind.PROPERTY:
body.append(
ast.AnnAssign(
target=ast.Name(id=name),
annotation=self.dump_type(type),
simple=1,
)
)
case m.MemberKind.METHOD:
body.extend(self.dump_method(name, type))
return body
def dump_type(self, type: Type) -> ast.expr:
match type:
case AliasType(name=name) | GenericType(name=name) if (
name in self.substitutions
):
type = substitute_typevars(type, self.substitutions[name])
match type:
case TopType() | UnknownType():
self.add_typing_import("Any")
return ast.Name(id="Any")
case BaseType(name=name):
return ast.Name(id=name)
case AliasType(name=name):
return ast.Name(id=name)
case UnitType():
return ast.Constant(value=None)
case Function():
name: str = self.define_protocol(type)
return ast.Name(id=name)
case OverloadedFunction(overloads=overloads):
if len(overloads) == 1:
return self.dump_type(overloads[0])
return ast.BinOp(
left=self.dump_type(OverloadedFunction(overloads=overloads[:-1])),
op=ast.BitOr(),
right=self.dump_type(overloads[-1]),
)
case ComplexType():
name: str = self.new_stub_name()
self.generate_stub(name, type)
return ast.Name(id=name)
case ExtensionType():
raise NotImplementedError
case TypeVar():
return ast.Name(id=type.name)
case GenericType(name=name):
params: ast.expr
if len(type.params) == 1:
params = self.dump_type(type.params[0])
else:
params = ast.Tuple(
elts=[self.dump_type(param) for param in type.params]
)
return ast.Subscript(
value=ast.Name(id=type.name),
slice=params,
)
case AppliedType():
args: ast.expr
if len(type.args) == 1:
args = self.dump_type(type.args[0])
else:
args = ast.Tuple(elts=[self.dump_type(arg) for arg in type.args])
return ast.Subscript(
value=ast.Name(id=type.name),
slice=args,
)
def dump_method(
self, name: str, method: Type, overloaded: bool = False
) -> list[ast.stmt]:
match method:
case Function():
if overloaded:
self.add_typing_import("overload")
return [
ast.FunctionDef(
name=name,
args=self.dump_args(method, with_self=True),
returns=self.dump_type(method.returns),
body=[ast.Expr(value=Empty)],
decorator_list=[ast.Name(id="overload")] if overloaded else [],
)
]
case OverloadedFunction(overloads=overloads):
stmts: list[ast.stmt] = []
for overload in overloads:
stmts.extend(self.dump_method(name, overload, True))
return stmts
case _:
return [
ast.AnnAssign(
target=ast.Name(id=name),
annotation=self.dump_type(method),
simple=1,
)
]
def dump_args(self, func: Function, with_self: bool = False) -> ast.arguments:
pos: list[ast.arg] = [
ast.arg(arg=f"_{arg.pos}", annotation=self.dump_type(arg.type))
for arg in func.pos_args
]
mixed: list[ast.arg] = [
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type))
for arg in func.args
]
kw: list[ast.arg] = [
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type))
for arg in func.kw_args
]
defaults: list[ast.expr] = [
Empty for arg in func.pos_args + func.args if not arg.required
]
kw_defaults: list[Optional[ast.expr]] = [
None if arg.required else Empty for arg in func.kw_args
]
if with_self:
arg = ast.arg(arg="self", annotation=None)
if len(pos) != 0:
pos.insert(0, arg)
else:
mixed.insert(0, arg)
return ast.arguments(
posonlyargs=pos,
args=mixed,
kwonlyargs=kw,
defaults=defaults,
kw_defaults=kw_defaults,
)
def define_protocol(self, func: Function) -> str:
self.add_typing_import("Protocol")
name: str = self.new_protocol_name()
protocol = ast.ClassDef(
name=name,
bases=[ast.Name(id="Protocol")],
keywords=[],
body=[
ast.FunctionDef(
name="__call__",
args=self.dump_args(func, with_self=True),
returns=self.dump_type(func.returns),
body=[ast.Expr(value=Empty)],
decorator_list=[],
),
],
decorator_list=[],
)
self.add_stub(protocol)
return name
def new_protocol_name(self) -> str:
name: str = f"_Protocol{self.protocol_idx}"
self.protocol_idx += 1
return name
def new_stub_name(self) -> str:
name: str = f"_Stub_{self.stub_idx}"
self.stub_idx += 1
return name
def new_type_var_name(self) -> str:
name: str = f"_T{self.type_var_idx}"
self.type_var_idx += 1
return name
def add_stub(self, stub: ast.stmt):
self.stubs.append(stub)
def add_typing_import(self, name: str):
self.typing_imports.add(name)
def define_type_vars(self, vars: list[TypeVar]) -> list[TypeVar]:
vars2: list[TypeVar] = []
for var in vars:
vars2.append(self.define_type_var(var))
return vars2
def define_type_var(self, var: TypeVar) -> TypeVar:
name: str = self.new_type_var_name()
self.add_typing_import("TypeVar")
self.add_stub(
ast.Assign(
targets=[ast.Name(id=name)],
value=ast.Call(
func=ast.Name(id="TypeVar"),
args=[
ast.Constant(value=name),
],
keywords=(
[]
if var.bound is None
else [
ast.keyword(
arg="bound",
value=self.dump_type(var.bound),
)
]
),
),
)
)
return TypeVar(name=name, bound=None)