feat(checker): map and check function call arguments

This commit is contained in:
2026-05-29 15:48:36 +02:00
parent 3f61f84e5a
commit d16e192a3a
2 changed files with 116 additions and 8 deletions

View File

@@ -1,3 +1,4 @@
from dataclasses import dataclass
import logging
from pathlib import Path
from typing import Optional
@@ -19,6 +20,13 @@ class ReturnException(Exception):
pass
@dataclass(frozen=True, kw_only=True)
class MappedArgument:
expr: p.Expr
type: Type
argument: Function.Argument
class Checker(
p.Stmt.Visitor[None],
p.Expr.Visitor[Type],
@@ -126,15 +134,18 @@ class Checker(
kw_args: list[Function.Argument] = []
def eval_arg_type(arg: p.Function.Argument) -> Type:
if arg.type is None:
return UnknownType()
return arg.type.accept(self)
if arg.type is not None:
return arg.type.accept(self)
if arg.default is not None:
return arg.default.accept(self)
return UnknownType()
for arg in stmt.posonlyargs:
pos_args.append(
Function.Argument(
name=arg.name,
type=eval_arg_type(arg),
required=arg.default is None,
)
)
for arg in stmt.args:
@@ -142,6 +153,7 @@ class Checker(
Function.Argument(
name=arg.name,
type=eval_arg_type(arg),
required=arg.default is None,
)
)
for arg in stmt.kwonlyargs:
@@ -149,6 +161,7 @@ class Checker(
Function.Argument(
name=arg.name,
type=eval_arg_type(arg),
required=arg.default is None,
)
)
@@ -175,7 +188,9 @@ class Checker(
else:
returns = inferred_return
# TODO: handle *args and **kwargs sinks
function: Function = Function(
name=stmt.name,
pos_args=pos_args,
args=args,
kw_args=kw_args,
@@ -240,11 +255,18 @@ class Checker(
self.import_midas(path)
return UnknownType()
callee: Type = self.evaluate(expr.callee)
arguments: list[Type] = [self.evaluate(arg) for arg in expr.arguments]
keywords: dict[str, Type] = {
name: self.evaluate(arg) for name, arg in expr.keywords.items()
}
return UnknownType()
if not isinstance(callee, Function):
self.error(expr.callee.location, "Callee is not a function")
return UnknownType()
function: Function = callee
mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
for arg in mapped:
if arg.type != arg.argument.type:
self.error(
arg.expr.location,
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
)
return function.returns
def visit_get_expr(self, expr: p.GetExpr) -> Type: ...
@@ -277,3 +299,87 @@ class Checker(
def visit_frame_column(self, node: p.FrameColumn) -> Type: ...
def visit_frame_type(self, node: p.FrameType) -> Type: ...
def map_call_arguments(
self, function: Function, call: p.CallExpr
) -> list[MappedArgument]:
positional: list[tuple[p.Expr, Type]] = [
(arg, self.evaluate(arg)) for arg in call.arguments
]
keywords: dict[str, tuple[p.Expr, Type]] = {
name: (arg, self.evaluate(arg)) for name, arg in call.keywords.items()
}
set_args: set[str] = set()
required_positional: set[str] = {
arg.name for arg in function.pos_args + function.args if arg.required
}
required_keyword: set[str] = {
arg.name for arg in function.kw_args if arg.required
}
mapped: list[MappedArgument] = []
pos_params: list[Function.Argument] = list(function.pos_args)
mixed_params: list[Function.Argument] = list(function.args)
kw_params: dict[str, Function.Argument] = {
arg.name: arg for arg in function.kw_args
}
# TODO: handle *args and **kwargs sinks
for arg in positional:
param: Function.Argument
if len(pos_params) != 0:
param = pos_params.pop(0)
elif len(mixed_params) != 0:
param = mixed_params.pop(0)
else:
self.error(arg[0].location, "Too many positional arguments")
break
required_positional.discard(param.name)
required_keyword.discard(param.name)
set_args.add(param.name)
mapped.append(
MappedArgument(
expr=arg[0],
type=arg[1],
argument=param,
)
)
kw_params.update({arg.name: arg for arg in mixed_params})
for name, arg in keywords.items():
param: Function.Argument
if name not in kw_params:
if name in set_args:
self.error(
arg[0].location, f"Multiple values for argument '{name}'"
)
else:
self.error(arg[0].location, f"Unknown keyword argument '{name}'")
continue
param = kw_params.pop(name)
required_positional.discard(name)
required_keyword.discard(name)
set_args.add(name)
mapped.append(
MappedArgument(
expr=arg[0],
type=arg[1],
argument=param,
)
)
if len(required_positional) != 0:
self.error(
call.location,
f"Missing required positional arguments: {required_positional}",
)
if len(required_keyword) != 0:
self.error(
call.location,
f"Missing required keyword arguments: {required_keyword}",
)
return mapped

View File

@@ -26,6 +26,7 @@ class UnitType:
@dataclass(frozen=True, kw_only=True)
class Function:
name: str
pos_args: list[Argument]
args: list[Argument]
kw_args: list[Argument]
@@ -35,6 +36,7 @@ class Function:
class Argument:
name: str
type: Type
required: bool
Type = BaseType | SimpleType | UnknownType | UnitType | Function