diff --git a/gen/python.py b/gen/python.py index 76b68a4..91c6058 100644 --- a/gen/python.py +++ b/gen/python.py @@ -44,7 +44,9 @@ class Function: name: str posonlyargs: list[Argument] args: list[Argument] + sink: Optional[Argument] kwonlyargs: list[Argument] + kw_sink: Optional[Argument] returns: Optional[MidasType] body: list[Stmt] @@ -53,6 +55,7 @@ class Function: location: Optional[Location] = None name: str type: Optional[MidasType] + default: Optional[Expr] @property def all_args(self) -> list[Argument]: diff --git a/midas/ast/python.py b/midas/ast/python.py index 7c22b2f..dee4aa7 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -117,7 +117,9 @@ class Function(Stmt): name: str posonlyargs: list[Argument] args: list[Argument] + sink: Optional[Argument] kwonlyargs: list[Argument] + kw_sink: Optional[Argument] returns: Optional[MidasType] body: list[Stmt] @@ -126,6 +128,7 @@ class Function(Stmt): location: Optional[Location] = None name: str type: Optional[MidasType] + default: Optional[Expr] @property def all_args(self) -> list[Argument]: diff --git a/midas/parser/python.py b/midas/parser/python.py index 36d6efb..d580bda 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -136,14 +136,23 @@ class PythonParser: args=ast.arguments( posonlyargs=posonlyargs, args=args, + vararg=sink, kwonlyargs=kwonlyargs, + kwarg=kw_sink, + defaults=defaults, + kw_defaults=kw_defaults, ), returns=returns, body=raw_body, ): - def parse_args(args_list: list[ast.arg]) -> list[Function.Argument]: - return [self._parse_function_argument(arg) for arg in args_list] + def parse_args( + args_list: list[ast.arg], defaults: list[Optional[Expr]] + ) -> list[Function.Argument]: + return [ + self._parse_function_argument(arg, default) + for arg, default in zip(args_list, defaults) + ] body: list[Stmt] = [] for stmt in raw_body: @@ -152,19 +161,49 @@ class PythonParser: body.append(stmts) elif stmts is not None: body.extend(stmts) + + parsed_defaults: list[Optional[Expr]] = [ + self.parse_expr(default) for default in defaults + ] + n_posargs: int = len(posonlyargs) + n_args: int = len(args) + n_all_posargs = n_posargs + n_args + parsed_defaults = [ + None, + ] * (n_all_posargs - len(defaults)) + parsed_defaults + + posargs_defaults: list[Optional[Expr]] = parsed_defaults[:n_posargs] + args_defaults: list[Optional[Expr]] = parsed_defaults[n_posargs:] + kwargs_defaults: list[Optional[Expr]] = [ + self.parse_expr(default) if default is not None else None + for default in kw_defaults + ] + return Function( location=loc, name=name, - posonlyargs=parse_args(posonlyargs), - args=parse_args(args), - kwonlyargs=parse_args(kwonlyargs), + posonlyargs=parse_args(posonlyargs, posargs_defaults), + args=parse_args(args, args_defaults), + sink=( + self._parse_function_argument(sink, None) + if sink is not None + else None + ), + kwonlyargs=parse_args(kwonlyargs, kwargs_defaults), + kw_sink=( + self._parse_function_argument(kw_sink, None) + if kw_sink is not None + else None + ), returns=self._parse_type(returns) if returns is not None else None, body=body, ) case _: print(f"Unsupported function definition: {ast.unparse(node)}") - def _parse_function_argument(self, arg: ast.arg) -> Function.Argument: + def _parse_function_argument( + self, arg: ast.arg, default: Optional[Expr] + ) -> Function.Argument: loc: Location = Location.from_ast(arg) name: str = arg.arg type: Optional[MidasType] = None @@ -174,6 +213,7 @@ class PythonParser: location=loc, name=name, type=type, + default=default, ) def _parse_type(