10 Commits

19 changed files with 708 additions and 603 deletions

View File

@@ -26,6 +26,14 @@ class MemberKind(Enum):
METHOD = auto()
@dataclass(frozen=True, kw_only=True)
class ParamSpec:
l_paren: Token
pos: list[FunctionType.Argument]
mixed: list[FunctionType.Argument]
kw: list[FunctionType.Argument]
###<
@@ -50,9 +58,8 @@ class ExtendStmt:
class PredicateStmt:
name: Token
subject: Token
type: Type
condition: Expr
params: list[ParamSpec]
body: Expr
###<
@@ -78,6 +85,12 @@ class UnaryExpr:
right: Expr
class CallExpr:
callee: Expr
arguments: list[Expr]
keywords: dict[str, Expr]
class GetExpr:
expr: Expr
name: Token
@@ -128,9 +141,7 @@ class ExtensionType:
class FunctionType:
pos_args: list[Argument]
args: list[Argument]
kw_args: list[Argument]
params: ParamSpec
returns: Type
@dataclass(frozen=True, kw_only=True)

View File

@@ -27,6 +27,14 @@ class MemberKind(Enum):
METHOD = auto()
@dataclass(frozen=True, kw_only=True)
class ParamSpec:
l_paren: Token
pos: list[FunctionType.Argument]
mixed: list[FunctionType.Argument]
kw: list[FunctionType.Argument]
##############
# Statements #
##############
@@ -86,9 +94,8 @@ class ExtendStmt(Stmt):
@dataclass(frozen=True)
class PredicateStmt(Stmt):
name: Token
subject: Token
type: Type
condition: Expr
params: list[ParamSpec]
body: Expr
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_predicate_stmt(self)
@@ -116,6 +123,9 @@ class Expr(ABC):
@abstractmethod
def visit_unary_expr(self, expr: UnaryExpr) -> T: ...
@abstractmethod
def visit_call_expr(self, expr: CallExpr) -> T: ...
@abstractmethod
def visit_get_expr(self, expr: GetExpr) -> T: ...
@@ -161,6 +171,16 @@ class UnaryExpr(Expr):
return visitor.visit_unary_expr(self)
@dataclass(frozen=True)
class CallExpr(Expr):
callee: Expr
arguments: list[Expr]
keywords: dict[str, Expr]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_call_expr(self)
@dataclass(frozen=True)
class GetExpr(Expr):
expr: Expr
@@ -279,9 +299,7 @@ class ExtensionType(Type):
@dataclass(frozen=True)
class FunctionType(Type):
pos_args: list[Argument]
args: list[Argument]
kw_args: list[Argument]
params: ParamSpec
returns: Type
@dataclass(frozen=True, kw_only=True)

View File

@@ -150,13 +150,17 @@ class MidasAstPrinter(
self._write_line("PredicateStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line(f'subject: "{stmt.subject.lexeme}"')
self._write_line("type")
self._write_line("params")
with self._child_level():
for i, spec in enumerate(stmt.params):
self._idx = i
if i == len(stmt.params) - 1:
self._mark_last()
self._visit_param_spec(spec)
self._write_line("body", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
self._write_line("condition", last=True)
with self._child_level(single=True):
stmt.condition.accept(self)
stmt.body.accept(self)
# Expressions
@@ -195,6 +199,29 @@ class MidasAstPrinter(
with self._child_level(single=True):
expr.right.accept(self)
def visit_call_expr(self, expr: m.CallExpr) -> None:
self._write_line("CallExpr")
with self._child_level():
self._write_line("callee")
with self._child_level(single=True):
expr.callee.accept(self)
self._write_line("arguments")
with self._child_level():
for i, arg in enumerate(expr.arguments):
self._idx = i
if i == len(expr.arguments) - 1:
self._mark_last()
arg.accept(self)
self._write_line("keywords", last=True)
with self._child_level():
for i, (name, arg) in enumerate(expr.keywords.items()):
self._idx = i
if i == len(expr.keywords) - 1:
self._mark_last()
self._write_line(name)
with self._child_level(single=True):
arg.accept(self)
def visit_get_expr(self, expr: m.GetExpr):
self._write_line("GetExpr")
with self._child_level():
@@ -276,34 +303,41 @@ class MidasAstPrinter(
def visit_function_type(self, type: m.FunctionType) -> None:
self._write_line("FunctionType")
with self._child_level():
self._write_line("pos_args")
with self._child_level():
for i, arg in enumerate(type.pos_args):
self._idx = i
if i == len(type.pos_args) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("args")
with self._child_level():
for i, arg in enumerate(type.args):
self._idx = i
if i == len(type.args) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("kw_args")
with self._child_level():
for i, arg in enumerate(type.kw_args):
self._idx = i
if i == len(type.kw_args) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("params")
with self._child_level(single=True):
self._visit_param_spec(type.params)
self._write_line("returns", last=True)
with self._child_level(single=True):
type.returns.accept(self)
def _visit_param_spec(self, spec: m.ParamSpec) -> None:
self._write_line("ParamSpec")
with self._child_level():
self._write_line("pos")
with self._child_level():
for i, arg in enumerate(spec.pos):
self._idx = i
if i == len(spec.pos) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("mixed")
with self._child_level():
for i, arg in enumerate(spec.mixed):
self._idx = i
if i == len(spec.mixed) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("kw", last=True)
with self._child_level():
for i, arg in enumerate(spec.kw):
self._idx = i
if i == len(spec.kw) - 1:
self._mark_last()
self._print_function_arg(arg)
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
self._write_line("Argument")
with self._child_level():
@@ -367,10 +401,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
name: str = stmt.name.lexeme
subject: str = stmt.subject.lexeme
type: str = stmt.type.accept(self)
condition: str = stmt.condition.accept(self)
return self.indented(f"predicate {name}({subject}: {type}) = {condition}")
sig: str = "".join(self._visit_param_spec(spec) for spec in stmt.params)
body: str = stmt.body.accept(self)
return self.indented(f"predicate {name}{sig} = {body}")
def visit_logical_expr(self, expr: m.LogicalExpr):
left: str = expr.left.accept(self)
@@ -389,6 +422,12 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
right: str = expr.right.accept(self)
return f"{operator}{right}"
def visit_call_expr(self, expr: m.CallExpr) -> str:
args: list[str] = [arg.accept(self) for arg in expr.arguments] + [
f"{name}={arg.accept(self)}" for name, arg in expr.keywords.items()
]
return f"{expr.callee.accept(self)}({', '.join(args)})"
def visit_get_expr(self, expr: m.GetExpr):
expr_: str = expr.expr.accept(self)
name: str = expr.name.lexeme
@@ -436,9 +475,13 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
return f"{type.base.accept(self)} & {type.extension.accept(self)}"
def visit_function_type(self, type: m.FunctionType) -> str:
pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
mixed_args: list[str] = [self._print_arg(arg) for arg in type.args]
kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args]
spec: str = self._visit_param_spec(type.params)
return f"fn {spec} -> {type.returns.accept(self)}"
def _visit_param_spec(self, spec: m.ParamSpec) -> str:
pos_args: list[str] = [self._print_arg(arg) for arg in spec.pos]
mixed_args: list[str] = [self._print_arg(arg) for arg in spec.mixed]
kw_args: list[str] = [self._print_arg(arg) for arg in spec.kw]
args: list[str] = pos_args
if len(pos_args) != 0:
@@ -447,8 +490,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
if len(kw_args) != 0:
args.append("*")
args += kw_args
return f"fn ({', '.join(args)}) -> {type.returns.accept(self)}"
return f"({', '.join(args)})"
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
res: str = ""

View File

@@ -1,4 +1,5 @@
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
@@ -9,9 +10,11 @@ from midas.checker.reporter import FileReporter, Reporter
from midas.checker.types import (
AliasType,
ComplexType,
ConstraintType,
ExtensionType,
Function,
GenericType,
Predicate,
Type,
TypeVar,
UnknownType,
@@ -21,6 +24,13 @@ from midas.lexer.token import Token
from midas.parser.midas import MidasParser
@dataclass(frozen=True, kw_only=True)
class TypedParamSpec:
pos: list[Function.Argument]
mixed: list[Function.Argument]
kw: list[Function.Argument]
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
"""A resolver which evaluates Midas type definitions and build a registry"""
@@ -37,6 +47,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve()
self.process(builtins_path.read_text(), str(builtins_path))
self._bool: Type = self.get_type("bool")
def process(self, source: str, path: Optional[str]):
self.reporter = self.reporter.for_file(path)
lexer: MidasLexer = MidasLexer(source)
@@ -102,11 +114,28 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
base_name,
member.name.lexeme,
member_type,
member.kind,
member.kind == m.MemberKind.METHOD,
)
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
self.reporter.warning(stmt.location, "PredicateStmt not yet supported")
params: list[TypedParamSpec] = [
self._visit_param_spec(spec) for spec in stmt.params
]
type: Type = self._bool
for spec in reversed(params):
type = Function(
pos_args=spec.pos,
args=spec.mixed,
kw_args=spec.kw,
returns=type,
)
self.types.define_predicate(
stmt.name.lexeme,
Predicate(
type=type,
body=stmt.body,
),
)
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
self.reporter.warning(expr.location, "LogicalExpr not yet supported")
@@ -117,6 +146,9 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
self.reporter.warning(expr.location, "UnaryExpr not yet supported")
def visit_call_expr(self, expr: m.CallExpr) -> None:
self.reporter.warning(expr.location, "CallExpr not yet supported")
def visit_get_expr(self, expr: m.GetExpr) -> None:
self.reporter.warning(expr.location, "GetExpr not yet supported")
@@ -153,10 +185,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
return UnknownType()
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
type_: Type = type.type.accept(self)
type.constraint.accept(self)
# TODO
return UnknownType()
return ConstraintType(
type=type.type.accept(self),
constraint=type.constraint,
)
def visit_complex_type(self, type: m.ComplexType) -> ComplexType:
return ComplexType(
@@ -172,8 +204,17 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
)
def visit_function_type(self, type: m.FunctionType) -> Type:
n_pos_args: int = len(type.pos_args)
n_args: int = len(type.args)
params: TypedParamSpec = self._visit_param_spec(type.params)
return Function(
pos_args=params.pos,
args=params.mixed,
kw_args=params.kw,
returns=type.returns.accept(self),
)
def _visit_param_spec(self, spec: m.ParamSpec) -> TypedParamSpec:
n_pos: int = len(spec.pos)
n_mixed: int = len(spec.mixed)
def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument:
return Function.Argument(
@@ -183,14 +224,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
required=arg.required,
)
return Function(
pos_args=[process_arg(arg, i) for i, arg in enumerate(type.pos_args)],
args=[process_arg(arg, i + n_pos_args) for i, arg in enumerate(type.args)],
kw_args=[
process_arg(arg, i + n_pos_args + n_args)
for i, arg in enumerate(type.kw_args)
],
returns=type.returns.accept(self),
return TypedParamSpec(
pos=[process_arg(arg, i) for i, arg in enumerate(spec.pos)],
mixed=[process_arg(arg, i + n_pos) for i, arg in enumerate(spec.mixed)],
kw=[process_arg(arg, i + n_pos + n_mixed) for i, arg in enumerate(spec.kw)],
)
def _resolve_type_params(self, params: list[m.TypeParam]):

View File

@@ -12,7 +12,6 @@ from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter
from midas.checker.resolver import Resolver
from midas.checker.types import (
AliasType,
AppliedType,
Function,
OverloadedFunction,
@@ -695,17 +694,9 @@ class PythonTyper(
case UnknownType():
return UnknownType()
case AliasType(type=base):
return self._get_call_result(
location, base, positional, keywords, report_errors
)
case _:
if report_errors:
self.reporter.error(
location,
f"{callee} ({callee.__class__.__name__}) is not callable",
)
self.reporter.error(location, f"{callee} is not callable")
return None
def _are_arguments_valid(

View File

@@ -1,8 +1,6 @@
import logging
from dataclasses import dataclass
from typing import Optional
from midas.ast.midas import MemberKind
from midas.checker.builtins import BUILTIN_SUBTYPES
from midas.checker.types import (
AliasType,
@@ -13,6 +11,7 @@ from midas.checker.types import (
Function,
GenericType,
OverloadedFunction,
Predicate,
TopType,
Type,
TypeVar,
@@ -21,17 +20,12 @@ from midas.checker.types import (
)
@dataclass
class Member:
kind: MemberKind
type: Type
class TypesRegistry:
def __init__(self) -> None:
self.logger: logging.Logger = logging.getLogger("TypesRegistry")
self._types: dict[str, Type] = {}
self._members: dict[str, dict[str, Member]] = {}
self._members: dict[str, dict[str, Type]] = {}
self._predicates: dict[str, Predicate] = {}
def get_type(self, name: str) -> Type:
"""Get a type from its name
@@ -68,38 +62,31 @@ class TypesRegistry:
return type
def define_member(
self,
type_name: str,
member_name: str,
member_type: Type,
kind: MemberKind,
self, type_name: str, member_name: str, member_type: Type, is_method: bool
):
members: dict[str, Member] = self._members.setdefault(type_name, {})
members: dict[str, Type] = self._members.setdefault(type_name, {})
if member_name in members:
current: Member = members[member_name]
if current.kind != kind:
if not is_method:
self.logger.error(
f"Member '{member_name}' is already defined as a {current.kind},"
+ f" cannot define a {kind} with the same name"
f"Member '{member_name}' already defined for type {type_name}"
)
return
if kind != MemberKind.METHOD:
self.logger.error(
f"Member '{member_name}' already defined for type {type_name},"
+ " only methods can be overloaded"
)
return
current: Type = members[member_name]
combined: Type
match current.type:
match current:
case OverloadedFunction(overloads=overloads):
combined = OverloadedFunction(overloads=overloads + [member_type])
case _:
combined = OverloadedFunction(overloads=[current.type, member_type])
members[member_name] = Member(kind=current.kind, type=combined)
combined = OverloadedFunction(overloads=[current, member_type])
members[member_name] = combined
else:
members[member_name] = Member(kind=kind, type=member_type)
members[member_name] = member_type
def define_predicate(self, name: str, predicate: Predicate):
if name in self._predicates:
raise ValueError(f"Predicate {name} already defined")
self._predicates[name] = predicate
def is_subtype(self, type1: Type, type2: Type) -> bool:
"""Check whether `type1` is a subtype of `type2`
@@ -317,13 +304,13 @@ class TypesRegistry:
case BaseType(name=name):
if name in self._members:
if member_name in self._members[name]:
return self._members[name][member_name].type
return self._members[name][member_name]
return None
case AliasType(name=name, type=base):
if name in self._members:
if member_name in self._members[name]:
return self._members[name][member_name].type
return self._members[name][member_name]
return self.lookup_member(base, member_name)
case AppliedType(name=name, body=body, args=args):
@@ -337,7 +324,7 @@ class TypesRegistry:
}
if name in self._members:
if member_name in self._members[name]:
member_type: Type = self._members[name][member_name].type
member_type: Type = self._members[name][member_name]
return substitute_typevars(member_type, substitutions)
member_type2: Optional[Type] = self.lookup_member(body, member_name)
@@ -365,3 +352,6 @@ class TypesRegistry:
case _:
self.logger.debug(f"Can't get member on {type}")
return None
def lookup_predicate(self, name: str) -> Optional[Predicate]:
return self._predicates.get(name)

View File

@@ -1,7 +1,10 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional
from typing import Optional, assert_never
import midas.ast.midas as m
from midas.ast.printer import MidasPrinter
@dataclass(frozen=True, kw_only=True)
@@ -130,6 +133,16 @@ class AppliedType:
return f"{self.name}[{', '.join(map(str, self.args))}]"
@dataclass(frozen=True, kw_only=True)
class ConstraintType:
type: Type
constraint: m.Expr
def __str__(self) -> str:
printer = MidasPrinter()
return f"{self.type} where {printer.print(self.constraint)}"
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
def sub_argument(arg: Function.Argument):
return Function.Argument(
@@ -140,9 +153,6 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
)
match type:
case TopType():
return type
case BaseType(name=name) if name in substitutions:
return substitutions[name]
@@ -198,32 +208,27 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
body=substitute_typevars(body, substitutions),
)
case ConstraintType():
return ConstraintType(
type=substitute_typevars(type.type, substitutions),
constraint=type.constraint,
)
case TypeVar(name=name):
if name in substitutions:
return substitutions[name]
raise ValueError(f"Missing TypeVar substitution for {name}")
case GenericType(name=name, params=params, body=body):
params2: list[TypeVar] = []
for param in params:
param2: Type = substitute_typevars(param, substitutions)
if not isinstance(param2, TypeVar):
raise ValueError(
f"Invalid type parameter substitution, expected TypeVar, got {param2}"
)
params2.append(param2)
return GenericType(
name=name,
params=params2,
body=substitute_typevars(body, substitutions),
)
case UnknownType() | UnitType():
return type
case _:
case TopType() | GenericType():
raise NotImplementedError(f"Unsupported type {type}")
# Ensure exhaustiveness
case _:
assert_never(type)
def unfold_type(type: Type) -> Type:
match type:
@@ -233,6 +238,12 @@ def unfold_type(type: Type) -> Type:
return type
@dataclass(frozen=True, kw_only=True)
class Predicate:
type: Type
body: m.Expr
Type = (
TopType
| BaseType
@@ -246,4 +257,5 @@ Type = (
| TypeVar
| GenericType
| AppliedType
| ConstraintType
)

View File

@@ -4,6 +4,5 @@ from .format import format as format
from .highlight import highlight as highlight
from .parse import parse as parse
from .registry import dump_registry as dump_registry
from .stubs import stubs as stubs
from .types import types as types
from .validate import validate as validate

View File

@@ -38,5 +38,5 @@ def compile(
if any(map(lambda d: d.type == DiagnosticType.ERROR, diagnostics)):
sys.exit(1)
generator = Generator(workdir=source_path.parent)
generator = Generator(workdir=source_path.parent, types=checker.types)
generator.generate(typed_ast, source_path)

View File

@@ -1,27 +0,0 @@
import ast
from pathlib import Path
from typing import TextIO
import click
from midas.checker.checker import TypeChecker
from midas.generator.stubs import StubsGenerator
@click.command(help="Generate stubs from Midas definitions")
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-")
def stubs(
file: TextIO,
output: TextIO,
):
source_path: Path = Path(file.name).resolve()
checker = TypeChecker()
checker.import_midas(source_path)
generator = StubsGenerator(checker.types)
module: ast.Module = generator.generate_stubs()
module = ast.fix_missing_locations(module)
output.write(ast.unparse(module))

View File

@@ -18,7 +18,6 @@ midas.add_command(commands.highlight)
midas.add_command(commands.parse)
midas.add_command(commands.dump_registry)
midas.add_command(commands.types)
midas.add_command(commands.stubs)
midas.add_command(commands.validate)

View File

@@ -0,0 +1,188 @@
import ast
from typing import Optional
import midas.ast.midas as m
from midas.checker.registry import TypesRegistry
from midas.checker.types import Function, Predicate, Type
from midas.lexer.token import TokenType
LOGICAL_OPERATORS: dict[TokenType, type[ast.boolop]] = {
TokenType.AND: ast.And,
# TokenType.OR: ast.Or,
}
BINARY_OPERATORS: dict[TokenType, type[ast.operator]] = {
# TokenType.PLUS: ast.Add,
TokenType.MINUS: ast.Sub,
TokenType.STAR: ast.Mult,
TokenType.SLASH: ast.Div,
}
UNARY_OPERATORS: dict[TokenType, type[ast.unaryop]] = {
# TokenType.PLUS: ast.UAdd,
TokenType.MINUS: ast.USub,
}
COMPARISON_OPERATORS: dict[TokenType, type[ast.cmpop]] = {
TokenType.GREATER: ast.Gt,
TokenType.GREATER_EQUAL: ast.GtE,
TokenType.LESS: ast.Lt,
TokenType.LESS_EQUAL: ast.LtE,
TokenType.EQUAL_EQUAL: ast.Eq,
TokenType.BANG_EQUAL: ast.NotEq,
}
class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
def __init__(self, types: TypesRegistry):
self.types: TypesRegistry = types
self._id: int = 0
self._definitions: list[ast.stmt] = []
self._aliases: dict[str, str] = {}
def get_definitions(self) -> list[ast.stmt]:
return self._definitions
def generate(self, expr: m.Expr) -> ast.expr:
match expr:
case m.VariableExpr():
return expr.accept(self)
case _:
func = Function(
pos_args=[],
args=[
Function.Argument(
pos=0,
name="_",
type=self.types.get_type("Any"),
required=True,
)
],
kw_args=[],
returns=self.types.get_type("bool"),
)
alias: str = self.make_alias(None)
definition: ast.stmt = self.make_definition(
alias, Predicate(type=func, body=expr)
)
self._definitions.append(definition)
return ast.Name(id=alias)
def make_alias(self, name: Optional[str]) -> str:
suffix: str = f"_{name}" if name is not None else ""
alias: str = f"__midas_p{self._id}{suffix}__"
self._id += 1
return alias
def make_definition(self, name: str, predicate: Predicate) -> ast.stmt:
body: list[ast.stmt] = [ast.Return(value=predicate.body.accept(self))]
return self.make_func(name, body, predicate.type)
def make_args(self, func: Function) -> ast.arguments:
return ast.arguments(
posonlyargs=[ast.arg(arg=arg.name) for arg in func.pos_args],
args=[ast.arg(arg=arg.name) for arg in func.args],
kwonlyargs=[ast.arg(arg=arg.name) for arg in func.kw_args],
defaults=[],
kw_defaults=[],
)
def make_func(
self, name: str, inner_body: list[ast.stmt], type: Type, level: int = 0
) -> ast.stmt:
match type:
case Function(returns=Function()):
inner_name: str = f"inner{level}"
return ast.FunctionDef(
name=name,
args=self.make_args(type),
body=[
self.make_func(inner_name, inner_body, type.returns, level + 1),
ast.Return(value=ast.Name(id=inner_name)),
],
decorator_list=[],
)
case Function():
return ast.FunctionDef(
name=name,
args=self.make_args(type),
body=inner_body,
decorator_list=[],
)
case _:
raise ValueError(f"Expected function, got {type}")
def get_predicate(self, name: str) -> Optional[ast.expr]:
if name not in self._aliases:
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
if predicate is None:
return None
alias: str = self.make_alias(name)
self._aliases[name] = alias
self._definitions.append(self.make_definition(alias, predicate))
return ast.Name(id=self._aliases[name])
def visit_logical_expr(self, expr: m.LogicalExpr) -> ast.expr:
return ast.BoolOp(
op=LOGICAL_OPERATORS[expr.operator.type](),
values=[
expr.left.accept(self),
expr.right.accept(self),
],
)
def visit_binary_expr(self, expr: m.BinaryExpr) -> ast.expr:
op: TokenType = expr.operator.type
if op in BINARY_OPERATORS:
return ast.BinOp(
left=expr.left.accept(self),
op=BINARY_OPERATORS[op](),
right=expr.right.accept(self),
)
if op in COMPARISON_OPERATORS:
return ast.Compare(
left=expr.left.accept(self),
ops=[COMPARISON_OPERATORS[op]()],
comparators=[expr.right.accept(self)],
)
raise ValueError(f"Unexpected binary operator {op}")
def visit_unary_expr(self, expr: m.UnaryExpr) -> ast.expr:
return ast.UnaryOp(
op=UNARY_OPERATORS[expr.operator.type](),
operand=expr.right.accept(self),
)
def visit_call_expr(self, expr: m.CallExpr) -> ast.expr:
return ast.Call(
func=expr.callee.accept(self),
args=[arg.accept(self) for arg in expr.arguments],
keywords=[
ast.keyword(arg=name, value=arg.accept(self))
for name, arg in expr.keywords.items()
],
)
def visit_get_expr(self, expr: m.GetExpr) -> ast.expr:
return ast.Attribute(
value=expr.expr.accept(self),
attr=expr.name.lexeme,
)
def visit_variable_expr(self, expr: m.VariableExpr) -> ast.expr:
name: str = expr.name.lexeme
if (p := self.get_predicate(name)) is not None:
return p
return ast.Name(id=name)
def visit_grouping_expr(self, expr: m.GroupingExpr) -> ast.expr:
return expr.accept(self)
def visit_literal_expr(self, expr: m.LiteralExpr) -> ast.expr:
return ast.Constant(value=expr.value)
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> ast.expr:
return ast.Name(id="_")

View File

@@ -2,15 +2,19 @@ import ast
import shutil
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
from typing import Optional, assert_never
import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.ast.printer import MidasPrinter
from midas.checker.registry import TypesRegistry
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ComplexType,
ConstraintType,
ExtensionType,
Function,
GenericType,
@@ -19,7 +23,9 @@ from midas.checker.types import (
Type,
TypeVar,
UnitType,
UnknownType,
)
from midas.generator.constraints import ConstraintGenerator
from midas.utils import TypedAST
@@ -30,7 +36,7 @@ class Scope:
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def __init__(self, workdir: Path) -> None:
def __init__(self, workdir: Path, types: TypesRegistry) -> None:
self.workdir: Path = workdir.resolve()
self.build_dir: Path = self.workdir / "build" / "midas"
if self.build_dir.exists():
@@ -43,13 +49,18 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
judgements=[],
)
self._alias_count: int = 0
self._predicate_count: int = 0
self._scopes: list[Scope] = []
self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
self._constraints: list[tuple[m.Expr, ast.expr]] = []
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
self.rel_src_path = src_path.relative_to(self.workdir)
self._typed_ast = typed_ast
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
module = ast.Module(body=body, type_ignores=[])
predicates: list[ast.stmt] = self._constraint_generator.get_definitions()
module = ast.Module(body=predicates + body, type_ignores=[])
module = ast.fix_missing_locations(module)
return module
@@ -246,7 +257,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
return generated
def _make_alias(self, expr: ast.expr) -> ast.expr:
name: str = f"__midas_alias_{self._alias_count}__"
name: str = f"__midas_a{self._alias_count}__"
alias = ast.Name(id=name)
self._alias_count += 1
self._scopes[-1].aliases.append(name)
@@ -276,6 +287,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type):
match type:
case UnknownType():
pass
case BaseType(name=name):
self._add_assert(
ast.Call(
@@ -301,8 +315,15 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self._make_cast_assert_message(src_location, expr, type),
)
case AppliedType():
self._make_cast_asserts(src_location, expr, type.body)
case AppliedType(body=body):
self._make_cast_asserts(src_location, expr, body)
case ConstraintType(type=base, constraint=constraint):
self._make_cast_asserts(src_location, expr, base)
self._make_constraint_assert(src_location, expr, constraint)
case TypeVar():
raise RuntimeError("Unexpected TypeVar")
case (
TopType()
@@ -314,8 +335,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
):
raise NotImplementedError(f"Can't make assertion for type {type}")
case TypeVar():
raise RuntimeError("Unexpected TypeVar")
# Ensure exhaustiveness
case _:
assert_never(type)
def _make_cast_assert_message(
self, location: Location, expr: ast.expr, type: Type
@@ -339,3 +361,36 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
ast.Constant(f" to {type}"),
]
)
def _make_constraint_assert(
self, src_location: Location, expr: ast.expr, constraint: m.Expr
):
test_func: ast.expr = self._get_constraint(constraint)
self._add_assert(
ast.Call(
func=test_func,
args=[expr],
keywords=[],
),
self._make_constraint_assert_message(src_location, expr, constraint),
)
def _make_constraint_assert_message(
self, location: Location, expr: ast.expr, constraint: m.Expr
) -> ast.expr:
printer = MidasPrinter()
constraint_str: str = printer.print(constraint)
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
# f"file.py:L1:1: ConstraintError: Value does not fit constraint 'v > 0'"
return ast.Constant(
f"{loc_str}: ConstraintError: Value does not fit constraint '{constraint_str}'"
)
def _get_constraint(self, expr: m.Expr) -> ast.expr:
for expr2, constraint in self._constraints:
if expr2 == expr:
return constraint
constraint: ast.expr = self._constraint_generator.generate(expr)
self._constraints.append((expr, constraint))
return constraint

View File

@@ -1,337 +0,0 @@
import ast
from typing import Optional
import midas.ast.midas as m
from midas.checker.registry import Member, TypesRegistry
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ComplexType,
ExtensionType,
Function,
GenericType,
OverloadedFunction,
TopType,
Type,
TypeVar,
UnitType,
UnknownType,
substitute_typevars,
)
Empty = ast.Constant(value=...)
class StubsGenerator:
def __init__(self, types: TypesRegistry) -> None:
self.types: TypesRegistry = types
self.stubs: list[ast.stmt] = []
self.typing_imports: set[str] = set()
self.protocol_idx: int = 0
self.stub_idx: int = 0
self.type_var_idx: int = 0
self.substitutions: dict[str, dict[str, Type]] = {}
def generate_stubs(self) -> ast.Module:
self.stubs = []
self.typing_imports = set()
for name, type in self.types._types.items():
self.generate_stub(name, type)
imports = [
ast.ImportFrom(
module="__future__",
names=[ast.alias(name="annotations")],
level=0,
)
]
if len(self.typing_imports) != 0:
imports.append(
ast.ImportFrom(
module="typing",
names=[
ast.alias(name=name) for name in sorted(self.typing_imports)
],
level=0,
)
)
return ast.Module(body=imports + self.stubs, type_ignores=[])
def generate_stub(self, name: str, type: Type):
base_type: Type = type
members: dict[str, Member] = self.types._members.get(name, {})
if isinstance(base_type, (BaseType, TopType, UnitType)) and len(members) == 0:
return
bases: list[ast.expr] = []
substitutions: dict[str, Type] = {}
bases, substitutions = self.get_bases(type)
self.substitutions[name] = substitutions
body = self.generate_body(members, substitutions)
stub = ast.ClassDef(
name=name,
bases=bases,
body=body,
keywords=[],
decorator_list=[],
)
self.add_stub(stub)
def get_bases(self, type: Type) -> tuple[list[ast.expr], dict[str, Type]]:
match type:
case AliasType(type=base):
return [self.dump_type(base)], {}
case GenericType(params=params, body=body):
self.add_typing_import("Generic")
type_vars: ast.expr
params2: list[TypeVar] = self.define_type_vars(params)
if len(params) == 1:
type_vars = ast.Name(id=params2[0].name)
else:
type_vars = ast.Tuple(
elts=[ast.Name(id=param.name) for param in params2]
)
substitutions: dict[str, TypeVar] = {
param.name: param2 for param, param2 in zip(params, params2)
}
body_bases, body_subsitutions = self.get_bases(body)
return (
body_bases
+ [
ast.Subscript(
value=ast.Name(id="Generic"),
slice=type_vars,
)
],
body_subsitutions | substitutions,
)
case _:
return [], {}
def generate_body(
self, members: dict[str, Member], substitutions: dict[str, Type]
) -> list[ast.stmt]:
if len(members) == 0:
return [ast.Expr(value=Empty)]
body: list[ast.stmt] = []
for name, member in members.items():
type: Type = member.type
type = substitute_typevars(type, substitutions)
match member.kind:
case m.MemberKind.PROPERTY:
body.append(
ast.AnnAssign(
target=ast.Name(id=name),
annotation=self.dump_type(type),
simple=1,
)
)
case m.MemberKind.METHOD:
body.extend(self.dump_method(name, type))
return body
def dump_type(self, type: Type) -> ast.expr:
match type:
case AliasType(name=name) | GenericType(name=name) if (
name in self.substitutions
):
type = substitute_typevars(type, self.substitutions[name])
match type:
case TopType() | UnknownType():
self.add_typing_import("Any")
return ast.Name(id="Any")
case BaseType(name=name):
return ast.Name(id=name)
case AliasType(name=name):
return ast.Name(id=name)
case UnitType():
return ast.Constant(value=None)
case Function():
name: str = self.define_protocol(type)
return ast.Name(id=name)
case OverloadedFunction(overloads=overloads):
if len(overloads) == 1:
return self.dump_type(overloads[0])
return ast.BinOp(
left=self.dump_type(OverloadedFunction(overloads=overloads[:-1])),
op=ast.BitOr(),
right=self.dump_type(overloads[-1]),
)
case ComplexType():
name: str = self.new_stub_name()
self.generate_stub(name, type)
return ast.Name(id=name)
case ExtensionType():
raise NotImplementedError
case TypeVar():
return ast.Name(id=type.name)
case GenericType(name=name):
params: ast.expr
if len(type.params) == 1:
params = self.dump_type(type.params[0])
else:
params = ast.Tuple(
elts=[self.dump_type(param) for param in type.params]
)
return ast.Subscript(
value=ast.Name(id=type.name),
slice=params,
)
case AppliedType():
args: ast.expr
if len(type.args) == 1:
args = self.dump_type(type.args[0])
else:
args = ast.Tuple(elts=[self.dump_type(arg) for arg in type.args])
return ast.Subscript(
value=ast.Name(id=type.name),
slice=args,
)
def dump_method(
self, name: str, method: Type, overloaded: bool = False
) -> list[ast.stmt]:
match method:
case Function():
if overloaded:
self.add_typing_import("overload")
return [
ast.FunctionDef(
name=name,
args=self.dump_args(method, with_self=True),
returns=self.dump_type(method.returns),
body=[ast.Expr(value=Empty)],
decorator_list=[ast.Name(id="overload")] if overloaded else [],
)
]
case OverloadedFunction(overloads=overloads):
stmts: list[ast.stmt] = []
for overload in overloads:
stmts.extend(self.dump_method(name, overload, True))
return stmts
case _:
return [
ast.AnnAssign(
target=ast.Name(id=name),
annotation=self.dump_type(method),
simple=1,
)
]
def dump_args(self, func: Function, with_self: bool = False) -> ast.arguments:
pos: list[ast.arg] = [
ast.arg(arg=f"_{arg.pos}", annotation=self.dump_type(arg.type))
for arg in func.pos_args
]
mixed: list[ast.arg] = [
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type))
for arg in func.args
]
kw: list[ast.arg] = [
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type))
for arg in func.kw_args
]
defaults: list[ast.expr] = [
Empty for arg in func.pos_args + func.args if not arg.required
]
kw_defaults: list[Optional[ast.expr]] = [
None if arg.required else Empty for arg in func.kw_args
]
if with_self:
arg = ast.arg(arg="self", annotation=None)
if len(pos) != 0:
pos.insert(0, arg)
else:
mixed.insert(0, arg)
return ast.arguments(
posonlyargs=pos,
args=mixed,
kwonlyargs=kw,
defaults=defaults,
kw_defaults=kw_defaults,
)
def define_protocol(self, func: Function) -> str:
self.add_typing_import("Protocol")
name: str = self.new_protocol_name()
protocol = ast.ClassDef(
name=name,
bases=[ast.Name(id="Protocol")],
keywords=[],
body=[
ast.FunctionDef(
name="__call__",
args=self.dump_args(func, with_self=True),
returns=self.dump_type(func.returns),
body=[ast.Expr(value=Empty)],
decorator_list=[],
),
],
decorator_list=[],
)
self.add_stub(protocol)
return name
def new_protocol_name(self) -> str:
name: str = f"_Protocol{self.protocol_idx}"
self.protocol_idx += 1
return name
def new_stub_name(self) -> str:
name: str = f"_Stub_{self.stub_idx}"
self.stub_idx += 1
return name
def new_type_var_name(self) -> str:
name: str = f"_T{self.type_var_idx}"
self.type_var_idx += 1
return name
def add_stub(self, stub: ast.stmt):
self.stubs.append(stub)
def add_typing_import(self, name: str):
self.typing_imports.add(name)
def define_type_vars(self, vars: list[TypeVar]) -> list[TypeVar]:
vars2: list[TypeVar] = []
for var in vars:
vars2.append(self.define_type_var(var))
return vars2
def define_type_var(self, var: TypeVar) -> TypeVar:
name: str = self.new_type_var_name()
self.add_typing_import("TypeVar")
self.add_stub(
ast.Assign(
targets=[ast.Name(id=name)],
value=ast.Call(
func=ast.Name(id="TypeVar"),
args=[
ast.Constant(value=name),
],
keywords=(
[]
if var.bound is None
else [
ast.keyword(
arg="bound",
value=self.dump_type(var.bound),
)
]
),
),
)
)
return TypeVar(name=name, bound=None)

View File

@@ -3,6 +3,7 @@ from typing import Optional
from midas.ast.location import Location
from midas.ast.midas import (
BinaryExpr,
CallExpr,
ComplexType,
ConstraintType,
Expr,
@@ -17,6 +18,7 @@ from midas.ast.midas import (
MemberKind,
MemberStmt,
NamedType,
ParamSpec,
PredicateStmt,
Stmt,
Type,
@@ -265,6 +267,9 @@ class MidasParser(Parser):
Returns:
Expr: the parsed constraint expression
"""
return self.expression()
def expression(self) -> Expr:
return self.and_()
def and_(self) -> Expr:
@@ -331,7 +336,55 @@ class MidasParser(Parser):
right: Expr = self.unary()
location: Location = Location.span(operator.get_location(), right.location)
return UnaryExpr(location=location, operator=operator, right=right)
return self.reference()
return self.call()
def call(self) -> Expr:
expr: Expr = self.reference()
while self.match(TokenType.LEFT_PAREN):
expr = self.finish_call(expr)
return expr
def finish_call(self, callee: Expr) -> Expr:
l_paren: Token = self.previous()
pos_args: list[Expr] = []
kw_args: dict[str, Expr] = {}
keywords: bool = False
while not self.match(TokenType.RIGHT_PAREN):
if self.check_identifier() and self.check_next(TokenType.EQUAL):
keywords = True
keyword: Token = self.advance()
value: Expr = self.expression()
name: str = keyword.lexeme
if name in kw_args:
self.error(
self.peek(),
f"Multiple values passed for '{name}', only the last occurrence will be used",
)
kw_args[name] = value
else:
value = self.expression()
if self.check(TokenType.EQUAL):
if keywords:
raise self.error(self.peek(), "Invalid keyword argument name")
else:
raise self.error(
self.peek(),
"Cannot pass positional arguments after a keyword argument",
)
pos_args.append(value)
if not self.match(TokenType.COMMA):
break
r_paren: Token = self.consume(
TokenType.RIGHT_PAREN, "Expected ')' after arguments."
)
return CallExpr(
location=l_paren.location_to(r_paren),
callee=callee,
arguments=pos_args,
keywords=kw_args,
)
def reference(self) -> Expr:
"""Parse an attribute access expression or a simpler expression
@@ -453,23 +506,35 @@ class MidasParser(Parser):
PredicateStmt: the parsed predicate declaration statement
"""
keyword: Token = self.previous()
name: Token = self.consume_identifier("Expected predicate name")
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
subject: Token = self.consume_identifier("Expected subject name")
self.consume(TokenType.COLON, "Expected ':' after subject name")
type: Type = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
params: list[ParamSpec] = []
while self.check(TokenType.LEFT_PAREN):
params.append(self.function_args())
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
condition: Expr = self.constraint()
body: Expr = self.constraint()
return PredicateStmt(
location=keyword.location_to(self.previous()),
name=name,
subject=subject,
type=type,
condition=condition,
params=params,
body=body,
)
def function(self) -> FunctionType:
params: ParamSpec = self.function_args()
self.consume(TokenType.ARROW, "Expected '->' before result type")
result: Type = self.type_expr()
return FunctionType(
location=params.l_paren.location_to(self.previous()),
params=params,
returns=result,
)
def function_args(self) -> ParamSpec:
l_paren: Token = self.consume(
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
)
@@ -526,14 +591,4 @@ class MidasParser(Parser):
self.error(token, "Unnamed mixed argument")
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
self.consume(TokenType.ARROW, "Expected '->' before result type")
result: Type = self.type_expr()
return FunctionType(
location=l_paren.location_to(self.previous()),
pos_args=pos_args,
args=args,
kw_args=kw_args,
returns=result,
)
return ParamSpec(l_paren=l_paren, pos=pos_args, mixed=args, kw=kw_args)

View File

@@ -9,13 +9,13 @@ Module(
level=0),
Assign(
targets=[
Name(id='__midas_alias_0__')],
Name(id='__midas_a0__')],
value=Constant(value=123.45)),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_alias_0__'),
Name(id='__midas_a0__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
@@ -26,7 +26,7 @@ Module(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_alias_0__')],
Name(id='__midas_a0__')],
keywords=[]),
attr='__name__'),
conversion=-1),
@@ -34,19 +34,19 @@ Module(
Assign(
targets=[
Name(id='distance')],
value=Name(id='__midas_alias_0__')),
value=Name(id='__midas_a0__')),
Delete(
targets=[
Name(id='__midas_alias_0__')]),
Name(id='__midas_a0__')]),
Assign(
targets=[
Name(id='__midas_alias_1__')],
Name(id='__midas_a1__')],
value=Constant(value=6.7)),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_alias_1__'),
Name(id='__midas_a1__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
@@ -57,7 +57,7 @@ Module(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_alias_1__')],
Name(id='__midas_a1__')],
keywords=[]),
attr='__name__'),
conversion=-1),
@@ -65,10 +65,10 @@ Module(
Assign(
targets=[
Name(id='time')],
value=Name(id='__midas_alias_1__')),
value=Name(id='__midas_a1__')),
Delete(
targets=[
Name(id='__midas_alias_1__')]),
Name(id='__midas_a1__')]),
Assign(
targets=[
Name(id='speed')],

View File

@@ -2582,18 +2582,21 @@
"name": "__sub__",
"type": {
"_type": "FunctionType",
"pos_args": [
{
"name": null,
"type": {
"_type": "NamedType",
"name": "GeoLocation"
},
"required": true
}
],
"args": [],
"kw_args": [],
"params": {
"_type": "ParamSpec",
"pos": [
{
"name": null,
"type": {
"_type": "NamedType",
"name": "GeoLocation"
},
"required": true
}
],
"mixed": [],
"kw": []
},
"returns": {
"_type": "GenericType",
"type": {
@@ -2673,18 +2676,21 @@
"name": "__sub__",
"type": {
"_type": "FunctionType",
"pos_args": [
{
"name": null,
"type": {
"_type": "NamedType",
"name": "Latitude"
},
"required": true
}
],
"args": [],
"kw_args": [],
"params": {
"_type": "ParamSpec",
"pos": [
{
"name": null,
"type": {
"_type": "NamedType",
"name": "Latitude"
},
"required": true
}
],
"mixed": [],
"kw": []
},
"returns": {
"_type": "GenericType",
"type": {
@@ -2713,18 +2719,21 @@
"name": "__sub__",
"type": {
"_type": "FunctionType",
"pos_args": [
{
"name": null,
"type": {
"_type": "NamedType",
"name": "Longitude"
},
"required": true
}
],
"args": [],
"kw_args": [],
"params": {
"_type": "ParamSpec",
"pos": [
{
"name": null,
"type": {
"_type": "NamedType",
"name": "Longitude"
},
"required": true
}
],
"mixed": [],
"kw": []
},
"returns": {
"_type": "GenericType",
"type": {
@@ -2745,12 +2754,24 @@
{
"_type": "PredicateStmt",
"name": "Positive",
"subject": "v",
"type": {
"_type": "NamedType",
"name": "float"
},
"condition": {
"params": [
{
"_type": "ParamSpec",
"pos": [],
"mixed": [
{
"name": "v",
"type": {
"_type": "NamedType",
"name": "float"
},
"required": true
}
],
"kw": []
}
],
"body": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
@@ -2766,12 +2787,24 @@
{
"_type": "PredicateStmt",
"name": "StrictlyPositive",
"subject": "v",
"type": {
"_type": "NamedType",
"name": "float"
},
"condition": {
"params": [
{
"_type": "ParamSpec",
"pos": [],
"mixed": [
{
"name": "v",
"type": {
"_type": "NamedType",
"name": "float"
},
"required": true
}
],
"kw": []
}
],
"body": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
@@ -2787,12 +2820,24 @@
{
"_type": "PredicateStmt",
"name": "Equatorial",
"subject": "loc",
"type": {
"_type": "NamedType",
"name": "GeoLocation"
},
"condition": {
"params": [
{
"_type": "ParamSpec",
"pos": [],
"mixed": [
{
"name": "loc",
"type": {
"_type": "NamedType",
"name": "GeoLocation"
},
"required": true
}
],
"kw": []
}
],
"body": {
"_type": "GroupingExpr",
"expr": {
"_type": "BinaryExpr",
@@ -2827,12 +2872,24 @@
{
"_type": "PredicateStmt",
"name": "Arctic",
"subject": "loc",
"type": {
"_type": "NamedType",
"name": "GeoLocation"
},
"condition": {
"params": [
{
"_type": "ParamSpec",
"pos": [],
"mixed": [
{
"name": "loc",
"type": {
"_type": "NamedType",
"name": "GeoLocation"
},
"required": true
}
],
"kw": []
}
],
"body": {
"_type": "GroupingExpr",
"expr": {
"_type": "BinaryExpr",

View File

@@ -45,7 +45,7 @@ class GeneratorTester(Tester):
typed_ast: TypedAST = checker.type_check(path)
if not any(d.type == DiagnosticType.ERROR for d in checker.diagnostics):
generator = Generator(workdir=path.parent)
generator = Generator(workdir=path.parent, types=checker.types)
result.compiled_ast = generator.generate_ast(typed_ast, path)
return result

View File

@@ -2,6 +2,7 @@ from typing import Optional, Sequence
from midas.ast.midas import (
BinaryExpr,
CallExpr,
ComplexType,
ConstraintType,
Expr,
@@ -15,6 +16,7 @@ from midas.ast.midas import (
LogicalExpr,
MemberStmt,
NamedType,
ParamSpec,
PredicateStmt,
Stmt,
Type,
@@ -78,9 +80,8 @@ class MidasAstJsonSerializer(
return {
"_type": "PredicateStmt",
"name": stmt.name.lexeme,
"subject": stmt.subject.lexeme,
"type": stmt.type.accept(self),
"condition": stmt.condition.accept(self),
"params": [self._serialize_param_spec(spec) for spec in stmt.params],
"body": stmt.body.accept(self),
}
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
@@ -106,6 +107,14 @@ class MidasAstJsonSerializer(
"right": expr.right.accept(self),
}
def visit_call_expr(self, expr: CallExpr) -> dict:
return {
"_type": "CallExpr",
"callee": expr.callee.accept(self),
"arguments": self._serialize_list(expr.arguments),
"keywords": {name: arg.accept(self) for name, arg in expr.keywords.items()},
}
def visit_get_expr(self, expr: GetExpr) -> dict:
return {
"_type": "GetExpr",
@@ -163,15 +172,21 @@ class MidasAstJsonSerializer(
def visit_function_type(self, type: FunctionType) -> dict:
return {
"_type": "FunctionType",
"pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args],
"args": [self._serialize_func_arg(arg) for arg in type.args],
"kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args],
"params": self._serialize_param_spec(type.params),
"returns": type.returns.accept(self),
}
def _serialize_param_spec(self, spec: ParamSpec) -> dict:
return {
"_type": "ParamSpec",
"pos": [self._serialize_func_arg(arg) for arg in spec.pos],
"mixed": [self._serialize_func_arg(arg) for arg in spec.mixed],
"kw": [self._serialize_func_arg(arg) for arg in spec.kw],
}
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
return {
"name": arg.name,
"name": arg.name.lexeme if arg.name is not None else None,
"type": arg.type.accept(self),
"required": arg.required,
}