feat(checker): handle overloaded function calls

This commit is contained in:
2026-06-14 15:48:31 +02:00
parent 04c0d683de
commit 25a96d20e1

View File

@@ -12,12 +12,16 @@ from midas.checker.reporter import FileReporter, Reporter
from midas.checker.resolver import Resolver from midas.checker.resolver import Resolver
from midas.checker.types import ( from midas.checker.types import (
Function, Function,
OverloadedFunction,
Type, Type,
UnitType, UnitType,
UnknownType, UnknownType,
unfold_type,
) )
from midas.parser.python import PythonParser from midas.parser.python import PythonParser
TypedExpr = tuple[p.Expr, Type]
class ReturnException(Exception): class ReturnException(Exception):
pass pass
@@ -354,26 +358,7 @@ class PythonTyper(
) )
return UnknownType() return UnknownType()
match operation: return self._get_call_result(location, operation, [(right_expr, right)], {})
case Function() as function:
if not self._check_arity(function, 1, 0, 0):
self.reporter.error(
location,
f"Wrong definition of binary operation. Expected function with 1 positional-only parameters, got {function}",
)
return UnknownType()
rhs: Function.Argument = function.pos_args[0]
if not self.is_subtype(right, rhs.type):
self.reporter.error(
location,
f"Wrong type for right-hand side, expected {rhs.type}, got {right}",
)
return UnknownType()
return function.returns
case _:
self.reporter.warning(location, f"Unsupported operation {operation}")
return UnknownType()
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__) method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__)
@@ -393,35 +378,24 @@ class PythonTyper(
) )
return UnknownType() return UnknownType()
match operation: return self._get_call_result(
case Function() as function: expr.location, operation, [(expr.right, operand)], {}
if not self._check_arity(function, 0, 0, 0): )
self.reporter.error(
expr.location,
f"Wrong definition of unary operation. Expected function with 0 parameters, got {function}",
)
return UnknownType()
return function.returns
case _:
self.reporter.warning(
expr.location, f"Unsupported operation {operation}"
)
return UnknownType()
def visit_call_expr(self, expr: p.CallExpr) -> Type: def visit_call_expr(self, expr: p.CallExpr) -> Type:
callee: Type = self.type_of(expr.callee) callee: Type = self.type_of(expr.callee)
if not isinstance(callee, Function): positional: list[TypedExpr] = [
self.reporter.error(expr.callee.location, "Callee is not a function") (arg, self.type_of(arg)) for arg in expr.arguments
return UnknownType() ]
function: Function = callee keywords: dict[str, TypedExpr] = {
mapped: list[MappedArgument] = self.map_call_arguments(function, expr) name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
for arg in mapped: }
if not self.is_subtype(arg.type, arg.argument.type): return self._get_call_result(
self.reporter.error( location=expr.location,
arg.expr.location, callee=callee,
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}", positional=positional,
) keywords=keywords,
return function.returns )
def visit_get_expr(self, expr: p.GetExpr) -> Type: def visit_get_expr(self, expr: p.GetExpr) -> Type:
object: Type = self.type_of(expr.object) object: Type = self.type_of(expr.object)
@@ -572,9 +546,105 @@ class PythonTyper(
self.reporter.warning(node.location, "FrameType not yet supported") self.reporter.warning(node.location, "FrameType not yet supported")
return UnknownType() return UnknownType()
def _get_call_result(
self,
location: Location,
callee: Type,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
match callee:
case Function() as function:
valid: bool
mapped: list[MappedArgument]
valid, mapped = self.map_call_arguments(
function, location, positional, keywords
)
valid = valid and self._are_arguments_valid(mapped)
if not valid:
return UnknownType()
return function.returns
case OverloadedFunction(overloads=overloads):
function = self._match_overload(
overloads, location, positional, keywords
)
if function is None:
return UnknownType()
return function.returns
case _:
self.reporter.error(location, f"{callee} is not callable")
return UnknownType()
def _are_arguments_valid(
self,
arguments: list[MappedArgument],
report_errors: bool = True,
) -> bool:
valid: bool = True
for arg in arguments:
if not self.is_subtype(arg.type, arg.argument.type):
if report_errors:
self.reporter.error(
arg.expr.location,
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
)
valid = False
return valid
def _match_overload(
self,
overloads: list[Type],
location: Location,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Optional[Function]:
candidates: list[Function] = []
for overload in overloads:
function: Type = unfold_type(overload)
if not isinstance(function, Function):
self.logger.error(
f"Overload is not a function: {overload} is {function}"
)
continue
valid, mapped = self.map_call_arguments(
function=function,
location=location,
positional=positional,
keywords=keywords,
report_errors=False,
)
if valid and self._are_arguments_valid(mapped, report_errors=False):
candidates.append(function)
pos_types: str = ", ".join(str(type) for _, type in positional)
kw_types: str = ", ".join(
f"{name}: {type}" for name, (_, type) in keywords.items()
)
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
if len(candidates) == 0:
self.reporter.error(
location,
f"No matching overload in {overloads} {for_args}",
)
return None
if len(candidates) > 1:
self.reporter.error(
location,
f"Multiple matching overloads {for_args}: {', '.join(map(str, candidates))}",
)
return None
return candidates[0]
def map_call_arguments( def map_call_arguments(
self, function: Function, call: p.CallExpr self,
) -> list[MappedArgument]: function: Function,
location: Location,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
report_errors: bool = True,
) -> tuple[bool, list[MappedArgument]]:
"""Map call arguments to function parameters as defined in its signature """Map call arguments to function parameters as defined in its signature
This method maps positional-only, keyword-only and mixed parameter definitions This method maps positional-only, keyword-only and mixed parameter definitions
@@ -589,12 +659,6 @@ class PythonTyper(
Returns: Returns:
list[MappedArgument]: the list of mapped arguments list[MappedArgument]: the list of mapped arguments
""" """
positional: list[tuple[p.Expr, Type]] = [
(arg, self.type_of(arg)) for arg in call.arguments
]
keywords: dict[str, tuple[p.Expr, Type]] = {
name: (arg, self.type_of(arg)) for name, arg in call.keywords.items()
}
set_args: set[str] = set() set_args: set[str] = set()
required_positional: list[str] = [ required_positional: list[str] = [
@@ -612,6 +676,8 @@ class PythonTyper(
arg.name: arg for arg in function.kw_args arg.name: arg for arg in function.kw_args
} }
valid_call: bool = True
# TODO: handle *args and **kwargs sinks # TODO: handle *args and **kwargs sinks
for arg in positional: for arg in positional:
param: Function.Argument param: Function.Argument
@@ -620,7 +686,11 @@ class PythonTyper(
elif len(mixed_params) != 0: elif len(mixed_params) != 0:
param = mixed_params.pop(0) param = mixed_params.pop(0)
else: else:
self.reporter.error(arg[0].location, "Too many positional arguments") if report_errors:
self.reporter.error(
arg[0].location, "Too many positional arguments"
)
valid_call = False
break break
name: str = param.name name: str = param.name
if name in required_positional: if name in required_positional:
@@ -640,14 +710,16 @@ class PythonTyper(
for name, arg in keywords.items(): for name, arg in keywords.items():
param: Function.Argument param: Function.Argument
if name not in kw_params: if name not in kw_params:
if name in set_args: if report_errors:
self.reporter.error( if name in set_args:
arg[0].location, f"Multiple values for argument '{name}'" self.reporter.error(
) arg[0].location, f"Multiple values for argument '{name}'"
else: )
self.reporter.error( else:
arg[0].location, f"Unknown keyword argument '{name}'" self.reporter.error(
) arg[0].location, f"Unknown keyword argument '{name}'"
)
valid_call = False
continue continue
param = kw_params.pop(name) param = kw_params.pop(name)
if name in required_positional: if name in required_positional:
@@ -674,20 +746,24 @@ class PythonTyper(
if len(required_positional) != 0: if len(required_positional) != 0:
plural: str = "" if len(required_positional) == 1 else "s" plural: str = "" if len(required_positional) == 1 else "s"
args: str = join_args(required_positional) args: str = join_args(required_positional)
self.reporter.error( if report_errors:
call.location, self.reporter.error(
f"Missing required positional argument{plural}: {args}", location,
) f"Missing required positional argument{plural}: {args}",
)
valid_call = False
if len(required_keyword) != 0: if len(required_keyword) != 0:
plural: str = "" if len(required_keyword) == 1 else "s" plural: str = "" if len(required_keyword) == 1 else "s"
args: str = join_args(required_keyword) args: str = join_args(required_keyword)
self.reporter.error( if report_errors:
call.location, self.reporter.error(
f"Missing required keyword argument{plural}: {args}", location,
) f"Missing required keyword argument{plural}: {args}",
)
valid_call = False
return mapped return valid_call, mapped
def _check_arity( def _check_arity(
self, self,