feat(checker): handle overloaded function calls

This commit is contained in:
2026-06-14 15:48:31 +02:00
parent 04c0d683de
commit 25a96d20e1

View File

@@ -12,12 +12,16 @@ from midas.checker.reporter import FileReporter, Reporter
from midas.checker.resolver import Resolver
from midas.checker.types import (
Function,
OverloadedFunction,
Type,
UnitType,
UnknownType,
unfold_type,
)
from midas.parser.python import PythonParser
TypedExpr = tuple[p.Expr, Type]
class ReturnException(Exception):
pass
@@ -354,26 +358,7 @@ class PythonTyper(
)
return UnknownType()
match operation:
case Function() as function:
if not self._check_arity(function, 1, 0, 0):
self.reporter.error(
location,
f"Wrong definition of binary operation. Expected function with 1 positional-only parameters, got {function}",
)
return UnknownType()
rhs: Function.Argument = function.pos_args[0]
if not self.is_subtype(right, rhs.type):
self.reporter.error(
location,
f"Wrong type for right-hand side, expected {rhs.type}, got {right}",
)
return UnknownType()
return function.returns
case _:
self.reporter.warning(location, f"Unsupported operation {operation}")
return UnknownType()
return self._get_call_result(location, operation, [(right_expr, right)], {})
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__)
@@ -393,35 +378,24 @@ class PythonTyper(
)
return UnknownType()
match operation:
case Function() as function:
if not self._check_arity(function, 0, 0, 0):
self.reporter.error(
expr.location,
f"Wrong definition of unary operation. Expected function with 0 parameters, got {function}",
)
return UnknownType()
return function.returns
case _:
self.reporter.warning(
expr.location, f"Unsupported operation {operation}"
)
return UnknownType()
return self._get_call_result(
expr.location, operation, [(expr.right, operand)], {}
)
def visit_call_expr(self, expr: p.CallExpr) -> Type:
callee: Type = self.type_of(expr.callee)
if not isinstance(callee, Function):
self.reporter.error(expr.callee.location, "Callee is not a function")
return UnknownType()
function: Function = callee
mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
for arg in mapped:
if not self.is_subtype(arg.type, arg.argument.type):
self.reporter.error(
arg.expr.location,
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
)
return function.returns
positional: list[TypedExpr] = [
(arg, self.type_of(arg)) for arg in expr.arguments
]
keywords: dict[str, TypedExpr] = {
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
}
return self._get_call_result(
location=expr.location,
callee=callee,
positional=positional,
keywords=keywords,
)
def visit_get_expr(self, expr: p.GetExpr) -> Type:
object: Type = self.type_of(expr.object)
@@ -572,9 +546,105 @@ class PythonTyper(
self.reporter.warning(node.location, "FrameType not yet supported")
return UnknownType()
def _get_call_result(
self,
location: Location,
callee: Type,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
match callee:
case Function() as function:
valid: bool
mapped: list[MappedArgument]
valid, mapped = self.map_call_arguments(
function, location, positional, keywords
)
valid = valid and self._are_arguments_valid(mapped)
if not valid:
return UnknownType()
return function.returns
case OverloadedFunction(overloads=overloads):
function = self._match_overload(
overloads, location, positional, keywords
)
if function is None:
return UnknownType()
return function.returns
case _:
self.reporter.error(location, f"{callee} is not callable")
return UnknownType()
def _are_arguments_valid(
self,
arguments: list[MappedArgument],
report_errors: bool = True,
) -> bool:
valid: bool = True
for arg in arguments:
if not self.is_subtype(arg.type, arg.argument.type):
if report_errors:
self.reporter.error(
arg.expr.location,
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
)
valid = False
return valid
def _match_overload(
self,
overloads: list[Type],
location: Location,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Optional[Function]:
candidates: list[Function] = []
for overload in overloads:
function: Type = unfold_type(overload)
if not isinstance(function, Function):
self.logger.error(
f"Overload is not a function: {overload} is {function}"
)
continue
valid, mapped = self.map_call_arguments(
function=function,
location=location,
positional=positional,
keywords=keywords,
report_errors=False,
)
if valid and self._are_arguments_valid(mapped, report_errors=False):
candidates.append(function)
pos_types: str = ", ".join(str(type) for _, type in positional)
kw_types: str = ", ".join(
f"{name}: {type}" for name, (_, type) in keywords.items()
)
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
if len(candidates) == 0:
self.reporter.error(
location,
f"No matching overload in {overloads} {for_args}",
)
return None
if len(candidates) > 1:
self.reporter.error(
location,
f"Multiple matching overloads {for_args}: {', '.join(map(str, candidates))}",
)
return None
return candidates[0]
def map_call_arguments(
self, function: Function, call: p.CallExpr
) -> list[MappedArgument]:
self,
function: Function,
location: Location,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
report_errors: bool = True,
) -> tuple[bool, list[MappedArgument]]:
"""Map call arguments to function parameters as defined in its signature
This method maps positional-only, keyword-only and mixed parameter definitions
@@ -589,12 +659,6 @@ class PythonTyper(
Returns:
list[MappedArgument]: the list of mapped arguments
"""
positional: list[tuple[p.Expr, Type]] = [
(arg, self.type_of(arg)) for arg in call.arguments
]
keywords: dict[str, tuple[p.Expr, Type]] = {
name: (arg, self.type_of(arg)) for name, arg in call.keywords.items()
}
set_args: set[str] = set()
required_positional: list[str] = [
@@ -612,6 +676,8 @@ class PythonTyper(
arg.name: arg for arg in function.kw_args
}
valid_call: bool = True
# TODO: handle *args and **kwargs sinks
for arg in positional:
param: Function.Argument
@@ -620,7 +686,11 @@ class PythonTyper(
elif len(mixed_params) != 0:
param = mixed_params.pop(0)
else:
self.reporter.error(arg[0].location, "Too many positional arguments")
if report_errors:
self.reporter.error(
arg[0].location, "Too many positional arguments"
)
valid_call = False
break
name: str = param.name
if name in required_positional:
@@ -640,14 +710,16 @@ class PythonTyper(
for name, arg in keywords.items():
param: Function.Argument
if name not in kw_params:
if name in set_args:
self.reporter.error(
arg[0].location, f"Multiple values for argument '{name}'"
)
else:
self.reporter.error(
arg[0].location, f"Unknown keyword argument '{name}'"
)
if report_errors:
if name in set_args:
self.reporter.error(
arg[0].location, f"Multiple values for argument '{name}'"
)
else:
self.reporter.error(
arg[0].location, f"Unknown keyword argument '{name}'"
)
valid_call = False
continue
param = kw_params.pop(name)
if name in required_positional:
@@ -674,20 +746,24 @@ class PythonTyper(
if len(required_positional) != 0:
plural: str = "" if len(required_positional) == 1 else "s"
args: str = join_args(required_positional)
self.reporter.error(
call.location,
f"Missing required positional argument{plural}: {args}",
)
if report_errors:
self.reporter.error(
location,
f"Missing required positional argument{plural}: {args}",
)
valid_call = False
if len(required_keyword) != 0:
plural: str = "" if len(required_keyword) == 1 else "s"
args: str = join_args(required_keyword)
self.reporter.error(
call.location,
f"Missing required keyword argument{plural}: {args}",
)
if report_errors:
self.reporter.error(
location,
f"Missing required keyword argument{plural}: {args}",
)
valid_call = False
return mapped
return valid_call, mapped
def _check_arity(
self,