feat(parser): add mixed arguments in midas functions

This commit is contained in:
2026-06-13 13:16:24 +02:00
parent 4f5967a151
commit 77263139f6
7 changed files with 57 additions and 40 deletions

View File

@@ -135,6 +135,7 @@ class ExtensionType:
class FunctionType:
pos_args: list[Argument]
args: list[Argument]
kw_args: list[Argument]
returns: Type

View File

@@ -293,6 +293,7 @@ class ExtensionType(Type):
@dataclass(frozen=True)
class FunctionType(Type):
pos_args: list[Argument]
args: list[Argument]
kw_args: list[Argument]
returns: Type

View File

@@ -297,6 +297,14 @@ class MidasAstPrinter(
self._mark_last()
self._print_function_arg(arg)
self._write_line("args")
with self._child_level():
for i, arg in enumerate(type.args):
self._idx = i
if i == len(type.args) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("kw_args")
with self._child_level():
for i, arg in enumerate(type.kw_args):
@@ -447,11 +455,13 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
def visit_function_type(self, type: m.FunctionType) -> str:
pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
mixed_args: list[str] = [self._print_arg(arg) for arg in type.args]
kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args]
args: list[str] = pos_args
if len(pos_args) != 0:
args.append("/")
args += mixed_args
if len(kw_args) != 0:
args.append("*")
args += kw_args

View File

@@ -164,25 +164,23 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
)
def visit_function_type(self, type: m.FunctionType) -> Type:
n_pos_args: int = len(type.pos_args)
n_args: int = len(type.args)
def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument:
return Function.Argument(
pos=i,
name=arg.name.lexeme if arg.name is not None else str(i),
type=arg.type.accept(self),
required=arg.required,
)
return Function(
pos_args=[
Function.Argument(
pos=i,
name=arg.name.lexeme if arg.name is not None else str(i),
type=arg.type.accept(self),
required=arg.required,
)
for i, arg in enumerate(type.pos_args)
],
args=[],
pos_args=[process_arg(arg, i) for i, arg in enumerate(type.pos_args)],
args=[process_arg(arg, i + n_pos_args) for i, arg in enumerate(type.args)],
kw_args=[
Function.Argument(
pos=i,
name=arg.name.lexeme if arg.name is not None else str(i),
type=arg.type.accept(self),
required=arg.required,
)
for i, arg in enumerate(type.kw_args, start=len(type.pos_args))
process_arg(arg, i + n_pos_args + n_args)
for i, arg in enumerate(type.kw_args)
],
returns=type.returns.accept(self),
)

View File

@@ -336,7 +336,7 @@ class PythonTyper(
if not self._is_binary_function(function):
self.reporter.error(
expr.location,
f"Wrong definition of binary operation. Expected function with 2 positional-only parameters, got {function}",
f"Wrong definition of binary operation. Expected function with 1 positional-only parameters, got {function}",
)
return UnknownType()

View File

@@ -499,34 +499,39 @@ class MidasParser(Parser):
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
)
pos_args: list[FunctionType.Argument] = []
args: list[FunctionType.Argument] = []
kw_args: list[FunctionType.Argument] = []
positional: bool = True
section: int = 0
while not self.is_at_end() and not self.check(TokenType.RIGHT_PAREN):
if positional and (
self.match(TokenType.STAR) or self.match(TokenType.SLASH)
):
positional = False
else:
name: Optional[Token] = None
if self.check_identifier() and self.check_next(TokenType.COLON):
name = self.advance()
self.advance()
type: Type = self.type_expr()
optional: bool = self.match(TokenType.QMARK)
arg = FunctionType.Argument(
location=None,
name=name,
type=type,
required=not optional,
)
if positional:
pos_args.append(arg)
else:
kw_args.append(arg)
match section:
case 0 if self.match(TokenType.SLASH):
pos_args = args
args = []
section = 1
case 0 | 1 if self.match(TokenType.STAR):
section = 2
case _:
name: Optional[Token] = None
if self.check_identifier() and self.check_next(TokenType.COLON):
name = self.advance()
self.advance()
type: Type = self.type_expr()
optional: bool = self.match(TokenType.QMARK)
arg = FunctionType.Argument(
location=None,
name=name,
type=type,
required=not optional,
)
if section == 2:
kw_args.append(arg)
else:
args.append(arg)
if not self.match(TokenType.COMMA):
break
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
self.consume(TokenType.ARROW, "Expected '->' before result type")
@@ -535,6 +540,7 @@ class MidasParser(Parser):
return FunctionType(
location=l_paren.location_to(self.previous()),
pos_args=pos_args,
args=args,
kw_args=kw_args,
returns=result,
)

View File

@@ -173,6 +173,7 @@ class MidasAstJsonSerializer(
return {
"_type": "FunctionType",
"pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args],
"args": [self._serialize_func_arg(arg) for arg in type.args],
"kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args],
"returns": type.returns.accept(self),
}