Merge pull request 'Generic call unification' (#18) from feat/unification into main

Reviewed-on: #18
This commit was merged in pull request #18.
This commit is contained in:
2026-06-21 11:41:48 +00:00
5 changed files with 1082 additions and 0 deletions

View File

@@ -52,6 +52,7 @@ class Preamble(Environment):
),
],
returns=self._list_of(map_out), # TODO: replace with Iterable[U]
type_vars=[map_in, map_out],
)
def _list_of(self, item_type: Type) -> Type:

View File

@@ -19,12 +19,14 @@ from midas.checker.types import (
AliasType,
AppliedType,
Function,
GenericType,
OverloadedFunction,
Type,
UnitType,
UnknownType,
unfold_type,
)
from midas.checker.unifier import Unifier
from midas.parser.python import PythonParser
from midas.utils import TypedAST
@@ -704,6 +706,28 @@ class PythonTyper(
location, base, positional, keywords, report_errors
)
case GenericType():
unifier: Unifier = Unifier(self.types)
pos: list[Type] = [a[1] for a in positional]
kw: dict[str, Type] = {k: v[1] for k, v in keywords.items()}
unified: Optional[Type] = unifier.unify_call(callee, pos, kw)
if unified is None:
if report_errors:
pos_str: str = ", ".join(str(t) for t in pos)
kw_str: str = ", ".join(f"{k}: {v}" for k, v in kw.items())
self.reporter.error(
location,
f"Could not unify {callee}={callee.body} with pos=[{pos_str}] and kw={{{kw_str}}}",
)
return None
return self._get_call_result(
location,
unified,
positional,
keywords,
report_errors,
)
case _:
if report_errors:
self.reporter.error(

169
midas/checker/unifier.py Normal file
View File

@@ -0,0 +1,169 @@
import logging
from typing import Optional
from midas.checker.registry import TypesRegistry
from midas.checker.types import (
AppliedType,
Function,
GenericType,
TopType,
Type,
TypeVar,
)
class UnificationError(Exception): ...
class Unifier:
def __init__(self, types: TypesRegistry) -> None:
self.types: TypesRegistry = types
self.logger: logging.Logger = logging.getLogger("Unifier")
def unify_call(
self,
type: GenericType,
positional: list[Type],
keywords: dict[str, Type],
) -> Optional[Type]:
concrete_func: Function = Function(
pos_args=[
Function.Argument(
pos=i,
name=str(i),
type=arg,
required=True,
)
for i, arg in enumerate(positional)
],
args=[],
kw_args=[
Function.Argument(
pos=len(positional) + i,
name=name,
type=arg,
required=True,
)
for i, (name, arg) in enumerate(keywords.items())
],
returns=TopType(), # TODO: use expected type
)
return self.unify_generic(type, concrete_func, match_return=False)
def unify_generic(
self,
template: GenericType,
concrete: Type,
match_return: bool = True,
) -> Optional[Type]:
substitutions: dict[str, Type]
try:
substitutions = self.match(template.body, concrete, match_return)
except UnificationError:
return None
args: list[Type] = []
for param in template.params:
if param.name not in substitutions:
return None
args.append(substitutions[param.name])
applied: Type = self.types.apply_generic(template, args)
return applied
def match(
self,
template: Type,
concrete: Type,
match_return: bool = True,
) -> dict[str, Type]:
# TODO: if concrete is Generic, record bound TypeVar. Then when merging
# substitutions, check that the constraint is respected
match (template, concrete):
case (TypeVar(name=name), _):
return {name: concrete}
case (
AppliedType(name=template_name, args=template_args),
AppliedType(name=concrete_name, args=concrete_args),
) if template_name == concrete_name and len(template_args) == len(
concrete_args
):
substitutions: dict[str, Type] = {}
for template_arg, concrete_arg in zip(template_args, concrete_args):
new_substistutions: dict[str, Type] = self.match(
template_arg, concrete_arg
)
substitutions = self.merge(substitutions, new_substistutions)
return substitutions
case (Function(), Function()):
mapped: list[tuple[Function.Argument, Function.Argument]] = (
self.map_params(template, concrete)
)
substitutions: dict[str, Type] = {}
for template_arg, concrete_arg in mapped:
arg_subs: dict[str, Type] = self.match(
template_arg.type, concrete_arg.type
)
substitutions = self.merge(substitutions, arg_subs)
if match_return:
return_subs: dict[str, Type] = self.match(
template.returns, concrete.returns
)
substitutions = self.merge(substitutions, return_subs)
return substitutions
case _:
self.logger.debug(f"Can't match {concrete!r} with {template!r}")
return {}
def merge(self, subs1: dict[str, Type], subs2: dict[str, Type]) -> dict[str, Type]:
merged: dict[str, Type] = subs1.copy()
for k, v in subs2.items():
if k in merged and merged[k] != v:
self.logger.debug(
f"Substitution already defined for {k} with type {merged[k]}, got {v}"
)
raise UnificationError
merged[k] = v
return merged
def map_params(
self, func1: Function, func2: Function
) -> list[tuple[Function.Argument, Function.Argument]]:
pos1: list[Function.Argument] = func1.pos_args
mixed1: list[Function.Argument] = func1.args
kw1: list[Function.Argument] = func1.kw_args
pos2: list[Function.Argument] = func2.pos_args
mixed2: list[Function.Argument] = func2.args
kw2: list[Function.Argument] = func2.kw_args
mapped: list[tuple[Function.Argument, Function.Argument]] = []
by_pos2: dict[int, Function.Argument] = {arg.pos: arg for arg in pos2 + mixed2}
by_name2: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2 + kw2}
for arg1 in pos1:
if (arg2 := by_pos2.get(arg1.pos)) is not None:
mapped.append((arg1, arg2))
for arg1 in mixed1:
# Match both positionally and by name, conflicts are caught
# when merging substitutions
if (arg2 := by_pos2.get(arg1.pos)) is not None:
mapped.append((arg1, arg2))
if (arg2 := by_name2.get(arg1.name)) is not None:
mapped.append((arg1, arg2))
for arg1 in kw1:
if (arg2 := by_name2.get(arg1.name)) is not None:
mapped.append((arg1, arg2))
return mapped

View File

@@ -0,0 +1,14 @@
def double(value: float) -> float:
return value * 2
def is_odd(value: int) -> bool:
return bool(value % 2)
floats: list[float] = [0.2, 0.5, 0.1, 1.2]
ints: list[int] = [1, 2, 6, -3]
doubled_floats = map(double, floats)
doubled_ints = map(double, ints)
odd_ints = map(is_odd, ints)

View File

@@ -0,0 +1,874 @@
{
"diagnostics": [
{
"type": "Error",
"location": {
"start": [
13,
15
],
"end": [
13,
32
]
},
"message": "Could not unify map[T, U]=(transform: (v: T, /) -> U, iterable: list[T], /) -> list[U] with pos=[(value: float) -> float, list[int]] and kw={}"
}
],
"judgments": [
{
"location": {
"from": "L2:11",
"to": "L2:16"
},
"expr": {
"_type": "VariableExpr",
"name": "value"
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L2:19",
"to": "L2:20"
},
"expr": {
"_type": "LiteralExpr",
"value": 2
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L2:11",
"to": "L2:20"
},
"expr": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "value"
},
"operator": "*",
"right": {
"_type": "LiteralExpr",
"value": 2
}
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L6:11",
"to": "L6:15"
},
"expr": {
"_type": "VariableExpr",
"name": "bool"
},
"type": {
"pos_args": [
{
"pos": 0,
"name": "object",
"type": {},
"required": false
}
],
"args": [],
"kw_args": [],
"returns": {
"name": "bool"
}
}
},
{
"location": {
"from": "L6:16",
"to": "L6:21"
},
"expr": {
"_type": "VariableExpr",
"name": "value"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L6:24",
"to": "L6:25"
},
"expr": {
"_type": "LiteralExpr",
"value": 2
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L6:16",
"to": "L6:25"
},
"expr": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "value"
},
"operator": "%",
"right": {
"_type": "LiteralExpr",
"value": 2
}
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L6:11",
"to": "L6:26"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "VariableExpr",
"name": "bool"
},
"arguments": [
{
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "value"
},
"operator": "%",
"right": {
"_type": "LiteralExpr",
"value": 2
}
}
],
"keywords": {}
},
"type": {
"name": "bool"
}
},
{
"location": {
"from": "L9:23",
"to": "L9:26"
},
"expr": {
"_type": "LiteralExpr",
"value": 0.2
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L9:28",
"to": "L9:31"
},
"expr": {
"_type": "LiteralExpr",
"value": 0.5
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L9:33",
"to": "L9:36"
},
"expr": {
"_type": "LiteralExpr",
"value": 0.1
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L9:38",
"to": "L9:41"
},
"expr": {
"_type": "LiteralExpr",
"value": 1.2
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L9:22",
"to": "L9:42"
},
"expr": {
"_type": "ListExpr",
"items": [
{
"_type": "LiteralExpr",
"value": 0.2
},
{
"_type": "LiteralExpr",
"value": 0.5
},
{
"_type": "LiteralExpr",
"value": 0.1
},
{
"_type": "LiteralExpr",
"value": 1.2
}
]
},
"type": {
"name": "list",
"args": [
{
"name": "float"
}
],
"body": {
"name": "list"
}
}
},
{
"location": {
"from": "L10:19",
"to": "L10:20"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L10:22",
"to": "L10:23"
},
"expr": {
"_type": "LiteralExpr",
"value": 2
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L10:25",
"to": "L10:26"
},
"expr": {
"_type": "LiteralExpr",
"value": 6
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L10:29",
"to": "L10:30"
},
"expr": {
"_type": "LiteralExpr",
"value": 3
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L10:28",
"to": "L10:30"
},
"expr": {
"_type": "UnaryExpr",
"operator": "-",
"right": {
"_type": "LiteralExpr",
"value": 3
}
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L10:18",
"to": "L10:31"
},
"expr": {
"_type": "ListExpr",
"items": [
{
"_type": "LiteralExpr",
"value": 1
},
{
"_type": "LiteralExpr",
"value": 2
},
{
"_type": "LiteralExpr",
"value": 6
},
{
"_type": "UnaryExpr",
"operator": "-",
"right": {
"_type": "LiteralExpr",
"value": 3
}
}
]
},
"type": {
"name": "list",
"args": [
{
"name": "int"
}
],
"body": {
"name": "list"
}
}
},
{
"location": {
"from": "L12:17",
"to": "L12:20"
},
"expr": {
"_type": "VariableExpr",
"name": "map"
},
"type": {
"name": "map",
"params": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
},
{
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"pos_args": [
{
"pos": 0,
"name": "transform",
"type": {
"pos_args": [
{
"pos": 0,
"name": "v",
"type": {
"name": "T",
"bound": null,
"variance": "INVARIANT"
},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
},
"required": true
},
{
"pos": 1,
"name": "iterable",
"type": {
"name": "list",
"args": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "list"
}
},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {
"name": "list",
"args": [
{
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "list"
}
}
}
}
},
{
"location": {
"from": "L12:21",
"to": "L12:27"
},
"expr": {
"_type": "VariableExpr",
"name": "double"
},
"type": {
"pos_args": [],
"args": [
{
"pos": 0,
"name": "value",
"type": {
"name": "float"
},
"required": true
}
],
"kw_args": [],
"returns": {
"name": "float"
}
}
},
{
"location": {
"from": "L12:29",
"to": "L12:35"
},
"expr": {
"_type": "VariableExpr",
"name": "floats"
},
"type": {
"name": "list",
"args": [
{
"name": "float"
}
],
"body": {
"name": "list"
}
}
},
{
"location": {
"from": "L12:17",
"to": "L12:36"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "VariableExpr",
"name": "map"
},
"arguments": [
{
"_type": "VariableExpr",
"name": "double"
},
{
"_type": "VariableExpr",
"name": "floats"
}
],
"keywords": {}
},
"type": {
"name": "list",
"args": [
{
"name": "float"
}
],
"body": {
"name": "list"
}
}
},
{
"location": {
"from": "L13:15",
"to": "L13:18"
},
"expr": {
"_type": "VariableExpr",
"name": "map"
},
"type": {
"name": "map",
"params": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
},
{
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"pos_args": [
{
"pos": 0,
"name": "transform",
"type": {
"pos_args": [
{
"pos": 0,
"name": "v",
"type": {
"name": "T",
"bound": null,
"variance": "INVARIANT"
},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
},
"required": true
},
{
"pos": 1,
"name": "iterable",
"type": {
"name": "list",
"args": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "list"
}
},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {
"name": "list",
"args": [
{
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "list"
}
}
}
}
},
{
"location": {
"from": "L13:19",
"to": "L13:25"
},
"expr": {
"_type": "VariableExpr",
"name": "double"
},
"type": {
"pos_args": [],
"args": [
{
"pos": 0,
"name": "value",
"type": {
"name": "float"
},
"required": true
}
],
"kw_args": [],
"returns": {
"name": "float"
}
}
},
{
"location": {
"from": "L13:27",
"to": "L13:31"
},
"expr": {
"_type": "VariableExpr",
"name": "ints"
},
"type": {
"name": "list",
"args": [
{
"name": "int"
}
],
"body": {
"name": "list"
}
}
},
{
"location": {
"from": "L13:15",
"to": "L13:32"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "VariableExpr",
"name": "map"
},
"arguments": [
{
"_type": "VariableExpr",
"name": "double"
},
{
"_type": "VariableExpr",
"name": "ints"
}
],
"keywords": {}
},
"type": {}
},
{
"location": {
"from": "L14:11",
"to": "L14:14"
},
"expr": {
"_type": "VariableExpr",
"name": "map"
},
"type": {
"name": "map",
"params": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
},
{
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"pos_args": [
{
"pos": 0,
"name": "transform",
"type": {
"pos_args": [
{
"pos": 0,
"name": "v",
"type": {
"name": "T",
"bound": null,
"variance": "INVARIANT"
},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
},
"required": true
},
{
"pos": 1,
"name": "iterable",
"type": {
"name": "list",
"args": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "list"
}
},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {
"name": "list",
"args": [
{
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "list"
}
}
}
}
},
{
"location": {
"from": "L14:15",
"to": "L14:21"
},
"expr": {
"_type": "VariableExpr",
"name": "is_odd"
},
"type": {
"pos_args": [],
"args": [
{
"pos": 0,
"name": "value",
"type": {
"name": "int"
},
"required": true
}
],
"kw_args": [],
"returns": {
"name": "bool"
}
}
},
{
"location": {
"from": "L14:23",
"to": "L14:27"
},
"expr": {
"_type": "VariableExpr",
"name": "ints"
},
"type": {
"name": "list",
"args": [
{
"name": "int"
}
],
"body": {
"name": "list"
}
}
},
{
"location": {
"from": "L14:11",
"to": "L14:28"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "VariableExpr",
"name": "map"
},
"arguments": [
{
"_type": "VariableExpr",
"name": "is_odd"
},
{
"_type": "VariableExpr",
"name": "ints"
}
],
"keywords": {}
},
"type": {
"name": "list",
"args": [
{
"name": "bool"
}
],
"body": {
"name": "list"
}
}
}
]
}