fix(checker): abort unification on conflict

This commit is contained in:
2026-06-21 13:35:44 +02:00
parent 29e601128d
commit 4395e9339b
2 changed files with 45 additions and 20 deletions

View File

@@ -12,6 +12,9 @@ from midas.checker.types import (
) )
class UnificationError(Exception): ...
class Unifier: class Unifier:
def __init__(self, types: TypesRegistry) -> None: def __init__(self, types: TypesRegistry) -> None:
self.types: TypesRegistry = types self.types: TypesRegistry = types
@@ -45,10 +48,20 @@ class Unifier:
], ],
returns=TopType(), # TODO: use expected type returns=TopType(), # TODO: use expected type
) )
return self.unify_generic(type, concrete_func) return self.unify_generic(type, concrete_func, match_return=False)
def unify_generic(
self,
template: GenericType,
concrete: Type,
match_return: bool = True,
) -> Optional[Type]:
substitutions: dict[str, Type]
try:
substitutions = self.match(template.body, concrete, match_return)
except UnificationError:
return None
def unify_generic(self, template: GenericType, concrete: Type) -> Optional[Type]:
substitutions: dict[str, Type] = self.match(template.body, concrete)
args: list[Type] = [] args: list[Type] = []
for param in template.params: for param in template.params:
if param.name not in substitutions: if param.name not in substitutions:
@@ -58,7 +71,12 @@ class Unifier:
applied: Type = self.types.apply_generic(template, args) applied: Type = self.types.apply_generic(template, args)
return applied return applied
def match(self, template: Type, concrete: Type) -> dict[str, Type]: def match(
self,
template: Type,
concrete: Type,
match_return: bool = True,
) -> dict[str, Type]:
# TODO: if concrete is Generic, record bound TypeVar. Then when merging # TODO: if concrete is Generic, record bound TypeVar. Then when merging
# substitutions, check that the constraint is respected # substitutions, check that the constraint is respected
match (template, concrete): match (template, concrete):
@@ -91,10 +109,11 @@ class Unifier:
) )
substitutions = self.merge(substitutions, arg_subs) substitutions = self.merge(substitutions, arg_subs)
return_subs: dict[str, Type] = self.match( if match_return:
template.returns, concrete.returns return_subs: dict[str, Type] = self.match(
) template.returns, concrete.returns
substitutions = self.merge(substitutions, return_subs) )
substitutions = self.merge(substitutions, return_subs)
return substitutions return substitutions
@@ -110,6 +129,7 @@ class Unifier:
self.logger.debug( self.logger.debug(
f"Substitution already defined for {k} with type {merged[k]}, got {v}" f"Substitution already defined for {k} with type {merged[k]}, got {v}"
) )
raise UnificationError
merged[k] = v merged[k] = v
return merged return merged

View File

@@ -1,5 +1,20 @@
{ {
"diagnostics": [], "diagnostics": [
{
"type": "Error",
"location": {
"start": [
13,
15
],
"end": [
13,
32
]
},
"message": "Could not unify map[T, U]=(transform: (v: T, /) -> U, iterable: list[T], /) -> list[U] with pos=[(value: float) -> float, list[int]] and kw={}"
}
],
"judgments": [ "judgments": [
{ {
"location": { "location": {
@@ -682,17 +697,7 @@
], ],
"keywords": {} "keywords": {}
}, },
"type": { "type": {}
"name": "list",
"args": [
{
"name": "float"
}
],
"body": {
"name": "list"
}
}
}, },
{ {
"location": { "location": {