feat(parser): add mixed arguments in midas functions
This commit is contained in:
@@ -135,6 +135,7 @@ class ExtensionType:
|
||||
|
||||
class FunctionType:
|
||||
pos_args: list[Argument]
|
||||
args: list[Argument]
|
||||
kw_args: list[Argument]
|
||||
returns: Type
|
||||
|
||||
|
||||
@@ -293,6 +293,7 @@ class ExtensionType(Type):
|
||||
@dataclass(frozen=True)
|
||||
class FunctionType(Type):
|
||||
pos_args: list[Argument]
|
||||
args: list[Argument]
|
||||
kw_args: list[Argument]
|
||||
returns: Type
|
||||
|
||||
|
||||
@@ -297,6 +297,14 @@ class MidasAstPrinter(
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("args")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(type.args):
|
||||
self._idx = i
|
||||
if i == len(type.args) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("kw_args")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(type.kw_args):
|
||||
@@ -447,11 +455,13 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> str:
|
||||
pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
|
||||
mixed_args: list[str] = [self._print_arg(arg) for arg in type.args]
|
||||
kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args]
|
||||
args: list[str] = pos_args
|
||||
|
||||
if len(pos_args) != 0:
|
||||
args.append("/")
|
||||
args += mixed_args
|
||||
if len(kw_args) != 0:
|
||||
args.append("*")
|
||||
args += kw_args
|
||||
|
||||
@@ -164,25 +164,23 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
||||
)
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> Type:
|
||||
n_pos_args: int = len(type.pos_args)
|
||||
n_args: int = len(type.args)
|
||||
|
||||
def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument:
|
||||
return Function.Argument(
|
||||
pos=i,
|
||||
name=arg.name.lexeme if arg.name is not None else str(i),
|
||||
type=arg.type.accept(self),
|
||||
required=arg.required,
|
||||
)
|
||||
|
||||
return Function(
|
||||
pos_args=[
|
||||
Function.Argument(
|
||||
pos=i,
|
||||
name=arg.name.lexeme if arg.name is not None else str(i),
|
||||
type=arg.type.accept(self),
|
||||
required=arg.required,
|
||||
)
|
||||
for i, arg in enumerate(type.pos_args)
|
||||
],
|
||||
args=[],
|
||||
pos_args=[process_arg(arg, i) for i, arg in enumerate(type.pos_args)],
|
||||
args=[process_arg(arg, i + n_pos_args) for i, arg in enumerate(type.args)],
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
pos=i,
|
||||
name=arg.name.lexeme if arg.name is not None else str(i),
|
||||
type=arg.type.accept(self),
|
||||
required=arg.required,
|
||||
)
|
||||
for i, arg in enumerate(type.kw_args, start=len(type.pos_args))
|
||||
process_arg(arg, i + n_pos_args + n_args)
|
||||
for i, arg in enumerate(type.kw_args)
|
||||
],
|
||||
returns=type.returns.accept(self),
|
||||
)
|
||||
|
||||
@@ -336,7 +336,7 @@ class PythonTyper(
|
||||
if not self._is_binary_function(function):
|
||||
self.reporter.error(
|
||||
expr.location,
|
||||
f"Wrong definition of binary operation. Expected function with 2 positional-only parameters, got {function}",
|
||||
f"Wrong definition of binary operation. Expected function with 1 positional-only parameters, got {function}",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
|
||||
@@ -499,34 +499,39 @@ class MidasParser(Parser):
|
||||
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
|
||||
)
|
||||
pos_args: list[FunctionType.Argument] = []
|
||||
args: list[FunctionType.Argument] = []
|
||||
kw_args: list[FunctionType.Argument] = []
|
||||
|
||||
positional: bool = True
|
||||
section: int = 0
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_PAREN):
|
||||
if positional and (
|
||||
self.match(TokenType.STAR) or self.match(TokenType.SLASH)
|
||||
):
|
||||
positional = False
|
||||
else:
|
||||
name: Optional[Token] = None
|
||||
if self.check_identifier() and self.check_next(TokenType.COLON):
|
||||
name = self.advance()
|
||||
self.advance()
|
||||
type: Type = self.type_expr()
|
||||
optional: bool = self.match(TokenType.QMARK)
|
||||
arg = FunctionType.Argument(
|
||||
location=None,
|
||||
name=name,
|
||||
type=type,
|
||||
required=not optional,
|
||||
)
|
||||
if positional:
|
||||
pos_args.append(arg)
|
||||
else:
|
||||
kw_args.append(arg)
|
||||
match section:
|
||||
case 0 if self.match(TokenType.SLASH):
|
||||
pos_args = args
|
||||
args = []
|
||||
section = 1
|
||||
case 0 | 1 if self.match(TokenType.STAR):
|
||||
section = 2
|
||||
case _:
|
||||
name: Optional[Token] = None
|
||||
if self.check_identifier() and self.check_next(TokenType.COLON):
|
||||
name = self.advance()
|
||||
self.advance()
|
||||
type: Type = self.type_expr()
|
||||
optional: bool = self.match(TokenType.QMARK)
|
||||
arg = FunctionType.Argument(
|
||||
location=None,
|
||||
name=name,
|
||||
type=type,
|
||||
required=not optional,
|
||||
)
|
||||
if section == 2:
|
||||
kw_args.append(arg)
|
||||
else:
|
||||
args.append(arg)
|
||||
|
||||
if not self.match(TokenType.COMMA):
|
||||
break
|
||||
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
|
||||
|
||||
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
||||
@@ -535,6 +540,7 @@ class MidasParser(Parser):
|
||||
return FunctionType(
|
||||
location=l_paren.location_to(self.previous()),
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
returns=result,
|
||||
)
|
||||
|
||||
@@ -173,6 +173,7 @@ class MidasAstJsonSerializer(
|
||||
return {
|
||||
"_type": "FunctionType",
|
||||
"pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args],
|
||||
"args": [self._serialize_func_arg(arg) for arg in type.args],
|
||||
"kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args],
|
||||
"returns": type.returns.accept(self),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user