Compare commits
10 Commits
feat/stubs
...
feat/const
| Author | SHA1 | Date | |
|---|---|---|---|
|
35ec0d0db8
|
|||
|
48fcb499a1
|
|||
|
bdc1b265a6
|
|||
|
1fb4b6f8c6
|
|||
|
48c1ecc1c8
|
|||
|
04853eac70
|
|||
|
020824d1f8
|
|||
|
ad86446a2d
|
|||
|
94d84ab170
|
|||
|
8381f4f31d
|
23
gen/midas.py
23
gen/midas.py
@@ -26,6 +26,14 @@ class MemberKind(Enum):
|
||||
METHOD = auto()
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ParamSpec:
|
||||
l_paren: Token
|
||||
pos: list[FunctionType.Argument]
|
||||
mixed: list[FunctionType.Argument]
|
||||
kw: list[FunctionType.Argument]
|
||||
|
||||
|
||||
###<
|
||||
|
||||
|
||||
@@ -50,9 +58,8 @@ class ExtendStmt:
|
||||
|
||||
class PredicateStmt:
|
||||
name: Token
|
||||
subject: Token
|
||||
type: Type
|
||||
condition: Expr
|
||||
params: list[ParamSpec]
|
||||
body: Expr
|
||||
|
||||
|
||||
###<
|
||||
@@ -78,6 +85,12 @@ class UnaryExpr:
|
||||
right: Expr
|
||||
|
||||
|
||||
class CallExpr:
|
||||
callee: Expr
|
||||
arguments: list[Expr]
|
||||
keywords: dict[str, Expr]
|
||||
|
||||
|
||||
class GetExpr:
|
||||
expr: Expr
|
||||
name: Token
|
||||
@@ -128,9 +141,7 @@ class ExtensionType:
|
||||
|
||||
|
||||
class FunctionType:
|
||||
pos_args: list[Argument]
|
||||
args: list[Argument]
|
||||
kw_args: list[Argument]
|
||||
params: ParamSpec
|
||||
returns: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
|
||||
@@ -27,6 +27,14 @@ class MemberKind(Enum):
|
||||
METHOD = auto()
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ParamSpec:
|
||||
l_paren: Token
|
||||
pos: list[FunctionType.Argument]
|
||||
mixed: list[FunctionType.Argument]
|
||||
kw: list[FunctionType.Argument]
|
||||
|
||||
|
||||
##############
|
||||
# Statements #
|
||||
##############
|
||||
@@ -86,9 +94,8 @@ class ExtendStmt(Stmt):
|
||||
@dataclass(frozen=True)
|
||||
class PredicateStmt(Stmt):
|
||||
name: Token
|
||||
subject: Token
|
||||
type: Type
|
||||
condition: Expr
|
||||
params: list[ParamSpec]
|
||||
body: Expr
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_predicate_stmt(self)
|
||||
@@ -116,6 +123,9 @@ class Expr(ABC):
|
||||
@abstractmethod
|
||||
def visit_unary_expr(self, expr: UnaryExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_call_expr(self, expr: CallExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_get_expr(self, expr: GetExpr) -> T: ...
|
||||
|
||||
@@ -161,6 +171,16 @@ class UnaryExpr(Expr):
|
||||
return visitor.visit_unary_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CallExpr(Expr):
|
||||
callee: Expr
|
||||
arguments: list[Expr]
|
||||
keywords: dict[str, Expr]
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_call_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GetExpr(Expr):
|
||||
expr: Expr
|
||||
@@ -279,9 +299,7 @@ class ExtensionType(Type):
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FunctionType(Type):
|
||||
pos_args: list[Argument]
|
||||
args: list[Argument]
|
||||
kw_args: list[Argument]
|
||||
params: ParamSpec
|
||||
returns: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
|
||||
@@ -150,13 +150,17 @@ class MidasAstPrinter(
|
||||
self._write_line("PredicateStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line(f'subject: "{stmt.subject.lexeme}"')
|
||||
self._write_line("type")
|
||||
self._write_line("params")
|
||||
with self._child_level():
|
||||
for i, spec in enumerate(stmt.params):
|
||||
self._idx = i
|
||||
if i == len(stmt.params) - 1:
|
||||
self._mark_last()
|
||||
self._visit_param_spec(spec)
|
||||
|
||||
self._write_line("body", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
self._write_line("condition", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.condition.accept(self)
|
||||
stmt.body.accept(self)
|
||||
|
||||
# Expressions
|
||||
|
||||
@@ -195,6 +199,29 @@ class MidasAstPrinter(
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_call_expr(self, expr: m.CallExpr) -> None:
|
||||
self._write_line("CallExpr")
|
||||
with self._child_level():
|
||||
self._write_line("callee")
|
||||
with self._child_level(single=True):
|
||||
expr.callee.accept(self)
|
||||
self._write_line("arguments")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(expr.arguments):
|
||||
self._idx = i
|
||||
if i == len(expr.arguments) - 1:
|
||||
self._mark_last()
|
||||
arg.accept(self)
|
||||
self._write_line("keywords", last=True)
|
||||
with self._child_level():
|
||||
for i, (name, arg) in enumerate(expr.keywords.items()):
|
||||
self._idx = i
|
||||
if i == len(expr.keywords) - 1:
|
||||
self._mark_last()
|
||||
self._write_line(name)
|
||||
with self._child_level(single=True):
|
||||
arg.accept(self)
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr):
|
||||
self._write_line("GetExpr")
|
||||
with self._child_level():
|
||||
@@ -276,34 +303,41 @@ class MidasAstPrinter(
|
||||
def visit_function_type(self, type: m.FunctionType) -> None:
|
||||
self._write_line("FunctionType")
|
||||
with self._child_level():
|
||||
self._write_line("pos_args")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(type.pos_args):
|
||||
self._idx = i
|
||||
if i == len(type.pos_args) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("args")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(type.args):
|
||||
self._idx = i
|
||||
if i == len(type.args) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("kw_args")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(type.kw_args):
|
||||
self._idx = i
|
||||
if i == len(type.kw_args) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
self._write_line("params")
|
||||
with self._child_level(single=True):
|
||||
self._visit_param_spec(type.params)
|
||||
|
||||
self._write_line("returns", last=True)
|
||||
with self._child_level(single=True):
|
||||
type.returns.accept(self)
|
||||
|
||||
def _visit_param_spec(self, spec: m.ParamSpec) -> None:
|
||||
self._write_line("ParamSpec")
|
||||
with self._child_level():
|
||||
self._write_line("pos")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(spec.pos):
|
||||
self._idx = i
|
||||
if i == len(spec.pos) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("mixed")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(spec.mixed):
|
||||
self._idx = i
|
||||
if i == len(spec.mixed) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("kw", last=True)
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(spec.kw):
|
||||
self._idx = i
|
||||
if i == len(spec.kw) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
|
||||
self._write_line("Argument")
|
||||
with self._child_level():
|
||||
@@ -367,10 +401,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
||||
name: str = stmt.name.lexeme
|
||||
subject: str = stmt.subject.lexeme
|
||||
type: str = stmt.type.accept(self)
|
||||
condition: str = stmt.condition.accept(self)
|
||||
return self.indented(f"predicate {name}({subject}: {type}) = {condition}")
|
||||
sig: str = "".join(self._visit_param_spec(spec) for spec in stmt.params)
|
||||
body: str = stmt.body.accept(self)
|
||||
return self.indented(f"predicate {name}{sig} = {body}")
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr):
|
||||
left: str = expr.left.accept(self)
|
||||
@@ -389,6 +422,12 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
||||
right: str = expr.right.accept(self)
|
||||
return f"{operator}{right}"
|
||||
|
||||
def visit_call_expr(self, expr: m.CallExpr) -> str:
|
||||
args: list[str] = [arg.accept(self) for arg in expr.arguments] + [
|
||||
f"{name}={arg.accept(self)}" for name, arg in expr.keywords.items()
|
||||
]
|
||||
return f"{expr.callee.accept(self)}({', '.join(args)})"
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr):
|
||||
expr_: str = expr.expr.accept(self)
|
||||
name: str = expr.name.lexeme
|
||||
@@ -436,9 +475,13 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
||||
return f"{type.base.accept(self)} & {type.extension.accept(self)}"
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> str:
|
||||
pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
|
||||
mixed_args: list[str] = [self._print_arg(arg) for arg in type.args]
|
||||
kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args]
|
||||
spec: str = self._visit_param_spec(type.params)
|
||||
return f"fn {spec} -> {type.returns.accept(self)}"
|
||||
|
||||
def _visit_param_spec(self, spec: m.ParamSpec) -> str:
|
||||
pos_args: list[str] = [self._print_arg(arg) for arg in spec.pos]
|
||||
mixed_args: list[str] = [self._print_arg(arg) for arg in spec.mixed]
|
||||
kw_args: list[str] = [self._print_arg(arg) for arg in spec.kw]
|
||||
args: list[str] = pos_args
|
||||
|
||||
if len(pos_args) != 0:
|
||||
@@ -447,8 +490,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
||||
if len(kw_args) != 0:
|
||||
args.append("*")
|
||||
args += kw_args
|
||||
|
||||
return f"fn ({', '.join(args)}) -> {type.returns.accept(self)}"
|
||||
return f"({', '.join(args)})"
|
||||
|
||||
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
|
||||
res: str = ""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -9,9 +10,11 @@ from midas.checker.reporter import FileReporter, Reporter
|
||||
from midas.checker.types import (
|
||||
AliasType,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
ExtensionType,
|
||||
Function,
|
||||
GenericType,
|
||||
Predicate,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnknownType,
|
||||
@@ -21,6 +24,13 @@ from midas.lexer.token import Token
|
||||
from midas.parser.midas import MidasParser
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class TypedParamSpec:
|
||||
pos: list[Function.Argument]
|
||||
mixed: list[Function.Argument]
|
||||
kw: list[Function.Argument]
|
||||
|
||||
|
||||
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
|
||||
"""A resolver which evaluates Midas type definitions and build a registry"""
|
||||
|
||||
@@ -37,6 +47,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
||||
builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve()
|
||||
self.process(builtins_path.read_text(), str(builtins_path))
|
||||
|
||||
self._bool: Type = self.get_type("bool")
|
||||
|
||||
def process(self, source: str, path: Optional[str]):
|
||||
self.reporter = self.reporter.for_file(path)
|
||||
lexer: MidasLexer = MidasLexer(source)
|
||||
@@ -102,11 +114,28 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
||||
base_name,
|
||||
member.name.lexeme,
|
||||
member_type,
|
||||
member.kind,
|
||||
member.kind == m.MemberKind.METHOD,
|
||||
)
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
||||
self.reporter.warning(stmt.location, "PredicateStmt not yet supported")
|
||||
params: list[TypedParamSpec] = [
|
||||
self._visit_param_spec(spec) for spec in stmt.params
|
||||
]
|
||||
type: Type = self._bool
|
||||
for spec in reversed(params):
|
||||
type = Function(
|
||||
pos_args=spec.pos,
|
||||
args=spec.mixed,
|
||||
kw_args=spec.kw,
|
||||
returns=type,
|
||||
)
|
||||
self.types.define_predicate(
|
||||
stmt.name.lexeme,
|
||||
Predicate(
|
||||
type=type,
|
||||
body=stmt.body,
|
||||
),
|
||||
)
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
|
||||
self.reporter.warning(expr.location, "LogicalExpr not yet supported")
|
||||
@@ -117,6 +146,9 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
|
||||
self.reporter.warning(expr.location, "UnaryExpr not yet supported")
|
||||
|
||||
def visit_call_expr(self, expr: m.CallExpr) -> None:
|
||||
self.reporter.warning(expr.location, "CallExpr not yet supported")
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> None:
|
||||
self.reporter.warning(expr.location, "GetExpr not yet supported")
|
||||
|
||||
@@ -153,10 +185,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
||||
return UnknownType()
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
|
||||
type_: Type = type.type.accept(self)
|
||||
type.constraint.accept(self)
|
||||
# TODO
|
||||
return UnknownType()
|
||||
return ConstraintType(
|
||||
type=type.type.accept(self),
|
||||
constraint=type.constraint,
|
||||
)
|
||||
|
||||
def visit_complex_type(self, type: m.ComplexType) -> ComplexType:
|
||||
return ComplexType(
|
||||
@@ -172,8 +204,17 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
||||
)
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> Type:
|
||||
n_pos_args: int = len(type.pos_args)
|
||||
n_args: int = len(type.args)
|
||||
params: TypedParamSpec = self._visit_param_spec(type.params)
|
||||
return Function(
|
||||
pos_args=params.pos,
|
||||
args=params.mixed,
|
||||
kw_args=params.kw,
|
||||
returns=type.returns.accept(self),
|
||||
)
|
||||
|
||||
def _visit_param_spec(self, spec: m.ParamSpec) -> TypedParamSpec:
|
||||
n_pos: int = len(spec.pos)
|
||||
n_mixed: int = len(spec.mixed)
|
||||
|
||||
def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument:
|
||||
return Function.Argument(
|
||||
@@ -183,14 +224,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
||||
required=arg.required,
|
||||
)
|
||||
|
||||
return Function(
|
||||
pos_args=[process_arg(arg, i) for i, arg in enumerate(type.pos_args)],
|
||||
args=[process_arg(arg, i + n_pos_args) for i, arg in enumerate(type.args)],
|
||||
kw_args=[
|
||||
process_arg(arg, i + n_pos_args + n_args)
|
||||
for i, arg in enumerate(type.kw_args)
|
||||
],
|
||||
returns=type.returns.accept(self),
|
||||
return TypedParamSpec(
|
||||
pos=[process_arg(arg, i) for i, arg in enumerate(spec.pos)],
|
||||
mixed=[process_arg(arg, i + n_pos) for i, arg in enumerate(spec.mixed)],
|
||||
kw=[process_arg(arg, i + n_pos + n_mixed) for i, arg in enumerate(spec.kw)],
|
||||
)
|
||||
|
||||
def _resolve_type_params(self, params: list[m.TypeParam]):
|
||||
|
||||
@@ -12,7 +12,6 @@ from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter, Reporter
|
||||
from midas.checker.resolver import Resolver
|
||||
from midas.checker.types import (
|
||||
AliasType,
|
||||
AppliedType,
|
||||
Function,
|
||||
OverloadedFunction,
|
||||
@@ -695,17 +694,9 @@ class PythonTyper(
|
||||
case UnknownType():
|
||||
return UnknownType()
|
||||
|
||||
case AliasType(type=base):
|
||||
return self._get_call_result(
|
||||
location, base, positional, keywords, report_errors
|
||||
)
|
||||
|
||||
case _:
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"{callee} ({callee.__class__.__name__}) is not callable",
|
||||
)
|
||||
self.reporter.error(location, f"{callee} is not callable")
|
||||
return None
|
||||
|
||||
def _are_arguments_valid(
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.midas import MemberKind
|
||||
from midas.checker.builtins import BUILTIN_SUBTYPES
|
||||
from midas.checker.types import (
|
||||
AliasType,
|
||||
@@ -13,6 +11,7 @@ from midas.checker.types import (
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
Predicate,
|
||||
TopType,
|
||||
Type,
|
||||
TypeVar,
|
||||
@@ -21,17 +20,12 @@ from midas.checker.types import (
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Member:
|
||||
kind: MemberKind
|
||||
type: Type
|
||||
|
||||
|
||||
class TypesRegistry:
|
||||
def __init__(self) -> None:
|
||||
self.logger: logging.Logger = logging.getLogger("TypesRegistry")
|
||||
self._types: dict[str, Type] = {}
|
||||
self._members: dict[str, dict[str, Member]] = {}
|
||||
self._members: dict[str, dict[str, Type]] = {}
|
||||
self._predicates: dict[str, Predicate] = {}
|
||||
|
||||
def get_type(self, name: str) -> Type:
|
||||
"""Get a type from its name
|
||||
@@ -68,38 +62,31 @@ class TypesRegistry:
|
||||
return type
|
||||
|
||||
def define_member(
|
||||
self,
|
||||
type_name: str,
|
||||
member_name: str,
|
||||
member_type: Type,
|
||||
kind: MemberKind,
|
||||
self, type_name: str, member_name: str, member_type: Type, is_method: bool
|
||||
):
|
||||
members: dict[str, Member] = self._members.setdefault(type_name, {})
|
||||
members: dict[str, Type] = self._members.setdefault(type_name, {})
|
||||
if member_name in members:
|
||||
current: Member = members[member_name]
|
||||
if current.kind != kind:
|
||||
if not is_method:
|
||||
self.logger.error(
|
||||
f"Member '{member_name}' is already defined as a {current.kind},"
|
||||
+ f" cannot define a {kind} with the same name"
|
||||
f"Member '{member_name}' already defined for type {type_name}"
|
||||
)
|
||||
return
|
||||
if kind != MemberKind.METHOD:
|
||||
self.logger.error(
|
||||
f"Member '{member_name}' already defined for type {type_name},"
|
||||
+ " only methods can be overloaded"
|
||||
)
|
||||
return
|
||||
|
||||
current: Type = members[member_name]
|
||||
combined: Type
|
||||
match current.type:
|
||||
match current:
|
||||
case OverloadedFunction(overloads=overloads):
|
||||
combined = OverloadedFunction(overloads=overloads + [member_type])
|
||||
case _:
|
||||
combined = OverloadedFunction(overloads=[current.type, member_type])
|
||||
members[member_name] = Member(kind=current.kind, type=combined)
|
||||
combined = OverloadedFunction(overloads=[current, member_type])
|
||||
members[member_name] = combined
|
||||
|
||||
else:
|
||||
members[member_name] = Member(kind=kind, type=member_type)
|
||||
members[member_name] = member_type
|
||||
|
||||
def define_predicate(self, name: str, predicate: Predicate):
|
||||
if name in self._predicates:
|
||||
raise ValueError(f"Predicate {name} already defined")
|
||||
self._predicates[name] = predicate
|
||||
|
||||
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||
"""Check whether `type1` is a subtype of `type2`
|
||||
@@ -317,13 +304,13 @@ class TypesRegistry:
|
||||
case BaseType(name=name):
|
||||
if name in self._members:
|
||||
if member_name in self._members[name]:
|
||||
return self._members[name][member_name].type
|
||||
return self._members[name][member_name]
|
||||
return None
|
||||
|
||||
case AliasType(name=name, type=base):
|
||||
if name in self._members:
|
||||
if member_name in self._members[name]:
|
||||
return self._members[name][member_name].type
|
||||
return self._members[name][member_name]
|
||||
return self.lookup_member(base, member_name)
|
||||
|
||||
case AppliedType(name=name, body=body, args=args):
|
||||
@@ -337,7 +324,7 @@ class TypesRegistry:
|
||||
}
|
||||
if name in self._members:
|
||||
if member_name in self._members[name]:
|
||||
member_type: Type = self._members[name][member_name].type
|
||||
member_type: Type = self._members[name][member_name]
|
||||
return substitute_typevars(member_type, substitutions)
|
||||
|
||||
member_type2: Optional[Type] = self.lookup_member(body, member_name)
|
||||
@@ -365,3 +352,6 @@ class TypesRegistry:
|
||||
case _:
|
||||
self.logger.debug(f"Can't get member on {type}")
|
||||
return None
|
||||
|
||||
def lookup_predicate(self, name: str) -> Optional[Predicate]:
|
||||
return self._predicates.get(name)
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from typing import Optional, assert_never
|
||||
|
||||
import midas.ast.midas as m
|
||||
from midas.ast.printer import MidasPrinter
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
@@ -130,6 +133,16 @@ class AppliedType:
|
||||
return f"{self.name}[{', '.join(map(str, self.args))}]"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ConstraintType:
|
||||
type: Type
|
||||
constraint: m.Expr
|
||||
|
||||
def __str__(self) -> str:
|
||||
printer = MidasPrinter()
|
||||
return f"{self.type} where {printer.print(self.constraint)}"
|
||||
|
||||
|
||||
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||
def sub_argument(arg: Function.Argument):
|
||||
return Function.Argument(
|
||||
@@ -140,9 +153,6 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||
)
|
||||
|
||||
match type:
|
||||
case TopType():
|
||||
return type
|
||||
|
||||
case BaseType(name=name) if name in substitutions:
|
||||
return substitutions[name]
|
||||
|
||||
@@ -198,32 +208,27 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||
body=substitute_typevars(body, substitutions),
|
||||
)
|
||||
|
||||
case ConstraintType():
|
||||
return ConstraintType(
|
||||
type=substitute_typevars(type.type, substitutions),
|
||||
constraint=type.constraint,
|
||||
)
|
||||
|
||||
case TypeVar(name=name):
|
||||
if name in substitutions:
|
||||
return substitutions[name]
|
||||
raise ValueError(f"Missing TypeVar substitution for {name}")
|
||||
|
||||
case GenericType(name=name, params=params, body=body):
|
||||
params2: list[TypeVar] = []
|
||||
for param in params:
|
||||
param2: Type = substitute_typevars(param, substitutions)
|
||||
if not isinstance(param2, TypeVar):
|
||||
raise ValueError(
|
||||
f"Invalid type parameter substitution, expected TypeVar, got {param2}"
|
||||
)
|
||||
params2.append(param2)
|
||||
return GenericType(
|
||||
name=name,
|
||||
params=params2,
|
||||
body=substitute_typevars(body, substitutions),
|
||||
)
|
||||
|
||||
case UnknownType() | UnitType():
|
||||
return type
|
||||
|
||||
case _:
|
||||
case TopType() | GenericType():
|
||||
raise NotImplementedError(f"Unsupported type {type}")
|
||||
|
||||
# Ensure exhaustiveness
|
||||
case _:
|
||||
assert_never(type)
|
||||
|
||||
|
||||
def unfold_type(type: Type) -> Type:
|
||||
match type:
|
||||
@@ -233,6 +238,12 @@ def unfold_type(type: Type) -> Type:
|
||||
return type
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Predicate:
|
||||
type: Type
|
||||
body: m.Expr
|
||||
|
||||
|
||||
Type = (
|
||||
TopType
|
||||
| BaseType
|
||||
@@ -246,4 +257,5 @@ Type = (
|
||||
| TypeVar
|
||||
| GenericType
|
||||
| AppliedType
|
||||
| ConstraintType
|
||||
)
|
||||
|
||||
@@ -4,6 +4,5 @@ from .format import format as format
|
||||
from .highlight import highlight as highlight
|
||||
from .parse import parse as parse
|
||||
from .registry import dump_registry as dump_registry
|
||||
from .stubs import stubs as stubs
|
||||
from .types import types as types
|
||||
from .validate import validate as validate
|
||||
|
||||
@@ -38,5 +38,5 @@ def compile(
|
||||
if any(map(lambda d: d.type == DiagnosticType.ERROR, diagnostics)):
|
||||
sys.exit(1)
|
||||
|
||||
generator = Generator(workdir=source_path.parent)
|
||||
generator = Generator(workdir=source_path.parent, types=checker.types)
|
||||
generator.generate(typed_ast, source_path)
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
import ast
|
||||
from pathlib import Path
|
||||
from typing import TextIO
|
||||
|
||||
import click
|
||||
|
||||
from midas.checker.checker import TypeChecker
|
||||
from midas.generator.stubs import StubsGenerator
|
||||
|
||||
|
||||
@click.command(help="Generate stubs from Midas definitions")
|
||||
@click.argument("file", type=click.File("r"))
|
||||
@click.option("-o", "--output", type=click.File("w"), default="-")
|
||||
def stubs(
|
||||
file: TextIO,
|
||||
output: TextIO,
|
||||
):
|
||||
source_path: Path = Path(file.name).resolve()
|
||||
|
||||
checker = TypeChecker()
|
||||
checker.import_midas(source_path)
|
||||
|
||||
generator = StubsGenerator(checker.types)
|
||||
module: ast.Module = generator.generate_stubs()
|
||||
module = ast.fix_missing_locations(module)
|
||||
|
||||
output.write(ast.unparse(module))
|
||||
@@ -18,7 +18,6 @@ midas.add_command(commands.highlight)
|
||||
midas.add_command(commands.parse)
|
||||
midas.add_command(commands.dump_registry)
|
||||
midas.add_command(commands.types)
|
||||
midas.add_command(commands.stubs)
|
||||
midas.add_command(commands.validate)
|
||||
|
||||
|
||||
|
||||
188
midas/generator/constraints.py
Normal file
188
midas/generator/constraints.py
Normal file
@@ -0,0 +1,188 @@
|
||||
import ast
|
||||
from typing import Optional
|
||||
|
||||
import midas.ast.midas as m
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import Function, Predicate, Type
|
||||
from midas.lexer.token import TokenType
|
||||
|
||||
LOGICAL_OPERATORS: dict[TokenType, type[ast.boolop]] = {
|
||||
TokenType.AND: ast.And,
|
||||
# TokenType.OR: ast.Or,
|
||||
}
|
||||
|
||||
BINARY_OPERATORS: dict[TokenType, type[ast.operator]] = {
|
||||
# TokenType.PLUS: ast.Add,
|
||||
TokenType.MINUS: ast.Sub,
|
||||
TokenType.STAR: ast.Mult,
|
||||
TokenType.SLASH: ast.Div,
|
||||
}
|
||||
|
||||
UNARY_OPERATORS: dict[TokenType, type[ast.unaryop]] = {
|
||||
# TokenType.PLUS: ast.UAdd,
|
||||
TokenType.MINUS: ast.USub,
|
||||
}
|
||||
|
||||
COMPARISON_OPERATORS: dict[TokenType, type[ast.cmpop]] = {
|
||||
TokenType.GREATER: ast.Gt,
|
||||
TokenType.GREATER_EQUAL: ast.GtE,
|
||||
TokenType.LESS: ast.Lt,
|
||||
TokenType.LESS_EQUAL: ast.LtE,
|
||||
TokenType.EQUAL_EQUAL: ast.Eq,
|
||||
TokenType.BANG_EQUAL: ast.NotEq,
|
||||
}
|
||||
|
||||
|
||||
class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
|
||||
def __init__(self, types: TypesRegistry):
|
||||
self.types: TypesRegistry = types
|
||||
self._id: int = 0
|
||||
self._definitions: list[ast.stmt] = []
|
||||
self._aliases: dict[str, str] = {}
|
||||
|
||||
def get_definitions(self) -> list[ast.stmt]:
|
||||
return self._definitions
|
||||
|
||||
def generate(self, expr: m.Expr) -> ast.expr:
|
||||
match expr:
|
||||
case m.VariableExpr():
|
||||
return expr.accept(self)
|
||||
case _:
|
||||
func = Function(
|
||||
pos_args=[],
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="_",
|
||||
type=self.types.get_type("Any"),
|
||||
required=True,
|
||||
)
|
||||
],
|
||||
kw_args=[],
|
||||
returns=self.types.get_type("bool"),
|
||||
)
|
||||
alias: str = self.make_alias(None)
|
||||
definition: ast.stmt = self.make_definition(
|
||||
alias, Predicate(type=func, body=expr)
|
||||
)
|
||||
self._definitions.append(definition)
|
||||
return ast.Name(id=alias)
|
||||
|
||||
def make_alias(self, name: Optional[str]) -> str:
|
||||
suffix: str = f"_{name}" if name is not None else ""
|
||||
alias: str = f"__midas_p{self._id}{suffix}__"
|
||||
self._id += 1
|
||||
return alias
|
||||
|
||||
def make_definition(self, name: str, predicate: Predicate) -> ast.stmt:
|
||||
body: list[ast.stmt] = [ast.Return(value=predicate.body.accept(self))]
|
||||
return self.make_func(name, body, predicate.type)
|
||||
|
||||
def make_args(self, func: Function) -> ast.arguments:
|
||||
return ast.arguments(
|
||||
posonlyargs=[ast.arg(arg=arg.name) for arg in func.pos_args],
|
||||
args=[ast.arg(arg=arg.name) for arg in func.args],
|
||||
kwonlyargs=[ast.arg(arg=arg.name) for arg in func.kw_args],
|
||||
defaults=[],
|
||||
kw_defaults=[],
|
||||
)
|
||||
|
||||
def make_func(
|
||||
self, name: str, inner_body: list[ast.stmt], type: Type, level: int = 0
|
||||
) -> ast.stmt:
|
||||
match type:
|
||||
case Function(returns=Function()):
|
||||
inner_name: str = f"inner{level}"
|
||||
return ast.FunctionDef(
|
||||
name=name,
|
||||
args=self.make_args(type),
|
||||
body=[
|
||||
self.make_func(inner_name, inner_body, type.returns, level + 1),
|
||||
ast.Return(value=ast.Name(id=inner_name)),
|
||||
],
|
||||
decorator_list=[],
|
||||
)
|
||||
|
||||
case Function():
|
||||
return ast.FunctionDef(
|
||||
name=name,
|
||||
args=self.make_args(type),
|
||||
body=inner_body,
|
||||
decorator_list=[],
|
||||
)
|
||||
|
||||
case _:
|
||||
raise ValueError(f"Expected function, got {type}")
|
||||
|
||||
def get_predicate(self, name: str) -> Optional[ast.expr]:
|
||||
if name not in self._aliases:
|
||||
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
|
||||
if predicate is None:
|
||||
return None
|
||||
alias: str = self.make_alias(name)
|
||||
self._aliases[name] = alias
|
||||
self._definitions.append(self.make_definition(alias, predicate))
|
||||
|
||||
return ast.Name(id=self._aliases[name])
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> ast.expr:
|
||||
return ast.BoolOp(
|
||||
op=LOGICAL_OPERATORS[expr.operator.type](),
|
||||
values=[
|
||||
expr.left.accept(self),
|
||||
expr.right.accept(self),
|
||||
],
|
||||
)
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> ast.expr:
|
||||
op: TokenType = expr.operator.type
|
||||
if op in BINARY_OPERATORS:
|
||||
return ast.BinOp(
|
||||
left=expr.left.accept(self),
|
||||
op=BINARY_OPERATORS[op](),
|
||||
right=expr.right.accept(self),
|
||||
)
|
||||
if op in COMPARISON_OPERATORS:
|
||||
return ast.Compare(
|
||||
left=expr.left.accept(self),
|
||||
ops=[COMPARISON_OPERATORS[op]()],
|
||||
comparators=[expr.right.accept(self)],
|
||||
)
|
||||
raise ValueError(f"Unexpected binary operator {op}")
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> ast.expr:
|
||||
return ast.UnaryOp(
|
||||
op=UNARY_OPERATORS[expr.operator.type](),
|
||||
operand=expr.right.accept(self),
|
||||
)
|
||||
|
||||
def visit_call_expr(self, expr: m.CallExpr) -> ast.expr:
|
||||
return ast.Call(
|
||||
func=expr.callee.accept(self),
|
||||
args=[arg.accept(self) for arg in expr.arguments],
|
||||
keywords=[
|
||||
ast.keyword(arg=name, value=arg.accept(self))
|
||||
for name, arg in expr.keywords.items()
|
||||
],
|
||||
)
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> ast.expr:
|
||||
return ast.Attribute(
|
||||
value=expr.expr.accept(self),
|
||||
attr=expr.name.lexeme,
|
||||
)
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr) -> ast.expr:
|
||||
name: str = expr.name.lexeme
|
||||
if (p := self.get_predicate(name)) is not None:
|
||||
return p
|
||||
return ast.Name(id=name)
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> ast.expr:
|
||||
return expr.accept(self)
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> ast.expr:
|
||||
return ast.Constant(value=expr.value)
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> ast.expr:
|
||||
return ast.Name(id="_")
|
||||
@@ -2,15 +2,19 @@ import ast
|
||||
import shutil
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Optional, assert_never
|
||||
|
||||
import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.ast.printer import MidasPrinter
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import (
|
||||
AliasType,
|
||||
AppliedType,
|
||||
BaseType,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
ExtensionType,
|
||||
Function,
|
||||
GenericType,
|
||||
@@ -19,7 +23,9 @@ from midas.checker.types import (
|
||||
Type,
|
||||
TypeVar,
|
||||
UnitType,
|
||||
UnknownType,
|
||||
)
|
||||
from midas.generator.constraints import ConstraintGenerator
|
||||
from midas.utils import TypedAST
|
||||
|
||||
|
||||
@@ -30,7 +36,7 @@ class Scope:
|
||||
|
||||
|
||||
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
def __init__(self, workdir: Path) -> None:
|
||||
def __init__(self, workdir: Path, types: TypesRegistry) -> None:
|
||||
self.workdir: Path = workdir.resolve()
|
||||
self.build_dir: Path = self.workdir / "build" / "midas"
|
||||
if self.build_dir.exists():
|
||||
@@ -43,13 +49,18 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
judgements=[],
|
||||
)
|
||||
self._alias_count: int = 0
|
||||
self._predicate_count: int = 0
|
||||
self._scopes: list[Scope] = []
|
||||
|
||||
self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
|
||||
self._constraints: list[tuple[m.Expr, ast.expr]] = []
|
||||
|
||||
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
|
||||
self.rel_src_path = src_path.relative_to(self.workdir)
|
||||
self._typed_ast = typed_ast
|
||||
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
|
||||
module = ast.Module(body=body, type_ignores=[])
|
||||
predicates: list[ast.stmt] = self._constraint_generator.get_definitions()
|
||||
module = ast.Module(body=predicates + body, type_ignores=[])
|
||||
module = ast.fix_missing_locations(module)
|
||||
return module
|
||||
|
||||
@@ -246,7 +257,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
return generated
|
||||
|
||||
def _make_alias(self, expr: ast.expr) -> ast.expr:
|
||||
name: str = f"__midas_alias_{self._alias_count}__"
|
||||
name: str = f"__midas_a{self._alias_count}__"
|
||||
alias = ast.Name(id=name)
|
||||
self._alias_count += 1
|
||||
self._scopes[-1].aliases.append(name)
|
||||
@@ -276,6 +287,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
|
||||
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type):
|
||||
match type:
|
||||
case UnknownType():
|
||||
pass
|
||||
|
||||
case BaseType(name=name):
|
||||
self._add_assert(
|
||||
ast.Call(
|
||||
@@ -301,8 +315,15 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
self._make_cast_assert_message(src_location, expr, type),
|
||||
)
|
||||
|
||||
case AppliedType():
|
||||
self._make_cast_asserts(src_location, expr, type.body)
|
||||
case AppliedType(body=body):
|
||||
self._make_cast_asserts(src_location, expr, body)
|
||||
|
||||
case ConstraintType(type=base, constraint=constraint):
|
||||
self._make_cast_asserts(src_location, expr, base)
|
||||
self._make_constraint_assert(src_location, expr, constraint)
|
||||
|
||||
case TypeVar():
|
||||
raise RuntimeError("Unexpected TypeVar")
|
||||
|
||||
case (
|
||||
TopType()
|
||||
@@ -314,8 +335,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
):
|
||||
raise NotImplementedError(f"Can't make assertion for type {type}")
|
||||
|
||||
case TypeVar():
|
||||
raise RuntimeError("Unexpected TypeVar")
|
||||
# Ensure exhaustiveness
|
||||
case _:
|
||||
assert_never(type)
|
||||
|
||||
def _make_cast_assert_message(
|
||||
self, location: Location, expr: ast.expr, type: Type
|
||||
@@ -339,3 +361,36 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
ast.Constant(f" to {type}"),
|
||||
]
|
||||
)
|
||||
|
||||
def _make_constraint_assert(
|
||||
self, src_location: Location, expr: ast.expr, constraint: m.Expr
|
||||
):
|
||||
test_func: ast.expr = self._get_constraint(constraint)
|
||||
self._add_assert(
|
||||
ast.Call(
|
||||
func=test_func,
|
||||
args=[expr],
|
||||
keywords=[],
|
||||
),
|
||||
self._make_constraint_assert_message(src_location, expr, constraint),
|
||||
)
|
||||
|
||||
def _make_constraint_assert_message(
|
||||
self, location: Location, expr: ast.expr, constraint: m.Expr
|
||||
) -> ast.expr:
|
||||
printer = MidasPrinter()
|
||||
constraint_str: str = printer.print(constraint)
|
||||
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
|
||||
# f"file.py:L1:1: ConstraintError: Value does not fit constraint 'v > 0'"
|
||||
return ast.Constant(
|
||||
f"{loc_str}: ConstraintError: Value does not fit constraint '{constraint_str}'"
|
||||
)
|
||||
|
||||
def _get_constraint(self, expr: m.Expr) -> ast.expr:
|
||||
for expr2, constraint in self._constraints:
|
||||
if expr2 == expr:
|
||||
return constraint
|
||||
|
||||
constraint: ast.expr = self._constraint_generator.generate(expr)
|
||||
self._constraints.append((expr, constraint))
|
||||
return constraint
|
||||
|
||||
@@ -1,337 +0,0 @@
|
||||
import ast
|
||||
from typing import Optional
|
||||
|
||||
import midas.ast.midas as m
|
||||
from midas.checker.registry import Member, TypesRegistry
|
||||
from midas.checker.types import (
|
||||
AliasType,
|
||||
AppliedType,
|
||||
BaseType,
|
||||
ComplexType,
|
||||
ExtensionType,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
TopType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnitType,
|
||||
UnknownType,
|
||||
substitute_typevars,
|
||||
)
|
||||
|
||||
Empty = ast.Constant(value=...)
|
||||
|
||||
|
||||
class StubsGenerator:
|
||||
def __init__(self, types: TypesRegistry) -> None:
|
||||
self.types: TypesRegistry = types
|
||||
self.stubs: list[ast.stmt] = []
|
||||
self.typing_imports: set[str] = set()
|
||||
self.protocol_idx: int = 0
|
||||
self.stub_idx: int = 0
|
||||
self.type_var_idx: int = 0
|
||||
self.substitutions: dict[str, dict[str, Type]] = {}
|
||||
|
||||
def generate_stubs(self) -> ast.Module:
|
||||
self.stubs = []
|
||||
self.typing_imports = set()
|
||||
for name, type in self.types._types.items():
|
||||
self.generate_stub(name, type)
|
||||
|
||||
imports = [
|
||||
ast.ImportFrom(
|
||||
module="__future__",
|
||||
names=[ast.alias(name="annotations")],
|
||||
level=0,
|
||||
)
|
||||
]
|
||||
if len(self.typing_imports) != 0:
|
||||
imports.append(
|
||||
ast.ImportFrom(
|
||||
module="typing",
|
||||
names=[
|
||||
ast.alias(name=name) for name in sorted(self.typing_imports)
|
||||
],
|
||||
level=0,
|
||||
)
|
||||
)
|
||||
return ast.Module(body=imports + self.stubs, type_ignores=[])
|
||||
|
||||
def generate_stub(self, name: str, type: Type):
|
||||
base_type: Type = type
|
||||
|
||||
members: dict[str, Member] = self.types._members.get(name, {})
|
||||
if isinstance(base_type, (BaseType, TopType, UnitType)) and len(members) == 0:
|
||||
return
|
||||
|
||||
bases: list[ast.expr] = []
|
||||
substitutions: dict[str, Type] = {}
|
||||
bases, substitutions = self.get_bases(type)
|
||||
self.substitutions[name] = substitutions
|
||||
|
||||
body = self.generate_body(members, substitutions)
|
||||
stub = ast.ClassDef(
|
||||
name=name,
|
||||
bases=bases,
|
||||
body=body,
|
||||
keywords=[],
|
||||
decorator_list=[],
|
||||
)
|
||||
self.add_stub(stub)
|
||||
|
||||
def get_bases(self, type: Type) -> tuple[list[ast.expr], dict[str, Type]]:
|
||||
match type:
|
||||
case AliasType(type=base):
|
||||
return [self.dump_type(base)], {}
|
||||
case GenericType(params=params, body=body):
|
||||
self.add_typing_import("Generic")
|
||||
type_vars: ast.expr
|
||||
|
||||
params2: list[TypeVar] = self.define_type_vars(params)
|
||||
if len(params) == 1:
|
||||
type_vars = ast.Name(id=params2[0].name)
|
||||
else:
|
||||
type_vars = ast.Tuple(
|
||||
elts=[ast.Name(id=param.name) for param in params2]
|
||||
)
|
||||
|
||||
substitutions: dict[str, TypeVar] = {
|
||||
param.name: param2 for param, param2 in zip(params, params2)
|
||||
}
|
||||
|
||||
body_bases, body_subsitutions = self.get_bases(body)
|
||||
return (
|
||||
body_bases
|
||||
+ [
|
||||
ast.Subscript(
|
||||
value=ast.Name(id="Generic"),
|
||||
slice=type_vars,
|
||||
)
|
||||
],
|
||||
body_subsitutions | substitutions,
|
||||
)
|
||||
case _:
|
||||
return [], {}
|
||||
|
||||
def generate_body(
|
||||
self, members: dict[str, Member], substitutions: dict[str, Type]
|
||||
) -> list[ast.stmt]:
|
||||
if len(members) == 0:
|
||||
return [ast.Expr(value=Empty)]
|
||||
|
||||
body: list[ast.stmt] = []
|
||||
for name, member in members.items():
|
||||
type: Type = member.type
|
||||
type = substitute_typevars(type, substitutions)
|
||||
match member.kind:
|
||||
case m.MemberKind.PROPERTY:
|
||||
body.append(
|
||||
ast.AnnAssign(
|
||||
target=ast.Name(id=name),
|
||||
annotation=self.dump_type(type),
|
||||
simple=1,
|
||||
)
|
||||
)
|
||||
case m.MemberKind.METHOD:
|
||||
body.extend(self.dump_method(name, type))
|
||||
return body
|
||||
|
||||
def dump_type(self, type: Type) -> ast.expr:
|
||||
match type:
|
||||
case AliasType(name=name) | GenericType(name=name) if (
|
||||
name in self.substitutions
|
||||
):
|
||||
type = substitute_typevars(type, self.substitutions[name])
|
||||
|
||||
match type:
|
||||
case TopType() | UnknownType():
|
||||
self.add_typing_import("Any")
|
||||
return ast.Name(id="Any")
|
||||
case BaseType(name=name):
|
||||
return ast.Name(id=name)
|
||||
case AliasType(name=name):
|
||||
return ast.Name(id=name)
|
||||
case UnitType():
|
||||
return ast.Constant(value=None)
|
||||
case Function():
|
||||
name: str = self.define_protocol(type)
|
||||
return ast.Name(id=name)
|
||||
case OverloadedFunction(overloads=overloads):
|
||||
if len(overloads) == 1:
|
||||
return self.dump_type(overloads[0])
|
||||
return ast.BinOp(
|
||||
left=self.dump_type(OverloadedFunction(overloads=overloads[:-1])),
|
||||
op=ast.BitOr(),
|
||||
right=self.dump_type(overloads[-1]),
|
||||
)
|
||||
|
||||
case ComplexType():
|
||||
name: str = self.new_stub_name()
|
||||
self.generate_stub(name, type)
|
||||
return ast.Name(id=name)
|
||||
|
||||
case ExtensionType():
|
||||
raise NotImplementedError
|
||||
|
||||
case TypeVar():
|
||||
return ast.Name(id=type.name)
|
||||
case GenericType(name=name):
|
||||
params: ast.expr
|
||||
if len(type.params) == 1:
|
||||
params = self.dump_type(type.params[0])
|
||||
else:
|
||||
params = ast.Tuple(
|
||||
elts=[self.dump_type(param) for param in type.params]
|
||||
)
|
||||
return ast.Subscript(
|
||||
value=ast.Name(id=type.name),
|
||||
slice=params,
|
||||
)
|
||||
case AppliedType():
|
||||
args: ast.expr
|
||||
if len(type.args) == 1:
|
||||
args = self.dump_type(type.args[0])
|
||||
else:
|
||||
args = ast.Tuple(elts=[self.dump_type(arg) for arg in type.args])
|
||||
return ast.Subscript(
|
||||
value=ast.Name(id=type.name),
|
||||
slice=args,
|
||||
)
|
||||
|
||||
def dump_method(
|
||||
self, name: str, method: Type, overloaded: bool = False
|
||||
) -> list[ast.stmt]:
|
||||
match method:
|
||||
case Function():
|
||||
if overloaded:
|
||||
self.add_typing_import("overload")
|
||||
return [
|
||||
ast.FunctionDef(
|
||||
name=name,
|
||||
args=self.dump_args(method, with_self=True),
|
||||
returns=self.dump_type(method.returns),
|
||||
body=[ast.Expr(value=Empty)],
|
||||
decorator_list=[ast.Name(id="overload")] if overloaded else [],
|
||||
)
|
||||
]
|
||||
case OverloadedFunction(overloads=overloads):
|
||||
stmts: list[ast.stmt] = []
|
||||
for overload in overloads:
|
||||
stmts.extend(self.dump_method(name, overload, True))
|
||||
return stmts
|
||||
case _:
|
||||
return [
|
||||
ast.AnnAssign(
|
||||
target=ast.Name(id=name),
|
||||
annotation=self.dump_type(method),
|
||||
simple=1,
|
||||
)
|
||||
]
|
||||
|
||||
def dump_args(self, func: Function, with_self: bool = False) -> ast.arguments:
|
||||
pos: list[ast.arg] = [
|
||||
ast.arg(arg=f"_{arg.pos}", annotation=self.dump_type(arg.type))
|
||||
for arg in func.pos_args
|
||||
]
|
||||
mixed: list[ast.arg] = [
|
||||
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type))
|
||||
for arg in func.args
|
||||
]
|
||||
kw: list[ast.arg] = [
|
||||
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type))
|
||||
for arg in func.kw_args
|
||||
]
|
||||
defaults: list[ast.expr] = [
|
||||
Empty for arg in func.pos_args + func.args if not arg.required
|
||||
]
|
||||
kw_defaults: list[Optional[ast.expr]] = [
|
||||
None if arg.required else Empty for arg in func.kw_args
|
||||
]
|
||||
if with_self:
|
||||
arg = ast.arg(arg="self", annotation=None)
|
||||
if len(pos) != 0:
|
||||
pos.insert(0, arg)
|
||||
else:
|
||||
mixed.insert(0, arg)
|
||||
return ast.arguments(
|
||||
posonlyargs=pos,
|
||||
args=mixed,
|
||||
kwonlyargs=kw,
|
||||
defaults=defaults,
|
||||
kw_defaults=kw_defaults,
|
||||
)
|
||||
|
||||
def define_protocol(self, func: Function) -> str:
|
||||
self.add_typing_import("Protocol")
|
||||
name: str = self.new_protocol_name()
|
||||
protocol = ast.ClassDef(
|
||||
name=name,
|
||||
bases=[ast.Name(id="Protocol")],
|
||||
keywords=[],
|
||||
body=[
|
||||
ast.FunctionDef(
|
||||
name="__call__",
|
||||
args=self.dump_args(func, with_self=True),
|
||||
returns=self.dump_type(func.returns),
|
||||
body=[ast.Expr(value=Empty)],
|
||||
decorator_list=[],
|
||||
),
|
||||
],
|
||||
decorator_list=[],
|
||||
)
|
||||
self.add_stub(protocol)
|
||||
return name
|
||||
|
||||
def new_protocol_name(self) -> str:
|
||||
name: str = f"_Protocol{self.protocol_idx}"
|
||||
self.protocol_idx += 1
|
||||
return name
|
||||
|
||||
def new_stub_name(self) -> str:
|
||||
name: str = f"_Stub_{self.stub_idx}"
|
||||
self.stub_idx += 1
|
||||
return name
|
||||
|
||||
def new_type_var_name(self) -> str:
|
||||
name: str = f"_T{self.type_var_idx}"
|
||||
self.type_var_idx += 1
|
||||
return name
|
||||
|
||||
def add_stub(self, stub: ast.stmt):
|
||||
self.stubs.append(stub)
|
||||
|
||||
def add_typing_import(self, name: str):
|
||||
self.typing_imports.add(name)
|
||||
|
||||
def define_type_vars(self, vars: list[TypeVar]) -> list[TypeVar]:
|
||||
vars2: list[TypeVar] = []
|
||||
for var in vars:
|
||||
vars2.append(self.define_type_var(var))
|
||||
return vars2
|
||||
|
||||
def define_type_var(self, var: TypeVar) -> TypeVar:
|
||||
name: str = self.new_type_var_name()
|
||||
self.add_typing_import("TypeVar")
|
||||
self.add_stub(
|
||||
ast.Assign(
|
||||
targets=[ast.Name(id=name)],
|
||||
value=ast.Call(
|
||||
func=ast.Name(id="TypeVar"),
|
||||
args=[
|
||||
ast.Constant(value=name),
|
||||
],
|
||||
keywords=(
|
||||
[]
|
||||
if var.bound is None
|
||||
else [
|
||||
ast.keyword(
|
||||
arg="bound",
|
||||
value=self.dump_type(var.bound),
|
||||
)
|
||||
]
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
return TypeVar(name=name, bound=None)
|
||||
@@ -3,6 +3,7 @@ from typing import Optional
|
||||
from midas.ast.location import Location
|
||||
from midas.ast.midas import (
|
||||
BinaryExpr,
|
||||
CallExpr,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
Expr,
|
||||
@@ -17,6 +18,7 @@ from midas.ast.midas import (
|
||||
MemberKind,
|
||||
MemberStmt,
|
||||
NamedType,
|
||||
ParamSpec,
|
||||
PredicateStmt,
|
||||
Stmt,
|
||||
Type,
|
||||
@@ -265,6 +267,9 @@ class MidasParser(Parser):
|
||||
Returns:
|
||||
Expr: the parsed constraint expression
|
||||
"""
|
||||
return self.expression()
|
||||
|
||||
def expression(self) -> Expr:
|
||||
return self.and_()
|
||||
|
||||
def and_(self) -> Expr:
|
||||
@@ -331,7 +336,55 @@ class MidasParser(Parser):
|
||||
right: Expr = self.unary()
|
||||
location: Location = Location.span(operator.get_location(), right.location)
|
||||
return UnaryExpr(location=location, operator=operator, right=right)
|
||||
return self.reference()
|
||||
return self.call()
|
||||
|
||||
def call(self) -> Expr:
|
||||
expr: Expr = self.reference()
|
||||
while self.match(TokenType.LEFT_PAREN):
|
||||
expr = self.finish_call(expr)
|
||||
return expr
|
||||
|
||||
def finish_call(self, callee: Expr) -> Expr:
|
||||
l_paren: Token = self.previous()
|
||||
pos_args: list[Expr] = []
|
||||
kw_args: dict[str, Expr] = {}
|
||||
keywords: bool = False
|
||||
while not self.match(TokenType.RIGHT_PAREN):
|
||||
if self.check_identifier() and self.check_next(TokenType.EQUAL):
|
||||
keywords = True
|
||||
keyword: Token = self.advance()
|
||||
value: Expr = self.expression()
|
||||
name: str = keyword.lexeme
|
||||
if name in kw_args:
|
||||
self.error(
|
||||
self.peek(),
|
||||
f"Multiple values passed for '{name}', only the last occurrence will be used",
|
||||
)
|
||||
kw_args[name] = value
|
||||
else:
|
||||
value = self.expression()
|
||||
if self.check(TokenType.EQUAL):
|
||||
if keywords:
|
||||
raise self.error(self.peek(), "Invalid keyword argument name")
|
||||
else:
|
||||
raise self.error(
|
||||
self.peek(),
|
||||
"Cannot pass positional arguments after a keyword argument",
|
||||
)
|
||||
pos_args.append(value)
|
||||
|
||||
if not self.match(TokenType.COMMA):
|
||||
break
|
||||
|
||||
r_paren: Token = self.consume(
|
||||
TokenType.RIGHT_PAREN, "Expected ')' after arguments."
|
||||
)
|
||||
return CallExpr(
|
||||
location=l_paren.location_to(r_paren),
|
||||
callee=callee,
|
||||
arguments=pos_args,
|
||||
keywords=kw_args,
|
||||
)
|
||||
|
||||
def reference(self) -> Expr:
|
||||
"""Parse an attribute access expression or a simpler expression
|
||||
@@ -453,23 +506,35 @@ class MidasParser(Parser):
|
||||
PredicateStmt: the parsed predicate declaration statement
|
||||
"""
|
||||
keyword: Token = self.previous()
|
||||
|
||||
name: Token = self.consume_identifier("Expected predicate name")
|
||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
|
||||
subject: Token = self.consume_identifier("Expected subject name")
|
||||
self.consume(TokenType.COLON, "Expected ':' after subject name")
|
||||
type: Type = self.type_expr()
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
|
||||
|
||||
params: list[ParamSpec] = []
|
||||
while self.check(TokenType.LEFT_PAREN):
|
||||
params.append(self.function_args())
|
||||
|
||||
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
|
||||
condition: Expr = self.constraint()
|
||||
body: Expr = self.constraint()
|
||||
return PredicateStmt(
|
||||
location=keyword.location_to(self.previous()),
|
||||
name=name,
|
||||
subject=subject,
|
||||
type=type,
|
||||
condition=condition,
|
||||
params=params,
|
||||
body=body,
|
||||
)
|
||||
|
||||
def function(self) -> FunctionType:
|
||||
params: ParamSpec = self.function_args()
|
||||
|
||||
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
||||
result: Type = self.type_expr()
|
||||
|
||||
return FunctionType(
|
||||
location=params.l_paren.location_to(self.previous()),
|
||||
params=params,
|
||||
returns=result,
|
||||
)
|
||||
|
||||
def function_args(self) -> ParamSpec:
|
||||
l_paren: Token = self.consume(
|
||||
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
|
||||
)
|
||||
@@ -526,14 +591,4 @@ class MidasParser(Parser):
|
||||
self.error(token, "Unnamed mixed argument")
|
||||
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
|
||||
|
||||
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
||||
result: Type = self.type_expr()
|
||||
|
||||
return FunctionType(
|
||||
location=l_paren.location_to(self.previous()),
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
returns=result,
|
||||
)
|
||||
return ParamSpec(l_paren=l_paren, pos=pos_args, mixed=args, kw=kw_args)
|
||||
|
||||
@@ -9,13 +9,13 @@ Module(
|
||||
level=0),
|
||||
Assign(
|
||||
targets=[
|
||||
Name(id='__midas_alias_0__')],
|
||||
Name(id='__midas_a0__')],
|
||||
value=Constant(value=123.45)),
|
||||
Assert(
|
||||
test=Call(
|
||||
func=Name(id='isinstance'),
|
||||
args=[
|
||||
Name(id='__midas_alias_0__'),
|
||||
Name(id='__midas_a0__'),
|
||||
Name(id='float')],
|
||||
keywords=[]),
|
||||
msg=JoinedStr(
|
||||
@@ -26,7 +26,7 @@ Module(
|
||||
value=Call(
|
||||
func=Name(id='type'),
|
||||
args=[
|
||||
Name(id='__midas_alias_0__')],
|
||||
Name(id='__midas_a0__')],
|
||||
keywords=[]),
|
||||
attr='__name__'),
|
||||
conversion=-1),
|
||||
@@ -34,19 +34,19 @@ Module(
|
||||
Assign(
|
||||
targets=[
|
||||
Name(id='distance')],
|
||||
value=Name(id='__midas_alias_0__')),
|
||||
value=Name(id='__midas_a0__')),
|
||||
Delete(
|
||||
targets=[
|
||||
Name(id='__midas_alias_0__')]),
|
||||
Name(id='__midas_a0__')]),
|
||||
Assign(
|
||||
targets=[
|
||||
Name(id='__midas_alias_1__')],
|
||||
Name(id='__midas_a1__')],
|
||||
value=Constant(value=6.7)),
|
||||
Assert(
|
||||
test=Call(
|
||||
func=Name(id='isinstance'),
|
||||
args=[
|
||||
Name(id='__midas_alias_1__'),
|
||||
Name(id='__midas_a1__'),
|
||||
Name(id='float')],
|
||||
keywords=[]),
|
||||
msg=JoinedStr(
|
||||
@@ -57,7 +57,7 @@ Module(
|
||||
value=Call(
|
||||
func=Name(id='type'),
|
||||
args=[
|
||||
Name(id='__midas_alias_1__')],
|
||||
Name(id='__midas_a1__')],
|
||||
keywords=[]),
|
||||
attr='__name__'),
|
||||
conversion=-1),
|
||||
@@ -65,10 +65,10 @@ Module(
|
||||
Assign(
|
||||
targets=[
|
||||
Name(id='time')],
|
||||
value=Name(id='__midas_alias_1__')),
|
||||
value=Name(id='__midas_a1__')),
|
||||
Delete(
|
||||
targets=[
|
||||
Name(id='__midas_alias_1__')]),
|
||||
Name(id='__midas_a1__')]),
|
||||
Assign(
|
||||
targets=[
|
||||
Name(id='speed')],
|
||||
|
||||
@@ -2582,18 +2582,21 @@
|
||||
"name": "__sub__",
|
||||
"type": {
|
||||
"_type": "FunctionType",
|
||||
"pos_args": [
|
||||
{
|
||||
"name": null,
|
||||
"type": {
|
||||
"_type": "NamedType",
|
||||
"name": "GeoLocation"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"params": {
|
||||
"_type": "ParamSpec",
|
||||
"pos": [
|
||||
{
|
||||
"name": null,
|
||||
"type": {
|
||||
"_type": "NamedType",
|
||||
"name": "GeoLocation"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"mixed": [],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"_type": "GenericType",
|
||||
"type": {
|
||||
@@ -2673,18 +2676,21 @@
|
||||
"name": "__sub__",
|
||||
"type": {
|
||||
"_type": "FunctionType",
|
||||
"pos_args": [
|
||||
{
|
||||
"name": null,
|
||||
"type": {
|
||||
"_type": "NamedType",
|
||||
"name": "Latitude"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"params": {
|
||||
"_type": "ParamSpec",
|
||||
"pos": [
|
||||
{
|
||||
"name": null,
|
||||
"type": {
|
||||
"_type": "NamedType",
|
||||
"name": "Latitude"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"mixed": [],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"_type": "GenericType",
|
||||
"type": {
|
||||
@@ -2713,18 +2719,21 @@
|
||||
"name": "__sub__",
|
||||
"type": {
|
||||
"_type": "FunctionType",
|
||||
"pos_args": [
|
||||
{
|
||||
"name": null,
|
||||
"type": {
|
||||
"_type": "NamedType",
|
||||
"name": "Longitude"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"params": {
|
||||
"_type": "ParamSpec",
|
||||
"pos": [
|
||||
{
|
||||
"name": null,
|
||||
"type": {
|
||||
"_type": "NamedType",
|
||||
"name": "Longitude"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"mixed": [],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"_type": "GenericType",
|
||||
"type": {
|
||||
@@ -2745,12 +2754,24 @@
|
||||
{
|
||||
"_type": "PredicateStmt",
|
||||
"name": "Positive",
|
||||
"subject": "v",
|
||||
"type": {
|
||||
"_type": "NamedType",
|
||||
"name": "float"
|
||||
},
|
||||
"condition": {
|
||||
"params": [
|
||||
{
|
||||
"_type": "ParamSpec",
|
||||
"pos": [],
|
||||
"mixed": [
|
||||
{
|
||||
"name": "v",
|
||||
"type": {
|
||||
"_type": "NamedType",
|
||||
"name": "float"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw": []
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
@@ -2766,12 +2787,24 @@
|
||||
{
|
||||
"_type": "PredicateStmt",
|
||||
"name": "StrictlyPositive",
|
||||
"subject": "v",
|
||||
"type": {
|
||||
"_type": "NamedType",
|
||||
"name": "float"
|
||||
},
|
||||
"condition": {
|
||||
"params": [
|
||||
{
|
||||
"_type": "ParamSpec",
|
||||
"pos": [],
|
||||
"mixed": [
|
||||
{
|
||||
"name": "v",
|
||||
"type": {
|
||||
"_type": "NamedType",
|
||||
"name": "float"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw": []
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
@@ -2787,12 +2820,24 @@
|
||||
{
|
||||
"_type": "PredicateStmt",
|
||||
"name": "Equatorial",
|
||||
"subject": "loc",
|
||||
"type": {
|
||||
"_type": "NamedType",
|
||||
"name": "GeoLocation"
|
||||
},
|
||||
"condition": {
|
||||
"params": [
|
||||
{
|
||||
"_type": "ParamSpec",
|
||||
"pos": [],
|
||||
"mixed": [
|
||||
{
|
||||
"name": "loc",
|
||||
"type": {
|
||||
"_type": "NamedType",
|
||||
"name": "GeoLocation"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw": []
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"_type": "GroupingExpr",
|
||||
"expr": {
|
||||
"_type": "BinaryExpr",
|
||||
@@ -2827,12 +2872,24 @@
|
||||
{
|
||||
"_type": "PredicateStmt",
|
||||
"name": "Arctic",
|
||||
"subject": "loc",
|
||||
"type": {
|
||||
"_type": "NamedType",
|
||||
"name": "GeoLocation"
|
||||
},
|
||||
"condition": {
|
||||
"params": [
|
||||
{
|
||||
"_type": "ParamSpec",
|
||||
"pos": [],
|
||||
"mixed": [
|
||||
{
|
||||
"name": "loc",
|
||||
"type": {
|
||||
"_type": "NamedType",
|
||||
"name": "GeoLocation"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw": []
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"_type": "GroupingExpr",
|
||||
"expr": {
|
||||
"_type": "BinaryExpr",
|
||||
|
||||
@@ -45,7 +45,7 @@ class GeneratorTester(Tester):
|
||||
typed_ast: TypedAST = checker.type_check(path)
|
||||
|
||||
if not any(d.type == DiagnosticType.ERROR for d in checker.diagnostics):
|
||||
generator = Generator(workdir=path.parent)
|
||||
generator = Generator(workdir=path.parent, types=checker.types)
|
||||
result.compiled_ast = generator.generate_ast(typed_ast, path)
|
||||
|
||||
return result
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Optional, Sequence
|
||||
|
||||
from midas.ast.midas import (
|
||||
BinaryExpr,
|
||||
CallExpr,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
Expr,
|
||||
@@ -15,6 +16,7 @@ from midas.ast.midas import (
|
||||
LogicalExpr,
|
||||
MemberStmt,
|
||||
NamedType,
|
||||
ParamSpec,
|
||||
PredicateStmt,
|
||||
Stmt,
|
||||
Type,
|
||||
@@ -78,9 +80,8 @@ class MidasAstJsonSerializer(
|
||||
return {
|
||||
"_type": "PredicateStmt",
|
||||
"name": stmt.name.lexeme,
|
||||
"subject": stmt.subject.lexeme,
|
||||
"type": stmt.type.accept(self),
|
||||
"condition": stmt.condition.accept(self),
|
||||
"params": [self._serialize_param_spec(spec) for spec in stmt.params],
|
||||
"body": stmt.body.accept(self),
|
||||
}
|
||||
|
||||
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
|
||||
@@ -106,6 +107,14 @@ class MidasAstJsonSerializer(
|
||||
"right": expr.right.accept(self),
|
||||
}
|
||||
|
||||
def visit_call_expr(self, expr: CallExpr) -> dict:
|
||||
return {
|
||||
"_type": "CallExpr",
|
||||
"callee": expr.callee.accept(self),
|
||||
"arguments": self._serialize_list(expr.arguments),
|
||||
"keywords": {name: arg.accept(self) for name, arg in expr.keywords.items()},
|
||||
}
|
||||
|
||||
def visit_get_expr(self, expr: GetExpr) -> dict:
|
||||
return {
|
||||
"_type": "GetExpr",
|
||||
@@ -163,15 +172,21 @@ class MidasAstJsonSerializer(
|
||||
def visit_function_type(self, type: FunctionType) -> dict:
|
||||
return {
|
||||
"_type": "FunctionType",
|
||||
"pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args],
|
||||
"args": [self._serialize_func_arg(arg) for arg in type.args],
|
||||
"kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args],
|
||||
"params": self._serialize_param_spec(type.params),
|
||||
"returns": type.returns.accept(self),
|
||||
}
|
||||
|
||||
def _serialize_param_spec(self, spec: ParamSpec) -> dict:
|
||||
return {
|
||||
"_type": "ParamSpec",
|
||||
"pos": [self._serialize_func_arg(arg) for arg in spec.pos],
|
||||
"mixed": [self._serialize_func_arg(arg) for arg in spec.mixed],
|
||||
"kw": [self._serialize_func_arg(arg) for arg in spec.kw],
|
||||
}
|
||||
|
||||
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
|
||||
return {
|
||||
"name": arg.name,
|
||||
"name": arg.name.lexeme if arg.name is not None else None,
|
||||
"type": arg.type.accept(self),
|
||||
"required": arg.required,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user