feat(checker): improve function unwrapping

This commit is contained in:
2026-07-01 12:59:12 +02:00
parent 9a276c34c7
commit 6ad2ce4b68

View File

@@ -14,7 +14,6 @@ from midas.checker.types import (
OverloadedFunction,
Type,
UnknownType,
unfold_type,
)
from midas.checker.unifier import Unifier
@@ -174,6 +173,33 @@ class CallDispatcher(Generic[E]):
message=message,
)
def _unwrap_function(
self,
callee: Type,
positional: list[TypedExpr[E]],
keywords: dict[str, TypedExpr[E]],
) -> Union[tuple[Function, None], tuple[None, CallError]]:
match callee:
case DerivedType(type=base):
return self._unwrap_function(base, positional, keywords)
case GenericType():
unifier: Unifier = Unifier(self.types)
unified: Optional[Type] = unifier.unify_call(
callee,
[a[1] for a in positional],
{k: v[1] for k, v in keywords.items()},
)
if unified is None:
return None, CallError.IMPOSSIBLE_UNIFICATION
return self._unwrap_function(unified, positional, keywords)
case Function():
return callee, None
case _:
return None, CallError.NOT_CALLABLE
def _are_arguments_valid(
self,
arguments: list[MappedArgument[E]],
@@ -222,13 +248,12 @@ class CallDispatcher(Generic[E]):
"""
candidates: list[OverloadCandidate] = []
for overload in overloads:
function: Type = unfold_type(overload)
if not isinstance(function, Function):
if report_errors:
self.logger.error(
f"Overload is not a function: {overload} is {function}"
)
function, unwrap_error = self._unwrap_function(
overload, positional, keywords
)
if function is None:
continue
valid, mapped = self.map_call_arguments(
function=function,
location=location,