From 25a96d20e1cdcb7cfcadffd5de783f6fca9d75a0 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sun, 14 Jun 2026 15:48:31 +0200 Subject: [PATCH] feat(checker): handle overloaded function calls --- midas/checker/python.py | 220 +++++++++++++++++++++++++++------------- 1 file changed, 148 insertions(+), 72 deletions(-) diff --git a/midas/checker/python.py b/midas/checker/python.py index 45b9a33..dbe8c8e 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -12,12 +12,16 @@ from midas.checker.reporter import FileReporter, Reporter from midas.checker.resolver import Resolver from midas.checker.types import ( Function, + OverloadedFunction, Type, UnitType, UnknownType, + unfold_type, ) from midas.parser.python import PythonParser +TypedExpr = tuple[p.Expr, Type] + class ReturnException(Exception): pass @@ -354,26 +358,7 @@ class PythonTyper( ) return UnknownType() - match operation: - case Function() as function: - if not self._check_arity(function, 1, 0, 0): - self.reporter.error( - location, - f"Wrong definition of binary operation. Expected function with 1 positional-only parameters, got {function}", - ) - return UnknownType() - - rhs: Function.Argument = function.pos_args[0] - if not self.is_subtype(right, rhs.type): - self.reporter.error( - location, - f"Wrong type for right-hand side, expected {rhs.type}, got {right}", - ) - return UnknownType() - return function.returns - case _: - self.reporter.warning(location, f"Unsupported operation {operation}") - return UnknownType() + return self._get_call_result(location, operation, [(right_expr, right)], {}) def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__) @@ -393,35 +378,24 @@ class PythonTyper( ) return UnknownType() - match operation: - case Function() as function: - if not self._check_arity(function, 0, 0, 0): - self.reporter.error( - expr.location, - f"Wrong definition of unary operation. Expected function with 0 parameters, got {function}", - ) - return UnknownType() - return function.returns - case _: - self.reporter.warning( - expr.location, f"Unsupported operation {operation}" - ) - return UnknownType() + return self._get_call_result( + expr.location, operation, [(expr.right, operand)], {} + ) def visit_call_expr(self, expr: p.CallExpr) -> Type: callee: Type = self.type_of(expr.callee) - if not isinstance(callee, Function): - self.reporter.error(expr.callee.location, "Callee is not a function") - return UnknownType() - function: Function = callee - mapped: list[MappedArgument] = self.map_call_arguments(function, expr) - for arg in mapped: - if not self.is_subtype(arg.type, arg.argument.type): - self.reporter.error( - arg.expr.location, - f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}", - ) - return function.returns + positional: list[TypedExpr] = [ + (arg, self.type_of(arg)) for arg in expr.arguments + ] + keywords: dict[str, TypedExpr] = { + name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items() + } + return self._get_call_result( + location=expr.location, + callee=callee, + positional=positional, + keywords=keywords, + ) def visit_get_expr(self, expr: p.GetExpr) -> Type: object: Type = self.type_of(expr.object) @@ -572,9 +546,105 @@ class PythonTyper( self.reporter.warning(node.location, "FrameType not yet supported") return UnknownType() + def _get_call_result( + self, + location: Location, + callee: Type, + positional: list[TypedExpr], + keywords: dict[str, TypedExpr], + ) -> Type: + match callee: + case Function() as function: + valid: bool + mapped: list[MappedArgument] + valid, mapped = self.map_call_arguments( + function, location, positional, keywords + ) + valid = valid and self._are_arguments_valid(mapped) + if not valid: + return UnknownType() + return function.returns + + case OverloadedFunction(overloads=overloads): + function = self._match_overload( + overloads, location, positional, keywords + ) + if function is None: + return UnknownType() + return function.returns + case _: + self.reporter.error(location, f"{callee} is not callable") + return UnknownType() + + def _are_arguments_valid( + self, + arguments: list[MappedArgument], + report_errors: bool = True, + ) -> bool: + valid: bool = True + for arg in arguments: + if not self.is_subtype(arg.type, arg.argument.type): + if report_errors: + self.reporter.error( + arg.expr.location, + f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}", + ) + valid = False + return valid + + def _match_overload( + self, + overloads: list[Type], + location: Location, + positional: list[TypedExpr], + keywords: dict[str, TypedExpr], + ) -> Optional[Function]: + candidates: list[Function] = [] + for overload in overloads: + function: Type = unfold_type(overload) + if not isinstance(function, Function): + self.logger.error( + f"Overload is not a function: {overload} is {function}" + ) + continue + valid, mapped = self.map_call_arguments( + function=function, + location=location, + positional=positional, + keywords=keywords, + report_errors=False, + ) + if valid and self._are_arguments_valid(mapped, report_errors=False): + candidates.append(function) + + pos_types: str = ", ".join(str(type) for _, type in positional) + kw_types: str = ", ".join( + f"{name}: {type}" for name, (_, type) in keywords.items() + ) + for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}" + + if len(candidates) == 0: + self.reporter.error( + location, + f"No matching overload in {overloads} {for_args}", + ) + return None + if len(candidates) > 1: + self.reporter.error( + location, + f"Multiple matching overloads {for_args}: {', '.join(map(str, candidates))}", + ) + return None + return candidates[0] + def map_call_arguments( - self, function: Function, call: p.CallExpr - ) -> list[MappedArgument]: + self, + function: Function, + location: Location, + positional: list[TypedExpr], + keywords: dict[str, TypedExpr], + report_errors: bool = True, + ) -> tuple[bool, list[MappedArgument]]: """Map call arguments to function parameters as defined in its signature This method maps positional-only, keyword-only and mixed parameter definitions @@ -589,12 +659,6 @@ class PythonTyper( Returns: list[MappedArgument]: the list of mapped arguments """ - positional: list[tuple[p.Expr, Type]] = [ - (arg, self.type_of(arg)) for arg in call.arguments - ] - keywords: dict[str, tuple[p.Expr, Type]] = { - name: (arg, self.type_of(arg)) for name, arg in call.keywords.items() - } set_args: set[str] = set() required_positional: list[str] = [ @@ -612,6 +676,8 @@ class PythonTyper( arg.name: arg for arg in function.kw_args } + valid_call: bool = True + # TODO: handle *args and **kwargs sinks for arg in positional: param: Function.Argument @@ -620,7 +686,11 @@ class PythonTyper( elif len(mixed_params) != 0: param = mixed_params.pop(0) else: - self.reporter.error(arg[0].location, "Too many positional arguments") + if report_errors: + self.reporter.error( + arg[0].location, "Too many positional arguments" + ) + valid_call = False break name: str = param.name if name in required_positional: @@ -640,14 +710,16 @@ class PythonTyper( for name, arg in keywords.items(): param: Function.Argument if name not in kw_params: - if name in set_args: - self.reporter.error( - arg[0].location, f"Multiple values for argument '{name}'" - ) - else: - self.reporter.error( - arg[0].location, f"Unknown keyword argument '{name}'" - ) + if report_errors: + if name in set_args: + self.reporter.error( + arg[0].location, f"Multiple values for argument '{name}'" + ) + else: + self.reporter.error( + arg[0].location, f"Unknown keyword argument '{name}'" + ) + valid_call = False continue param = kw_params.pop(name) if name in required_positional: @@ -674,20 +746,24 @@ class PythonTyper( if len(required_positional) != 0: plural: str = "" if len(required_positional) == 1 else "s" args: str = join_args(required_positional) - self.reporter.error( - call.location, - f"Missing required positional argument{plural}: {args}", - ) + if report_errors: + self.reporter.error( + location, + f"Missing required positional argument{plural}: {args}", + ) + valid_call = False if len(required_keyword) != 0: plural: str = "" if len(required_keyword) == 1 else "s" args: str = join_args(required_keyword) - self.reporter.error( - call.location, - f"Missing required keyword argument{plural}: {args}", - ) + if report_errors: + self.reporter.error( + location, + f"Missing required keyword argument{plural}: {args}", + ) + valid_call = False - return mapped + return valid_call, mapped def _check_arity( self,