feat(checker): resolve overloads with subtypes

try to find the most specific overload if multiple matches are found
This commit is contained in:
2026-06-14 16:36:10 +02:00
parent 9594c74952
commit 905132a18e
3 changed files with 81 additions and 10 deletions

View File

@@ -1,8 +1,10 @@
type T1 = object type T1 = object
type T2 = object type T2 = object
type Foo = object type Foo = object
type T2b = T2
extend Foo { extend Foo {
def bar: fn(T1, /) -> int def bar: fn(T1, /) -> int
def bar: fn(T2, /) -> float def bar: fn(T2, /) -> float
def bar: fn(T2b, /) -> int
} }

View File

@@ -12,3 +12,7 @@ func = foo.bar
c = func(t1) c = func(t1)
d = func(t2) d = func(t2)
t2b: T2b
e = foo.bar(t2b)

View File

@@ -34,6 +34,12 @@ class MappedArgument:
argument: Function.Argument argument: Function.Argument
@dataclass(frozen=True, kw_only=True)
class OverloadCandidate:
function: Function
mapped: list[MappedArgument]
class PythonTyper( class PythonTyper(
p.Stmt.Visitor[None], p.Stmt.Visitor[None],
p.Expr.Visitor[Type], p.Expr.Visitor[Type],
@@ -618,7 +624,7 @@ class PythonTyper(
Optional[Function]: the resolved function signature if it can be Optional[Function]: the resolved function signature if it can be
determined unambigously, or `None`. determined unambigously, or `None`.
""" """
candidates: list[Function] = [] candidates: list[OverloadCandidate] = []
for overload in overloads: for overload in overloads:
function: Type = unfold_type(overload) function: Type = unfold_type(overload)
if not isinstance(function, Function): if not isinstance(function, Function):
@@ -634,7 +640,12 @@ class PythonTyper(
report_errors=False, report_errors=False,
) )
if valid and self._are_arguments_valid(mapped, report_errors=False): if valid and self._are_arguments_valid(mapped, report_errors=False):
candidates.append(function) candidates.append(
OverloadCandidate(
function=function,
mapped=mapped,
)
)
pos_types: str = ", ".join(str(type) for _, type in positional) pos_types: str = ", ".join(str(type) for _, type in positional)
kw_types: str = ", ".join( kw_types: str = ", ".join(
@@ -642,19 +653,43 @@ class PythonTyper(
) )
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}" for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
if len(candidates) == 0: n_candidates: int = len(candidates)
# Exactly 1 match -> return it
if n_candidates == 1:
return candidates[0].function
# No match -> invalid call
if n_candidates == 0:
self.reporter.error( self.reporter.error(
location, location,
f"No matching overload in {overloads} {for_args}", f"No matching overload in {overloads} {for_args}",
) )
return None return None
if len(candidates) > 1:
self.reporter.error( # Multiple matches -> see if one <: all others (more specific)
location, for i1, c1 in enumerate(candidates):
f"Multiple matching overloads {for_args}: {', '.join(map(str, candidates))}", mapped1: list[MappedArgument] = c1.mapped
) best_match: bool = True
return None for i2, c2 in enumerate(candidates):
return candidates[0] if i1 == i2:
continue
mapped2: list[MappedArgument] = c2.mapped
if not self._are_mapped_subtypes(mapped1, mapped2):
best_match = False
break
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
if best_match:
return c1.function
candidates_str: str = ", ".join(
str(candidate.function) for candidate in candidates
)
self.reporter.error(
location,
f"Multiple matching overloads {for_args}: {candidates_str}",
)
return None
def map_call_arguments( def map_call_arguments(
self, self,
@@ -788,3 +823,33 @@ class PythonTyper(
valid_call = False valid_call = False
return valid_call, mapped return valid_call, mapped
def _are_mapped_subtypes(
self, mapped1: list[MappedArgument], mapped2: list[MappedArgument]
) -> bool:
"""Check whether the given argument mappings are subtype/supertype of one another
This function checks whether the argument mappings `mapped1` are subtypes
of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
of the corresponding parameter in `mapped2`, `False` is returned.
This is used to check whether a given overload is
a more specific function/ a subtype of another.
Args:
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
Returns:
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
"""
by_expr: dict[p.Expr, Type] = {}
for arg in mapped1:
by_expr[arg.expr] = arg.argument.type
for arg in mapped2:
type2: Type = arg.argument.type
type1: Type = by_expr[arg.expr]
if not self.is_subtype(type1, type2):
return False
return True