fix(checker): abort unification on conflict
This commit is contained in:
@@ -12,6 +12,9 @@ from midas.checker.types import (
|
||||
)
|
||||
|
||||
|
||||
class UnificationError(Exception): ...
|
||||
|
||||
|
||||
class Unifier:
|
||||
def __init__(self, types: TypesRegistry) -> None:
|
||||
self.types: TypesRegistry = types
|
||||
@@ -45,10 +48,20 @@ class Unifier:
|
||||
],
|
||||
returns=TopType(), # TODO: use expected type
|
||||
)
|
||||
return self.unify_generic(type, concrete_func)
|
||||
return self.unify_generic(type, concrete_func, match_return=False)
|
||||
|
||||
def unify_generic(
|
||||
self,
|
||||
template: GenericType,
|
||||
concrete: Type,
|
||||
match_return: bool = True,
|
||||
) -> Optional[Type]:
|
||||
substitutions: dict[str, Type]
|
||||
try:
|
||||
substitutions = self.match(template.body, concrete, match_return)
|
||||
except UnificationError:
|
||||
return None
|
||||
|
||||
def unify_generic(self, template: GenericType, concrete: Type) -> Optional[Type]:
|
||||
substitutions: dict[str, Type] = self.match(template.body, concrete)
|
||||
args: list[Type] = []
|
||||
for param in template.params:
|
||||
if param.name not in substitutions:
|
||||
@@ -58,7 +71,12 @@ class Unifier:
|
||||
applied: Type = self.types.apply_generic(template, args)
|
||||
return applied
|
||||
|
||||
def match(self, template: Type, concrete: Type) -> dict[str, Type]:
|
||||
def match(
|
||||
self,
|
||||
template: Type,
|
||||
concrete: Type,
|
||||
match_return: bool = True,
|
||||
) -> dict[str, Type]:
|
||||
# TODO: if concrete is Generic, record bound TypeVar. Then when merging
|
||||
# substitutions, check that the constraint is respected
|
||||
match (template, concrete):
|
||||
@@ -91,10 +109,11 @@ class Unifier:
|
||||
)
|
||||
substitutions = self.merge(substitutions, arg_subs)
|
||||
|
||||
return_subs: dict[str, Type] = self.match(
|
||||
template.returns, concrete.returns
|
||||
)
|
||||
substitutions = self.merge(substitutions, return_subs)
|
||||
if match_return:
|
||||
return_subs: dict[str, Type] = self.match(
|
||||
template.returns, concrete.returns
|
||||
)
|
||||
substitutions = self.merge(substitutions, return_subs)
|
||||
|
||||
return substitutions
|
||||
|
||||
@@ -110,6 +129,7 @@ class Unifier:
|
||||
self.logger.debug(
|
||||
f"Substitution already defined for {k} with type {merged[k]}, got {v}"
|
||||
)
|
||||
raise UnificationError
|
||||
merged[k] = v
|
||||
return merged
|
||||
|
||||
|
||||
@@ -1,5 +1,20 @@
|
||||
{
|
||||
"diagnostics": [],
|
||||
"diagnostics": [
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
13,
|
||||
15
|
||||
],
|
||||
"end": [
|
||||
13,
|
||||
32
|
||||
]
|
||||
},
|
||||
"message": "Could not unify map[T, U]=(transform: (v: T, /) -> U, iterable: list[T], /) -> list[U] with pos=[(value: float) -> float, list[int]] and kw={}"
|
||||
}
|
||||
],
|
||||
"judgments": [
|
||||
{
|
||||
"location": {
|
||||
@@ -682,17 +697,7 @@
|
||||
],
|
||||
"keywords": {}
|
||||
},
|
||||
"type": {
|
||||
"name": "list",
|
||||
"args": [
|
||||
{
|
||||
"name": "float"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"name": "list"
|
||||
}
|
||||
}
|
||||
"type": {}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
|
||||
Reference in New Issue
Block a user