Compare commits
5 Commits
776a3fb86c
...
feat/stubs
| Author | SHA1 | Date | |
|---|---|---|---|
|
11422d4364
|
|||
|
e8f8a5ca2f
|
|||
|
df8d71c0a9
|
|||
|
e4fb142f99
|
|||
|
2f8f9d633b
|
23
gen/midas.py
23
gen/midas.py
@@ -26,14 +26,6 @@ class MemberKind(Enum):
|
||||
METHOD = auto()
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ParamSpec:
|
||||
l_paren: Token
|
||||
pos: list[FunctionType.Argument]
|
||||
mixed: list[FunctionType.Argument]
|
||||
kw: list[FunctionType.Argument]
|
||||
|
||||
|
||||
###<
|
||||
|
||||
|
||||
@@ -58,8 +50,9 @@ class ExtendStmt:
|
||||
|
||||
class PredicateStmt:
|
||||
name: Token
|
||||
params: list[ParamSpec]
|
||||
body: Expr
|
||||
subject: Token
|
||||
type: Type
|
||||
condition: Expr
|
||||
|
||||
|
||||
###<
|
||||
@@ -85,12 +78,6 @@ class UnaryExpr:
|
||||
right: Expr
|
||||
|
||||
|
||||
class CallExpr:
|
||||
callee: Expr
|
||||
arguments: list[Expr]
|
||||
keywords: dict[str, Expr]
|
||||
|
||||
|
||||
class GetExpr:
|
||||
expr: Expr
|
||||
name: Token
|
||||
@@ -141,7 +128,9 @@ class ExtensionType:
|
||||
|
||||
|
||||
class FunctionType:
|
||||
params: ParamSpec
|
||||
pos_args: list[Argument]
|
||||
args: list[Argument]
|
||||
kw_args: list[Argument]
|
||||
returns: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
|
||||
@@ -27,14 +27,6 @@ class MemberKind(Enum):
|
||||
METHOD = auto()
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ParamSpec:
|
||||
l_paren: Token
|
||||
pos: list[FunctionType.Argument]
|
||||
mixed: list[FunctionType.Argument]
|
||||
kw: list[FunctionType.Argument]
|
||||
|
||||
|
||||
##############
|
||||
# Statements #
|
||||
##############
|
||||
@@ -94,8 +86,9 @@ class ExtendStmt(Stmt):
|
||||
@dataclass(frozen=True)
|
||||
class PredicateStmt(Stmt):
|
||||
name: Token
|
||||
params: list[ParamSpec]
|
||||
body: Expr
|
||||
subject: Token
|
||||
type: Type
|
||||
condition: Expr
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_predicate_stmt(self)
|
||||
@@ -123,9 +116,6 @@ class Expr(ABC):
|
||||
@abstractmethod
|
||||
def visit_unary_expr(self, expr: UnaryExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_call_expr(self, expr: CallExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_get_expr(self, expr: GetExpr) -> T: ...
|
||||
|
||||
@@ -171,16 +161,6 @@ class UnaryExpr(Expr):
|
||||
return visitor.visit_unary_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CallExpr(Expr):
|
||||
callee: Expr
|
||||
arguments: list[Expr]
|
||||
keywords: dict[str, Expr]
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_call_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GetExpr(Expr):
|
||||
expr: Expr
|
||||
@@ -299,7 +279,9 @@ class ExtensionType(Type):
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FunctionType(Type):
|
||||
params: ParamSpec
|
||||
pos_args: list[Argument]
|
||||
args: list[Argument]
|
||||
kw_args: list[Argument]
|
||||
returns: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
|
||||
@@ -150,17 +150,13 @@ class MidasAstPrinter(
|
||||
self._write_line("PredicateStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line("params")
|
||||
with self._child_level():
|
||||
for i, spec in enumerate(stmt.params):
|
||||
self._idx = i
|
||||
if i == len(stmt.params) - 1:
|
||||
self._mark_last()
|
||||
self._visit_param_spec(spec)
|
||||
|
||||
self._write_line("body", last=True)
|
||||
self._write_line(f'subject: "{stmt.subject.lexeme}"')
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
stmt.body.accept(self)
|
||||
stmt.type.accept(self)
|
||||
self._write_line("condition", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.condition.accept(self)
|
||||
|
||||
# Expressions
|
||||
|
||||
@@ -199,29 +195,6 @@ class MidasAstPrinter(
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_call_expr(self, expr: m.CallExpr) -> None:
|
||||
self._write_line("CallExpr")
|
||||
with self._child_level():
|
||||
self._write_line("callee")
|
||||
with self._child_level(single=True):
|
||||
expr.callee.accept(self)
|
||||
self._write_line("arguments")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(expr.arguments):
|
||||
self._idx = i
|
||||
if i == len(expr.arguments) - 1:
|
||||
self._mark_last()
|
||||
arg.accept(self)
|
||||
self._write_line("keywords", last=True)
|
||||
with self._child_level():
|
||||
for i, (name, arg) in enumerate(expr.keywords.items()):
|
||||
self._idx = i
|
||||
if i == len(expr.keywords) - 1:
|
||||
self._mark_last()
|
||||
self._write_line(name)
|
||||
with self._child_level(single=True):
|
||||
arg.accept(self)
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr):
|
||||
self._write_line("GetExpr")
|
||||
with self._child_level():
|
||||
@@ -303,41 +276,34 @@ class MidasAstPrinter(
|
||||
def visit_function_type(self, type: m.FunctionType) -> None:
|
||||
self._write_line("FunctionType")
|
||||
with self._child_level():
|
||||
self._write_line("params")
|
||||
with self._child_level(single=True):
|
||||
self._visit_param_spec(type.params)
|
||||
self._write_line("pos_args")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(type.pos_args):
|
||||
self._idx = i
|
||||
if i == len(type.pos_args) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("args")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(type.args):
|
||||
self._idx = i
|
||||
if i == len(type.args) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("kw_args")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(type.kw_args):
|
||||
self._idx = i
|
||||
if i == len(type.kw_args) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("returns", last=True)
|
||||
with self._child_level(single=True):
|
||||
type.returns.accept(self)
|
||||
|
||||
def _visit_param_spec(self, spec: m.ParamSpec) -> None:
|
||||
self._write_line("ParamSpec")
|
||||
with self._child_level():
|
||||
self._write_line("pos")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(spec.pos):
|
||||
self._idx = i
|
||||
if i == len(spec.pos) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("mixed")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(spec.mixed):
|
||||
self._idx = i
|
||||
if i == len(spec.mixed) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("kw", last=True)
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(spec.kw):
|
||||
self._idx = i
|
||||
if i == len(spec.kw) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
|
||||
self._write_line("Argument")
|
||||
with self._child_level():
|
||||
@@ -401,9 +367,10 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
||||
name: str = stmt.name.lexeme
|
||||
sig: str = "".join(self._visit_param_spec(spec) for spec in stmt.params)
|
||||
body: str = stmt.body.accept(self)
|
||||
return self.indented(f"predicate {name}{sig} = {body}")
|
||||
subject: str = stmt.subject.lexeme
|
||||
type: str = stmt.type.accept(self)
|
||||
condition: str = stmt.condition.accept(self)
|
||||
return self.indented(f"predicate {name}({subject}: {type}) = {condition}")
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr):
|
||||
left: str = expr.left.accept(self)
|
||||
@@ -422,12 +389,6 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
||||
right: str = expr.right.accept(self)
|
||||
return f"{operator}{right}"
|
||||
|
||||
def visit_call_expr(self, expr: m.CallExpr) -> str:
|
||||
args: list[str] = [arg.accept(self) for arg in expr.arguments] + [
|
||||
f"{name}={arg.accept(self)}" for name, arg in expr.keywords.items()
|
||||
]
|
||||
return f"{expr.callee.accept(self)}({', '.join(args)})"
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr):
|
||||
expr_: str = expr.expr.accept(self)
|
||||
name: str = expr.name.lexeme
|
||||
@@ -475,13 +436,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
||||
return f"{type.base.accept(self)} & {type.extension.accept(self)}"
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> str:
|
||||
spec: str = self._visit_param_spec(type.params)
|
||||
return f"fn {spec} -> {type.returns.accept(self)}"
|
||||
|
||||
def _visit_param_spec(self, spec: m.ParamSpec) -> str:
|
||||
pos_args: list[str] = [self._print_arg(arg) for arg in spec.pos]
|
||||
mixed_args: list[str] = [self._print_arg(arg) for arg in spec.mixed]
|
||||
kw_args: list[str] = [self._print_arg(arg) for arg in spec.kw]
|
||||
pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
|
||||
mixed_args: list[str] = [self._print_arg(arg) for arg in type.args]
|
||||
kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args]
|
||||
args: list[str] = pos_args
|
||||
|
||||
if len(pos_args) != 0:
|
||||
@@ -490,7 +447,8 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
||||
if len(kw_args) != 0:
|
||||
args.append("*")
|
||||
args += kw_args
|
||||
return f"({', '.join(args)})"
|
||||
|
||||
return f"fn ({', '.join(args)}) -> {type.returns.accept(self)}"
|
||||
|
||||
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
|
||||
res: str = ""
|
||||
|
||||
@@ -15,7 +15,6 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
BUILTIN_SUBTYPES: dict[str, set[str]] = {
|
||||
"object": {"float", "list", "dict"},
|
||||
"float": {"int"},
|
||||
"int": {"bool"},
|
||||
}
|
||||
|
||||
@@ -1,64 +1,27 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import midas.ast.midas as m
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.builtins import define_builtins
|
||||
from midas.checker.environment import Environment
|
||||
from midas.checker.operators import MIDAS_BINARY_METHODS, MIDAS_UNARY_METHODS
|
||||
from midas.checker.preamble import Preamble
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter, Reporter
|
||||
from midas.checker.types import (
|
||||
AliasType,
|
||||
AppliedType,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
ExtensionType,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
Predicate,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnknownType,
|
||||
unfold_type,
|
||||
)
|
||||
from midas.lexer.midas import MidasLexer
|
||||
from midas.lexer.token import Token
|
||||
from midas.parser.midas import MidasParser
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class TypedParamSpec:
|
||||
pos: list[Function.Argument]
|
||||
mixed: list[Function.Argument]
|
||||
kw: list[Function.Argument]
|
||||
|
||||
|
||||
TypedExpr = tuple[m.Expr, Type]
|
||||
|
||||
|
||||
class ReturnException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class MappedArgument:
|
||||
expr: m.Expr
|
||||
type: Type
|
||||
argument: Function.Argument
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class OverloadCandidate:
|
||||
function: Function
|
||||
mapped: list[MappedArgument]
|
||||
|
||||
|
||||
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type]):
|
||||
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
|
||||
"""A resolver which evaluates Midas type definitions and build a registry"""
|
||||
|
||||
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
|
||||
@@ -68,18 +31,12 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
self.types: TypesRegistry = types
|
||||
self._local_variables: dict[str, TypeVar] = {}
|
||||
|
||||
self._predicate_params: dict[str, Type] = {}
|
||||
|
||||
self._current_name: Optional[str] = None
|
||||
|
||||
define_builtins(self.types)
|
||||
builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve()
|
||||
self.process(builtins_path.read_text(), str(builtins_path))
|
||||
|
||||
self._bool: Type = self.get_type("bool")
|
||||
|
||||
self._preamble: Environment = Preamble(self.types)
|
||||
|
||||
def process(self, source: str, path: Optional[str]):
|
||||
self.reporter = self.reporter.for_file(path)
|
||||
lexer: MidasLexer = MidasLexer(source)
|
||||
@@ -90,10 +47,6 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
self.reporter.error(error.token.get_location(), error.message)
|
||||
self.resolve(stmts)
|
||||
|
||||
def type_of(self, expr: m.Expr) -> Type:
|
||||
type: Type = expr.accept(self)
|
||||
return type
|
||||
|
||||
def get_type(self, name: str) -> Type:
|
||||
"""Get a type from its name
|
||||
|
||||
@@ -110,19 +63,6 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
return self._local_variables[name]
|
||||
return self.types.get_type(name)
|
||||
|
||||
def get_variable(self, name: str) -> Type:
|
||||
if name in self._predicate_params:
|
||||
return self._predicate_params[name]
|
||||
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
|
||||
if predicate is not None:
|
||||
return predicate.type
|
||||
|
||||
global_: Optional[Type] = self._preamble.get(name)
|
||||
if global_ is not None:
|
||||
return global_
|
||||
|
||||
raise NameError(f"Unknown variable '{name}'")
|
||||
|
||||
def resolve(self, stmts: list[m.Stmt]):
|
||||
"""Process a sequence of statements
|
||||
|
||||
@@ -132,11 +72,6 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
for stmt in stmts:
|
||||
stmt.accept(self)
|
||||
|
||||
def assert_bool(self, expr: m.Expr):
|
||||
type: Type = self.type_of(expr)
|
||||
if not self.types.is_subtype(type, self._bool):
|
||||
self.reporter.error(expr.location, f"Must be a boolean but is {type}")
|
||||
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||
name: str = stmt.name.lexeme
|
||||
self._current_name = name
|
||||
@@ -167,167 +102,35 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
base_name,
|
||||
member.name.lexeme,
|
||||
member_type,
|
||||
member.kind == m.MemberKind.METHOD,
|
||||
member.kind,
|
||||
)
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
||||
for spec in stmt.params:
|
||||
for param in spec.mixed:
|
||||
assert param.name is not None
|
||||
self._predicate_params[param.name.lexeme] = param.type.accept(self)
|
||||
self.reporter.warning(stmt.location, "PredicateStmt not yet supported")
|
||||
|
||||
type: Type = self.type_of(stmt.body)
|
||||
params: list[TypedParamSpec] = [
|
||||
self._visit_param_spec(spec) for spec in stmt.params
|
||||
]
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
|
||||
self.reporter.warning(expr.location, "LogicalExpr not yet supported")
|
||||
|
||||
if not self._is_valid_predicate(type):
|
||||
self.reporter.error(
|
||||
stmt.body.location,
|
||||
f"Predicate function body must evaluate to a boolean, got {type}",
|
||||
)
|
||||
if len(params) != 0:
|
||||
type = self._bool
|
||||
for spec in reversed(params):
|
||||
type = Function(
|
||||
pos_args=spec.pos,
|
||||
args=spec.mixed,
|
||||
kw_args=spec.kw,
|
||||
returns=type,
|
||||
)
|
||||
self._predicate_params = {}
|
||||
self.types.define_predicate(
|
||||
stmt.name.lexeme,
|
||||
Predicate(
|
||||
type=type,
|
||||
body=stmt.body,
|
||||
alias=len(params) == 0,
|
||||
),
|
||||
)
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> None:
|
||||
self.reporter.warning(expr.location, "BinaryExpr not yet supported")
|
||||
|
||||
def _is_valid_predicate(self, body: Type) -> bool:
|
||||
match body:
|
||||
case Function(returns=returns):
|
||||
return self._is_valid_predicate(returns)
|
||||
case _ if self.types.is_subtype(body, self._bool):
|
||||
return True
|
||||
case _:
|
||||
return False
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
|
||||
self.reporter.warning(expr.location, "UnaryExpr not yet supported")
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> Type:
|
||||
self.assert_bool(expr.left)
|
||||
self.assert_bool(expr.right)
|
||||
return self._bool
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> None:
|
||||
self.reporter.warning(expr.location, "GetExpr not yet supported")
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> Type:
|
||||
method: Optional[str] = MIDAS_BINARY_METHODS.get(expr.operator.type)
|
||||
if method is None:
|
||||
self.logger.warning(f"Unsupported operator {expr.operator.lexeme}")
|
||||
self.reporter.warning(
|
||||
expr.location, f"Unsupported operator {expr.operator.lexeme}"
|
||||
)
|
||||
return UnknownType()
|
||||
def visit_variable_expr(self, expr: m.VariableExpr) -> None:
|
||||
self.reporter.warning(expr.location, "VariableExpr not yet supported")
|
||||
|
||||
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
|
||||
|
||||
def _visit_binary_expr(
|
||||
self, location: Location, left_expr: m.Expr, right_expr: m.Expr, method: str
|
||||
) -> Type:
|
||||
left: Type = self.type_of(left_expr)
|
||||
right: Type = self.type_of(right_expr)
|
||||
|
||||
operation: Optional[Type] = self.types.lookup_member(left, method)
|
||||
if operation is None:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Undefined operation {method} between {left} and {right}",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
result: Optional[Type] = self._get_call_result(
|
||||
location,
|
||||
operation,
|
||||
[(right_expr, right)],
|
||||
{},
|
||||
)
|
||||
return result or UnknownType()
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
|
||||
method: Optional[str] = MIDAS_UNARY_METHODS.get(expr.operator.type)
|
||||
if method is None:
|
||||
self.logger.warning(f"Unsupported operator {expr.operator.lexeme}")
|
||||
self.reporter.warning(
|
||||
expr.location, f"Unsupported operator {expr.operator.lexeme}"
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
operand: Type = self.type_of(expr.right)
|
||||
operation: Optional[Type] = self.types.lookup_member(operand, method)
|
||||
if operation is None:
|
||||
self.reporter.error(
|
||||
expr.location,
|
||||
f"Undefined operation {method} for {operand}",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
result: Optional[Type] = self._get_call_result(
|
||||
expr.location,
|
||||
operation,
|
||||
[],
|
||||
{},
|
||||
)
|
||||
return result or UnknownType()
|
||||
|
||||
def visit_call_expr(self, expr: m.CallExpr) -> Type:
|
||||
callee: Type = expr.callee.accept(self)
|
||||
positional: list[TypedExpr] = [
|
||||
(arg, self.type_of(arg)) for arg in expr.arguments
|
||||
]
|
||||
keywords: dict[str, TypedExpr] = {
|
||||
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
|
||||
}
|
||||
return (
|
||||
self._get_call_result(
|
||||
expr.location,
|
||||
callee,
|
||||
positional,
|
||||
keywords,
|
||||
)
|
||||
or UnknownType()
|
||||
)
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> Type:
|
||||
object: Type = expr.expr.accept(self)
|
||||
member: Optional[Type] = self.types.lookup_member(object, expr.name.lexeme)
|
||||
if member is None:
|
||||
self.reporter.error(
|
||||
expr.location, f"Unknown member '{expr.name.lexeme}' of {object}"
|
||||
)
|
||||
return UnknownType()
|
||||
return member
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr) -> Type:
|
||||
return self.get_variable(expr.name.lexeme)
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type:
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
|
||||
return expr.expr.accept(self)
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> Type:
|
||||
match expr.value:
|
||||
case bool(): # Must be before int
|
||||
return self.types.get_type("bool")
|
||||
case int():
|
||||
return self.types.get_type("int")
|
||||
case float():
|
||||
return self.types.get_type("float")
|
||||
case str():
|
||||
return self.types.get_type("str")
|
||||
case _:
|
||||
self.reporter.warning(expr.location, f"Unknown literal {expr}")
|
||||
return UnknownType()
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
|
||||
self.reporter.warning(expr.location, "LiteralExpr not yet supported")
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type:
|
||||
return self.get_variable("_")
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
|
||||
self.reporter.warning(expr.location, "WildcardExpr not yet supported")
|
||||
|
||||
def visit_named_type(self, type: m.NamedType) -> Type:
|
||||
name: str = type.name.lexeme
|
||||
@@ -350,10 +153,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
return UnknownType()
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
|
||||
return ConstraintType(
|
||||
type=type.type.accept(self),
|
||||
constraint=type.constraint,
|
||||
)
|
||||
type_: Type = type.type.accept(self)
|
||||
type.constraint.accept(self)
|
||||
# TODO
|
||||
return UnknownType()
|
||||
|
||||
def visit_complex_type(self, type: m.ComplexType) -> ComplexType:
|
||||
return ComplexType(
|
||||
@@ -369,17 +172,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
)
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> Type:
|
||||
params: TypedParamSpec = self._visit_param_spec(type.params)
|
||||
return Function(
|
||||
pos_args=params.pos,
|
||||
args=params.mixed,
|
||||
kw_args=params.kw,
|
||||
returns=type.returns.accept(self),
|
||||
)
|
||||
|
||||
def _visit_param_spec(self, spec: m.ParamSpec) -> TypedParamSpec:
|
||||
n_pos: int = len(spec.pos)
|
||||
n_mixed: int = len(spec.mixed)
|
||||
n_pos_args: int = len(type.pos_args)
|
||||
n_args: int = len(type.args)
|
||||
|
||||
def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument:
|
||||
return Function.Argument(
|
||||
@@ -389,10 +183,14 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
required=arg.required,
|
||||
)
|
||||
|
||||
return TypedParamSpec(
|
||||
pos=[process_arg(arg, i) for i, arg in enumerate(spec.pos)],
|
||||
mixed=[process_arg(arg, i + n_pos) for i, arg in enumerate(spec.mixed)],
|
||||
kw=[process_arg(arg, i + n_pos + n_mixed) for i, arg in enumerate(spec.kw)],
|
||||
return Function(
|
||||
pos_args=[process_arg(arg, i) for i, arg in enumerate(type.pos_args)],
|
||||
args=[process_arg(arg, i + n_pos_args) for i, arg in enumerate(type.args)],
|
||||
kw_args=[
|
||||
process_arg(arg, i + n_pos_args + n_args)
|
||||
for i, arg in enumerate(type.kw_args)
|
||||
],
|
||||
returns=type.returns.accept(self),
|
||||
)
|
||||
|
||||
def _resolve_type_params(self, params: list[m.TypeParam]):
|
||||
@@ -406,343 +204,3 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
self._local_variables[name] = var
|
||||
vars.append(var)
|
||||
return vars
|
||||
|
||||
def _get_call_result(
|
||||
self,
|
||||
location: Location,
|
||||
callee: Type,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
report_errors: bool = True,
|
||||
) -> Optional[Type]:
|
||||
"""Get the result type of a function call
|
||||
|
||||
If the function has overloads, the function will try to resolve the
|
||||
appropriate signature.
|
||||
Argument types are matched to the defined parameters.
|
||||
The function doesn't take the raw expression as a parameter to accommodate
|
||||
for desugared calls such as for operators.
|
||||
|
||||
Args:
|
||||
location (Location): the call location
|
||||
callee (Type): the called function
|
||||
positional (list[TypedExpr]): the list positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Type: the return type of the call, or `None` if either
|
||||
the call is invalid or no overload matched the arguments uniquely
|
||||
"""
|
||||
match callee:
|
||||
case Function() as function:
|
||||
valid: bool
|
||||
mapped: list[MappedArgument]
|
||||
valid, mapped = self.map_call_arguments(
|
||||
function, location, positional, keywords
|
||||
)
|
||||
valid = valid and self._are_arguments_valid(mapped, report_errors)
|
||||
if not valid:
|
||||
return None
|
||||
return function.returns
|
||||
|
||||
case OverloadedFunction(overloads=overloads):
|
||||
function = self._match_overload(
|
||||
overloads, location, positional, keywords, report_errors
|
||||
)
|
||||
if function is None:
|
||||
return None
|
||||
return function.returns
|
||||
|
||||
case AppliedType(body=body):
|
||||
return self._get_call_result(
|
||||
location, body, positional, keywords, report_errors
|
||||
)
|
||||
|
||||
case UnknownType():
|
||||
return UnknownType()
|
||||
|
||||
case _:
|
||||
if report_errors:
|
||||
self.reporter.error(location, f"{callee} is not callable")
|
||||
return None
|
||||
|
||||
def _are_arguments_valid(
|
||||
self,
|
||||
arguments: list[MappedArgument],
|
||||
report_errors: bool = True,
|
||||
) -> bool:
|
||||
"""Check whether the passed argument types correspond to their matched parameter definitions
|
||||
|
||||
Args:
|
||||
arguments (list[MappedArgument]): the list of argument/parameter pairs
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
bool: True if all arguments fit the matching parameter definitions, False otherwise
|
||||
"""
|
||||
valid: bool = True
|
||||
for arg in arguments:
|
||||
if not self.types.is_subtype(arg.type, arg.argument.type):
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
arg.expr.location,
|
||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||
)
|
||||
valid = False
|
||||
return valid
|
||||
|
||||
def _match_overload(
|
||||
self,
|
||||
overloads: list[Type],
|
||||
location: Location,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
report_errors: bool = True,
|
||||
) -> Optional[Function]:
|
||||
"""Try and resolve the appropriate overload for the given arguments
|
||||
|
||||
Args:
|
||||
overloads (list[Type]): the list of possible overloads
|
||||
location (Location): the call location
|
||||
positional (list[TypedExpr]): the list of positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keywords arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Optional[Function]: the resolved function signature if it can be
|
||||
determined unambiguously, or `None`.
|
||||
"""
|
||||
candidates: list[OverloadCandidate] = []
|
||||
for overload in overloads:
|
||||
function: Type = unfold_type(overload)
|
||||
if not isinstance(function, Function):
|
||||
if report_errors:
|
||||
self.logger.error(
|
||||
f"Overload is not a function: {overload} is {function}"
|
||||
)
|
||||
continue
|
||||
valid, mapped = self.map_call_arguments(
|
||||
function=function,
|
||||
location=location,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
report_errors=False,
|
||||
)
|
||||
if valid and self._are_arguments_valid(mapped, report_errors=False):
|
||||
candidates.append(
|
||||
OverloadCandidate(
|
||||
function=function,
|
||||
mapped=mapped,
|
||||
)
|
||||
)
|
||||
|
||||
pos_types: str = ", ".join(str(type) for _, type in positional)
|
||||
kw_types: str = ", ".join(
|
||||
f"{name}: {type}" for name, (_, type) in keywords.items()
|
||||
)
|
||||
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
|
||||
|
||||
n_candidates: int = len(candidates)
|
||||
|
||||
# Exactly 1 match -> return it
|
||||
if n_candidates == 1:
|
||||
return candidates[0].function
|
||||
|
||||
# No match -> invalid call
|
||||
if n_candidates == 0:
|
||||
overloads_str: str = ", ".join(map(str, overloads))
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"No matching overload in [{overloads_str}] {for_args}",
|
||||
)
|
||||
return None
|
||||
|
||||
# Multiple matches -> see if one <: all others (more specific)
|
||||
for i1, c1 in enumerate(candidates):
|
||||
mapped1: list[MappedArgument] = c1.mapped
|
||||
best_match: bool = True
|
||||
for i2, c2 in enumerate(candidates):
|
||||
if i1 == i2:
|
||||
continue
|
||||
mapped2: list[MappedArgument] = c2.mapped
|
||||
if not self._are_mapped_subtypes(mapped1, mapped2):
|
||||
best_match = False
|
||||
break
|
||||
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
|
||||
if best_match:
|
||||
return c1.function
|
||||
|
||||
candidates_str: str = ", ".join(
|
||||
str(candidate.function) for candidate in candidates
|
||||
)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Multiple matching overloads {for_args}: {candidates_str}",
|
||||
)
|
||||
return None
|
||||
|
||||
def map_call_arguments(
|
||||
self,
|
||||
function: Function,
|
||||
location: Location,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
report_errors: bool = True,
|
||||
) -> tuple[bool, list[MappedArgument]]:
|
||||
"""Map call arguments to a function's parameters as defined in its signature
|
||||
|
||||
This method maps positional-only, keyword-only and mixed parameter definitions
|
||||
with the arguments passed at the call site
|
||||
|
||||
Any mismatched, missing or unexpected argument is reported as a diagnostic,
|
||||
unless `report_errors` is set to `False`
|
||||
|
||||
Args:
|
||||
function (Function): the function definition
|
||||
location (Location): the call location
|
||||
positional (list[TypedExpr]): the list of positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
tuple[bool, list[MappedArgument]]: a boolean reporting whether
|
||||
the call is valid and the list of mapped arguments
|
||||
"""
|
||||
set_args: set[str] = set()
|
||||
|
||||
required_positional: list[str] = [
|
||||
arg.name for arg in function.pos_args + function.args if arg.required
|
||||
]
|
||||
required_keyword: list[str] = [
|
||||
arg.name for arg in function.kw_args if arg.required
|
||||
]
|
||||
|
||||
mapped: list[MappedArgument] = []
|
||||
|
||||
pos_params: list[Function.Argument] = list(function.pos_args)
|
||||
mixed_params: list[Function.Argument] = list(function.args)
|
||||
kw_params: dict[str, Function.Argument] = {
|
||||
arg.name: arg for arg in function.kw_args
|
||||
}
|
||||
|
||||
valid_call: bool = True
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
for arg in positional:
|
||||
param: Function.Argument
|
||||
if len(pos_params) != 0:
|
||||
param = pos_params.pop(0)
|
||||
elif len(mixed_params) != 0:
|
||||
param = mixed_params.pop(0)
|
||||
else:
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
arg[0].location, "Too many positional arguments"
|
||||
)
|
||||
valid_call = False
|
||||
break
|
||||
name: str = param.name
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
kw_params.update({arg.name: arg for arg in mixed_params})
|
||||
for name, arg in keywords.items():
|
||||
param: Function.Argument
|
||||
if name not in kw_params:
|
||||
if report_errors:
|
||||
if name in set_args:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Multiple values for argument '{name}'"
|
||||
)
|
||||
else:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Unknown keyword argument '{name}'"
|
||||
)
|
||||
valid_call = False
|
||||
continue
|
||||
param = kw_params.pop(name)
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
def join_args(args: list[str]) -> str:
|
||||
args = list(map(lambda a: f"'{a}'", args))
|
||||
if len(args) == 0:
|
||||
return ""
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return ", ".join(args[:-1]) + " and " + args[-1]
|
||||
|
||||
if len(required_positional) != 0:
|
||||
plural: str = "" if len(required_positional) == 1 else "s"
|
||||
args: str = join_args(required_positional)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Missing required positional argument{plural}: {args}",
|
||||
)
|
||||
valid_call = False
|
||||
|
||||
if len(required_keyword) != 0:
|
||||
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||
args: str = join_args(required_keyword)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Missing required keyword argument{plural}: {args}",
|
||||
)
|
||||
valid_call = False
|
||||
|
||||
return valid_call, mapped
|
||||
|
||||
def _are_mapped_subtypes(
|
||||
self, mapped1: list[MappedArgument], mapped2: list[MappedArgument]
|
||||
) -> bool:
|
||||
"""Check whether the given argument mappings are subtype/supertype of one another
|
||||
|
||||
This function checks whether the argument mappings `mapped1` are subtypes
|
||||
of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
|
||||
of the corresponding parameter in `mapped2`, `False` is returned.
|
||||
|
||||
This is used to check whether a given overload is
|
||||
a more specific function/ a subtype of another.
|
||||
|
||||
Args:
|
||||
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
|
||||
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
|
||||
|
||||
Returns:
|
||||
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
|
||||
"""
|
||||
by_expr: dict[m.Expr, Type] = {}
|
||||
for arg in mapped1:
|
||||
by_expr[arg.expr] = arg.argument.type
|
||||
|
||||
for arg in mapped2:
|
||||
type2: Type = arg.argument.type
|
||||
type1: Type = by_expr[arg.expr]
|
||||
if not self.types.is_subtype(type1, type2):
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import ast
|
||||
from typing import Type
|
||||
|
||||
from midas.lexer.token import TokenType
|
||||
|
||||
PY_OPERATOR_METHODS: dict[Type[ast.operator], str] = {
|
||||
OPERATOR_METHODS: dict[Type[ast.operator], str] = {
|
||||
ast.Add: "__add__",
|
||||
ast.Sub: "__sub__",
|
||||
ast.Mult: "__mul__",
|
||||
@@ -19,9 +17,9 @@ PY_OPERATOR_METHODS: dict[Type[ast.operator], str] = {
|
||||
ast.FloorDiv: "__floordiv__",
|
||||
}
|
||||
|
||||
PY_COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
|
||||
COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
|
||||
ast.Eq: "__eq__",
|
||||
ast.NotEq: "__eq__",
|
||||
# ast.NotEq: "__noteq__",
|
||||
ast.Lt: "__lt__",
|
||||
ast.LtE: "__le__",
|
||||
ast.Gt: "__gt__",
|
||||
@@ -32,40 +30,9 @@ PY_COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
|
||||
# ast.NotIn: "__notin__",
|
||||
}
|
||||
|
||||
PY_UNARY_METHODS: dict[Type[ast.unaryop], str] = {
|
||||
UNARY_METHODS: dict[Type[ast.unaryop], str] = {
|
||||
ast.Invert: "__invert__",
|
||||
# ast.Not: "",
|
||||
ast.UAdd: "__pos__",
|
||||
ast.USub: "__neg__",
|
||||
}
|
||||
|
||||
|
||||
MIDAS_BINARY_METHODS: dict[TokenType, str] = {
|
||||
# TokenType.PLUS: "__add__",
|
||||
TokenType.MINUS: "__sub__",
|
||||
TokenType.STAR: "__mul__",
|
||||
TokenType.SLASH: "__truediv__",
|
||||
# TokenType.MODULO: "__mod__",
|
||||
# TokenType.POW: "__pow__",
|
||||
# ast.BitOr: "__or__",
|
||||
# ast.BitXor: "__xor__",
|
||||
# ast.BitAnd: "__and__",
|
||||
# ast.FloorDiv: "__floordiv__",
|
||||
TokenType.EQUAL_EQUAL: "__eq__",
|
||||
TokenType.BANG_EQUAL: "__eq__",
|
||||
TokenType.LESS: "__lt__",
|
||||
TokenType.LESS_EQUAL: "__le__",
|
||||
TokenType.GREATER: "__gt__",
|
||||
TokenType.GREATER_EQUAL: "__ge__",
|
||||
# ast.Is: "__is__",
|
||||
# ast.IsNot: "__isnot__",
|
||||
# ast.In: "__in__",
|
||||
# ast.NotIn: "__notin__",
|
||||
}
|
||||
|
||||
MIDAS_UNARY_METHODS: dict[TokenType, str] = {
|
||||
# ast.Invert: "__invert__",
|
||||
# ast.Not: "",
|
||||
# TokenType.PLUS: "__pos__",
|
||||
TokenType.MINUS: "__neg__",
|
||||
}
|
||||
|
||||
@@ -6,16 +6,13 @@ from typing import Optional
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.environment import Environment
|
||||
from midas.checker.operators import (
|
||||
PY_COMPARATOR_METHODS,
|
||||
PY_OPERATOR_METHODS,
|
||||
PY_UNARY_METHODS,
|
||||
)
|
||||
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS
|
||||
from midas.checker.preamble import Preamble
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter, Reporter
|
||||
from midas.checker.resolver import Resolver
|
||||
from midas.checker.types import (
|
||||
AliasType,
|
||||
AppliedType,
|
||||
Function,
|
||||
OverloadedFunction,
|
||||
@@ -380,7 +377,7 @@ class PythonTyper(
|
||||
pass
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
||||
method: Optional[str] = PY_OPERATOR_METHODS.get(expr.operator.__class__)
|
||||
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
||||
if method is None:
|
||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||
self.reporter.warning(
|
||||
@@ -391,7 +388,7 @@ class PythonTyper(
|
||||
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
||||
method: Optional[str] = PY_COMPARATOR_METHODS.get(expr.operator.__class__)
|
||||
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
|
||||
if method is None:
|
||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||
self.reporter.warning(
|
||||
@@ -424,7 +421,7 @@ class PythonTyper(
|
||||
return result or UnknownType()
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
|
||||
method: Optional[str] = PY_UNARY_METHODS.get(expr.operator.__class__)
|
||||
method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__)
|
||||
if method is None:
|
||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||
self.reporter.warning(
|
||||
@@ -656,7 +653,7 @@ class PythonTyper(
|
||||
If the function has overloads, the function will try to resolve the
|
||||
appropriate signature.
|
||||
Argument types are matched to the defined parameters.
|
||||
The function doesn't take the raw expression as a parameter to accommodate
|
||||
The function doesn't take the raw expression as a parameter to accomodate
|
||||
for desugared calls such as for operators.
|
||||
|
||||
Args:
|
||||
@@ -698,9 +695,17 @@ class PythonTyper(
|
||||
case UnknownType():
|
||||
return UnknownType()
|
||||
|
||||
case AliasType(type=base):
|
||||
return self._get_call_result(
|
||||
location, base, positional, keywords, report_errors
|
||||
)
|
||||
|
||||
case _:
|
||||
if report_errors:
|
||||
self.reporter.error(location, f"{callee} is not callable")
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"{callee} ({callee.__class__.__name__}) is not callable",
|
||||
)
|
||||
return None
|
||||
|
||||
def _are_arguments_valid(
|
||||
@@ -747,7 +752,7 @@ class PythonTyper(
|
||||
|
||||
Returns:
|
||||
Optional[Function]: the resolved function signature if it can be
|
||||
determined unambiguously, or `None`.
|
||||
determined unambigously, or `None`.
|
||||
"""
|
||||
candidates: list[OverloadCandidate] = []
|
||||
for overload in overloads:
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.midas import MemberKind
|
||||
from midas.checker.builtins import BUILTIN_SUBTYPES
|
||||
from midas.checker.types import (
|
||||
AliasType,
|
||||
AppliedType,
|
||||
BaseType,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
ExtensionType,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
Predicate,
|
||||
TopType,
|
||||
Type,
|
||||
TypeVar,
|
||||
@@ -21,12 +21,17 @@ from midas.checker.types import (
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Member:
|
||||
kind: MemberKind
|
||||
type: Type
|
||||
|
||||
|
||||
class TypesRegistry:
|
||||
def __init__(self) -> None:
|
||||
self.logger: logging.Logger = logging.getLogger("TypesRegistry")
|
||||
self._types: dict[str, Type] = {}
|
||||
self._members: dict[str, dict[str, Type]] = {}
|
||||
self._predicates: dict[str, Predicate] = {}
|
||||
self._members: dict[str, dict[str, Member]] = {}
|
||||
|
||||
def get_type(self, name: str) -> Type:
|
||||
"""Get a type from its name
|
||||
@@ -63,31 +68,38 @@ class TypesRegistry:
|
||||
return type
|
||||
|
||||
def define_member(
|
||||
self, type_name: str, member_name: str, member_type: Type, is_method: bool
|
||||
self,
|
||||
type_name: str,
|
||||
member_name: str,
|
||||
member_type: Type,
|
||||
kind: MemberKind,
|
||||
):
|
||||
members: dict[str, Type] = self._members.setdefault(type_name, {})
|
||||
members: dict[str, Member] = self._members.setdefault(type_name, {})
|
||||
if member_name in members:
|
||||
if not is_method:
|
||||
current: Member = members[member_name]
|
||||
if current.kind != kind:
|
||||
self.logger.error(
|
||||
f"Member '{member_name}' already defined for type {type_name}"
|
||||
f"Member '{member_name}' is already defined as a {current.kind},"
|
||||
+ f" cannot define a {kind} with the same name"
|
||||
)
|
||||
return
|
||||
current: Type = members[member_name]
|
||||
if kind != MemberKind.METHOD:
|
||||
self.logger.error(
|
||||
f"Member '{member_name}' already defined for type {type_name},"
|
||||
+ " only methods can be overloaded"
|
||||
)
|
||||
return
|
||||
|
||||
combined: Type
|
||||
match current:
|
||||
match current.type:
|
||||
case OverloadedFunction(overloads=overloads):
|
||||
combined = OverloadedFunction(overloads=overloads + [member_type])
|
||||
case _:
|
||||
combined = OverloadedFunction(overloads=[current, member_type])
|
||||
members[member_name] = combined
|
||||
combined = OverloadedFunction(overloads=[current.type, member_type])
|
||||
members[member_name] = Member(kind=current.kind, type=combined)
|
||||
|
||||
else:
|
||||
members[member_name] = member_type
|
||||
|
||||
def define_predicate(self, name: str, predicate: Predicate):
|
||||
if name in self._predicates:
|
||||
raise ValueError(f"Predicate {name} already defined")
|
||||
self._predicates[name] = predicate
|
||||
members[member_name] = Member(kind=kind, type=member_type)
|
||||
|
||||
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||
"""Check whether `type1` is a subtype of `type2`
|
||||
@@ -131,9 +143,6 @@ class TypesRegistry:
|
||||
return False
|
||||
return self.is_subtype(bound, type2)
|
||||
|
||||
case (ConstraintType(type=base1), _):
|
||||
return self.is_subtype(base1, type2)
|
||||
|
||||
return False
|
||||
|
||||
# TODO: verify the logic in here
|
||||
@@ -308,13 +317,13 @@ class TypesRegistry:
|
||||
case BaseType(name=name):
|
||||
if name in self._members:
|
||||
if member_name in self._members[name]:
|
||||
return self._members[name][member_name]
|
||||
return self._members[name][member_name].type
|
||||
return None
|
||||
|
||||
case AliasType(name=name, type=base):
|
||||
if name in self._members:
|
||||
if member_name in self._members[name]:
|
||||
return self._members[name][member_name]
|
||||
return self._members[name][member_name].type
|
||||
return self.lookup_member(base, member_name)
|
||||
|
||||
case AppliedType(name=name, body=body, args=args):
|
||||
@@ -328,7 +337,7 @@ class TypesRegistry:
|
||||
}
|
||||
if name in self._members:
|
||||
if member_name in self._members[name]:
|
||||
member_type: Type = self._members[name][member_name]
|
||||
member_type: Type = self._members[name][member_name].type
|
||||
return substitute_typevars(member_type, substitutions)
|
||||
|
||||
member_type2: Optional[Type] = self.lookup_member(body, member_name)
|
||||
@@ -356,6 +365,3 @@ class TypesRegistry:
|
||||
case _:
|
||||
self.logger.debug(f"Can't get member on {type}")
|
||||
return None
|
||||
|
||||
def lookup_predicate(self, name: str) -> Optional[Predicate]:
|
||||
return self._predicates.get(name)
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, assert_never
|
||||
|
||||
import midas.ast.midas as m
|
||||
from midas.ast.printer import MidasPrinter
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
@@ -133,16 +130,6 @@ class AppliedType:
|
||||
return f"{self.name}[{', '.join(map(str, self.args))}]"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ConstraintType:
|
||||
type: Type
|
||||
constraint: m.Expr
|
||||
|
||||
def __str__(self) -> str:
|
||||
printer = MidasPrinter()
|
||||
return f"{self.type} where {printer.print(self.constraint)}"
|
||||
|
||||
|
||||
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||
def sub_argument(arg: Function.Argument):
|
||||
return Function.Argument(
|
||||
@@ -153,6 +140,9 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||
)
|
||||
|
||||
match type:
|
||||
case TopType():
|
||||
return type
|
||||
|
||||
case BaseType(name=name) if name in substitutions:
|
||||
return substitutions[name]
|
||||
|
||||
@@ -208,26 +198,31 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||
body=substitute_typevars(body, substitutions),
|
||||
)
|
||||
|
||||
case ConstraintType():
|
||||
return ConstraintType(
|
||||
type=substitute_typevars(type.type, substitutions),
|
||||
constraint=type.constraint,
|
||||
)
|
||||
|
||||
case TypeVar(name=name):
|
||||
if name in substitutions:
|
||||
return substitutions[name]
|
||||
raise ValueError(f"Missing TypeVar substitution for {name}")
|
||||
|
||||
case GenericType(name=name, params=params, body=body):
|
||||
params2: list[TypeVar] = []
|
||||
for param in params:
|
||||
param2: Type = substitute_typevars(param, substitutions)
|
||||
if not isinstance(param2, TypeVar):
|
||||
raise ValueError(
|
||||
f"Invalid type parameter substitution, expected TypeVar, got {param2}"
|
||||
)
|
||||
params2.append(param2)
|
||||
return GenericType(
|
||||
name=name,
|
||||
params=params2,
|
||||
body=substitute_typevars(body, substitutions),
|
||||
)
|
||||
|
||||
case UnknownType() | UnitType():
|
||||
return type
|
||||
|
||||
case TopType() | GenericType():
|
||||
raise NotImplementedError(f"Unsupported type {type}")
|
||||
|
||||
# Ensure exhaustiveness
|
||||
case _:
|
||||
assert_never(type)
|
||||
raise NotImplementedError(f"Unsupported type {type}")
|
||||
|
||||
|
||||
def unfold_type(type: Type) -> Type:
|
||||
@@ -238,65 +233,6 @@ def unfold_type(type: Type) -> Type:
|
||||
return type
|
||||
|
||||
|
||||
def to_annotation(type: Type) -> str:
|
||||
def _args_annotation(func: Function) -> str:
|
||||
if len(func.kw_args) != 0:
|
||||
return "..."
|
||||
|
||||
args: str = ", ".join(
|
||||
to_annotation(arg.type) for arg in func.pos_args + func.args
|
||||
)
|
||||
return f"[{args}]"
|
||||
|
||||
match type:
|
||||
case TopType():
|
||||
return "Any"
|
||||
|
||||
case BaseType(name=name):
|
||||
return name
|
||||
|
||||
case AliasType(name=name):
|
||||
return name
|
||||
|
||||
case UnknownType():
|
||||
return "Any"
|
||||
|
||||
case UnitType():
|
||||
return "None"
|
||||
|
||||
case Function(returns=returns):
|
||||
params_annot: str = _args_annotation(type)
|
||||
return f"Callable[{params_annot}, {to_annotation(returns)}]"
|
||||
|
||||
case OverloadedFunction():
|
||||
return "Callable"
|
||||
|
||||
case ComplexType() | ExtensionType():
|
||||
raise NotImplementedError
|
||||
|
||||
case TypeVar(name=name):
|
||||
return name
|
||||
|
||||
case GenericType(name=name, params=params):
|
||||
return f"{name}[{', '.join(map(to_annotation, params))}]"
|
||||
|
||||
case AppliedType(name=name, args=args):
|
||||
return f"{name}[{', '.join(map(to_annotation, args))}]"
|
||||
|
||||
case ConstraintType():
|
||||
return str(type)
|
||||
|
||||
case _:
|
||||
assert_never(type)
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Predicate:
|
||||
type: Type
|
||||
body: m.Expr
|
||||
alias: bool
|
||||
|
||||
|
||||
Type = (
|
||||
TopType
|
||||
| BaseType
|
||||
@@ -310,5 +246,4 @@ Type = (
|
||||
| TypeVar
|
||||
| GenericType
|
||||
| AppliedType
|
||||
| ConstraintType
|
||||
)
|
||||
|
||||
@@ -4,5 +4,6 @@ from .format import format as format
|
||||
from .highlight import highlight as highlight
|
||||
from .parse import parse as parse
|
||||
from .registry import dump_registry as dump_registry
|
||||
from .stubs import stubs as stubs
|
||||
from .types import types as types
|
||||
from .validate import validate as validate
|
||||
|
||||
@@ -38,5 +38,5 @@ def compile(
|
||||
if any(map(lambda d: d.type == DiagnosticType.ERROR, diagnostics)):
|
||||
sys.exit(1)
|
||||
|
||||
generator = Generator(workdir=source_path.parent, types=checker.types)
|
||||
generator = Generator(workdir=source_path.parent)
|
||||
generator.generate(typed_ast, source_path)
|
||||
|
||||
@@ -8,7 +8,6 @@ from typing import TextIO
|
||||
|
||||
import click
|
||||
|
||||
from midas.ast.printer import MidasPrinter
|
||||
from midas.checker.checker import TypeChecker
|
||||
from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type
|
||||
|
||||
@@ -36,7 +35,6 @@ def dump_registry(
|
||||
for types_file in types:
|
||||
checker.import_midas(Path(types_file.name).resolve())
|
||||
|
||||
print("##### Types #####")
|
||||
for name, type in checker.types._types.items():
|
||||
members: dict[str, Type] = checker.types._members.get(name, {})
|
||||
print(f"{name} = {base_type(type)}")
|
||||
@@ -44,17 +42,3 @@ def dump_registry(
|
||||
print(" " * 4 + "Members:")
|
||||
for member_name, member_type in members.items():
|
||||
print(" " * 8 + f"{member_name}: {member_type}")
|
||||
|
||||
print("##### Predicates #####")
|
||||
printer = MidasPrinter()
|
||||
for name, predicate in checker.types._predicates.items():
|
||||
body: str = printer.print(predicate.body)
|
||||
if predicate.alias:
|
||||
print(f"{name}: {predicate.type} = {body}")
|
||||
else:
|
||||
print(f"{name}{predicate.type}:")
|
||||
body = "\n".join(
|
||||
" " + ("return " if i == 0 else "") + line
|
||||
for i, line in enumerate(body.split("\n"))
|
||||
)
|
||||
print(body)
|
||||
|
||||
27
midas/cli/commands/stubs.py
Normal file
27
midas/cli/commands/stubs.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import ast
|
||||
from pathlib import Path
|
||||
from typing import TextIO
|
||||
|
||||
import click
|
||||
|
||||
from midas.checker.checker import TypeChecker
|
||||
from midas.generator.stubs import StubsGenerator
|
||||
|
||||
|
||||
@click.command(help="Generate stubs from Midas definitions")
|
||||
@click.argument("file", type=click.File("r"))
|
||||
@click.option("-o", "--output", type=click.File("w"), default="-")
|
||||
def stubs(
|
||||
file: TextIO,
|
||||
output: TextIO,
|
||||
):
|
||||
source_path: Path = Path(file.name).resolve()
|
||||
|
||||
checker = TypeChecker()
|
||||
checker.import_midas(source_path)
|
||||
|
||||
generator = StubsGenerator(checker.types)
|
||||
module: ast.Module = generator.generate_stubs()
|
||||
module = ast.fix_missing_locations(module)
|
||||
|
||||
output.write(ast.unparse(module))
|
||||
@@ -18,6 +18,7 @@ midas.add_command(commands.highlight)
|
||||
midas.add_command(commands.parse)
|
||||
midas.add_command(commands.dump_registry)
|
||||
midas.add_command(commands.types)
|
||||
midas.add_command(commands.stubs)
|
||||
midas.add_command(commands.validate)
|
||||
|
||||
|
||||
|
||||
@@ -1,224 +0,0 @@
|
||||
import ast
|
||||
from typing import Optional
|
||||
|
||||
import midas.ast.midas as m
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import (
|
||||
Function,
|
||||
Predicate,
|
||||
Type,
|
||||
to_annotation,
|
||||
)
|
||||
from midas.lexer.token import TokenType
|
||||
|
||||
LOGICAL_OPERATORS: dict[TokenType, type[ast.boolop]] = {
|
||||
TokenType.AND: ast.And,
|
||||
# TokenType.OR: ast.Or,
|
||||
}
|
||||
|
||||
BINARY_OPERATORS: dict[TokenType, type[ast.operator]] = {
|
||||
# TokenType.PLUS: ast.Add,
|
||||
TokenType.MINUS: ast.Sub,
|
||||
TokenType.STAR: ast.Mult,
|
||||
TokenType.SLASH: ast.Div,
|
||||
}
|
||||
|
||||
UNARY_OPERATORS: dict[TokenType, type[ast.unaryop]] = {
|
||||
# TokenType.PLUS: ast.UAdd,
|
||||
TokenType.MINUS: ast.USub,
|
||||
}
|
||||
|
||||
COMPARISON_OPERATORS: dict[TokenType, type[ast.cmpop]] = {
|
||||
TokenType.GREATER: ast.Gt,
|
||||
TokenType.GREATER_EQUAL: ast.GtE,
|
||||
TokenType.LESS: ast.Lt,
|
||||
TokenType.LESS_EQUAL: ast.LtE,
|
||||
TokenType.EQUAL_EQUAL: ast.Eq,
|
||||
TokenType.BANG_EQUAL: ast.NotEq,
|
||||
}
|
||||
|
||||
|
||||
class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
|
||||
def __init__(self, types: TypesRegistry):
|
||||
self.types: TypesRegistry = types
|
||||
self._id: int = 0
|
||||
self._definitions: list[ast.stmt] = []
|
||||
self._aliases: dict[str, str] = {}
|
||||
|
||||
def get_definitions(self) -> list[ast.stmt]:
|
||||
return self._definitions
|
||||
|
||||
def generate(self, expr: m.Expr) -> ast.expr:
|
||||
match expr:
|
||||
case m.VariableExpr():
|
||||
return expr.accept(self)
|
||||
case _:
|
||||
func = Function(
|
||||
pos_args=[],
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="_",
|
||||
type=self.types.get_type("Any"),
|
||||
required=True,
|
||||
)
|
||||
],
|
||||
kw_args=[],
|
||||
returns=self.types.get_type("bool"),
|
||||
)
|
||||
alias: str = self.make_alias(None)
|
||||
definition: ast.stmt = self.make_definition(
|
||||
alias, Predicate(type=func, body=expr, alias=False)
|
||||
)
|
||||
self._definitions.append(definition)
|
||||
return ast.Name(id=alias)
|
||||
|
||||
def make_alias(self, name: Optional[str]) -> str:
|
||||
suffix: str
|
||||
if name is None:
|
||||
suffix = f"p{self._id}"
|
||||
self._id += 1
|
||||
else:
|
||||
suffix = name
|
||||
alias: str = f"__midas_{suffix}__"
|
||||
return alias
|
||||
|
||||
def make_definition(self, name: str, predicate: Predicate) -> ast.stmt:
|
||||
body: ast.expr = predicate.body.accept(self)
|
||||
if predicate.alias:
|
||||
return ast.Assign(
|
||||
targets=[
|
||||
ast.Name(id=name),
|
||||
],
|
||||
value=body,
|
||||
)
|
||||
return self.make_func(name, [ast.Return(value=body)], predicate.type)
|
||||
|
||||
def make_args(self, func: Function) -> ast.arguments:
|
||||
return ast.arguments(
|
||||
posonlyargs=[
|
||||
ast.arg(
|
||||
arg=arg.name,
|
||||
annotation=ast.Constant(value=to_annotation(arg.type)),
|
||||
)
|
||||
for arg in func.pos_args
|
||||
],
|
||||
args=[
|
||||
ast.arg(
|
||||
arg=arg.name,
|
||||
annotation=ast.Constant(value=to_annotation(arg.type)),
|
||||
)
|
||||
for arg in func.args
|
||||
],
|
||||
kwonlyargs=[
|
||||
ast.arg(
|
||||
arg=arg.name,
|
||||
annotation=ast.Constant(value=to_annotation(arg.type)),
|
||||
)
|
||||
for arg in func.kw_args
|
||||
],
|
||||
defaults=[],
|
||||
kw_defaults=[],
|
||||
)
|
||||
|
||||
def make_func(
|
||||
self, name: str, inner_body: list[ast.stmt], type: Type, level: int = 0
|
||||
) -> ast.stmt:
|
||||
match type:
|
||||
case Function(returns=Function()):
|
||||
inner_name: str = f"inner{level}"
|
||||
return ast.FunctionDef(
|
||||
name=name,
|
||||
args=self.make_args(type),
|
||||
body=[
|
||||
self.make_func(inner_name, inner_body, type.returns, level + 1),
|
||||
ast.Return(value=ast.Name(id=inner_name)),
|
||||
],
|
||||
returns=ast.Constant(value=to_annotation(type.returns)),
|
||||
decorator_list=[],
|
||||
)
|
||||
|
||||
case Function():
|
||||
return ast.FunctionDef(
|
||||
name=name,
|
||||
args=self.make_args(type),
|
||||
body=inner_body,
|
||||
returns=ast.Constant(value=to_annotation(type.returns)),
|
||||
decorator_list=[],
|
||||
)
|
||||
|
||||
case _:
|
||||
raise ValueError(f"Expected function, got {type!r}")
|
||||
|
||||
def get_predicate(self, name: str) -> Optional[ast.expr]:
|
||||
if name not in self._aliases:
|
||||
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
|
||||
if predicate is None:
|
||||
return None
|
||||
alias: str = self.make_alias(name)
|
||||
self._aliases[name] = alias
|
||||
self._definitions.append(self.make_definition(alias, predicate))
|
||||
|
||||
return ast.Name(id=self._aliases[name])
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> ast.expr:
|
||||
return ast.BoolOp(
|
||||
op=LOGICAL_OPERATORS[expr.operator.type](),
|
||||
values=[
|
||||
expr.left.accept(self),
|
||||
expr.right.accept(self),
|
||||
],
|
||||
)
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> ast.expr:
|
||||
op: TokenType = expr.operator.type
|
||||
if op in BINARY_OPERATORS:
|
||||
return ast.BinOp(
|
||||
left=expr.left.accept(self),
|
||||
op=BINARY_OPERATORS[op](),
|
||||
right=expr.right.accept(self),
|
||||
)
|
||||
if op in COMPARISON_OPERATORS:
|
||||
return ast.Compare(
|
||||
left=expr.left.accept(self),
|
||||
ops=[COMPARISON_OPERATORS[op]()],
|
||||
comparators=[expr.right.accept(self)],
|
||||
)
|
||||
raise ValueError(f"Unexpected binary operator {op}")
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> ast.expr:
|
||||
return ast.UnaryOp(
|
||||
op=UNARY_OPERATORS[expr.operator.type](),
|
||||
operand=expr.right.accept(self),
|
||||
)
|
||||
|
||||
def visit_call_expr(self, expr: m.CallExpr) -> ast.expr:
|
||||
return ast.Call(
|
||||
func=expr.callee.accept(self),
|
||||
args=[arg.accept(self) for arg in expr.arguments],
|
||||
keywords=[
|
||||
ast.keyword(arg=name, value=arg.accept(self))
|
||||
for name, arg in expr.keywords.items()
|
||||
],
|
||||
)
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> ast.expr:
|
||||
return ast.Attribute(
|
||||
value=expr.expr.accept(self),
|
||||
attr=expr.name.lexeme,
|
||||
)
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr) -> ast.expr:
|
||||
name: str = expr.name.lexeme
|
||||
if (p := self.get_predicate(name)) is not None:
|
||||
return p
|
||||
return ast.Name(id=name)
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> ast.expr:
|
||||
return expr.accept(self)
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> ast.expr:
|
||||
return ast.Constant(value=expr.value)
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> ast.expr:
|
||||
return ast.Name(id="_")
|
||||
@@ -2,19 +2,15 @@ import ast
|
||||
import shutil
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional, assert_never
|
||||
from typing import Optional
|
||||
|
||||
import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.ast.printer import MidasPrinter
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import (
|
||||
AliasType,
|
||||
AppliedType,
|
||||
BaseType,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
ExtensionType,
|
||||
Function,
|
||||
GenericType,
|
||||
@@ -23,9 +19,7 @@ from midas.checker.types import (
|
||||
Type,
|
||||
TypeVar,
|
||||
UnitType,
|
||||
UnknownType,
|
||||
)
|
||||
from midas.generator.constraints import ConstraintGenerator
|
||||
from midas.utils import TypedAST
|
||||
|
||||
|
||||
@@ -36,9 +30,12 @@ class Scope:
|
||||
|
||||
|
||||
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
def __init__(self, workdir: Path, types: TypesRegistry) -> None:
|
||||
def __init__(self, workdir: Path) -> None:
|
||||
self.workdir: Path = workdir.resolve()
|
||||
self.build_dir: Path = self.workdir / "build" / "midas"
|
||||
if self.build_dir.exists():
|
||||
shutil.rmtree(self.build_dir)
|
||||
self.build_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.rel_src_path: Path = Path()
|
||||
|
||||
self._typed_ast: TypedAST = TypedAST(
|
||||
@@ -46,18 +43,13 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
judgements=[],
|
||||
)
|
||||
self._alias_count: int = 0
|
||||
self._predicate_count: int = 0
|
||||
self._scopes: list[Scope] = []
|
||||
|
||||
self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
|
||||
self._constraints: list[tuple[m.Expr, ast.expr]] = []
|
||||
|
||||
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
|
||||
self.rel_src_path = src_path.resolve().relative_to(self.workdir)
|
||||
self.rel_src_path = src_path.relative_to(self.workdir)
|
||||
self._typed_ast = typed_ast
|
||||
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
|
||||
predicates: list[ast.stmt] = self._constraint_generator.get_definitions()
|
||||
module = ast.Module(body=predicates + body, type_ignores=[])
|
||||
module = ast.Module(body=body, type_ignores=[])
|
||||
module = ast.fix_missing_locations(module)
|
||||
return module
|
||||
|
||||
@@ -67,9 +59,6 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
module: ast.AST = self.generate_ast(typed_ast, src_path)
|
||||
compiled: str = ast.unparse(module)
|
||||
if out_path is None:
|
||||
if self.build_dir.exists():
|
||||
shutil.rmtree(self.build_dir)
|
||||
self.build_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_path = (self.build_dir / self.rel_src_path).resolve()
|
||||
try:
|
||||
_ = out_path.relative_to(self.build_dir)
|
||||
@@ -257,7 +246,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
return generated
|
||||
|
||||
def _make_alias(self, expr: ast.expr) -> ast.expr:
|
||||
name: str = f"__midas_a{self._alias_count}__"
|
||||
name: str = f"__midas_alias_{self._alias_count}__"
|
||||
alias = ast.Name(id=name)
|
||||
self._alias_count += 1
|
||||
self._scopes[-1].aliases.append(name)
|
||||
@@ -287,9 +276,6 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
|
||||
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type):
|
||||
match type:
|
||||
case UnknownType():
|
||||
pass
|
||||
|
||||
case BaseType(name=name):
|
||||
self._add_assert(
|
||||
ast.Call(
|
||||
@@ -315,15 +301,8 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
self._make_cast_assert_message(src_location, expr, type),
|
||||
)
|
||||
|
||||
case AppliedType(body=body):
|
||||
self._make_cast_asserts(src_location, expr, body)
|
||||
|
||||
case ConstraintType(type=base, constraint=constraint):
|
||||
self._make_cast_asserts(src_location, expr, base)
|
||||
self._make_constraint_assert(src_location, expr, constraint)
|
||||
|
||||
case TypeVar():
|
||||
raise RuntimeError("Unexpected TypeVar")
|
||||
case AppliedType():
|
||||
self._make_cast_asserts(src_location, expr, type.body)
|
||||
|
||||
case (
|
||||
TopType()
|
||||
@@ -335,9 +314,8 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
):
|
||||
raise NotImplementedError(f"Can't make assertion for type {type}")
|
||||
|
||||
# Ensure exhaustiveness
|
||||
case _:
|
||||
assert_never(type)
|
||||
case TypeVar():
|
||||
raise RuntimeError("Unexpected TypeVar")
|
||||
|
||||
def _make_cast_assert_message(
|
||||
self, location: Location, expr: ast.expr, type: Type
|
||||
@@ -361,36 +339,3 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
ast.Constant(f" to {type}"),
|
||||
]
|
||||
)
|
||||
|
||||
def _make_constraint_assert(
|
||||
self, src_location: Location, expr: ast.expr, constraint: m.Expr
|
||||
):
|
||||
test_func: ast.expr = self._get_constraint(constraint)
|
||||
self._add_assert(
|
||||
ast.Call(
|
||||
func=test_func,
|
||||
args=[expr],
|
||||
keywords=[],
|
||||
),
|
||||
self._make_constraint_assert_message(src_location, expr, constraint),
|
||||
)
|
||||
|
||||
def _make_constraint_assert_message(
|
||||
self, location: Location, expr: ast.expr, constraint: m.Expr
|
||||
) -> ast.expr:
|
||||
printer = MidasPrinter()
|
||||
constraint_str: str = printer.print(constraint)
|
||||
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
|
||||
# f"file.py:L1:1: ConstraintError: Value does not fit constraint 'v > 0'"
|
||||
return ast.Constant(
|
||||
f"{loc_str}: ConstraintError: Value does not fit constraint '{constraint_str}'"
|
||||
)
|
||||
|
||||
def _get_constraint(self, expr: m.Expr) -> ast.expr:
|
||||
for expr2, constraint in self._constraints:
|
||||
if expr2 == expr:
|
||||
return constraint
|
||||
|
||||
constraint: ast.expr = self._constraint_generator.generate(expr)
|
||||
self._constraints.append((expr, constraint))
|
||||
return constraint
|
||||
|
||||
337
midas/generator/stubs.py
Normal file
337
midas/generator/stubs.py
Normal file
@@ -0,0 +1,337 @@
|
||||
import ast
|
||||
from typing import Optional
|
||||
|
||||
import midas.ast.midas as m
|
||||
from midas.checker.registry import Member, TypesRegistry
|
||||
from midas.checker.types import (
|
||||
AliasType,
|
||||
AppliedType,
|
||||
BaseType,
|
||||
ComplexType,
|
||||
ExtensionType,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
TopType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnitType,
|
||||
UnknownType,
|
||||
substitute_typevars,
|
||||
)
|
||||
|
||||
Empty = ast.Constant(value=...)
|
||||
|
||||
|
||||
class StubsGenerator:
|
||||
def __init__(self, types: TypesRegistry) -> None:
|
||||
self.types: TypesRegistry = types
|
||||
self.stubs: list[ast.stmt] = []
|
||||
self.typing_imports: set[str] = set()
|
||||
self.protocol_idx: int = 0
|
||||
self.stub_idx: int = 0
|
||||
self.type_var_idx: int = 0
|
||||
self.substitutions: dict[str, dict[str, Type]] = {}
|
||||
|
||||
def generate_stubs(self) -> ast.Module:
|
||||
self.stubs = []
|
||||
self.typing_imports = set()
|
||||
for name, type in self.types._types.items():
|
||||
self.generate_stub(name, type)
|
||||
|
||||
imports = [
|
||||
ast.ImportFrom(
|
||||
module="__future__",
|
||||
names=[ast.alias(name="annotations")],
|
||||
level=0,
|
||||
)
|
||||
]
|
||||
if len(self.typing_imports) != 0:
|
||||
imports.append(
|
||||
ast.ImportFrom(
|
||||
module="typing",
|
||||
names=[
|
||||
ast.alias(name=name) for name in sorted(self.typing_imports)
|
||||
],
|
||||
level=0,
|
||||
)
|
||||
)
|
||||
return ast.Module(body=imports + self.stubs, type_ignores=[])
|
||||
|
||||
def generate_stub(self, name: str, type: Type):
|
||||
base_type: Type = type
|
||||
|
||||
members: dict[str, Member] = self.types._members.get(name, {})
|
||||
if isinstance(base_type, (BaseType, TopType, UnitType)) and len(members) == 0:
|
||||
return
|
||||
|
||||
bases: list[ast.expr] = []
|
||||
substitutions: dict[str, Type] = {}
|
||||
bases, substitutions = self.get_bases(type)
|
||||
self.substitutions[name] = substitutions
|
||||
|
||||
body = self.generate_body(members, substitutions)
|
||||
stub = ast.ClassDef(
|
||||
name=name,
|
||||
bases=bases,
|
||||
body=body,
|
||||
keywords=[],
|
||||
decorator_list=[],
|
||||
)
|
||||
self.add_stub(stub)
|
||||
|
||||
def get_bases(self, type: Type) -> tuple[list[ast.expr], dict[str, Type]]:
|
||||
match type:
|
||||
case AliasType(type=base):
|
||||
return [self.dump_type(base)], {}
|
||||
case GenericType(params=params, body=body):
|
||||
self.add_typing_import("Generic")
|
||||
type_vars: ast.expr
|
||||
|
||||
params2: list[TypeVar] = self.define_type_vars(params)
|
||||
if len(params) == 1:
|
||||
type_vars = ast.Name(id=params2[0].name)
|
||||
else:
|
||||
type_vars = ast.Tuple(
|
||||
elts=[ast.Name(id=param.name) for param in params2]
|
||||
)
|
||||
|
||||
substitutions: dict[str, TypeVar] = {
|
||||
param.name: param2 for param, param2 in zip(params, params2)
|
||||
}
|
||||
|
||||
body_bases, body_subsitutions = self.get_bases(body)
|
||||
return (
|
||||
body_bases
|
||||
+ [
|
||||
ast.Subscript(
|
||||
value=ast.Name(id="Generic"),
|
||||
slice=type_vars,
|
||||
)
|
||||
],
|
||||
body_subsitutions | substitutions,
|
||||
)
|
||||
case _:
|
||||
return [], {}
|
||||
|
||||
def generate_body(
|
||||
self, members: dict[str, Member], substitutions: dict[str, Type]
|
||||
) -> list[ast.stmt]:
|
||||
if len(members) == 0:
|
||||
return [ast.Expr(value=Empty)]
|
||||
|
||||
body: list[ast.stmt] = []
|
||||
for name, member in members.items():
|
||||
type: Type = member.type
|
||||
type = substitute_typevars(type, substitutions)
|
||||
match member.kind:
|
||||
case m.MemberKind.PROPERTY:
|
||||
body.append(
|
||||
ast.AnnAssign(
|
||||
target=ast.Name(id=name),
|
||||
annotation=self.dump_type(type),
|
||||
simple=1,
|
||||
)
|
||||
)
|
||||
case m.MemberKind.METHOD:
|
||||
body.extend(self.dump_method(name, type))
|
||||
return body
|
||||
|
||||
def dump_type(self, type: Type) -> ast.expr:
|
||||
match type:
|
||||
case AliasType(name=name) | GenericType(name=name) if (
|
||||
name in self.substitutions
|
||||
):
|
||||
type = substitute_typevars(type, self.substitutions[name])
|
||||
|
||||
match type:
|
||||
case TopType() | UnknownType():
|
||||
self.add_typing_import("Any")
|
||||
return ast.Name(id="Any")
|
||||
case BaseType(name=name):
|
||||
return ast.Name(id=name)
|
||||
case AliasType(name=name):
|
||||
return ast.Name(id=name)
|
||||
case UnitType():
|
||||
return ast.Constant(value=None)
|
||||
case Function():
|
||||
name: str = self.define_protocol(type)
|
||||
return ast.Name(id=name)
|
||||
case OverloadedFunction(overloads=overloads):
|
||||
if len(overloads) == 1:
|
||||
return self.dump_type(overloads[0])
|
||||
return ast.BinOp(
|
||||
left=self.dump_type(OverloadedFunction(overloads=overloads[:-1])),
|
||||
op=ast.BitOr(),
|
||||
right=self.dump_type(overloads[-1]),
|
||||
)
|
||||
|
||||
case ComplexType():
|
||||
name: str = self.new_stub_name()
|
||||
self.generate_stub(name, type)
|
||||
return ast.Name(id=name)
|
||||
|
||||
case ExtensionType():
|
||||
raise NotImplementedError
|
||||
|
||||
case TypeVar():
|
||||
return ast.Name(id=type.name)
|
||||
case GenericType(name=name):
|
||||
params: ast.expr
|
||||
if len(type.params) == 1:
|
||||
params = self.dump_type(type.params[0])
|
||||
else:
|
||||
params = ast.Tuple(
|
||||
elts=[self.dump_type(param) for param in type.params]
|
||||
)
|
||||
return ast.Subscript(
|
||||
value=ast.Name(id=type.name),
|
||||
slice=params,
|
||||
)
|
||||
case AppliedType():
|
||||
args: ast.expr
|
||||
if len(type.args) == 1:
|
||||
args = self.dump_type(type.args[0])
|
||||
else:
|
||||
args = ast.Tuple(elts=[self.dump_type(arg) for arg in type.args])
|
||||
return ast.Subscript(
|
||||
value=ast.Name(id=type.name),
|
||||
slice=args,
|
||||
)
|
||||
|
||||
def dump_method(
|
||||
self, name: str, method: Type, overloaded: bool = False
|
||||
) -> list[ast.stmt]:
|
||||
match method:
|
||||
case Function():
|
||||
if overloaded:
|
||||
self.add_typing_import("overload")
|
||||
return [
|
||||
ast.FunctionDef(
|
||||
name=name,
|
||||
args=self.dump_args(method, with_self=True),
|
||||
returns=self.dump_type(method.returns),
|
||||
body=[ast.Expr(value=Empty)],
|
||||
decorator_list=[ast.Name(id="overload")] if overloaded else [],
|
||||
)
|
||||
]
|
||||
case OverloadedFunction(overloads=overloads):
|
||||
stmts: list[ast.stmt] = []
|
||||
for overload in overloads:
|
||||
stmts.extend(self.dump_method(name, overload, True))
|
||||
return stmts
|
||||
case _:
|
||||
return [
|
||||
ast.AnnAssign(
|
||||
target=ast.Name(id=name),
|
||||
annotation=self.dump_type(method),
|
||||
simple=1,
|
||||
)
|
||||
]
|
||||
|
||||
def dump_args(self, func: Function, with_self: bool = False) -> ast.arguments:
|
||||
pos: list[ast.arg] = [
|
||||
ast.arg(arg=f"_{arg.pos}", annotation=self.dump_type(arg.type))
|
||||
for arg in func.pos_args
|
||||
]
|
||||
mixed: list[ast.arg] = [
|
||||
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type))
|
||||
for arg in func.args
|
||||
]
|
||||
kw: list[ast.arg] = [
|
||||
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type))
|
||||
for arg in func.kw_args
|
||||
]
|
||||
defaults: list[ast.expr] = [
|
||||
Empty for arg in func.pos_args + func.args if not arg.required
|
||||
]
|
||||
kw_defaults: list[Optional[ast.expr]] = [
|
||||
None if arg.required else Empty for arg in func.kw_args
|
||||
]
|
||||
if with_self:
|
||||
arg = ast.arg(arg="self", annotation=None)
|
||||
if len(pos) != 0:
|
||||
pos.insert(0, arg)
|
||||
else:
|
||||
mixed.insert(0, arg)
|
||||
return ast.arguments(
|
||||
posonlyargs=pos,
|
||||
args=mixed,
|
||||
kwonlyargs=kw,
|
||||
defaults=defaults,
|
||||
kw_defaults=kw_defaults,
|
||||
)
|
||||
|
||||
def define_protocol(self, func: Function) -> str:
|
||||
self.add_typing_import("Protocol")
|
||||
name: str = self.new_protocol_name()
|
||||
protocol = ast.ClassDef(
|
||||
name=name,
|
||||
bases=[ast.Name(id="Protocol")],
|
||||
keywords=[],
|
||||
body=[
|
||||
ast.FunctionDef(
|
||||
name="__call__",
|
||||
args=self.dump_args(func, with_self=True),
|
||||
returns=self.dump_type(func.returns),
|
||||
body=[ast.Expr(value=Empty)],
|
||||
decorator_list=[],
|
||||
),
|
||||
],
|
||||
decorator_list=[],
|
||||
)
|
||||
self.add_stub(protocol)
|
||||
return name
|
||||
|
||||
def new_protocol_name(self) -> str:
|
||||
name: str = f"_Protocol{self.protocol_idx}"
|
||||
self.protocol_idx += 1
|
||||
return name
|
||||
|
||||
def new_stub_name(self) -> str:
|
||||
name: str = f"_Stub_{self.stub_idx}"
|
||||
self.stub_idx += 1
|
||||
return name
|
||||
|
||||
def new_type_var_name(self) -> str:
|
||||
name: str = f"_T{self.type_var_idx}"
|
||||
self.type_var_idx += 1
|
||||
return name
|
||||
|
||||
def add_stub(self, stub: ast.stmt):
|
||||
self.stubs.append(stub)
|
||||
|
||||
def add_typing_import(self, name: str):
|
||||
self.typing_imports.add(name)
|
||||
|
||||
def define_type_vars(self, vars: list[TypeVar]) -> list[TypeVar]:
|
||||
vars2: list[TypeVar] = []
|
||||
for var in vars:
|
||||
vars2.append(self.define_type_var(var))
|
||||
return vars2
|
||||
|
||||
def define_type_var(self, var: TypeVar) -> TypeVar:
|
||||
name: str = self.new_type_var_name()
|
||||
self.add_typing_import("TypeVar")
|
||||
self.add_stub(
|
||||
ast.Assign(
|
||||
targets=[ast.Name(id=name)],
|
||||
value=ast.Call(
|
||||
func=ast.Name(id="TypeVar"),
|
||||
args=[
|
||||
ast.Constant(value=name),
|
||||
],
|
||||
keywords=(
|
||||
[]
|
||||
if var.bound is None
|
||||
else [
|
||||
ast.keyword(
|
||||
arg="bound",
|
||||
value=self.dump_type(var.bound),
|
||||
)
|
||||
]
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
return TypeVar(name=name, bound=None)
|
||||
@@ -69,8 +69,6 @@ class MidasLexer(Lexer):
|
||||
):
|
||||
self.advance()
|
||||
self.add_token(TokenType.WHITESPACE)
|
||||
case '"' | "'":
|
||||
self.scan_string(char)
|
||||
case _:
|
||||
if char.isdigit():
|
||||
self.scan_number()
|
||||
@@ -80,17 +78,6 @@ class MidasLexer(Lexer):
|
||||
self.error("Unexpected character")
|
||||
return None
|
||||
|
||||
def scan_string(self, opening: str):
|
||||
while self.peek() != opening and not self.is_at_end():
|
||||
self.advance()
|
||||
|
||||
if self.is_at_end():
|
||||
self.error("Unterminated string")
|
||||
|
||||
self.advance()
|
||||
value: str = self.source[self.start + 1 : self.idx - 1]
|
||||
self.add_token(TokenType.STRING, value)
|
||||
|
||||
def scan_number(self):
|
||||
"""Scan the rest of number and add it as a token
|
||||
|
||||
|
||||
@@ -43,7 +43,6 @@ class TokenType(Enum):
|
||||
TRUE = auto()
|
||||
FALSE = auto()
|
||||
NONE = auto()
|
||||
STRING = auto()
|
||||
|
||||
# Keywords
|
||||
TYPE = auto()
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import Optional
|
||||
from midas.ast.location import Location
|
||||
from midas.ast.midas import (
|
||||
BinaryExpr,
|
||||
CallExpr,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
Expr,
|
||||
@@ -18,7 +17,6 @@ from midas.ast.midas import (
|
||||
MemberKind,
|
||||
MemberStmt,
|
||||
NamedType,
|
||||
ParamSpec,
|
||||
PredicateStmt,
|
||||
Stmt,
|
||||
Type,
|
||||
@@ -267,9 +265,6 @@ class MidasParser(Parser):
|
||||
Returns:
|
||||
Expr: the parsed constraint expression
|
||||
"""
|
||||
return self.expression()
|
||||
|
||||
def expression(self) -> Expr:
|
||||
return self.and_()
|
||||
|
||||
def and_(self) -> Expr:
|
||||
@@ -336,55 +331,7 @@ class MidasParser(Parser):
|
||||
right: Expr = self.unary()
|
||||
location: Location = Location.span(operator.get_location(), right.location)
|
||||
return UnaryExpr(location=location, operator=operator, right=right)
|
||||
return self.call()
|
||||
|
||||
def call(self) -> Expr:
|
||||
expr: Expr = self.reference()
|
||||
while self.match(TokenType.LEFT_PAREN):
|
||||
expr = self.finish_call(expr)
|
||||
return expr
|
||||
|
||||
def finish_call(self, callee: Expr) -> Expr:
|
||||
pos_args: list[Expr] = []
|
||||
kw_args: dict[str, Expr] = {}
|
||||
keywords: bool = False
|
||||
while not self.match(TokenType.RIGHT_PAREN):
|
||||
if self.check_identifier() and self.check_next(TokenType.EQUAL):
|
||||
keywords = True
|
||||
keyword: Token = self.advance()
|
||||
self.advance()
|
||||
value: Expr = self.expression()
|
||||
name: str = keyword.lexeme
|
||||
if name in kw_args:
|
||||
self.error(
|
||||
self.peek(),
|
||||
f"Multiple values passed for '{name}', only the last occurrence will be used",
|
||||
)
|
||||
kw_args[name] = value
|
||||
else:
|
||||
value = self.expression()
|
||||
if self.check(TokenType.EQUAL):
|
||||
if keywords:
|
||||
raise self.error(self.peek(), "Invalid keyword argument name")
|
||||
else:
|
||||
raise self.error(
|
||||
self.peek(),
|
||||
"Cannot pass positional arguments after a keyword argument",
|
||||
)
|
||||
pos_args.append(value)
|
||||
|
||||
if not self.match(TokenType.COMMA):
|
||||
break
|
||||
|
||||
r_paren: Token = self.consume(
|
||||
TokenType.RIGHT_PAREN, "Expected ')' after arguments."
|
||||
)
|
||||
return CallExpr(
|
||||
location=Location.span(callee.location, r_paren.get_location()),
|
||||
callee=callee,
|
||||
arguments=pos_args,
|
||||
keywords=kw_args,
|
||||
)
|
||||
return self.reference()
|
||||
|
||||
def reference(self) -> Expr:
|
||||
"""Parse an attribute access expression or a simpler expression
|
||||
@@ -418,9 +365,6 @@ class MidasParser(Parser):
|
||||
if self.match(TokenType.NUMBER):
|
||||
return LiteralExpr(location=token.get_location(), value=token.value)
|
||||
|
||||
if self.match(TokenType.STRING):
|
||||
return LiteralExpr(location=token.get_location(), value=token.value)
|
||||
|
||||
if self.match_identifier():
|
||||
return VariableExpr(location=token.get_location(), name=token)
|
||||
|
||||
@@ -509,35 +453,23 @@ class MidasParser(Parser):
|
||||
PredicateStmt: the parsed predicate declaration statement
|
||||
"""
|
||||
keyword: Token = self.previous()
|
||||
|
||||
name: Token = self.consume_identifier("Expected predicate name")
|
||||
|
||||
params: list[ParamSpec] = []
|
||||
while self.check(TokenType.LEFT_PAREN):
|
||||
params.append(self.function_args())
|
||||
|
||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
|
||||
subject: Token = self.consume_identifier("Expected subject name")
|
||||
self.consume(TokenType.COLON, "Expected ':' after subject name")
|
||||
type: Type = self.type_expr()
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
|
||||
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
|
||||
body: Expr = self.constraint()
|
||||
condition: Expr = self.constraint()
|
||||
return PredicateStmt(
|
||||
location=keyword.location_to(self.previous()),
|
||||
name=name,
|
||||
params=params,
|
||||
body=body,
|
||||
subject=subject,
|
||||
type=type,
|
||||
condition=condition,
|
||||
)
|
||||
|
||||
def function(self) -> FunctionType:
|
||||
params: ParamSpec = self.function_args()
|
||||
|
||||
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
||||
result: Type = self.type_expr()
|
||||
|
||||
return FunctionType(
|
||||
location=params.l_paren.location_to(self.previous()),
|
||||
params=params,
|
||||
returns=result,
|
||||
)
|
||||
|
||||
def function_args(self) -> ParamSpec:
|
||||
l_paren: Token = self.consume(
|
||||
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
|
||||
)
|
||||
@@ -594,4 +526,14 @@ class MidasParser(Parser):
|
||||
self.error(token, "Unnamed mixed argument")
|
||||
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
|
||||
return ParamSpec(l_paren=l_paren, pos=pos_args, mixed=args, kw=kw_args)
|
||||
|
||||
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
||||
result: Type = self.type_expr()
|
||||
|
||||
return FunctionType(
|
||||
location=l_paren.location_to(self.previous()),
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
returns=result,
|
||||
)
|
||||
|
||||
@@ -9,13 +9,13 @@ Module(
|
||||
level=0),
|
||||
Assign(
|
||||
targets=[
|
||||
Name(id='__midas_a0__')],
|
||||
Name(id='__midas_alias_0__')],
|
||||
value=Constant(value=123.45)),
|
||||
Assert(
|
||||
test=Call(
|
||||
func=Name(id='isinstance'),
|
||||
args=[
|
||||
Name(id='__midas_a0__'),
|
||||
Name(id='__midas_alias_0__'),
|
||||
Name(id='float')],
|
||||
keywords=[]),
|
||||
msg=JoinedStr(
|
||||
@@ -26,7 +26,7 @@ Module(
|
||||
value=Call(
|
||||
func=Name(id='type'),
|
||||
args=[
|
||||
Name(id='__midas_a0__')],
|
||||
Name(id='__midas_alias_0__')],
|
||||
keywords=[]),
|
||||
attr='__name__'),
|
||||
conversion=-1),
|
||||
@@ -34,19 +34,19 @@ Module(
|
||||
Assign(
|
||||
targets=[
|
||||
Name(id='distance')],
|
||||
value=Name(id='__midas_a0__')),
|
||||
value=Name(id='__midas_alias_0__')),
|
||||
Delete(
|
||||
targets=[
|
||||
Name(id='__midas_a0__')]),
|
||||
Name(id='__midas_alias_0__')]),
|
||||
Assign(
|
||||
targets=[
|
||||
Name(id='__midas_a1__')],
|
||||
Name(id='__midas_alias_1__')],
|
||||
value=Constant(value=6.7)),
|
||||
Assert(
|
||||
test=Call(
|
||||
func=Name(id='isinstance'),
|
||||
args=[
|
||||
Name(id='__midas_a1__'),
|
||||
Name(id='__midas_alias_1__'),
|
||||
Name(id='float')],
|
||||
keywords=[]),
|
||||
msg=JoinedStr(
|
||||
@@ -57,7 +57,7 @@ Module(
|
||||
value=Call(
|
||||
func=Name(id='type'),
|
||||
args=[
|
||||
Name(id='__midas_a1__')],
|
||||
Name(id='__midas_alias_1__')],
|
||||
keywords=[]),
|
||||
attr='__name__'),
|
||||
conversion=-1),
|
||||
@@ -65,10 +65,10 @@ Module(
|
||||
Assign(
|
||||
targets=[
|
||||
Name(id='time')],
|
||||
value=Name(id='__midas_a1__')),
|
||||
value=Name(id='__midas_alias_1__')),
|
||||
Delete(
|
||||
targets=[
|
||||
Name(id='__midas_a1__')]),
|
||||
Name(id='__midas_alias_1__')]),
|
||||
Assign(
|
||||
targets=[
|
||||
Name(id='speed')],
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
// Inline
|
||||
type T1 = float where _ > 0
|
||||
|
||||
// Named
|
||||
predicate is_positive(v: float) = v > 0
|
||||
type T2 = float where is_positive(_)
|
||||
|
||||
// Curried
|
||||
predicate in_range(mn: float, mx: float)(v: float) = v >= mn & v < mx
|
||||
type T3 = float where in_range(100, 200)(_)
|
||||
|
||||
// Alias
|
||||
predicate minor = in_range(0, 18)
|
||||
type T4 = float where minor(_)
|
||||
@@ -1,8 +0,0 @@
|
||||
from midas import T1, T2, T3, T4, cast
|
||||
|
||||
t: float = 12.5
|
||||
|
||||
t1: T1 = cast(T1, t)
|
||||
t2: T2 = cast(T2, t)
|
||||
t3: T3 = cast(T3, t)
|
||||
t4: T4 = cast(T4, t)
|
||||
@@ -1,333 +0,0 @@
|
||||
Module(
|
||||
body=[
|
||||
FunctionDef(
|
||||
name='__midas_p0__',
|
||||
args=arguments(
|
||||
posonlyargs=[],
|
||||
args=[
|
||||
arg(
|
||||
arg='_',
|
||||
annotation=Constant(value='Any'))],
|
||||
kwonlyargs=[],
|
||||
kw_defaults=[],
|
||||
defaults=[]),
|
||||
body=[
|
||||
Return(
|
||||
value=Compare(
|
||||
left=Name(id='_'),
|
||||
ops=[
|
||||
Gt()],
|
||||
comparators=[
|
||||
Constant(value=0.0)]))],
|
||||
decorator_list=[],
|
||||
returns=Constant(value='bool')),
|
||||
FunctionDef(
|
||||
name='__midas_is_positive__',
|
||||
args=arguments(
|
||||
posonlyargs=[],
|
||||
args=[
|
||||
arg(
|
||||
arg='v',
|
||||
annotation=Constant(value='float'))],
|
||||
kwonlyargs=[],
|
||||
kw_defaults=[],
|
||||
defaults=[]),
|
||||
body=[
|
||||
Return(
|
||||
value=Compare(
|
||||
left=Name(id='v'),
|
||||
ops=[
|
||||
Gt()],
|
||||
comparators=[
|
||||
Constant(value=0.0)]))],
|
||||
decorator_list=[],
|
||||
returns=Constant(value='bool')),
|
||||
FunctionDef(
|
||||
name='__midas_p1__',
|
||||
args=arguments(
|
||||
posonlyargs=[],
|
||||
args=[
|
||||
arg(
|
||||
arg='_',
|
||||
annotation=Constant(value='Any'))],
|
||||
kwonlyargs=[],
|
||||
kw_defaults=[],
|
||||
defaults=[]),
|
||||
body=[
|
||||
Return(
|
||||
value=Call(
|
||||
func=Name(id='__midas_is_positive__'),
|
||||
args=[
|
||||
Name(id='_')],
|
||||
keywords=[]))],
|
||||
decorator_list=[],
|
||||
returns=Constant(value='bool')),
|
||||
FunctionDef(
|
||||
name='__midas_in_range__',
|
||||
args=arguments(
|
||||
posonlyargs=[],
|
||||
args=[
|
||||
arg(
|
||||
arg='mn',
|
||||
annotation=Constant(value='float')),
|
||||
arg(
|
||||
arg='mx',
|
||||
annotation=Constant(value='float'))],
|
||||
kwonlyargs=[],
|
||||
kw_defaults=[],
|
||||
defaults=[]),
|
||||
body=[
|
||||
FunctionDef(
|
||||
name='inner0',
|
||||
args=arguments(
|
||||
posonlyargs=[],
|
||||
args=[
|
||||
arg(
|
||||
arg='v',
|
||||
annotation=Constant(value='float'))],
|
||||
kwonlyargs=[],
|
||||
kw_defaults=[],
|
||||
defaults=[]),
|
||||
body=[
|
||||
Return(
|
||||
value=BoolOp(
|
||||
op=And(),
|
||||
values=[
|
||||
Compare(
|
||||
left=Name(id='v'),
|
||||
ops=[
|
||||
GtE()],
|
||||
comparators=[
|
||||
Name(id='mn')]),
|
||||
Compare(
|
||||
left=Name(id='v'),
|
||||
ops=[
|
||||
Lt()],
|
||||
comparators=[
|
||||
Name(id='mx')])]))],
|
||||
decorator_list=[],
|
||||
returns=Constant(value='bool')),
|
||||
Return(
|
||||
value=Name(id='inner0'))],
|
||||
decorator_list=[],
|
||||
returns=Constant(value='Callable[[float], bool]')),
|
||||
FunctionDef(
|
||||
name='__midas_p2__',
|
||||
args=arguments(
|
||||
posonlyargs=[],
|
||||
args=[
|
||||
arg(
|
||||
arg='_',
|
||||
annotation=Constant(value='Any'))],
|
||||
kwonlyargs=[],
|
||||
kw_defaults=[],
|
||||
defaults=[]),
|
||||
body=[
|
||||
Return(
|
||||
value=Call(
|
||||
func=Call(
|
||||
func=Name(id='__midas_in_range__'),
|
||||
args=[
|
||||
Constant(value=100.0),
|
||||
Constant(value=200.0)],
|
||||
keywords=[]),
|
||||
args=[
|
||||
Name(id='_')],
|
||||
keywords=[]))],
|
||||
decorator_list=[],
|
||||
returns=Constant(value='bool')),
|
||||
Assign(
|
||||
targets=[
|
||||
Name(id='__midas_minor__')],
|
||||
value=Call(
|
||||
func=Name(id='__midas_in_range__'),
|
||||
args=[
|
||||
Constant(value=0.0),
|
||||
Constant(value=18.0)],
|
||||
keywords=[])),
|
||||
FunctionDef(
|
||||
name='__midas_p3__',
|
||||
args=arguments(
|
||||
posonlyargs=[],
|
||||
args=[
|
||||
arg(
|
||||
arg='_',
|
||||
annotation=Constant(value='Any'))],
|
||||
kwonlyargs=[],
|
||||
kw_defaults=[],
|
||||
defaults=[]),
|
||||
body=[
|
||||
Return(
|
||||
value=Call(
|
||||
func=Name(id='__midas_minor__'),
|
||||
args=[
|
||||
Name(id='_')],
|
||||
keywords=[]))],
|
||||
decorator_list=[],
|
||||
returns=Constant(value='bool')),
|
||||
ImportFrom(
|
||||
module='midas',
|
||||
names=[
|
||||
alias(name='T1'),
|
||||
alias(name='T2'),
|
||||
alias(name='T3'),
|
||||
alias(name='T4'),
|
||||
alias(name='cast')],
|
||||
level=0),
|
||||
Assign(
|
||||
targets=[
|
||||
Name(id='t')],
|
||||
value=Constant(value=12.5)),
|
||||
Assign(
|
||||
targets=[
|
||||
Name(id='__midas_a0__')],
|
||||
value=Name(id='t')),
|
||||
Assert(
|
||||
test=Call(
|
||||
func=Name(id='isinstance'),
|
||||
args=[
|
||||
Name(id='__midas_a0__'),
|
||||
Name(id='float')],
|
||||
keywords=[]),
|
||||
msg=JoinedStr(
|
||||
values=[
|
||||
Constant(value='02_constraints.py:L5:10: CastError: Cannot cast '),
|
||||
FormattedValue(
|
||||
value=Attribute(
|
||||
value=Call(
|
||||
func=Name(id='type'),
|
||||
args=[
|
||||
Name(id='__midas_a0__')],
|
||||
keywords=[]),
|
||||
attr='__name__'),
|
||||
conversion=-1),
|
||||
Constant(value=' to float')])),
|
||||
Assert(
|
||||
test=Call(
|
||||
func=Name(id='__midas_p0__'),
|
||||
args=[
|
||||
Name(id='__midas_a0__')],
|
||||
keywords=[]),
|
||||
msg=Constant(value="02_constraints.py:L5:10: ConstraintError: Value does not fit constraint '_ > 0.0'")),
|
||||
Assign(
|
||||
targets=[
|
||||
Name(id='t1')],
|
||||
value=Name(id='__midas_a0__')),
|
||||
Delete(
|
||||
targets=[
|
||||
Name(id='__midas_a0__')]),
|
||||
Assign(
|
||||
targets=[
|
||||
Name(id='__midas_a1__')],
|
||||
value=Name(id='t')),
|
||||
Assert(
|
||||
test=Call(
|
||||
func=Name(id='isinstance'),
|
||||
args=[
|
||||
Name(id='__midas_a1__'),
|
||||
Name(id='float')],
|
||||
keywords=[]),
|
||||
msg=JoinedStr(
|
||||
values=[
|
||||
Constant(value='02_constraints.py:L6:10: CastError: Cannot cast '),
|
||||
FormattedValue(
|
||||
value=Attribute(
|
||||
value=Call(
|
||||
func=Name(id='type'),
|
||||
args=[
|
||||
Name(id='__midas_a1__')],
|
||||
keywords=[]),
|
||||
attr='__name__'),
|
||||
conversion=-1),
|
||||
Constant(value=' to float')])),
|
||||
Assert(
|
||||
test=Call(
|
||||
func=Name(id='__midas_p1__'),
|
||||
args=[
|
||||
Name(id='__midas_a1__')],
|
||||
keywords=[]),
|
||||
msg=Constant(value="02_constraints.py:L6:10: ConstraintError: Value does not fit constraint 'is_positive(_)'")),
|
||||
Assign(
|
||||
targets=[
|
||||
Name(id='t2')],
|
||||
value=Name(id='__midas_a1__')),
|
||||
Delete(
|
||||
targets=[
|
||||
Name(id='__midas_a1__')]),
|
||||
Assign(
|
||||
targets=[
|
||||
Name(id='__midas_a2__')],
|
||||
value=Name(id='t')),
|
||||
Assert(
|
||||
test=Call(
|
||||
func=Name(id='isinstance'),
|
||||
args=[
|
||||
Name(id='__midas_a2__'),
|
||||
Name(id='float')],
|
||||
keywords=[]),
|
||||
msg=JoinedStr(
|
||||
values=[
|
||||
Constant(value='02_constraints.py:L7:10: CastError: Cannot cast '),
|
||||
FormattedValue(
|
||||
value=Attribute(
|
||||
value=Call(
|
||||
func=Name(id='type'),
|
||||
args=[
|
||||
Name(id='__midas_a2__')],
|
||||
keywords=[]),
|
||||
attr='__name__'),
|
||||
conversion=-1),
|
||||
Constant(value=' to float')])),
|
||||
Assert(
|
||||
test=Call(
|
||||
func=Name(id='__midas_p2__'),
|
||||
args=[
|
||||
Name(id='__midas_a2__')],
|
||||
keywords=[]),
|
||||
msg=Constant(value="02_constraints.py:L7:10: ConstraintError: Value does not fit constraint 'in_range(100.0, 200.0)(_)'")),
|
||||
Assign(
|
||||
targets=[
|
||||
Name(id='t3')],
|
||||
value=Name(id='__midas_a2__')),
|
||||
Delete(
|
||||
targets=[
|
||||
Name(id='__midas_a2__')]),
|
||||
Assign(
|
||||
targets=[
|
||||
Name(id='__midas_a3__')],
|
||||
value=Name(id='t')),
|
||||
Assert(
|
||||
test=Call(
|
||||
func=Name(id='isinstance'),
|
||||
args=[
|
||||
Name(id='__midas_a3__'),
|
||||
Name(id='float')],
|
||||
keywords=[]),
|
||||
msg=JoinedStr(
|
||||
values=[
|
||||
Constant(value='02_constraints.py:L8:10: CastError: Cannot cast '),
|
||||
FormattedValue(
|
||||
value=Attribute(
|
||||
value=Call(
|
||||
func=Name(id='type'),
|
||||
args=[
|
||||
Name(id='__midas_a3__')],
|
||||
keywords=[]),
|
||||
attr='__name__'),
|
||||
conversion=-1),
|
||||
Constant(value=' to float')])),
|
||||
Assert(
|
||||
test=Call(
|
||||
func=Name(id='__midas_p3__'),
|
||||
args=[
|
||||
Name(id='__midas_a3__')],
|
||||
keywords=[]),
|
||||
msg=Constant(value="02_constraints.py:L8:10: ConstraintError: Value does not fit constraint 'minor(_)'")),
|
||||
Assign(
|
||||
targets=[
|
||||
Name(id='t4')],
|
||||
value=Name(id='__midas_a3__')),
|
||||
Delete(
|
||||
targets=[
|
||||
Name(id='__midas_a3__')])],
|
||||
type_ignores=[])
|
||||
@@ -2582,9 +2582,7 @@
|
||||
"name": "__sub__",
|
||||
"type": {
|
||||
"_type": "FunctionType",
|
||||
"params": {
|
||||
"_type": "ParamSpec",
|
||||
"pos": [
|
||||
"pos_args": [
|
||||
{
|
||||
"name": null,
|
||||
"type": {
|
||||
@@ -2594,9 +2592,8 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"mixed": [],
|
||||
"kw": []
|
||||
},
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"returns": {
|
||||
"_type": "GenericType",
|
||||
"type": {
|
||||
@@ -2676,9 +2673,7 @@
|
||||
"name": "__sub__",
|
||||
"type": {
|
||||
"_type": "FunctionType",
|
||||
"params": {
|
||||
"_type": "ParamSpec",
|
||||
"pos": [
|
||||
"pos_args": [
|
||||
{
|
||||
"name": null,
|
||||
"type": {
|
||||
@@ -2688,9 +2683,8 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"mixed": [],
|
||||
"kw": []
|
||||
},
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"returns": {
|
||||
"_type": "GenericType",
|
||||
"type": {
|
||||
@@ -2719,9 +2713,7 @@
|
||||
"name": "__sub__",
|
||||
"type": {
|
||||
"_type": "FunctionType",
|
||||
"params": {
|
||||
"_type": "ParamSpec",
|
||||
"pos": [
|
||||
"pos_args": [
|
||||
{
|
||||
"name": null,
|
||||
"type": {
|
||||
@@ -2731,9 +2723,8 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"mixed": [],
|
||||
"kw": []
|
||||
},
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"returns": {
|
||||
"_type": "GenericType",
|
||||
"type": {
|
||||
@@ -2754,24 +2745,12 @@
|
||||
{
|
||||
"_type": "PredicateStmt",
|
||||
"name": "Positive",
|
||||
"params": [
|
||||
{
|
||||
"_type": "ParamSpec",
|
||||
"pos": [],
|
||||
"mixed": [
|
||||
{
|
||||
"name": "v",
|
||||
"subject": "v",
|
||||
"type": {
|
||||
"_type": "NamedType",
|
||||
"name": "float"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw": []
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"condition": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
@@ -2787,24 +2766,12 @@
|
||||
{
|
||||
"_type": "PredicateStmt",
|
||||
"name": "StrictlyPositive",
|
||||
"params": [
|
||||
{
|
||||
"_type": "ParamSpec",
|
||||
"pos": [],
|
||||
"mixed": [
|
||||
{
|
||||
"name": "v",
|
||||
"subject": "v",
|
||||
"type": {
|
||||
"_type": "NamedType",
|
||||
"name": "float"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw": []
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"condition": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
@@ -2820,24 +2787,12 @@
|
||||
{
|
||||
"_type": "PredicateStmt",
|
||||
"name": "Equatorial",
|
||||
"params": [
|
||||
{
|
||||
"_type": "ParamSpec",
|
||||
"pos": [],
|
||||
"mixed": [
|
||||
{
|
||||
"name": "loc",
|
||||
"subject": "loc",
|
||||
"type": {
|
||||
"_type": "NamedType",
|
||||
"name": "GeoLocation"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw": []
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"condition": {
|
||||
"_type": "GroupingExpr",
|
||||
"expr": {
|
||||
"_type": "BinaryExpr",
|
||||
@@ -2872,24 +2827,12 @@
|
||||
{
|
||||
"_type": "PredicateStmt",
|
||||
"name": "Arctic",
|
||||
"params": [
|
||||
{
|
||||
"_type": "ParamSpec",
|
||||
"pos": [],
|
||||
"mixed": [
|
||||
{
|
||||
"name": "loc",
|
||||
"subject": "loc",
|
||||
"type": {
|
||||
"_type": "NamedType",
|
||||
"name": "GeoLocation"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw": []
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"condition": {
|
||||
"_type": "GroupingExpr",
|
||||
"expr": {
|
||||
"_type": "BinaryExpr",
|
||||
|
||||
@@ -45,7 +45,7 @@ class GeneratorTester(Tester):
|
||||
typed_ast: TypedAST = checker.type_check(path)
|
||||
|
||||
if not any(d.type == DiagnosticType.ERROR for d in checker.diagnostics):
|
||||
generator = Generator(workdir=path.parent, types=checker.types)
|
||||
generator = Generator(workdir=path.parent)
|
||||
result.compiled_ast = generator.generate_ast(typed_ast, path)
|
||||
|
||||
return result
|
||||
|
||||
@@ -2,7 +2,6 @@ from typing import Optional, Sequence
|
||||
|
||||
from midas.ast.midas import (
|
||||
BinaryExpr,
|
||||
CallExpr,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
Expr,
|
||||
@@ -16,7 +15,6 @@ from midas.ast.midas import (
|
||||
LogicalExpr,
|
||||
MemberStmt,
|
||||
NamedType,
|
||||
ParamSpec,
|
||||
PredicateStmt,
|
||||
Stmt,
|
||||
Type,
|
||||
@@ -80,8 +78,9 @@ class MidasAstJsonSerializer(
|
||||
return {
|
||||
"_type": "PredicateStmt",
|
||||
"name": stmt.name.lexeme,
|
||||
"params": [self._serialize_param_spec(spec) for spec in stmt.params],
|
||||
"body": stmt.body.accept(self),
|
||||
"subject": stmt.subject.lexeme,
|
||||
"type": stmt.type.accept(self),
|
||||
"condition": stmt.condition.accept(self),
|
||||
}
|
||||
|
||||
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
|
||||
@@ -107,14 +106,6 @@ class MidasAstJsonSerializer(
|
||||
"right": expr.right.accept(self),
|
||||
}
|
||||
|
||||
def visit_call_expr(self, expr: CallExpr) -> dict:
|
||||
return {
|
||||
"_type": "CallExpr",
|
||||
"callee": expr.callee.accept(self),
|
||||
"arguments": self._serialize_list(expr.arguments),
|
||||
"keywords": {name: arg.accept(self) for name, arg in expr.keywords.items()},
|
||||
}
|
||||
|
||||
def visit_get_expr(self, expr: GetExpr) -> dict:
|
||||
return {
|
||||
"_type": "GetExpr",
|
||||
@@ -172,21 +163,15 @@ class MidasAstJsonSerializer(
|
||||
def visit_function_type(self, type: FunctionType) -> dict:
|
||||
return {
|
||||
"_type": "FunctionType",
|
||||
"params": self._serialize_param_spec(type.params),
|
||||
"pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args],
|
||||
"args": [self._serialize_func_arg(arg) for arg in type.args],
|
||||
"kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args],
|
||||
"returns": type.returns.accept(self),
|
||||
}
|
||||
|
||||
def _serialize_param_spec(self, spec: ParamSpec) -> dict:
|
||||
return {
|
||||
"_type": "ParamSpec",
|
||||
"pos": [self._serialize_func_arg(arg) for arg in spec.pos],
|
||||
"mixed": [self._serialize_func_arg(arg) for arg in spec.mixed],
|
||||
"kw": [self._serialize_func_arg(arg) for arg in spec.kw],
|
||||
}
|
||||
|
||||
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
|
||||
return {
|
||||
"name": arg.name.lexeme if arg.name is not None else None,
|
||||
"name": arg.name,
|
||||
"type": arg.type.accept(self),
|
||||
"required": arg.required,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user