From 77263139f6b12a13431d9e6bc9bc6f0d07f75da1 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sat, 13 Jun 2026 13:16:24 +0200 Subject: [PATCH] feat(parser): add mixed arguments in midas functions --- gen/midas.py | 1 + midas/ast/midas.py | 1 + midas/ast/printer.py | 10 ++++++++ midas/checker/midas.py | 32 ++++++++++++------------- midas/checker/python.py | 2 +- midas/parser/midas.py | 50 ++++++++++++++++++++++----------------- tests/serializer/midas.py | 1 + 7 files changed, 57 insertions(+), 40 deletions(-) diff --git a/gen/midas.py b/gen/midas.py index 72813d4..5405b6c 100644 --- a/gen/midas.py +++ b/gen/midas.py @@ -135,6 +135,7 @@ class ExtensionType: class FunctionType: pos_args: list[Argument] + args: list[Argument] kw_args: list[Argument] returns: Type diff --git a/midas/ast/midas.py b/midas/ast/midas.py index affd768..c35e856 100644 --- a/midas/ast/midas.py +++ b/midas/ast/midas.py @@ -293,6 +293,7 @@ class ExtensionType(Type): @dataclass(frozen=True) class FunctionType(Type): pos_args: list[Argument] + args: list[Argument] kw_args: list[Argument] returns: Type diff --git a/midas/ast/printer.py b/midas/ast/printer.py index c9a9d33..d094d42 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -297,6 +297,14 @@ class MidasAstPrinter( self._mark_last() self._print_function_arg(arg) + self._write_line("args") + with self._child_level(): + for i, arg in enumerate(type.args): + self._idx = i + if i == len(type.args) - 1: + self._mark_last() + self._print_function_arg(arg) + self._write_line("kw_args") with self._child_level(): for i, arg in enumerate(type.kw_args): @@ -447,11 +455,13 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] def visit_function_type(self, type: m.FunctionType) -> str: pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args] + mixed_args: list[str] = [self._print_arg(arg) for arg in type.args] kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args] args: list[str] = pos_args if len(pos_args) != 0: args.append("/") + args += mixed_args if len(kw_args) != 0: args.append("*") args += kw_args diff --git a/midas/checker/midas.py b/midas/checker/midas.py index 25096ae..e874e63 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -164,25 +164,23 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type ) def visit_function_type(self, type: m.FunctionType) -> Type: + n_pos_args: int = len(type.pos_args) + n_args: int = len(type.args) + + def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument: + return Function.Argument( + pos=i, + name=arg.name.lexeme if arg.name is not None else str(i), + type=arg.type.accept(self), + required=arg.required, + ) + return Function( - pos_args=[ - Function.Argument( - pos=i, - name=arg.name.lexeme if arg.name is not None else str(i), - type=arg.type.accept(self), - required=arg.required, - ) - for i, arg in enumerate(type.pos_args) - ], - args=[], + pos_args=[process_arg(arg, i) for i, arg in enumerate(type.pos_args)], + args=[process_arg(arg, i + n_pos_args) for i, arg in enumerate(type.args)], kw_args=[ - Function.Argument( - pos=i, - name=arg.name.lexeme if arg.name is not None else str(i), - type=arg.type.accept(self), - required=arg.required, - ) - for i, arg in enumerate(type.kw_args, start=len(type.pos_args)) + process_arg(arg, i + n_pos_args + n_args) + for i, arg in enumerate(type.kw_args) ], returns=type.returns.accept(self), ) diff --git a/midas/checker/python.py b/midas/checker/python.py index 7d10392..57e6687 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -336,7 +336,7 @@ class PythonTyper( if not self._is_binary_function(function): self.reporter.error( expr.location, - f"Wrong definition of binary operation. Expected function with 2 positional-only parameters, got {function}", + f"Wrong definition of binary operation. Expected function with 1 positional-only parameters, got {function}", ) return UnknownType() diff --git a/midas/parser/midas.py b/midas/parser/midas.py index ce94b2d..06f44a4 100644 --- a/midas/parser/midas.py +++ b/midas/parser/midas.py @@ -499,34 +499,39 @@ class MidasParser(Parser): TokenType.LEFT_PAREN, "Expected '(' before function parameters" ) pos_args: list[FunctionType.Argument] = [] + args: list[FunctionType.Argument] = [] kw_args: list[FunctionType.Argument] = [] - positional: bool = True + section: int = 0 while not self.is_at_end() and not self.check(TokenType.RIGHT_PAREN): - if positional and ( - self.match(TokenType.STAR) or self.match(TokenType.SLASH) - ): - positional = False - else: - name: Optional[Token] = None - if self.check_identifier() and self.check_next(TokenType.COLON): - name = self.advance() - self.advance() - type: Type = self.type_expr() - optional: bool = self.match(TokenType.QMARK) - arg = FunctionType.Argument( - location=None, - name=name, - type=type, - required=not optional, - ) - if positional: - pos_args.append(arg) - else: - kw_args.append(arg) + match section: + case 0 if self.match(TokenType.SLASH): + pos_args = args + args = [] + section = 1 + case 0 | 1 if self.match(TokenType.STAR): + section = 2 + case _: + name: Optional[Token] = None + if self.check_identifier() and self.check_next(TokenType.COLON): + name = self.advance() + self.advance() + type: Type = self.type_expr() + optional: bool = self.match(TokenType.QMARK) + arg = FunctionType.Argument( + location=None, + name=name, + type=type, + required=not optional, + ) + if section == 2: + kw_args.append(arg) + else: + args.append(arg) if not self.match(TokenType.COMMA): break + self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters") self.consume(TokenType.ARROW, "Expected '->' before result type") @@ -535,6 +540,7 @@ class MidasParser(Parser): return FunctionType( location=l_paren.location_to(self.previous()), pos_args=pos_args, + args=args, kw_args=kw_args, returns=result, ) diff --git a/tests/serializer/midas.py b/tests/serializer/midas.py index af4c0b1..f1e55da 100644 --- a/tests/serializer/midas.py +++ b/tests/serializer/midas.py @@ -173,6 +173,7 @@ class MidasAstJsonSerializer( return { "_type": "FunctionType", "pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args], + "args": [self._serialize_func_arg(arg) for arg in type.args], "kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args], "returns": type.returns.accept(self), }