feat(parser): parse function param defaults and sinks

This commit is contained in:
2026-05-29 15:47:19 +02:00
parent fd5399f50a
commit 3f61f84e5a
3 changed files with 52 additions and 6 deletions

View File

@@ -44,7 +44,9 @@ class Function:
name: str
posonlyargs: list[Argument]
args: list[Argument]
sink: Optional[Argument]
kwonlyargs: list[Argument]
kw_sink: Optional[Argument]
returns: Optional[MidasType]
body: list[Stmt]
@@ -53,6 +55,7 @@ class Function:
location: Optional[Location] = None
name: str
type: Optional[MidasType]
default: Optional[Expr]
@property
def all_args(self) -> list[Argument]:

View File

@@ -117,7 +117,9 @@ class Function(Stmt):
name: str
posonlyargs: list[Argument]
args: list[Argument]
sink: Optional[Argument]
kwonlyargs: list[Argument]
kw_sink: Optional[Argument]
returns: Optional[MidasType]
body: list[Stmt]
@@ -126,6 +128,7 @@ class Function(Stmt):
location: Optional[Location] = None
name: str
type: Optional[MidasType]
default: Optional[Expr]
@property
def all_args(self) -> list[Argument]:

View File

@@ -136,14 +136,23 @@ class PythonParser:
args=ast.arguments(
posonlyargs=posonlyargs,
args=args,
vararg=sink,
kwonlyargs=kwonlyargs,
kwarg=kw_sink,
defaults=defaults,
kw_defaults=kw_defaults,
),
returns=returns,
body=raw_body,
):
def parse_args(args_list: list[ast.arg]) -> list[Function.Argument]:
return [self._parse_function_argument(arg) for arg in args_list]
def parse_args(
args_list: list[ast.arg], defaults: list[Optional[Expr]]
) -> list[Function.Argument]:
return [
self._parse_function_argument(arg, default)
for arg, default in zip(args_list, defaults)
]
body: list[Stmt] = []
for stmt in raw_body:
@@ -152,19 +161,49 @@ class PythonParser:
body.append(stmts)
elif stmts is not None:
body.extend(stmts)
parsed_defaults: list[Optional[Expr]] = [
self.parse_expr(default) for default in defaults
]
n_posargs: int = len(posonlyargs)
n_args: int = len(args)
n_all_posargs = n_posargs + n_args
parsed_defaults = [
None,
] * (n_all_posargs - len(defaults)) + parsed_defaults
posargs_defaults: list[Optional[Expr]] = parsed_defaults[:n_posargs]
args_defaults: list[Optional[Expr]] = parsed_defaults[n_posargs:]
kwargs_defaults: list[Optional[Expr]] = [
self.parse_expr(default) if default is not None else None
for default in kw_defaults
]
return Function(
location=loc,
name=name,
posonlyargs=parse_args(posonlyargs),
args=parse_args(args),
kwonlyargs=parse_args(kwonlyargs),
posonlyargs=parse_args(posonlyargs, posargs_defaults),
args=parse_args(args, args_defaults),
sink=(
self._parse_function_argument(sink, None)
if sink is not None
else None
),
kwonlyargs=parse_args(kwonlyargs, kwargs_defaults),
kw_sink=(
self._parse_function_argument(kw_sink, None)
if kw_sink is not None
else None
),
returns=self._parse_type(returns) if returns is not None else None,
body=body,
)
case _:
print(f"Unsupported function definition: {ast.unparse(node)}")
def _parse_function_argument(self, arg: ast.arg) -> Function.Argument:
def _parse_function_argument(
self, arg: ast.arg, default: Optional[Expr]
) -> Function.Argument:
loc: Location = Location.from_ast(arg)
name: str = arg.arg
type: Optional[MidasType] = None
@@ -174,6 +213,7 @@ class PythonParser:
location=loc,
name=name,
type=type,
default=default,
)
def _parse_type(