feat(checker): resolve overloads with subtypes
try to find the most specific overload if multiple matches are found
This commit is contained in:
@@ -1,8 +1,10 @@
|
|||||||
type T1 = object
|
type T1 = object
|
||||||
type T2 = object
|
type T2 = object
|
||||||
type Foo = object
|
type Foo = object
|
||||||
|
type T2b = T2
|
||||||
|
|
||||||
extend Foo {
|
extend Foo {
|
||||||
def bar: fn(T1, /) -> int
|
def bar: fn(T1, /) -> int
|
||||||
def bar: fn(T2, /) -> float
|
def bar: fn(T2, /) -> float
|
||||||
|
def bar: fn(T2b, /) -> int
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,3 +12,7 @@ func = foo.bar
|
|||||||
|
|
||||||
c = func(t1)
|
c = func(t1)
|
||||||
d = func(t2)
|
d = func(t2)
|
||||||
|
|
||||||
|
t2b: T2b
|
||||||
|
|
||||||
|
e = foo.bar(t2b)
|
||||||
|
|||||||
@@ -34,6 +34,12 @@ class MappedArgument:
|
|||||||
argument: Function.Argument
|
argument: Function.Argument
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class OverloadCandidate:
|
||||||
|
function: Function
|
||||||
|
mapped: list[MappedArgument]
|
||||||
|
|
||||||
|
|
||||||
class PythonTyper(
|
class PythonTyper(
|
||||||
p.Stmt.Visitor[None],
|
p.Stmt.Visitor[None],
|
||||||
p.Expr.Visitor[Type],
|
p.Expr.Visitor[Type],
|
||||||
@@ -618,7 +624,7 @@ class PythonTyper(
|
|||||||
Optional[Function]: the resolved function signature if it can be
|
Optional[Function]: the resolved function signature if it can be
|
||||||
determined unambigously, or `None`.
|
determined unambigously, or `None`.
|
||||||
"""
|
"""
|
||||||
candidates: list[Function] = []
|
candidates: list[OverloadCandidate] = []
|
||||||
for overload in overloads:
|
for overload in overloads:
|
||||||
function: Type = unfold_type(overload)
|
function: Type = unfold_type(overload)
|
||||||
if not isinstance(function, Function):
|
if not isinstance(function, Function):
|
||||||
@@ -634,7 +640,12 @@ class PythonTyper(
|
|||||||
report_errors=False,
|
report_errors=False,
|
||||||
)
|
)
|
||||||
if valid and self._are_arguments_valid(mapped, report_errors=False):
|
if valid and self._are_arguments_valid(mapped, report_errors=False):
|
||||||
candidates.append(function)
|
candidates.append(
|
||||||
|
OverloadCandidate(
|
||||||
|
function=function,
|
||||||
|
mapped=mapped,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
pos_types: str = ", ".join(str(type) for _, type in positional)
|
pos_types: str = ", ".join(str(type) for _, type in positional)
|
||||||
kw_types: str = ", ".join(
|
kw_types: str = ", ".join(
|
||||||
@@ -642,19 +653,43 @@ class PythonTyper(
|
|||||||
)
|
)
|
||||||
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
|
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
|
||||||
|
|
||||||
if len(candidates) == 0:
|
n_candidates: int = len(candidates)
|
||||||
|
|
||||||
|
# Exactly 1 match -> return it
|
||||||
|
if n_candidates == 1:
|
||||||
|
return candidates[0].function
|
||||||
|
|
||||||
|
# No match -> invalid call
|
||||||
|
if n_candidates == 0:
|
||||||
self.reporter.error(
|
self.reporter.error(
|
||||||
location,
|
location,
|
||||||
f"No matching overload in {overloads} {for_args}",
|
f"No matching overload in {overloads} {for_args}",
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
if len(candidates) > 1:
|
|
||||||
|
# Multiple matches -> see if one <: all others (more specific)
|
||||||
|
for i1, c1 in enumerate(candidates):
|
||||||
|
mapped1: list[MappedArgument] = c1.mapped
|
||||||
|
best_match: bool = True
|
||||||
|
for i2, c2 in enumerate(candidates):
|
||||||
|
if i1 == i2:
|
||||||
|
continue
|
||||||
|
mapped2: list[MappedArgument] = c2.mapped
|
||||||
|
if not self._are_mapped_subtypes(mapped1, mapped2):
|
||||||
|
best_match = False
|
||||||
|
break
|
||||||
|
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
|
||||||
|
if best_match:
|
||||||
|
return c1.function
|
||||||
|
|
||||||
|
candidates_str: str = ", ".join(
|
||||||
|
str(candidate.function) for candidate in candidates
|
||||||
|
)
|
||||||
self.reporter.error(
|
self.reporter.error(
|
||||||
location,
|
location,
|
||||||
f"Multiple matching overloads {for_args}: {', '.join(map(str, candidates))}",
|
f"Multiple matching overloads {for_args}: {candidates_str}",
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
return candidates[0]
|
|
||||||
|
|
||||||
def map_call_arguments(
|
def map_call_arguments(
|
||||||
self,
|
self,
|
||||||
@@ -788,3 +823,33 @@ class PythonTyper(
|
|||||||
valid_call = False
|
valid_call = False
|
||||||
|
|
||||||
return valid_call, mapped
|
return valid_call, mapped
|
||||||
|
|
||||||
|
def _are_mapped_subtypes(
|
||||||
|
self, mapped1: list[MappedArgument], mapped2: list[MappedArgument]
|
||||||
|
) -> bool:
|
||||||
|
"""Check whether the given argument mappings are subtype/supertype of one another
|
||||||
|
|
||||||
|
This function checks whether the argument mappings `mapped1` are subtypes
|
||||||
|
of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
|
||||||
|
of the corresponding parameter in `mapped2`, `False` is returned.
|
||||||
|
|
||||||
|
This is used to check whether a given overload is
|
||||||
|
a more specific function/ a subtype of another.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
|
||||||
|
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
|
||||||
|
"""
|
||||||
|
by_expr: dict[p.Expr, Type] = {}
|
||||||
|
for arg in mapped1:
|
||||||
|
by_expr[arg.expr] = arg.argument.type
|
||||||
|
|
||||||
|
for arg in mapped2:
|
||||||
|
type2: Type = arg.argument.type
|
||||||
|
type1: Type = by_expr[arg.expr]
|
||||||
|
if not self.is_subtype(type1, type2):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|||||||
Reference in New Issue
Block a user