From 905132a18e07701f576d70c08b3e206642cb8d16 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sun, 14 Jun 2026 16:36:10 +0200 Subject: [PATCH] feat(checker): resolve overloads with subtypes try to find the most specific overload if multiple matches are found --- .../06_overloads.midas | 2 + .../01_simple_type_checking/06_overloads.py | 4 + midas/checker/python.py | 85 ++++++++++++++++--- 3 files changed, 81 insertions(+), 10 deletions(-) diff --git a/examples/01_simple_type_checking/06_overloads.midas b/examples/01_simple_type_checking/06_overloads.midas index 47c80e0..777c410 100644 --- a/examples/01_simple_type_checking/06_overloads.midas +++ b/examples/01_simple_type_checking/06_overloads.midas @@ -1,8 +1,10 @@ type T1 = object type T2 = object type Foo = object +type T2b = T2 extend Foo { def bar: fn(T1, /) -> int def bar: fn(T2, /) -> float + def bar: fn(T2b, /) -> int } diff --git a/examples/01_simple_type_checking/06_overloads.py b/examples/01_simple_type_checking/06_overloads.py index 105d5ce..86406e0 100644 --- a/examples/01_simple_type_checking/06_overloads.py +++ b/examples/01_simple_type_checking/06_overloads.py @@ -12,3 +12,7 @@ func = foo.bar c = func(t1) d = func(t2) + +t2b: T2b + +e = foo.bar(t2b) diff --git a/midas/checker/python.py b/midas/checker/python.py index 25d5d8b..9ce9399 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -34,6 +34,12 @@ class MappedArgument: argument: Function.Argument +@dataclass(frozen=True, kw_only=True) +class OverloadCandidate: + function: Function + mapped: list[MappedArgument] + + class PythonTyper( p.Stmt.Visitor[None], p.Expr.Visitor[Type], @@ -618,7 +624,7 @@ class PythonTyper( Optional[Function]: the resolved function signature if it can be determined unambigously, or `None`. """ - candidates: list[Function] = [] + candidates: list[OverloadCandidate] = [] for overload in overloads: function: Type = unfold_type(overload) if not isinstance(function, Function): @@ -634,7 +640,12 @@ class PythonTyper( report_errors=False, ) if valid and self._are_arguments_valid(mapped, report_errors=False): - candidates.append(function) + candidates.append( + OverloadCandidate( + function=function, + mapped=mapped, + ) + ) pos_types: str = ", ".join(str(type) for _, type in positional) kw_types: str = ", ".join( @@ -642,19 +653,43 @@ class PythonTyper( ) for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}" - if len(candidates) == 0: + n_candidates: int = len(candidates) + + # Exactly 1 match -> return it + if n_candidates == 1: + return candidates[0].function + + # No match -> invalid call + if n_candidates == 0: self.reporter.error( location, f"No matching overload in {overloads} {for_args}", ) return None - if len(candidates) > 1: - self.reporter.error( - location, - f"Multiple matching overloads {for_args}: {', '.join(map(str, candidates))}", - ) - return None - return candidates[0] + + # Multiple matches -> see if one <: all others (more specific) + for i1, c1 in enumerate(candidates): + mapped1: list[MappedArgument] = c1.mapped + best_match: bool = True + for i2, c2 in enumerate(candidates): + if i1 == i2: + continue + mapped2: list[MappedArgument] = c2.mapped + if not self._are_mapped_subtypes(mapped1, mapped2): + best_match = False + break + self.logger.debug(f"{c1.function} is a full overload of {c2.function}") + if best_match: + return c1.function + + candidates_str: str = ", ".join( + str(candidate.function) for candidate in candidates + ) + self.reporter.error( + location, + f"Multiple matching overloads {for_args}: {candidates_str}", + ) + return None def map_call_arguments( self, @@ -788,3 +823,33 @@ class PythonTyper( valid_call = False return valid_call, mapped + + def _are_mapped_subtypes( + self, mapped1: list[MappedArgument], mapped2: list[MappedArgument] + ) -> bool: + """Check whether the given argument mappings are subtype/supertype of one another + + This function checks whether the argument mappings `mapped1` are subtypes + of `mapped2`. If any of the parameter type in `mapped1` is not a subtype + of the corresponding parameter in `mapped2`, `False` is returned. + + This is used to check whether a given overload is + a more specific function/ a subtype of another. + + Args: + mapped1 (list[MappedArgument]): the first argument mappings (subtype) + mapped2 (list[MappedArgument]): the second argument mappings (supertype) + + Returns: + bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise + """ + by_expr: dict[p.Expr, Type] = {} + for arg in mapped1: + by_expr[arg.expr] = arg.argument.type + + for arg in mapped2: + type2: Type = arg.argument.type + type1: Type = by_expr[arg.expr] + if not self.is_subtype(type1, type2): + return False + return True