feat(parser): add mixed arguments in midas functions

This commit is contained in:
2026-06-13 13:16:24 +02:00
parent 4f5967a151
commit 77263139f6
7 changed files with 57 additions and 40 deletions

View File

@@ -135,6 +135,7 @@ class ExtensionType:
class FunctionType: class FunctionType:
pos_args: list[Argument] pos_args: list[Argument]
args: list[Argument]
kw_args: list[Argument] kw_args: list[Argument]
returns: Type returns: Type

View File

@@ -293,6 +293,7 @@ class ExtensionType(Type):
@dataclass(frozen=True) @dataclass(frozen=True)
class FunctionType(Type): class FunctionType(Type):
pos_args: list[Argument] pos_args: list[Argument]
args: list[Argument]
kw_args: list[Argument] kw_args: list[Argument]
returns: Type returns: Type

View File

@@ -297,6 +297,14 @@ class MidasAstPrinter(
self._mark_last() self._mark_last()
self._print_function_arg(arg) self._print_function_arg(arg)
self._write_line("args")
with self._child_level():
for i, arg in enumerate(type.args):
self._idx = i
if i == len(type.args) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("kw_args") self._write_line("kw_args")
with self._child_level(): with self._child_level():
for i, arg in enumerate(type.kw_args): for i, arg in enumerate(type.kw_args):
@@ -447,11 +455,13 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
def visit_function_type(self, type: m.FunctionType) -> str: def visit_function_type(self, type: m.FunctionType) -> str:
pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args] pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
mixed_args: list[str] = [self._print_arg(arg) for arg in type.args]
kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args] kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args]
args: list[str] = pos_args args: list[str] = pos_args
if len(pos_args) != 0: if len(pos_args) != 0:
args.append("/") args.append("/")
args += mixed_args
if len(kw_args) != 0: if len(kw_args) != 0:
args.append("*") args.append("*")
args += kw_args args += kw_args

View File

@@ -164,25 +164,23 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
) )
def visit_function_type(self, type: m.FunctionType) -> Type: def visit_function_type(self, type: m.FunctionType) -> Type:
n_pos_args: int = len(type.pos_args)
n_args: int = len(type.args)
def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument:
return Function.Argument(
pos=i,
name=arg.name.lexeme if arg.name is not None else str(i),
type=arg.type.accept(self),
required=arg.required,
)
return Function( return Function(
pos_args=[ pos_args=[process_arg(arg, i) for i, arg in enumerate(type.pos_args)],
Function.Argument( args=[process_arg(arg, i + n_pos_args) for i, arg in enumerate(type.args)],
pos=i,
name=arg.name.lexeme if arg.name is not None else str(i),
type=arg.type.accept(self),
required=arg.required,
)
for i, arg in enumerate(type.pos_args)
],
args=[],
kw_args=[ kw_args=[
Function.Argument( process_arg(arg, i + n_pos_args + n_args)
pos=i, for i, arg in enumerate(type.kw_args)
name=arg.name.lexeme if arg.name is not None else str(i),
type=arg.type.accept(self),
required=arg.required,
)
for i, arg in enumerate(type.kw_args, start=len(type.pos_args))
], ],
returns=type.returns.accept(self), returns=type.returns.accept(self),
) )

View File

@@ -336,7 +336,7 @@ class PythonTyper(
if not self._is_binary_function(function): if not self._is_binary_function(function):
self.reporter.error( self.reporter.error(
expr.location, expr.location,
f"Wrong definition of binary operation. Expected function with 2 positional-only parameters, got {function}", f"Wrong definition of binary operation. Expected function with 1 positional-only parameters, got {function}",
) )
return UnknownType() return UnknownType()

View File

@@ -499,15 +499,19 @@ class MidasParser(Parser):
TokenType.LEFT_PAREN, "Expected '(' before function parameters" TokenType.LEFT_PAREN, "Expected '(' before function parameters"
) )
pos_args: list[FunctionType.Argument] = [] pos_args: list[FunctionType.Argument] = []
args: list[FunctionType.Argument] = []
kw_args: list[FunctionType.Argument] = [] kw_args: list[FunctionType.Argument] = []
positional: bool = True section: int = 0
while not self.is_at_end() and not self.check(TokenType.RIGHT_PAREN): while not self.is_at_end() and not self.check(TokenType.RIGHT_PAREN):
if positional and ( match section:
self.match(TokenType.STAR) or self.match(TokenType.SLASH) case 0 if self.match(TokenType.SLASH):
): pos_args = args
positional = False args = []
else: section = 1
case 0 | 1 if self.match(TokenType.STAR):
section = 2
case _:
name: Optional[Token] = None name: Optional[Token] = None
if self.check_identifier() and self.check_next(TokenType.COLON): if self.check_identifier() and self.check_next(TokenType.COLON):
name = self.advance() name = self.advance()
@@ -520,13 +524,14 @@ class MidasParser(Parser):
type=type, type=type,
required=not optional, required=not optional,
) )
if positional: if section == 2:
pos_args.append(arg)
else:
kw_args.append(arg) kw_args.append(arg)
else:
args.append(arg)
if not self.match(TokenType.COMMA): if not self.match(TokenType.COMMA):
break break
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters") self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
self.consume(TokenType.ARROW, "Expected '->' before result type") self.consume(TokenType.ARROW, "Expected '->' before result type")
@@ -535,6 +540,7 @@ class MidasParser(Parser):
return FunctionType( return FunctionType(
location=l_paren.location_to(self.previous()), location=l_paren.location_to(self.previous()),
pos_args=pos_args, pos_args=pos_args,
args=args,
kw_args=kw_args, kw_args=kw_args,
returns=result, returns=result,
) )

View File

@@ -173,6 +173,7 @@ class MidasAstJsonSerializer(
return { return {
"_type": "FunctionType", "_type": "FunctionType",
"pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args], "pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args],
"args": [self._serialize_func_arg(arg) for arg in type.args],
"kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args], "kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args],
"returns": type.returns.accept(self), "returns": type.returns.accept(self),
} }