5 Commits

16 changed files with 588 additions and 577 deletions

View File

@@ -26,14 +26,6 @@ class MemberKind(Enum):
METHOD = auto()
@dataclass(frozen=True, kw_only=True)
class ParamSpec:
l_paren: Token
pos: list[FunctionType.Argument]
mixed: list[FunctionType.Argument]
kw: list[FunctionType.Argument]
###<
@@ -58,8 +50,9 @@ class ExtendStmt:
class PredicateStmt:
name: Token
params: list[ParamSpec]
body: Expr
subject: Token
type: Type
condition: Expr
###<
@@ -85,12 +78,6 @@ class UnaryExpr:
right: Expr
class CallExpr:
callee: Expr
arguments: list[Expr]
keywords: dict[str, Expr]
class GetExpr:
expr: Expr
name: Token
@@ -141,7 +128,9 @@ class ExtensionType:
class FunctionType:
params: ParamSpec
pos_args: list[Argument]
args: list[Argument]
kw_args: list[Argument]
returns: Type
@dataclass(frozen=True, kw_only=True)

View File

@@ -27,14 +27,6 @@ class MemberKind(Enum):
METHOD = auto()
@dataclass(frozen=True, kw_only=True)
class ParamSpec:
l_paren: Token
pos: list[FunctionType.Argument]
mixed: list[FunctionType.Argument]
kw: list[FunctionType.Argument]
##############
# Statements #
##############
@@ -94,8 +86,9 @@ class ExtendStmt(Stmt):
@dataclass(frozen=True)
class PredicateStmt(Stmt):
name: Token
params: list[ParamSpec]
body: Expr
subject: Token
type: Type
condition: Expr
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_predicate_stmt(self)
@@ -123,9 +116,6 @@ class Expr(ABC):
@abstractmethod
def visit_unary_expr(self, expr: UnaryExpr) -> T: ...
@abstractmethod
def visit_call_expr(self, expr: CallExpr) -> T: ...
@abstractmethod
def visit_get_expr(self, expr: GetExpr) -> T: ...
@@ -171,16 +161,6 @@ class UnaryExpr(Expr):
return visitor.visit_unary_expr(self)
@dataclass(frozen=True)
class CallExpr(Expr):
callee: Expr
arguments: list[Expr]
keywords: dict[str, Expr]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_call_expr(self)
@dataclass(frozen=True)
class GetExpr(Expr):
expr: Expr
@@ -299,7 +279,9 @@ class ExtensionType(Type):
@dataclass(frozen=True)
class FunctionType(Type):
params: ParamSpec
pos_args: list[Argument]
args: list[Argument]
kw_args: list[Argument]
returns: Type
@dataclass(frozen=True, kw_only=True)

View File

@@ -150,17 +150,13 @@ class MidasAstPrinter(
self._write_line("PredicateStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("params")
with self._child_level():
for i, spec in enumerate(stmt.params):
self._idx = i
if i == len(stmt.params) - 1:
self._mark_last()
self._visit_param_spec(spec)
self._write_line("body", last=True)
self._write_line(f'subject: "{stmt.subject.lexeme}"')
self._write_line("type")
with self._child_level(single=True):
stmt.body.accept(self)
stmt.type.accept(self)
self._write_line("condition", last=True)
with self._child_level(single=True):
stmt.condition.accept(self)
# Expressions
@@ -199,29 +195,6 @@ class MidasAstPrinter(
with self._child_level(single=True):
expr.right.accept(self)
def visit_call_expr(self, expr: m.CallExpr) -> None:
self._write_line("CallExpr")
with self._child_level():
self._write_line("callee")
with self._child_level(single=True):
expr.callee.accept(self)
self._write_line("arguments")
with self._child_level():
for i, arg in enumerate(expr.arguments):
self._idx = i
if i == len(expr.arguments) - 1:
self._mark_last()
arg.accept(self)
self._write_line("keywords", last=True)
with self._child_level():
for i, (name, arg) in enumerate(expr.keywords.items()):
self._idx = i
if i == len(expr.keywords) - 1:
self._mark_last()
self._write_line(name)
with self._child_level(single=True):
arg.accept(self)
def visit_get_expr(self, expr: m.GetExpr):
self._write_line("GetExpr")
with self._child_level():
@@ -303,41 +276,34 @@ class MidasAstPrinter(
def visit_function_type(self, type: m.FunctionType) -> None:
self._write_line("FunctionType")
with self._child_level():
self._write_line("params")
with self._child_level(single=True):
self._visit_param_spec(type.params)
self._write_line("pos_args")
with self._child_level():
for i, arg in enumerate(type.pos_args):
self._idx = i
if i == len(type.pos_args) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("args")
with self._child_level():
for i, arg in enumerate(type.args):
self._idx = i
if i == len(type.args) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("kw_args")
with self._child_level():
for i, arg in enumerate(type.kw_args):
self._idx = i
if i == len(type.kw_args) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("returns", last=True)
with self._child_level(single=True):
type.returns.accept(self)
def _visit_param_spec(self, spec: m.ParamSpec) -> None:
self._write_line("ParamSpec")
with self._child_level():
self._write_line("pos")
with self._child_level():
for i, arg in enumerate(spec.pos):
self._idx = i
if i == len(spec.pos) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("mixed")
with self._child_level():
for i, arg in enumerate(spec.mixed):
self._idx = i
if i == len(spec.mixed) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("kw", last=True)
with self._child_level():
for i, arg in enumerate(spec.kw):
self._idx = i
if i == len(spec.kw) - 1:
self._mark_last()
self._print_function_arg(arg)
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
self._write_line("Argument")
with self._child_level():
@@ -401,9 +367,10 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
name: str = stmt.name.lexeme
sig: str = "".join(self._visit_param_spec(spec) for spec in stmt.params)
body: str = stmt.body.accept(self)
return self.indented(f"predicate {name}{sig} = {body}")
subject: str = stmt.subject.lexeme
type: str = stmt.type.accept(self)
condition: str = stmt.condition.accept(self)
return self.indented(f"predicate {name}({subject}: {type}) = {condition}")
def visit_logical_expr(self, expr: m.LogicalExpr):
left: str = expr.left.accept(self)
@@ -422,12 +389,6 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
right: str = expr.right.accept(self)
return f"{operator}{right}"
def visit_call_expr(self, expr: m.CallExpr) -> str:
args: list[str] = [arg.accept(self) for arg in expr.arguments] + [
f"{name}={arg.accept(self)}" for name, arg in expr.keywords.items()
]
return f"{expr.callee.accept(self)}({', '.join(args)})"
def visit_get_expr(self, expr: m.GetExpr):
expr_: str = expr.expr.accept(self)
name: str = expr.name.lexeme
@@ -475,13 +436,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
return f"{type.base.accept(self)} & {type.extension.accept(self)}"
def visit_function_type(self, type: m.FunctionType) -> str:
spec: str = self._visit_param_spec(type.params)
return f"fn {spec} -> {type.returns.accept(self)}"
def _visit_param_spec(self, spec: m.ParamSpec) -> str:
pos_args: list[str] = [self._print_arg(arg) for arg in spec.pos]
mixed_args: list[str] = [self._print_arg(arg) for arg in spec.mixed]
kw_args: list[str] = [self._print_arg(arg) for arg in spec.kw]
pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
mixed_args: list[str] = [self._print_arg(arg) for arg in type.args]
kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args]
args: list[str] = pos_args
if len(pos_args) != 0:
@@ -490,7 +447,8 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
if len(kw_args) != 0:
args.append("*")
args += kw_args
return f"({', '.join(args)})"
return f"fn ({', '.join(args)}) -> {type.returns.accept(self)}"
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
res: str = ""

View File

@@ -1,5 +1,4 @@
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
@@ -10,11 +9,9 @@ from midas.checker.reporter import FileReporter, Reporter
from midas.checker.types import (
AliasType,
ComplexType,
ConstraintType,
ExtensionType,
Function,
GenericType,
Predicate,
Type,
TypeVar,
UnknownType,
@@ -24,13 +21,6 @@ from midas.lexer.token import Token
from midas.parser.midas import MidasParser
@dataclass(frozen=True, kw_only=True)
class TypedParamSpec:
pos: list[Function.Argument]
mixed: list[Function.Argument]
kw: list[Function.Argument]
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
"""A resolver which evaluates Midas type definitions and build a registry"""
@@ -47,8 +37,6 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve()
self.process(builtins_path.read_text(), str(builtins_path))
self._bool: Type = self.get_type("bool")
def process(self, source: str, path: Optional[str]):
self.reporter = self.reporter.for_file(path)
lexer: MidasLexer = MidasLexer(source)
@@ -114,28 +102,11 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
base_name,
member.name.lexeme,
member_type,
member.kind == m.MemberKind.METHOD,
member.kind,
)
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
params: list[TypedParamSpec] = [
self._visit_param_spec(spec) for spec in stmt.params
]
type: Type = self._bool
for spec in reversed(params):
type = Function(
pos_args=spec.pos,
args=spec.mixed,
kw_args=spec.kw,
returns=type,
)
self.types.define_predicate(
stmt.name.lexeme,
Predicate(
type=type,
body=stmt.body,
),
)
self.reporter.warning(stmt.location, "PredicateStmt not yet supported")
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
self.reporter.warning(expr.location, "LogicalExpr not yet supported")
@@ -146,9 +117,6 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
self.reporter.warning(expr.location, "UnaryExpr not yet supported")
def visit_call_expr(self, expr: m.CallExpr) -> None:
self.reporter.warning(expr.location, "CallExpr not yet supported")
def visit_get_expr(self, expr: m.GetExpr) -> None:
self.reporter.warning(expr.location, "GetExpr not yet supported")
@@ -185,10 +153,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
return UnknownType()
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
return ConstraintType(
type=type.type.accept(self),
constraint=type.constraint,
)
type_: Type = type.type.accept(self)
type.constraint.accept(self)
# TODO
return UnknownType()
def visit_complex_type(self, type: m.ComplexType) -> ComplexType:
return ComplexType(
@@ -204,17 +172,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
)
def visit_function_type(self, type: m.FunctionType) -> Type:
params: TypedParamSpec = self._visit_param_spec(type.params)
return Function(
pos_args=params.pos,
args=params.mixed,
kw_args=params.kw,
returns=type.returns.accept(self),
)
def _visit_param_spec(self, spec: m.ParamSpec) -> TypedParamSpec:
n_pos: int = len(spec.pos)
n_mixed: int = len(spec.mixed)
n_pos_args: int = len(type.pos_args)
n_args: int = len(type.args)
def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument:
return Function.Argument(
@@ -224,10 +183,14 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
required=arg.required,
)
return TypedParamSpec(
pos=[process_arg(arg, i) for i, arg in enumerate(spec.pos)],
mixed=[process_arg(arg, i + n_pos) for i, arg in enumerate(spec.mixed)],
kw=[process_arg(arg, i + n_pos + n_mixed) for i, arg in enumerate(spec.kw)],
return Function(
pos_args=[process_arg(arg, i) for i, arg in enumerate(type.pos_args)],
args=[process_arg(arg, i + n_pos_args) for i, arg in enumerate(type.args)],
kw_args=[
process_arg(arg, i + n_pos_args + n_args)
for i, arg in enumerate(type.kw_args)
],
returns=type.returns.accept(self),
)
def _resolve_type_params(self, params: list[m.TypeParam]):

View File

@@ -12,6 +12,7 @@ from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter
from midas.checker.resolver import Resolver
from midas.checker.types import (
AliasType,
AppliedType,
Function,
OverloadedFunction,
@@ -694,9 +695,17 @@ class PythonTyper(
case UnknownType():
return UnknownType()
case AliasType(type=base):
return self._get_call_result(
location, base, positional, keywords, report_errors
)
case _:
if report_errors:
self.reporter.error(location, f"{callee} is not callable")
self.reporter.error(
location,
f"{callee} ({callee.__class__.__name__}) is not callable",
)
return None
def _are_arguments_valid(

View File

@@ -1,6 +1,8 @@
import logging
from dataclasses import dataclass
from typing import Optional
from midas.ast.midas import MemberKind
from midas.checker.builtins import BUILTIN_SUBTYPES
from midas.checker.types import (
AliasType,
@@ -11,7 +13,6 @@ from midas.checker.types import (
Function,
GenericType,
OverloadedFunction,
Predicate,
TopType,
Type,
TypeVar,
@@ -20,12 +21,17 @@ from midas.checker.types import (
)
@dataclass
class Member:
kind: MemberKind
type: Type
class TypesRegistry:
def __init__(self) -> None:
self.logger: logging.Logger = logging.getLogger("TypesRegistry")
self._types: dict[str, Type] = {}
self._members: dict[str, dict[str, Type]] = {}
self._predicates: dict[str, Predicate] = {}
self._members: dict[str, dict[str, Member]] = {}
def get_type(self, name: str) -> Type:
"""Get a type from its name
@@ -62,31 +68,38 @@ class TypesRegistry:
return type
def define_member(
self, type_name: str, member_name: str, member_type: Type, is_method: bool
self,
type_name: str,
member_name: str,
member_type: Type,
kind: MemberKind,
):
members: dict[str, Type] = self._members.setdefault(type_name, {})
members: dict[str, Member] = self._members.setdefault(type_name, {})
if member_name in members:
if not is_method:
current: Member = members[member_name]
if current.kind != kind:
self.logger.error(
f"Member '{member_name}' already defined for type {type_name}"
f"Member '{member_name}' is already defined as a {current.kind},"
+ f" cannot define a {kind} with the same name"
)
return
current: Type = members[member_name]
if kind != MemberKind.METHOD:
self.logger.error(
f"Member '{member_name}' already defined for type {type_name},"
+ " only methods can be overloaded"
)
return
combined: Type
match current:
match current.type:
case OverloadedFunction(overloads=overloads):
combined = OverloadedFunction(overloads=overloads + [member_type])
case _:
combined = OverloadedFunction(overloads=[current, member_type])
members[member_name] = combined
combined = OverloadedFunction(overloads=[current.type, member_type])
members[member_name] = Member(kind=current.kind, type=combined)
else:
members[member_name] = member_type
def define_predicate(self, name: str, predicate: Predicate):
if name in self._predicates:
raise ValueError(f"Predicate {name} already defined")
self._predicates[name] = predicate
members[member_name] = Member(kind=kind, type=member_type)
def is_subtype(self, type1: Type, type2: Type) -> bool:
"""Check whether `type1` is a subtype of `type2`
@@ -304,13 +317,13 @@ class TypesRegistry:
case BaseType(name=name):
if name in self._members:
if member_name in self._members[name]:
return self._members[name][member_name]
return self._members[name][member_name].type
return None
case AliasType(name=name, type=base):
if name in self._members:
if member_name in self._members[name]:
return self._members[name][member_name]
return self._members[name][member_name].type
return self.lookup_member(base, member_name)
case AppliedType(name=name, body=body, args=args):
@@ -324,7 +337,7 @@ class TypesRegistry:
}
if name in self._members:
if member_name in self._members[name]:
member_type: Type = self._members[name][member_name]
member_type: Type = self._members[name][member_name].type
return substitute_typevars(member_type, substitutions)
member_type2: Optional[Type] = self.lookup_member(body, member_name)

View File

@@ -1,10 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional, assert_never
import midas.ast.midas as m
from midas.ast.printer import MidasPrinter
from typing import Optional
@dataclass(frozen=True, kw_only=True)
@@ -133,16 +130,6 @@ class AppliedType:
return f"{self.name}[{', '.join(map(str, self.args))}]"
@dataclass(frozen=True, kw_only=True)
class ConstraintType:
type: Type
constraint: m.Expr
def __str__(self) -> str:
printer = MidasPrinter()
return f"{self.type} where {printer.print(self.constraint)}"
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
def sub_argument(arg: Function.Argument):
return Function.Argument(
@@ -153,6 +140,9 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
)
match type:
case TopType():
return type
case BaseType(name=name) if name in substitutions:
return substitutions[name]
@@ -208,26 +198,31 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
body=substitute_typevars(body, substitutions),
)
case ConstraintType():
return ConstraintType(
type=substitute_typevars(type.type, substitutions),
constraint=type.constraint,
)
case TypeVar(name=name):
if name in substitutions:
return substitutions[name]
raise ValueError(f"Missing TypeVar substitution for {name}")
case GenericType(name=name, params=params, body=body):
params2: list[TypeVar] = []
for param in params:
param2: Type = substitute_typevars(param, substitutions)
if not isinstance(param2, TypeVar):
raise ValueError(
f"Invalid type parameter substitution, expected TypeVar, got {param2}"
)
params2.append(param2)
return GenericType(
name=name,
params=params2,
body=substitute_typevars(body, substitutions),
)
case UnknownType() | UnitType():
return type
case TopType() | GenericType():
raise NotImplementedError(f"Unsupported type {type}")
# Ensure exhaustiveness
case _:
assert_never(type)
raise NotImplementedError(f"Unsupported type {type}")
def unfold_type(type: Type) -> Type:
@@ -238,12 +233,6 @@ def unfold_type(type: Type) -> Type:
return type
@dataclass(frozen=True, kw_only=True)
class Predicate:
type: Type
body: m.Expr
Type = (
TopType
| BaseType
@@ -257,5 +246,4 @@ Type = (
| TypeVar
| GenericType
| AppliedType
| ConstraintType
)

View File

@@ -4,5 +4,6 @@ from .format import format as format
from .highlight import highlight as highlight
from .parse import parse as parse
from .registry import dump_registry as dump_registry
from .stubs import stubs as stubs
from .types import types as types
from .validate import validate as validate

View File

@@ -0,0 +1,27 @@
import ast
from pathlib import Path
from typing import TextIO
import click
from midas.checker.checker import TypeChecker
from midas.generator.stubs import StubsGenerator
@click.command(help="Generate stubs from Midas definitions")
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-")
def stubs(
file: TextIO,
output: TextIO,
):
source_path: Path = Path(file.name).resolve()
checker = TypeChecker()
checker.import_midas(source_path)
generator = StubsGenerator(checker.types)
module: ast.Module = generator.generate_stubs()
module = ast.fix_missing_locations(module)
output.write(ast.unparse(module))

View File

@@ -18,6 +18,7 @@ midas.add_command(commands.highlight)
midas.add_command(commands.parse)
midas.add_command(commands.dump_registry)
midas.add_command(commands.types)
midas.add_command(commands.stubs)
midas.add_command(commands.validate)

View File

@@ -1,92 +0,0 @@
import ast
import midas.ast.midas as m
from midas.lexer.token import TokenType
LOGICAL_OPERATORS: dict[TokenType, type[ast.boolop]] = {
TokenType.AND: ast.And,
# TokenType.OR: ast.Or,
}
BINARY_OPERATORS: dict[TokenType, type[ast.operator]] = {
# TokenType.PLUS: ast.Add,
TokenType.MINUS: ast.Sub,
TokenType.STAR: ast.Mult,
TokenType.SLASH: ast.Div,
}
UNARY_OPERATORS: dict[TokenType, type[ast.unaryop]] = {
# TokenType.PLUS: ast.UAdd,
TokenType.MINUS: ast.USub,
}
COMPARISON_OPERATORS: dict[TokenType, type[ast.cmpop]] = {
TokenType.GREATER: ast.Gt,
TokenType.GREATER_EQUAL: ast.GtE,
TokenType.LESS: ast.Lt,
TokenType.LESS_EQUAL: ast.LtE,
TokenType.EQUAL_EQUAL: ast.Eq,
TokenType.BANG_EQUAL: ast.NotEq,
}
class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
def visit_logical_expr(self, expr: m.LogicalExpr) -> ast.expr:
return ast.BoolOp(
op=LOGICAL_OPERATORS[expr.operator.type](),
values=[
expr.left.accept(self),
expr.right.accept(self),
],
)
def visit_binary_expr(self, expr: m.BinaryExpr) -> ast.expr:
op: TokenType = expr.operator.type
if op in BINARY_OPERATORS:
return ast.BinOp(
left=expr.left.accept(self),
op=BINARY_OPERATORS[op](),
right=expr.right.accept(self),
)
if op in COMPARISON_OPERATORS:
return ast.Compare(
left=expr.left.accept(self),
ops=[COMPARISON_OPERATORS[op]()],
comparators=[expr.right.accept(self)],
)
raise ValueError(f"Unexpected binary operator {op}")
def visit_unary_expr(self, expr: m.UnaryExpr) -> ast.expr:
return ast.UnaryOp(
op=UNARY_OPERATORS[expr.operator.type](),
operand=expr.right.accept(self),
)
def visit_call_expr(self, expr: m.CallExpr) -> ast.expr:
return ast.Call(
func=expr.callee.accept(self),
args=[arg.accept(self) for arg in expr.arguments],
keywords=[
ast.keyword(arg=name, value=arg.accept(self))
for name, arg in expr.keywords.items()
],
)
def visit_get_expr(self, expr: m.GetExpr) -> ast.expr:
return ast.Attribute(
value=expr.expr.accept(self),
attr=expr.name.lexeme,
)
def visit_variable_expr(self, expr: m.VariableExpr) -> ast.expr:
# TODO: lookup predicate
return ast.Name(id=expr.name.lexeme)
def visit_grouping_expr(self, expr: m.GroupingExpr) -> ast.expr:
return expr.accept(self)
def visit_literal_expr(self, expr: m.LiteralExpr) -> ast.expr:
return ast.Constant(value=expr.value)
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> ast.expr:
return ast.Name(id="_")

View File

@@ -2,18 +2,15 @@ import ast
import shutil
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, assert_never
from typing import Optional
import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.ast.printer import MidasPrinter
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ComplexType,
ConstraintType,
ExtensionType,
Function,
GenericType,
@@ -22,9 +19,7 @@ from midas.checker.types import (
Type,
TypeVar,
UnitType,
UnknownType,
)
from midas.generator.constraints import ConstraintGenerator
from midas.utils import TypedAST
@@ -50,8 +45,6 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self._alias_count: int = 0
self._scopes: list[Scope] = []
self._constraint_generator: ConstraintGenerator = ConstraintGenerator()
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
self.rel_src_path = src_path.relative_to(self.workdir)
self._typed_ast = typed_ast
@@ -283,9 +276,6 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type):
match type:
case UnknownType():
pass
case BaseType(name=name):
self._add_assert(
ast.Call(
@@ -311,15 +301,8 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self._make_cast_assert_message(src_location, expr, type),
)
case AppliedType(body=body):
self._make_cast_asserts(src_location, expr, body)
case ConstraintType(type=base, constraint=constraint):
self._make_cast_asserts(src_location, expr, base)
self._make_constraint_assert(src_location, expr, constraint)
case TypeVar():
raise RuntimeError("Unexpected TypeVar")
case AppliedType():
self._make_cast_asserts(src_location, expr, type.body)
case (
TopType()
@@ -331,9 +314,8 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
):
raise NotImplementedError(f"Can't make assertion for type {type}")
# Ensure exhaustiveness
case _:
assert_never(type)
case TypeVar():
raise RuntimeError("Unexpected TypeVar")
def _make_cast_assert_message(
self, location: Location, expr: ast.expr, type: Type
@@ -357,23 +339,3 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
ast.Constant(f" to {type}"),
]
)
def _make_constraint_assert(
self, src_location: Location, expr: ast.expr, constraint: m.Expr
):
test: ast.expr = constraint.accept(self._constraint_generator)
self._add_assert(
test,
self._make_constraint_assert_message(src_location, expr, constraint),
)
def _make_constraint_assert_message(
self, location: Location, expr: ast.expr, constraint: m.Expr
) -> ast.expr:
printer = MidasPrinter()
constraint_str: str = printer.print(constraint)
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
# f"file.py:L1:1: ConstraintError: Value does not fit constraint 'v > 0'"
return ast.Constant(
f"{loc_str}: ConstraintError: Value does not fit constraint '{constraint_str}'"
)

337
midas/generator/stubs.py Normal file
View File

@@ -0,0 +1,337 @@
import ast
from typing import Optional
import midas.ast.midas as m
from midas.checker.registry import Member, TypesRegistry
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ComplexType,
ExtensionType,
Function,
GenericType,
OverloadedFunction,
TopType,
Type,
TypeVar,
UnitType,
UnknownType,
substitute_typevars,
)
Empty = ast.Constant(value=...)
class StubsGenerator:
def __init__(self, types: TypesRegistry) -> None:
self.types: TypesRegistry = types
self.stubs: list[ast.stmt] = []
self.typing_imports: set[str] = set()
self.protocol_idx: int = 0
self.stub_idx: int = 0
self.type_var_idx: int = 0
self.substitutions: dict[str, dict[str, Type]] = {}
def generate_stubs(self) -> ast.Module:
self.stubs = []
self.typing_imports = set()
for name, type in self.types._types.items():
self.generate_stub(name, type)
imports = [
ast.ImportFrom(
module="__future__",
names=[ast.alias(name="annotations")],
level=0,
)
]
if len(self.typing_imports) != 0:
imports.append(
ast.ImportFrom(
module="typing",
names=[
ast.alias(name=name) for name in sorted(self.typing_imports)
],
level=0,
)
)
return ast.Module(body=imports + self.stubs, type_ignores=[])
def generate_stub(self, name: str, type: Type):
base_type: Type = type
members: dict[str, Member] = self.types._members.get(name, {})
if isinstance(base_type, (BaseType, TopType, UnitType)) and len(members) == 0:
return
bases: list[ast.expr] = []
substitutions: dict[str, Type] = {}
bases, substitutions = self.get_bases(type)
self.substitutions[name] = substitutions
body = self.generate_body(members, substitutions)
stub = ast.ClassDef(
name=name,
bases=bases,
body=body,
keywords=[],
decorator_list=[],
)
self.add_stub(stub)
def get_bases(self, type: Type) -> tuple[list[ast.expr], dict[str, Type]]:
match type:
case AliasType(type=base):
return [self.dump_type(base)], {}
case GenericType(params=params, body=body):
self.add_typing_import("Generic")
type_vars: ast.expr
params2: list[TypeVar] = self.define_type_vars(params)
if len(params) == 1:
type_vars = ast.Name(id=params2[0].name)
else:
type_vars = ast.Tuple(
elts=[ast.Name(id=param.name) for param in params2]
)
substitutions: dict[str, TypeVar] = {
param.name: param2 for param, param2 in zip(params, params2)
}
body_bases, body_subsitutions = self.get_bases(body)
return (
body_bases
+ [
ast.Subscript(
value=ast.Name(id="Generic"),
slice=type_vars,
)
],
body_subsitutions | substitutions,
)
case _:
return [], {}
def generate_body(
self, members: dict[str, Member], substitutions: dict[str, Type]
) -> list[ast.stmt]:
if len(members) == 0:
return [ast.Expr(value=Empty)]
body: list[ast.stmt] = []
for name, member in members.items():
type: Type = member.type
type = substitute_typevars(type, substitutions)
match member.kind:
case m.MemberKind.PROPERTY:
body.append(
ast.AnnAssign(
target=ast.Name(id=name),
annotation=self.dump_type(type),
simple=1,
)
)
case m.MemberKind.METHOD:
body.extend(self.dump_method(name, type))
return body
def dump_type(self, type: Type) -> ast.expr:
match type:
case AliasType(name=name) | GenericType(name=name) if (
name in self.substitutions
):
type = substitute_typevars(type, self.substitutions[name])
match type:
case TopType() | UnknownType():
self.add_typing_import("Any")
return ast.Name(id="Any")
case BaseType(name=name):
return ast.Name(id=name)
case AliasType(name=name):
return ast.Name(id=name)
case UnitType():
return ast.Constant(value=None)
case Function():
name: str = self.define_protocol(type)
return ast.Name(id=name)
case OverloadedFunction(overloads=overloads):
if len(overloads) == 1:
return self.dump_type(overloads[0])
return ast.BinOp(
left=self.dump_type(OverloadedFunction(overloads=overloads[:-1])),
op=ast.BitOr(),
right=self.dump_type(overloads[-1]),
)
case ComplexType():
name: str = self.new_stub_name()
self.generate_stub(name, type)
return ast.Name(id=name)
case ExtensionType():
raise NotImplementedError
case TypeVar():
return ast.Name(id=type.name)
case GenericType(name=name):
params: ast.expr
if len(type.params) == 1:
params = self.dump_type(type.params[0])
else:
params = ast.Tuple(
elts=[self.dump_type(param) for param in type.params]
)
return ast.Subscript(
value=ast.Name(id=type.name),
slice=params,
)
case AppliedType():
args: ast.expr
if len(type.args) == 1:
args = self.dump_type(type.args[0])
else:
args = ast.Tuple(elts=[self.dump_type(arg) for arg in type.args])
return ast.Subscript(
value=ast.Name(id=type.name),
slice=args,
)
def dump_method(
self, name: str, method: Type, overloaded: bool = False
) -> list[ast.stmt]:
match method:
case Function():
if overloaded:
self.add_typing_import("overload")
return [
ast.FunctionDef(
name=name,
args=self.dump_args(method, with_self=True),
returns=self.dump_type(method.returns),
body=[ast.Expr(value=Empty)],
decorator_list=[ast.Name(id="overload")] if overloaded else [],
)
]
case OverloadedFunction(overloads=overloads):
stmts: list[ast.stmt] = []
for overload in overloads:
stmts.extend(self.dump_method(name, overload, True))
return stmts
case _:
return [
ast.AnnAssign(
target=ast.Name(id=name),
annotation=self.dump_type(method),
simple=1,
)
]
def dump_args(self, func: Function, with_self: bool = False) -> ast.arguments:
pos: list[ast.arg] = [
ast.arg(arg=f"_{arg.pos}", annotation=self.dump_type(arg.type))
for arg in func.pos_args
]
mixed: list[ast.arg] = [
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type))
for arg in func.args
]
kw: list[ast.arg] = [
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type))
for arg in func.kw_args
]
defaults: list[ast.expr] = [
Empty for arg in func.pos_args + func.args if not arg.required
]
kw_defaults: list[Optional[ast.expr]] = [
None if arg.required else Empty for arg in func.kw_args
]
if with_self:
arg = ast.arg(arg="self", annotation=None)
if len(pos) != 0:
pos.insert(0, arg)
else:
mixed.insert(0, arg)
return ast.arguments(
posonlyargs=pos,
args=mixed,
kwonlyargs=kw,
defaults=defaults,
kw_defaults=kw_defaults,
)
def define_protocol(self, func: Function) -> str:
self.add_typing_import("Protocol")
name: str = self.new_protocol_name()
protocol = ast.ClassDef(
name=name,
bases=[ast.Name(id="Protocol")],
keywords=[],
body=[
ast.FunctionDef(
name="__call__",
args=self.dump_args(func, with_self=True),
returns=self.dump_type(func.returns),
body=[ast.Expr(value=Empty)],
decorator_list=[],
),
],
decorator_list=[],
)
self.add_stub(protocol)
return name
def new_protocol_name(self) -> str:
name: str = f"_Protocol{self.protocol_idx}"
self.protocol_idx += 1
return name
def new_stub_name(self) -> str:
name: str = f"_Stub_{self.stub_idx}"
self.stub_idx += 1
return name
def new_type_var_name(self) -> str:
name: str = f"_T{self.type_var_idx}"
self.type_var_idx += 1
return name
def add_stub(self, stub: ast.stmt):
self.stubs.append(stub)
def add_typing_import(self, name: str):
self.typing_imports.add(name)
def define_type_vars(self, vars: list[TypeVar]) -> list[TypeVar]:
vars2: list[TypeVar] = []
for var in vars:
vars2.append(self.define_type_var(var))
return vars2
def define_type_var(self, var: TypeVar) -> TypeVar:
name: str = self.new_type_var_name()
self.add_typing_import("TypeVar")
self.add_stub(
ast.Assign(
targets=[ast.Name(id=name)],
value=ast.Call(
func=ast.Name(id="TypeVar"),
args=[
ast.Constant(value=name),
],
keywords=(
[]
if var.bound is None
else [
ast.keyword(
arg="bound",
value=self.dump_type(var.bound),
)
]
),
),
)
)
return TypeVar(name=name, bound=None)

View File

@@ -3,7 +3,6 @@ from typing import Optional
from midas.ast.location import Location
from midas.ast.midas import (
BinaryExpr,
CallExpr,
ComplexType,
ConstraintType,
Expr,
@@ -18,7 +17,6 @@ from midas.ast.midas import (
MemberKind,
MemberStmt,
NamedType,
ParamSpec,
PredicateStmt,
Stmt,
Type,
@@ -267,9 +265,6 @@ class MidasParser(Parser):
Returns:
Expr: the parsed constraint expression
"""
return self.expression()
def expression(self) -> Expr:
return self.and_()
def and_(self) -> Expr:
@@ -336,55 +331,7 @@ class MidasParser(Parser):
right: Expr = self.unary()
location: Location = Location.span(operator.get_location(), right.location)
return UnaryExpr(location=location, operator=operator, right=right)
return self.call()
def call(self) -> Expr:
expr: Expr = self.reference()
if self.match(TokenType.LEFT_PAREN):
expr = self.finish_call(expr)
return expr
def finish_call(self, callee: Expr) -> Expr:
l_paren: Token = self.previous()
pos_args: list[Expr] = []
kw_args: dict[str, Expr] = {}
keywords: bool = False
while not self.match(TokenType.RIGHT_PAREN):
if self.check_identifier() and self.check_next(TokenType.EQUAL):
keywords = True
keyword: Token = self.advance()
value: Expr = self.expression()
name: str = keyword.lexeme
if name in kw_args:
self.error(
self.peek(),
f"Multiple values passed for '{name}', only the last occurrence will be used",
)
kw_args[name] = value
else:
value = self.expression()
if self.check(TokenType.EQUAL):
if keywords:
raise self.error(self.peek(), "Invalid keyword argument name")
else:
raise self.error(
self.peek(),
"Cannot pass positional arguments after a keyword argument",
)
pos_args.append(value)
if not self.match(TokenType.COMMA):
break
r_paren: Token = self.consume(
TokenType.RIGHT_PAREN, "Expected ')' after arguments."
)
return CallExpr(
location=l_paren.location_to(r_paren),
callee=callee,
arguments=pos_args,
keywords=kw_args,
)
return self.reference()
def reference(self) -> Expr:
"""Parse an attribute access expression or a simpler expression
@@ -506,35 +453,23 @@ class MidasParser(Parser):
PredicateStmt: the parsed predicate declaration statement
"""
keyword: Token = self.previous()
name: Token = self.consume_identifier("Expected predicate name")
params: list[ParamSpec] = []
while self.check(TokenType.LEFT_PAREN):
params.append(self.function_args())
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
subject: Token = self.consume_identifier("Expected subject name")
self.consume(TokenType.COLON, "Expected ':' after subject name")
type: Type = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
body: Expr = self.constraint()
condition: Expr = self.constraint()
return PredicateStmt(
location=keyword.location_to(self.previous()),
name=name,
params=params,
body=body,
subject=subject,
type=type,
condition=condition,
)
def function(self) -> FunctionType:
params: ParamSpec = self.function_args()
self.consume(TokenType.ARROW, "Expected '->' before result type")
result: Type = self.type_expr()
return FunctionType(
location=params.l_paren.location_to(self.previous()),
params=params,
returns=result,
)
def function_args(self) -> ParamSpec:
l_paren: Token = self.consume(
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
)
@@ -591,4 +526,14 @@ class MidasParser(Parser):
self.error(token, "Unnamed mixed argument")
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
return ParamSpec(l_paren=l_paren, pos=pos_args, mixed=args, kw=kw_args)
self.consume(TokenType.ARROW, "Expected '->' before result type")
result: Type = self.type_expr()
return FunctionType(
location=l_paren.location_to(self.previous()),
pos_args=pos_args,
args=args,
kw_args=kw_args,
returns=result,
)

View File

@@ -2582,21 +2582,18 @@
"name": "__sub__",
"type": {
"_type": "FunctionType",
"params": {
"_type": "ParamSpec",
"pos": [
{
"name": null,
"type": {
"_type": "NamedType",
"name": "GeoLocation"
},
"required": true
}
],
"mixed": [],
"kw": []
},
"pos_args": [
{
"name": null,
"type": {
"_type": "NamedType",
"name": "GeoLocation"
},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {
"_type": "GenericType",
"type": {
@@ -2676,21 +2673,18 @@
"name": "__sub__",
"type": {
"_type": "FunctionType",
"params": {
"_type": "ParamSpec",
"pos": [
{
"name": null,
"type": {
"_type": "NamedType",
"name": "Latitude"
},
"required": true
}
],
"mixed": [],
"kw": []
},
"pos_args": [
{
"name": null,
"type": {
"_type": "NamedType",
"name": "Latitude"
},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {
"_type": "GenericType",
"type": {
@@ -2719,21 +2713,18 @@
"name": "__sub__",
"type": {
"_type": "FunctionType",
"params": {
"_type": "ParamSpec",
"pos": [
{
"name": null,
"type": {
"_type": "NamedType",
"name": "Longitude"
},
"required": true
}
],
"mixed": [],
"kw": []
},
"pos_args": [
{
"name": null,
"type": {
"_type": "NamedType",
"name": "Longitude"
},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {
"_type": "GenericType",
"type": {
@@ -2754,24 +2745,12 @@
{
"_type": "PredicateStmt",
"name": "Positive",
"params": [
{
"_type": "ParamSpec",
"pos": [],
"mixed": [
{
"name": "v",
"type": {
"_type": "NamedType",
"name": "float"
},
"required": true
}
],
"kw": []
}
],
"body": {
"subject": "v",
"type": {
"_type": "NamedType",
"name": "float"
},
"condition": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
@@ -2787,24 +2766,12 @@
{
"_type": "PredicateStmt",
"name": "StrictlyPositive",
"params": [
{
"_type": "ParamSpec",
"pos": [],
"mixed": [
{
"name": "v",
"type": {
"_type": "NamedType",
"name": "float"
},
"required": true
}
],
"kw": []
}
],
"body": {
"subject": "v",
"type": {
"_type": "NamedType",
"name": "float"
},
"condition": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
@@ -2820,24 +2787,12 @@
{
"_type": "PredicateStmt",
"name": "Equatorial",
"params": [
{
"_type": "ParamSpec",
"pos": [],
"mixed": [
{
"name": "loc",
"type": {
"_type": "NamedType",
"name": "GeoLocation"
},
"required": true
}
],
"kw": []
}
],
"body": {
"subject": "loc",
"type": {
"_type": "NamedType",
"name": "GeoLocation"
},
"condition": {
"_type": "GroupingExpr",
"expr": {
"_type": "BinaryExpr",
@@ -2872,24 +2827,12 @@
{
"_type": "PredicateStmt",
"name": "Arctic",
"params": [
{
"_type": "ParamSpec",
"pos": [],
"mixed": [
{
"name": "loc",
"type": {
"_type": "NamedType",
"name": "GeoLocation"
},
"required": true
}
],
"kw": []
}
],
"body": {
"subject": "loc",
"type": {
"_type": "NamedType",
"name": "GeoLocation"
},
"condition": {
"_type": "GroupingExpr",
"expr": {
"_type": "BinaryExpr",

View File

@@ -2,7 +2,6 @@ from typing import Optional, Sequence
from midas.ast.midas import (
BinaryExpr,
CallExpr,
ComplexType,
ConstraintType,
Expr,
@@ -16,7 +15,6 @@ from midas.ast.midas import (
LogicalExpr,
MemberStmt,
NamedType,
ParamSpec,
PredicateStmt,
Stmt,
Type,
@@ -80,8 +78,9 @@ class MidasAstJsonSerializer(
return {
"_type": "PredicateStmt",
"name": stmt.name.lexeme,
"params": [self._serialize_param_spec(spec) for spec in stmt.params],
"body": stmt.body.accept(self),
"subject": stmt.subject.lexeme,
"type": stmt.type.accept(self),
"condition": stmt.condition.accept(self),
}
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
@@ -107,14 +106,6 @@ class MidasAstJsonSerializer(
"right": expr.right.accept(self),
}
def visit_call_expr(self, expr: CallExpr) -> dict:
return {
"_type": "CallExpr",
"callee": expr.callee.accept(self),
"arguments": self._serialize_list(expr.arguments),
"keywords": {name: arg.accept(self) for name, arg in expr.keywords.items()},
}
def visit_get_expr(self, expr: GetExpr) -> dict:
return {
"_type": "GetExpr",
@@ -172,21 +163,15 @@ class MidasAstJsonSerializer(
def visit_function_type(self, type: FunctionType) -> dict:
return {
"_type": "FunctionType",
"params": self._serialize_param_spec(type.params),
"pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args],
"args": [self._serialize_func_arg(arg) for arg in type.args],
"kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args],
"returns": type.returns.accept(self),
}
def _serialize_param_spec(self, spec: ParamSpec) -> dict:
return {
"_type": "ParamSpec",
"pos": [self._serialize_func_arg(arg) for arg in spec.pos],
"mixed": [self._serialize_func_arg(arg) for arg in spec.mixed],
"kw": [self._serialize_func_arg(arg) for arg in spec.kw],
}
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
return {
"name": arg.name.lexeme if arg.name is not None else None,
"name": arg.name,
"type": arg.type.accept(self),
"required": arg.required,
}