diff --git a/gen/midas.py b/gen/midas.py index 42caf4f..43c2f2c 100644 --- a/gen/midas.py +++ b/gen/midas.py @@ -26,6 +26,14 @@ class MemberKind(Enum): METHOD = auto() +@dataclass(frozen=True, kw_only=True) +class ParamSpec: + l_paren: Token + pos: list[FunctionType.Argument] + mixed: list[FunctionType.Argument] + kw: list[FunctionType.Argument] + + ###< @@ -128,9 +136,7 @@ class ExtensionType: class FunctionType: - pos_args: list[Argument] - args: list[Argument] - kw_args: list[Argument] + params: ParamSpec returns: Type @dataclass(frozen=True, kw_only=True) diff --git a/midas/ast/midas.py b/midas/ast/midas.py index e71aff9..f052c2b 100644 --- a/midas/ast/midas.py +++ b/midas/ast/midas.py @@ -27,6 +27,14 @@ class MemberKind(Enum): METHOD = auto() +@dataclass(frozen=True, kw_only=True) +class ParamSpec: + l_paren: Token + pos: list[FunctionType.Argument] + mixed: list[FunctionType.Argument] + kw: list[FunctionType.Argument] + + ############## # Statements # ############## @@ -279,9 +287,7 @@ class ExtensionType(Type): @dataclass(frozen=True) class FunctionType(Type): - pos_args: list[Argument] - args: list[Argument] - kw_args: list[Argument] + params: ParamSpec returns: Type @dataclass(frozen=True, kw_only=True) diff --git a/midas/ast/printer.py b/midas/ast/printer.py index 694c272..40968ab 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -276,34 +276,41 @@ class MidasAstPrinter( def visit_function_type(self, type: m.FunctionType) -> None: self._write_line("FunctionType") with self._child_level(): - self._write_line("pos_args") - with self._child_level(): - for i, arg in enumerate(type.pos_args): - self._idx = i - if i == len(type.pos_args) - 1: - self._mark_last() - self._print_function_arg(arg) - - self._write_line("args") - with self._child_level(): - for i, arg in enumerate(type.args): - self._idx = i - if i == len(type.args) - 1: - self._mark_last() - self._print_function_arg(arg) - - self._write_line("kw_args") - with self._child_level(): - for i, arg in enumerate(type.kw_args): - self._idx = i - if i == len(type.kw_args) - 1: - self._mark_last() - self._print_function_arg(arg) + self._write_line("params") + with self._child_level(single=True): + self._visit_param_spec(type.params) self._write_line("returns", last=True) with self._child_level(single=True): type.returns.accept(self) + def _visit_param_spec(self, spec: m.ParamSpec) -> None: + self._write_line("ParamSpec") + with self._child_level(): + self._write_line("pos") + with self._child_level(): + for i, arg in enumerate(spec.pos): + self._idx = i + if i == len(spec.pos) - 1: + self._mark_last() + self._print_function_arg(arg) + + self._write_line("mixed") + with self._child_level(): + for i, arg in enumerate(spec.mixed): + self._idx = i + if i == len(spec.mixed) - 1: + self._mark_last() + self._print_function_arg(arg) + + self._write_line("kw", last=True) + with self._child_level(): + for i, arg in enumerate(spec.kw): + self._idx = i + if i == len(spec.kw) - 1: + self._mark_last() + self._print_function_arg(arg) + def _print_function_arg(self, arg: m.FunctionType.Argument) -> None: self._write_line("Argument") with self._child_level(): @@ -436,9 +443,13 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] return f"{type.base.accept(self)} & {type.extension.accept(self)}" def visit_function_type(self, type: m.FunctionType) -> str: - pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args] - mixed_args: list[str] = [self._print_arg(arg) for arg in type.args] - kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args] + spec: str = self._visit_param_spec(type.params) + return f"fn {spec} -> {type.returns.accept(self)}" + + def _visit_param_spec(self, spec: m.ParamSpec) -> str: + pos_args: list[str] = [self._print_arg(arg) for arg in spec.pos] + mixed_args: list[str] = [self._print_arg(arg) for arg in spec.mixed] + kw_args: list[str] = [self._print_arg(arg) for arg in spec.kw] args: list[str] = pos_args if len(pos_args) != 0: @@ -447,8 +458,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] if len(kw_args) != 0: args.append("*") args += kw_args - - return f"fn ({', '.join(args)}) -> {type.returns.accept(self)}" + return f"({', '.join(args)})" def _print_arg(self, arg: m.FunctionType.Argument) -> str: res: str = "" diff --git a/midas/checker/midas.py b/midas/checker/midas.py index 3764c03..a0ee051 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -1,4 +1,5 @@ import logging +from dataclasses import dataclass from pathlib import Path from typing import Optional @@ -21,6 +22,13 @@ from midas.lexer.token import Token from midas.parser.midas import MidasParser +@dataclass(frozen=True, kw_only=True) +class TypedParamSpec: + pos: list[Function.Argument] + mixed: list[Function.Argument] + kw: list[Function.Argument] + + class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]): """A resolver which evaluates Midas type definitions and build a registry""" @@ -172,8 +180,17 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type ) def visit_function_type(self, type: m.FunctionType) -> Type: - n_pos_args: int = len(type.pos_args) - n_args: int = len(type.args) + params: TypedParamSpec = self._visit_param_spec(type.params) + return Function( + pos_args=params.pos, + args=params.mixed, + kw_args=params.kw, + returns=type.returns.accept(self), + ) + + def _visit_param_spec(self, spec: m.ParamSpec) -> TypedParamSpec: + n_pos: int = len(spec.pos) + n_mixed: int = len(spec.mixed) def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument: return Function.Argument( @@ -183,14 +200,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type required=arg.required, ) - return Function( - pos_args=[process_arg(arg, i) for i, arg in enumerate(type.pos_args)], - args=[process_arg(arg, i + n_pos_args) for i, arg in enumerate(type.args)], - kw_args=[ - process_arg(arg, i + n_pos_args + n_args) - for i, arg in enumerate(type.kw_args) - ], - returns=type.returns.accept(self), + return TypedParamSpec( + pos=[process_arg(arg, i) for i, arg in enumerate(spec.pos)], + mixed=[process_arg(arg, i + n_pos) for i, arg in enumerate(spec.mixed)], + kw=[process_arg(arg, i + n_pos + n_mixed) for i, arg in enumerate(spec.kw)], ) def _resolve_type_params(self, params: list[m.TypeParam]): diff --git a/midas/parser/midas.py b/midas/parser/midas.py index 33069f3..1656c68 100644 --- a/midas/parser/midas.py +++ b/midas/parser/midas.py @@ -17,6 +17,7 @@ from midas.ast.midas import ( MemberKind, MemberStmt, NamedType, + ParamSpec, PredicateStmt, Stmt, Type, @@ -265,6 +266,9 @@ class MidasParser(Parser): Returns: Expr: the parsed constraint expression """ + return self.expression() + + def expression(self) -> Expr: return self.and_() def and_(self) -> Expr: @@ -470,6 +474,18 @@ class MidasParser(Parser): ) def function(self) -> FunctionType: + params: ParamSpec = self.function_args() + + self.consume(TokenType.ARROW, "Expected '->' before result type") + result: Type = self.type_expr() + + return FunctionType( + location=params.l_paren.location_to(self.previous()), + params=params, + returns=result, + ) + + def function_args(self) -> ParamSpec: l_paren: Token = self.consume( TokenType.LEFT_PAREN, "Expected '(' before function parameters" ) @@ -526,14 +542,4 @@ class MidasParser(Parser): self.error(token, "Unnamed mixed argument") self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters") - - self.consume(TokenType.ARROW, "Expected '->' before result type") - result: Type = self.type_expr() - - return FunctionType( - location=l_paren.location_to(self.previous()), - pos_args=pos_args, - args=args, - kw_args=kw_args, - returns=result, - ) + return ParamSpec(l_paren=l_paren, pos=pos_args, mixed=args, kw=kw_args) diff --git a/tests/serializer/midas.py b/tests/serializer/midas.py index 8bffdb3..87ebce8 100644 --- a/tests/serializer/midas.py +++ b/tests/serializer/midas.py @@ -15,6 +15,7 @@ from midas.ast.midas import ( LogicalExpr, MemberStmt, NamedType, + ParamSpec, PredicateStmt, Stmt, Type, @@ -163,12 +164,18 @@ class MidasAstJsonSerializer( def visit_function_type(self, type: FunctionType) -> dict: return { "_type": "FunctionType", - "pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args], - "args": [self._serialize_func_arg(arg) for arg in type.args], - "kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args], + "params": self._serialize_param_spec(type.params), "returns": type.returns.accept(self), } + def _serialize_param_spec(self, spec: ParamSpec) -> dict: + return { + "_type": "ParamSpec", + "pos": [self._serialize_func_arg(arg) for arg in spec.pos], + "mixed": [self._serialize_func_arg(arg) for arg in spec.mixed], + "kw": [self._serialize_func_arg(arg) for arg in spec.kw], + } + def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict: return { "name": arg.name,