feat(checker): resolve overloads with subtypes
try to find the most specific overload if multiple matches are found
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
type T1 = object
|
||||
type T2 = object
|
||||
type Foo = object
|
||||
type T2b = T2
|
||||
|
||||
extend Foo {
|
||||
def bar: fn(T1, /) -> int
|
||||
def bar: fn(T2, /) -> float
|
||||
def bar: fn(T2b, /) -> int
|
||||
}
|
||||
|
||||
@@ -12,3 +12,7 @@ func = foo.bar
|
||||
|
||||
c = func(t1)
|
||||
d = func(t2)
|
||||
|
||||
t2b: T2b
|
||||
|
||||
e = foo.bar(t2b)
|
||||
|
||||
@@ -34,6 +34,12 @@ class MappedArgument:
|
||||
argument: Function.Argument
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class OverloadCandidate:
|
||||
function: Function
|
||||
mapped: list[MappedArgument]
|
||||
|
||||
|
||||
class PythonTyper(
|
||||
p.Stmt.Visitor[None],
|
||||
p.Expr.Visitor[Type],
|
||||
@@ -618,7 +624,7 @@ class PythonTyper(
|
||||
Optional[Function]: the resolved function signature if it can be
|
||||
determined unambigously, or `None`.
|
||||
"""
|
||||
candidates: list[Function] = []
|
||||
candidates: list[OverloadCandidate] = []
|
||||
for overload in overloads:
|
||||
function: Type = unfold_type(overload)
|
||||
if not isinstance(function, Function):
|
||||
@@ -634,7 +640,12 @@ class PythonTyper(
|
||||
report_errors=False,
|
||||
)
|
||||
if valid and self._are_arguments_valid(mapped, report_errors=False):
|
||||
candidates.append(function)
|
||||
candidates.append(
|
||||
OverloadCandidate(
|
||||
function=function,
|
||||
mapped=mapped,
|
||||
)
|
||||
)
|
||||
|
||||
pos_types: str = ", ".join(str(type) for _, type in positional)
|
||||
kw_types: str = ", ".join(
|
||||
@@ -642,19 +653,43 @@ class PythonTyper(
|
||||
)
|
||||
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
|
||||
|
||||
if len(candidates) == 0:
|
||||
n_candidates: int = len(candidates)
|
||||
|
||||
# Exactly 1 match -> return it
|
||||
if n_candidates == 1:
|
||||
return candidates[0].function
|
||||
|
||||
# No match -> invalid call
|
||||
if n_candidates == 0:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"No matching overload in {overloads} {for_args}",
|
||||
)
|
||||
return None
|
||||
if len(candidates) > 1:
|
||||
|
||||
# Multiple matches -> see if one <: all others (more specific)
|
||||
for i1, c1 in enumerate(candidates):
|
||||
mapped1: list[MappedArgument] = c1.mapped
|
||||
best_match: bool = True
|
||||
for i2, c2 in enumerate(candidates):
|
||||
if i1 == i2:
|
||||
continue
|
||||
mapped2: list[MappedArgument] = c2.mapped
|
||||
if not self._are_mapped_subtypes(mapped1, mapped2):
|
||||
best_match = False
|
||||
break
|
||||
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
|
||||
if best_match:
|
||||
return c1.function
|
||||
|
||||
candidates_str: str = ", ".join(
|
||||
str(candidate.function) for candidate in candidates
|
||||
)
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Multiple matching overloads {for_args}: {', '.join(map(str, candidates))}",
|
||||
f"Multiple matching overloads {for_args}: {candidates_str}",
|
||||
)
|
||||
return None
|
||||
return candidates[0]
|
||||
|
||||
def map_call_arguments(
|
||||
self,
|
||||
@@ -788,3 +823,33 @@ class PythonTyper(
|
||||
valid_call = False
|
||||
|
||||
return valid_call, mapped
|
||||
|
||||
def _are_mapped_subtypes(
|
||||
self, mapped1: list[MappedArgument], mapped2: list[MappedArgument]
|
||||
) -> bool:
|
||||
"""Check whether the given argument mappings are subtype/supertype of one another
|
||||
|
||||
This function checks whether the argument mappings `mapped1` are subtypes
|
||||
of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
|
||||
of the corresponding parameter in `mapped2`, `False` is returned.
|
||||
|
||||
This is used to check whether a given overload is
|
||||
a more specific function/ a subtype of another.
|
||||
|
||||
Args:
|
||||
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
|
||||
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
|
||||
|
||||
Returns:
|
||||
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
|
||||
"""
|
||||
by_expr: dict[p.Expr, Type] = {}
|
||||
for arg in mapped1:
|
||||
by_expr[arg.expr] = arg.argument.type
|
||||
|
||||
for arg in mapped2:
|
||||
type2: Type = arg.argument.type
|
||||
type1: Type = by_expr[arg.expr]
|
||||
if not self.is_subtype(type1, type2):
|
||||
return False
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user