refactor: add param spec for FunctionType
This commit is contained in:
12
gen/midas.py
12
gen/midas.py
@@ -26,6 +26,14 @@ class MemberKind(Enum):
|
|||||||
METHOD = auto()
|
METHOD = auto()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class ParamSpec:
|
||||||
|
l_paren: Token
|
||||||
|
pos: list[FunctionType.Argument]
|
||||||
|
mixed: list[FunctionType.Argument]
|
||||||
|
kw: list[FunctionType.Argument]
|
||||||
|
|
||||||
|
|
||||||
###<
|
###<
|
||||||
|
|
||||||
|
|
||||||
@@ -128,9 +136,7 @@ class ExtensionType:
|
|||||||
|
|
||||||
|
|
||||||
class FunctionType:
|
class FunctionType:
|
||||||
pos_args: list[Argument]
|
params: ParamSpec
|
||||||
args: list[Argument]
|
|
||||||
kw_args: list[Argument]
|
|
||||||
returns: Type
|
returns: Type
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
|||||||
@@ -27,6 +27,14 @@ class MemberKind(Enum):
|
|||||||
METHOD = auto()
|
METHOD = auto()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class ParamSpec:
|
||||||
|
l_paren: Token
|
||||||
|
pos: list[FunctionType.Argument]
|
||||||
|
mixed: list[FunctionType.Argument]
|
||||||
|
kw: list[FunctionType.Argument]
|
||||||
|
|
||||||
|
|
||||||
##############
|
##############
|
||||||
# Statements #
|
# Statements #
|
||||||
##############
|
##############
|
||||||
@@ -279,9 +287,7 @@ class ExtensionType(Type):
|
|||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class FunctionType(Type):
|
class FunctionType(Type):
|
||||||
pos_args: list[Argument]
|
params: ParamSpec
|
||||||
args: list[Argument]
|
|
||||||
kw_args: list[Argument]
|
|
||||||
returns: Type
|
returns: Type
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
|||||||
@@ -276,34 +276,41 @@ class MidasAstPrinter(
|
|||||||
def visit_function_type(self, type: m.FunctionType) -> None:
|
def visit_function_type(self, type: m.FunctionType) -> None:
|
||||||
self._write_line("FunctionType")
|
self._write_line("FunctionType")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
self._write_line("pos_args")
|
self._write_line("params")
|
||||||
with self._child_level():
|
with self._child_level(single=True):
|
||||||
for i, arg in enumerate(type.pos_args):
|
self._visit_param_spec(type.params)
|
||||||
self._idx = i
|
|
||||||
if i == len(type.pos_args) - 1:
|
|
||||||
self._mark_last()
|
|
||||||
self._print_function_arg(arg)
|
|
||||||
|
|
||||||
self._write_line("args")
|
|
||||||
with self._child_level():
|
|
||||||
for i, arg in enumerate(type.args):
|
|
||||||
self._idx = i
|
|
||||||
if i == len(type.args) - 1:
|
|
||||||
self._mark_last()
|
|
||||||
self._print_function_arg(arg)
|
|
||||||
|
|
||||||
self._write_line("kw_args")
|
|
||||||
with self._child_level():
|
|
||||||
for i, arg in enumerate(type.kw_args):
|
|
||||||
self._idx = i
|
|
||||||
if i == len(type.kw_args) - 1:
|
|
||||||
self._mark_last()
|
|
||||||
self._print_function_arg(arg)
|
|
||||||
|
|
||||||
self._write_line("returns", last=True)
|
self._write_line("returns", last=True)
|
||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
type.returns.accept(self)
|
type.returns.accept(self)
|
||||||
|
|
||||||
|
def _visit_param_spec(self, spec: m.ParamSpec) -> None:
|
||||||
|
self._write_line("ParamSpec")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("pos")
|
||||||
|
with self._child_level():
|
||||||
|
for i, arg in enumerate(spec.pos):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(spec.pos) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._print_function_arg(arg)
|
||||||
|
|
||||||
|
self._write_line("mixed")
|
||||||
|
with self._child_level():
|
||||||
|
for i, arg in enumerate(spec.mixed):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(spec.mixed) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._print_function_arg(arg)
|
||||||
|
|
||||||
|
self._write_line("kw", last=True)
|
||||||
|
with self._child_level():
|
||||||
|
for i, arg in enumerate(spec.kw):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(spec.kw) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._print_function_arg(arg)
|
||||||
|
|
||||||
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
|
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
|
||||||
self._write_line("Argument")
|
self._write_line("Argument")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
@@ -436,9 +443,13 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
|||||||
return f"{type.base.accept(self)} & {type.extension.accept(self)}"
|
return f"{type.base.accept(self)} & {type.extension.accept(self)}"
|
||||||
|
|
||||||
def visit_function_type(self, type: m.FunctionType) -> str:
|
def visit_function_type(self, type: m.FunctionType) -> str:
|
||||||
pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
|
spec: str = self._visit_param_spec(type.params)
|
||||||
mixed_args: list[str] = [self._print_arg(arg) for arg in type.args]
|
return f"fn {spec} -> {type.returns.accept(self)}"
|
||||||
kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args]
|
|
||||||
|
def _visit_param_spec(self, spec: m.ParamSpec) -> str:
|
||||||
|
pos_args: list[str] = [self._print_arg(arg) for arg in spec.pos]
|
||||||
|
mixed_args: list[str] = [self._print_arg(arg) for arg in spec.mixed]
|
||||||
|
kw_args: list[str] = [self._print_arg(arg) for arg in spec.kw]
|
||||||
args: list[str] = pos_args
|
args: list[str] = pos_args
|
||||||
|
|
||||||
if len(pos_args) != 0:
|
if len(pos_args) != 0:
|
||||||
@@ -447,8 +458,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
|||||||
if len(kw_args) != 0:
|
if len(kw_args) != 0:
|
||||||
args.append("*")
|
args.append("*")
|
||||||
args += kw_args
|
args += kw_args
|
||||||
|
return f"({', '.join(args)})"
|
||||||
return f"fn ({', '.join(args)}) -> {type.returns.accept(self)}"
|
|
||||||
|
|
||||||
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
|
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
|
||||||
res: str = ""
|
res: str = ""
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -21,6 +22,13 @@ from midas.lexer.token import Token
|
|||||||
from midas.parser.midas import MidasParser
|
from midas.parser.midas import MidasParser
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class TypedParamSpec:
|
||||||
|
pos: list[Function.Argument]
|
||||||
|
mixed: list[Function.Argument]
|
||||||
|
kw: list[Function.Argument]
|
||||||
|
|
||||||
|
|
||||||
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
|
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
|
||||||
"""A resolver which evaluates Midas type definitions and build a registry"""
|
"""A resolver which evaluates Midas type definitions and build a registry"""
|
||||||
|
|
||||||
@@ -172,8 +180,17 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
)
|
)
|
||||||
|
|
||||||
def visit_function_type(self, type: m.FunctionType) -> Type:
|
def visit_function_type(self, type: m.FunctionType) -> Type:
|
||||||
n_pos_args: int = len(type.pos_args)
|
params: TypedParamSpec = self._visit_param_spec(type.params)
|
||||||
n_args: int = len(type.args)
|
return Function(
|
||||||
|
pos_args=params.pos,
|
||||||
|
args=params.mixed,
|
||||||
|
kw_args=params.kw,
|
||||||
|
returns=type.returns.accept(self),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _visit_param_spec(self, spec: m.ParamSpec) -> TypedParamSpec:
|
||||||
|
n_pos: int = len(spec.pos)
|
||||||
|
n_mixed: int = len(spec.mixed)
|
||||||
|
|
||||||
def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument:
|
def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument:
|
||||||
return Function.Argument(
|
return Function.Argument(
|
||||||
@@ -183,14 +200,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
required=arg.required,
|
required=arg.required,
|
||||||
)
|
)
|
||||||
|
|
||||||
return Function(
|
return TypedParamSpec(
|
||||||
pos_args=[process_arg(arg, i) for i, arg in enumerate(type.pos_args)],
|
pos=[process_arg(arg, i) for i, arg in enumerate(spec.pos)],
|
||||||
args=[process_arg(arg, i + n_pos_args) for i, arg in enumerate(type.args)],
|
mixed=[process_arg(arg, i + n_pos) for i, arg in enumerate(spec.mixed)],
|
||||||
kw_args=[
|
kw=[process_arg(arg, i + n_pos + n_mixed) for i, arg in enumerate(spec.kw)],
|
||||||
process_arg(arg, i + n_pos_args + n_args)
|
|
||||||
for i, arg in enumerate(type.kw_args)
|
|
||||||
],
|
|
||||||
returns=type.returns.accept(self),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _resolve_type_params(self, params: list[m.TypeParam]):
|
def _resolve_type_params(self, params: list[m.TypeParam]):
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from midas.ast.midas import (
|
|||||||
MemberKind,
|
MemberKind,
|
||||||
MemberStmt,
|
MemberStmt,
|
||||||
NamedType,
|
NamedType,
|
||||||
|
ParamSpec,
|
||||||
PredicateStmt,
|
PredicateStmt,
|
||||||
Stmt,
|
Stmt,
|
||||||
Type,
|
Type,
|
||||||
@@ -265,6 +266,9 @@ class MidasParser(Parser):
|
|||||||
Returns:
|
Returns:
|
||||||
Expr: the parsed constraint expression
|
Expr: the parsed constraint expression
|
||||||
"""
|
"""
|
||||||
|
return self.expression()
|
||||||
|
|
||||||
|
def expression(self) -> Expr:
|
||||||
return self.and_()
|
return self.and_()
|
||||||
|
|
||||||
def and_(self) -> Expr:
|
def and_(self) -> Expr:
|
||||||
@@ -470,6 +474,18 @@ class MidasParser(Parser):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def function(self) -> FunctionType:
|
def function(self) -> FunctionType:
|
||||||
|
params: ParamSpec = self.function_args()
|
||||||
|
|
||||||
|
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
||||||
|
result: Type = self.type_expr()
|
||||||
|
|
||||||
|
return FunctionType(
|
||||||
|
location=params.l_paren.location_to(self.previous()),
|
||||||
|
params=params,
|
||||||
|
returns=result,
|
||||||
|
)
|
||||||
|
|
||||||
|
def function_args(self) -> ParamSpec:
|
||||||
l_paren: Token = self.consume(
|
l_paren: Token = self.consume(
|
||||||
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
|
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
|
||||||
)
|
)
|
||||||
@@ -526,14 +542,4 @@ class MidasParser(Parser):
|
|||||||
self.error(token, "Unnamed mixed argument")
|
self.error(token, "Unnamed mixed argument")
|
||||||
|
|
||||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
|
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
|
||||||
|
return ParamSpec(l_paren=l_paren, pos=pos_args, mixed=args, kw=kw_args)
|
||||||
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
|
||||||
result: Type = self.type_expr()
|
|
||||||
|
|
||||||
return FunctionType(
|
|
||||||
location=l_paren.location_to(self.previous()),
|
|
||||||
pos_args=pos_args,
|
|
||||||
args=args,
|
|
||||||
kw_args=kw_args,
|
|
||||||
returns=result,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from midas.ast.midas import (
|
|||||||
LogicalExpr,
|
LogicalExpr,
|
||||||
MemberStmt,
|
MemberStmt,
|
||||||
NamedType,
|
NamedType,
|
||||||
|
ParamSpec,
|
||||||
PredicateStmt,
|
PredicateStmt,
|
||||||
Stmt,
|
Stmt,
|
||||||
Type,
|
Type,
|
||||||
@@ -163,12 +164,18 @@ class MidasAstJsonSerializer(
|
|||||||
def visit_function_type(self, type: FunctionType) -> dict:
|
def visit_function_type(self, type: FunctionType) -> dict:
|
||||||
return {
|
return {
|
||||||
"_type": "FunctionType",
|
"_type": "FunctionType",
|
||||||
"pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args],
|
"params": self._serialize_param_spec(type.params),
|
||||||
"args": [self._serialize_func_arg(arg) for arg in type.args],
|
|
||||||
"kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args],
|
|
||||||
"returns": type.returns.accept(self),
|
"returns": type.returns.accept(self),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _serialize_param_spec(self, spec: ParamSpec) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "ParamSpec",
|
||||||
|
"pos": [self._serialize_func_arg(arg) for arg in spec.pos],
|
||||||
|
"mixed": [self._serialize_func_arg(arg) for arg in spec.mixed],
|
||||||
|
"kw": [self._serialize_func_arg(arg) for arg in spec.kw],
|
||||||
|
}
|
||||||
|
|
||||||
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
|
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
|
||||||
return {
|
return {
|
||||||
"name": arg.name,
|
"name": arg.name,
|
||||||
|
|||||||
Reference in New Issue
Block a user