refactor: add param spec for FunctionType
This commit is contained in:
12
gen/midas.py
12
gen/midas.py
@@ -26,6 +26,14 @@ class MemberKind(Enum):
|
||||
METHOD = auto()
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ParamSpec:
|
||||
l_paren: Token
|
||||
pos: list[FunctionType.Argument]
|
||||
mixed: list[FunctionType.Argument]
|
||||
kw: list[FunctionType.Argument]
|
||||
|
||||
|
||||
###<
|
||||
|
||||
|
||||
@@ -128,9 +136,7 @@ class ExtensionType:
|
||||
|
||||
|
||||
class FunctionType:
|
||||
pos_args: list[Argument]
|
||||
args: list[Argument]
|
||||
kw_args: list[Argument]
|
||||
params: ParamSpec
|
||||
returns: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
|
||||
@@ -27,6 +27,14 @@ class MemberKind(Enum):
|
||||
METHOD = auto()
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ParamSpec:
|
||||
l_paren: Token
|
||||
pos: list[FunctionType.Argument]
|
||||
mixed: list[FunctionType.Argument]
|
||||
kw: list[FunctionType.Argument]
|
||||
|
||||
|
||||
##############
|
||||
# Statements #
|
||||
##############
|
||||
@@ -279,9 +287,7 @@ class ExtensionType(Type):
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FunctionType(Type):
|
||||
pos_args: list[Argument]
|
||||
args: list[Argument]
|
||||
kw_args: list[Argument]
|
||||
params: ParamSpec
|
||||
returns: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
|
||||
@@ -276,34 +276,41 @@ class MidasAstPrinter(
|
||||
def visit_function_type(self, type: m.FunctionType) -> None:
|
||||
self._write_line("FunctionType")
|
||||
with self._child_level():
|
||||
self._write_line("pos_args")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(type.pos_args):
|
||||
self._idx = i
|
||||
if i == len(type.pos_args) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("args")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(type.args):
|
||||
self._idx = i
|
||||
if i == len(type.args) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("kw_args")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(type.kw_args):
|
||||
self._idx = i
|
||||
if i == len(type.kw_args) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
self._write_line("params")
|
||||
with self._child_level(single=True):
|
||||
self._visit_param_spec(type.params)
|
||||
|
||||
self._write_line("returns", last=True)
|
||||
with self._child_level(single=True):
|
||||
type.returns.accept(self)
|
||||
|
||||
def _visit_param_spec(self, spec: m.ParamSpec) -> None:
|
||||
self._write_line("ParamSpec")
|
||||
with self._child_level():
|
||||
self._write_line("pos")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(spec.pos):
|
||||
self._idx = i
|
||||
if i == len(spec.pos) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("mixed")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(spec.mixed):
|
||||
self._idx = i
|
||||
if i == len(spec.mixed) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("kw", last=True)
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(spec.kw):
|
||||
self._idx = i
|
||||
if i == len(spec.kw) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
|
||||
self._write_line("Argument")
|
||||
with self._child_level():
|
||||
@@ -436,9 +443,13 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
||||
return f"{type.base.accept(self)} & {type.extension.accept(self)}"
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> str:
|
||||
pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
|
||||
mixed_args: list[str] = [self._print_arg(arg) for arg in type.args]
|
||||
kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args]
|
||||
spec: str = self._visit_param_spec(type.params)
|
||||
return f"fn {spec} -> {type.returns.accept(self)}"
|
||||
|
||||
def _visit_param_spec(self, spec: m.ParamSpec) -> str:
|
||||
pos_args: list[str] = [self._print_arg(arg) for arg in spec.pos]
|
||||
mixed_args: list[str] = [self._print_arg(arg) for arg in spec.mixed]
|
||||
kw_args: list[str] = [self._print_arg(arg) for arg in spec.kw]
|
||||
args: list[str] = pos_args
|
||||
|
||||
if len(pos_args) != 0:
|
||||
@@ -447,8 +458,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
||||
if len(kw_args) != 0:
|
||||
args.append("*")
|
||||
args += kw_args
|
||||
|
||||
return f"fn ({', '.join(args)}) -> {type.returns.accept(self)}"
|
||||
return f"({', '.join(args)})"
|
||||
|
||||
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
|
||||
res: str = ""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -21,6 +22,13 @@ from midas.lexer.token import Token
|
||||
from midas.parser.midas import MidasParser
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class TypedParamSpec:
|
||||
pos: list[Function.Argument]
|
||||
mixed: list[Function.Argument]
|
||||
kw: list[Function.Argument]
|
||||
|
||||
|
||||
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
|
||||
"""A resolver which evaluates Midas type definitions and build a registry"""
|
||||
|
||||
@@ -172,8 +180,17 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
||||
)
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> Type:
|
||||
n_pos_args: int = len(type.pos_args)
|
||||
n_args: int = len(type.args)
|
||||
params: TypedParamSpec = self._visit_param_spec(type.params)
|
||||
return Function(
|
||||
pos_args=params.pos,
|
||||
args=params.mixed,
|
||||
kw_args=params.kw,
|
||||
returns=type.returns.accept(self),
|
||||
)
|
||||
|
||||
def _visit_param_spec(self, spec: m.ParamSpec) -> TypedParamSpec:
|
||||
n_pos: int = len(spec.pos)
|
||||
n_mixed: int = len(spec.mixed)
|
||||
|
||||
def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument:
|
||||
return Function.Argument(
|
||||
@@ -183,14 +200,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
||||
required=arg.required,
|
||||
)
|
||||
|
||||
return Function(
|
||||
pos_args=[process_arg(arg, i) for i, arg in enumerate(type.pos_args)],
|
||||
args=[process_arg(arg, i + n_pos_args) for i, arg in enumerate(type.args)],
|
||||
kw_args=[
|
||||
process_arg(arg, i + n_pos_args + n_args)
|
||||
for i, arg in enumerate(type.kw_args)
|
||||
],
|
||||
returns=type.returns.accept(self),
|
||||
return TypedParamSpec(
|
||||
pos=[process_arg(arg, i) for i, arg in enumerate(spec.pos)],
|
||||
mixed=[process_arg(arg, i + n_pos) for i, arg in enumerate(spec.mixed)],
|
||||
kw=[process_arg(arg, i + n_pos + n_mixed) for i, arg in enumerate(spec.kw)],
|
||||
)
|
||||
|
||||
def _resolve_type_params(self, params: list[m.TypeParam]):
|
||||
|
||||
@@ -17,6 +17,7 @@ from midas.ast.midas import (
|
||||
MemberKind,
|
||||
MemberStmt,
|
||||
NamedType,
|
||||
ParamSpec,
|
||||
PredicateStmt,
|
||||
Stmt,
|
||||
Type,
|
||||
@@ -265,6 +266,9 @@ class MidasParser(Parser):
|
||||
Returns:
|
||||
Expr: the parsed constraint expression
|
||||
"""
|
||||
return self.expression()
|
||||
|
||||
def expression(self) -> Expr:
|
||||
return self.and_()
|
||||
|
||||
def and_(self) -> Expr:
|
||||
@@ -470,6 +474,18 @@ class MidasParser(Parser):
|
||||
)
|
||||
|
||||
def function(self) -> FunctionType:
|
||||
params: ParamSpec = self.function_args()
|
||||
|
||||
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
||||
result: Type = self.type_expr()
|
||||
|
||||
return FunctionType(
|
||||
location=params.l_paren.location_to(self.previous()),
|
||||
params=params,
|
||||
returns=result,
|
||||
)
|
||||
|
||||
def function_args(self) -> ParamSpec:
|
||||
l_paren: Token = self.consume(
|
||||
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
|
||||
)
|
||||
@@ -526,14 +542,4 @@ class MidasParser(Parser):
|
||||
self.error(token, "Unnamed mixed argument")
|
||||
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
|
||||
|
||||
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
||||
result: Type = self.type_expr()
|
||||
|
||||
return FunctionType(
|
||||
location=l_paren.location_to(self.previous()),
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
returns=result,
|
||||
)
|
||||
return ParamSpec(l_paren=l_paren, pos=pos_args, mixed=args, kw=kw_args)
|
||||
|
||||
@@ -15,6 +15,7 @@ from midas.ast.midas import (
|
||||
LogicalExpr,
|
||||
MemberStmt,
|
||||
NamedType,
|
||||
ParamSpec,
|
||||
PredicateStmt,
|
||||
Stmt,
|
||||
Type,
|
||||
@@ -163,12 +164,18 @@ class MidasAstJsonSerializer(
|
||||
def visit_function_type(self, type: FunctionType) -> dict:
|
||||
return {
|
||||
"_type": "FunctionType",
|
||||
"pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args],
|
||||
"args": [self._serialize_func_arg(arg) for arg in type.args],
|
||||
"kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args],
|
||||
"params": self._serialize_param_spec(type.params),
|
||||
"returns": type.returns.accept(self),
|
||||
}
|
||||
|
||||
def _serialize_param_spec(self, spec: ParamSpec) -> dict:
|
||||
return {
|
||||
"_type": "ParamSpec",
|
||||
"pos": [self._serialize_func_arg(arg) for arg in spec.pos],
|
||||
"mixed": [self._serialize_func_arg(arg) for arg in spec.mixed],
|
||||
"kw": [self._serialize_func_arg(arg) for arg in spec.kw],
|
||||
}
|
||||
|
||||
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
|
||||
return {
|
||||
"name": arg.name,
|
||||
|
||||
Reference in New Issue
Block a user