diff --git a/midas/checker/unifier.py b/midas/checker/unifier.py index e414623..aae9310 100644 --- a/midas/checker/unifier.py +++ b/midas/checker/unifier.py @@ -12,6 +12,9 @@ from midas.checker.types import ( ) +class UnificationError(Exception): ... + + class Unifier: def __init__(self, types: TypesRegistry) -> None: self.types: TypesRegistry = types @@ -45,10 +48,20 @@ class Unifier: ], returns=TopType(), # TODO: use expected type ) - return self.unify_generic(type, concrete_func) + return self.unify_generic(type, concrete_func, match_return=False) + + def unify_generic( + self, + template: GenericType, + concrete: Type, + match_return: bool = True, + ) -> Optional[Type]: + substitutions: dict[str, Type] + try: + substitutions = self.match(template.body, concrete, match_return) + except UnificationError: + return None - def unify_generic(self, template: GenericType, concrete: Type) -> Optional[Type]: - substitutions: dict[str, Type] = self.match(template.body, concrete) args: list[Type] = [] for param in template.params: if param.name not in substitutions: @@ -58,7 +71,12 @@ class Unifier: applied: Type = self.types.apply_generic(template, args) return applied - def match(self, template: Type, concrete: Type) -> dict[str, Type]: + def match( + self, + template: Type, + concrete: Type, + match_return: bool = True, + ) -> dict[str, Type]: # TODO: if concrete is Generic, record bound TypeVar. Then when merging # substitutions, check that the constraint is respected match (template, concrete): @@ -91,10 +109,11 @@ class Unifier: ) substitutions = self.merge(substitutions, arg_subs) - return_subs: dict[str, Type] = self.match( - template.returns, concrete.returns - ) - substitutions = self.merge(substitutions, return_subs) + if match_return: + return_subs: dict[str, Type] = self.match( + template.returns, concrete.returns + ) + substitutions = self.merge(substitutions, return_subs) return substitutions @@ -110,6 +129,7 @@ class Unifier: self.logger.debug( f"Substitution already defined for {k} with type {merged[k]}, got {v}" ) + raise UnificationError merged[k] = v return merged diff --git a/tests/cases/checker/08_unification.py.ref.json b/tests/cases/checker/08_unification.py.ref.json index fd36cbd..bfaa7bd 100644 --- a/tests/cases/checker/08_unification.py.ref.json +++ b/tests/cases/checker/08_unification.py.ref.json @@ -1,5 +1,20 @@ { - "diagnostics": [], + "diagnostics": [ + { + "type": "Error", + "location": { + "start": [ + 13, + 15 + ], + "end": [ + 13, + 32 + ] + }, + "message": "Could not unify map[T, U]=(transform: (v: T, /) -> U, iterable: list[T], /) -> list[U] with pos=[(value: float) -> float, list[int]] and kw={}" + } + ], "judgments": [ { "location": { @@ -682,17 +697,7 @@ ], "keywords": {} }, - "type": { - "name": "list", - "args": [ - { - "name": "float" - } - ], - "body": { - "name": "list" - } - } + "type": {} }, { "location": {