From 41d0c84bbe69f2ee471df7d83522d36756d86d9c Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sun, 21 Jun 2026 13:12:27 +0200 Subject: [PATCH 1/4] feat(checker): add unifier add unifier class to infer type parameters from local call context --- midas/checker/python.py | 24 +++++++ midas/checker/unifier.py | 149 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 173 insertions(+) create mode 100644 midas/checker/unifier.py diff --git a/midas/checker/python.py b/midas/checker/python.py index 22ea98c..e1fb788 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -19,12 +19,14 @@ from midas.checker.types import ( AliasType, AppliedType, Function, + GenericType, OverloadedFunction, Type, UnitType, UnknownType, unfold_type, ) +from midas.checker.unifier import Unifier from midas.parser.python import PythonParser from midas.utils import TypedAST @@ -704,6 +706,28 @@ class PythonTyper( location, base, positional, keywords, report_errors ) + case GenericType(): + unifier: Unifier = Unifier(self.types) + pos: list[Type] = [a[1] for a in positional] + kw: dict[str, Type] = {k: v[1] for k, v in keywords.items()} + unified: Optional[Type] = unifier.unify_call(callee, pos, kw) + if unified is None: + if report_errors: + pos_str: str = ", ".join(str(t) for t in pos) + kw_str: str = ", ".join(f"{k}: {v}" for k, v in kw.items()) + self.reporter.error( + location, + f"Could not unify {callee}={callee.body} with pos=[{pos_str}] and kw={{{kw_str}}}", + ) + return None + return self._get_call_result( + location, + unified, + positional, + keywords, + report_errors, + ) + case _: if report_errors: self.reporter.error( diff --git a/midas/checker/unifier.py b/midas/checker/unifier.py new file mode 100644 index 0000000..e414623 --- /dev/null +++ b/midas/checker/unifier.py @@ -0,0 +1,149 @@ +import logging +from typing import Optional + +from midas.checker.registry import TypesRegistry +from midas.checker.types import ( + AppliedType, + Function, + GenericType, + TopType, + Type, + TypeVar, +) + + +class Unifier: + def __init__(self, types: TypesRegistry) -> None: + self.types: TypesRegistry = types + self.logger: logging.Logger = logging.getLogger("Unifier") + + def unify_call( + self, + type: GenericType, + positional: list[Type], + keywords: dict[str, Type], + ) -> Optional[Type]: + concrete_func: Function = Function( + pos_args=[ + Function.Argument( + pos=i, + name=str(i), + type=arg, + required=True, + ) + for i, arg in enumerate(positional) + ], + args=[], + kw_args=[ + Function.Argument( + pos=len(positional) + i, + name=name, + type=arg, + required=True, + ) + for i, (name, arg) in enumerate(keywords.items()) + ], + returns=TopType(), # TODO: use expected type + ) + return self.unify_generic(type, concrete_func) + + def unify_generic(self, template: GenericType, concrete: Type) -> Optional[Type]: + substitutions: dict[str, Type] = self.match(template.body, concrete) + args: list[Type] = [] + for param in template.params: + if param.name not in substitutions: + return None + args.append(substitutions[param.name]) + + applied: Type = self.types.apply_generic(template, args) + return applied + + def match(self, template: Type, concrete: Type) -> dict[str, Type]: + # TODO: if concrete is Generic, record bound TypeVar. Then when merging + # substitutions, check that the constraint is respected + match (template, concrete): + case (TypeVar(name=name), _): + return {name: concrete} + + case ( + AppliedType(name=template_name, args=template_args), + AppliedType(name=concrete_name, args=concrete_args), + ) if template_name == concrete_name and len(template_args) == len( + concrete_args + ): + substitutions: dict[str, Type] = {} + for template_arg, concrete_arg in zip(template_args, concrete_args): + new_substistutions: dict[str, Type] = self.match( + template_arg, concrete_arg + ) + substitutions = self.merge(substitutions, new_substistutions) + + return substitutions + + case (Function(), Function()): + mapped: list[tuple[Function.Argument, Function.Argument]] = ( + self.map_params(template, concrete) + ) + substitutions: dict[str, Type] = {} + for template_arg, concrete_arg in mapped: + arg_subs: dict[str, Type] = self.match( + template_arg.type, concrete_arg.type + ) + substitutions = self.merge(substitutions, arg_subs) + + return_subs: dict[str, Type] = self.match( + template.returns, concrete.returns + ) + substitutions = self.merge(substitutions, return_subs) + + return substitutions + + case _: + self.logger.debug(f"Can't match {concrete!r} with {template!r}") + return {} + + def merge(self, subs1: dict[str, Type], subs2: dict[str, Type]) -> dict[str, Type]: + merged: dict[str, Type] = subs1.copy() + + for k, v in subs2.items(): + if k in merged and merged[k] != v: + self.logger.debug( + f"Substitution already defined for {k} with type {merged[k]}, got {v}" + ) + merged[k] = v + return merged + + def map_params( + self, func1: Function, func2: Function + ) -> list[tuple[Function.Argument, Function.Argument]]: + pos1: list[Function.Argument] = func1.pos_args + mixed1: list[Function.Argument] = func1.args + kw1: list[Function.Argument] = func1.kw_args + + pos2: list[Function.Argument] = func2.pos_args + mixed2: list[Function.Argument] = func2.args + kw2: list[Function.Argument] = func2.kw_args + + mapped: list[tuple[Function.Argument, Function.Argument]] = [] + + by_pos2: dict[int, Function.Argument] = {arg.pos: arg for arg in pos2 + mixed2} + by_name2: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2 + kw2} + + for arg1 in pos1: + if (arg2 := by_pos2.get(arg1.pos)) is not None: + mapped.append((arg1, arg2)) + + for arg1 in mixed1: + # Match both positionally and by name, conflicts are caught + # when merging substitutions + if (arg2 := by_pos2.get(arg1.pos)) is not None: + mapped.append((arg1, arg2)) + + if (arg2 := by_name2.get(arg1.name)) is not None: + mapped.append((arg1, arg2)) + + for arg1 in kw1: + if (arg2 := by_name2.get(arg1.name)) is not None: + mapped.append((arg1, arg2)) + + return mapped From b591f5508f1532f10a6689b17f3e1b0bfee50261 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sun, 21 Jun 2026 13:17:35 +0200 Subject: [PATCH 2/4] fix(checker): make map definition generic --- midas/checker/preamble.py | 1 + 1 file changed, 1 insertion(+) diff --git a/midas/checker/preamble.py b/midas/checker/preamble.py index ea7001b..96a4ef7 100644 --- a/midas/checker/preamble.py +++ b/midas/checker/preamble.py @@ -52,6 +52,7 @@ class Preamble(Environment): ), ], returns=self._list_of(map_out), # TODO: replace with Iterable[U] + type_vars=[map_in, map_out], ) def _list_of(self, item_type: Type) -> Type: From 29e601128dc1931ccfcdcc778f6020467496eaf7 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sun, 21 Jun 2026 13:19:17 +0200 Subject: [PATCH 3/4] tests: add unification test --- tests/cases/checker/08_unification.py | 14 + .../cases/checker/08_unification.py.ref.json | 869 ++++++++++++++++++ 2 files changed, 883 insertions(+) create mode 100644 tests/cases/checker/08_unification.py create mode 100644 tests/cases/checker/08_unification.py.ref.json diff --git a/tests/cases/checker/08_unification.py b/tests/cases/checker/08_unification.py new file mode 100644 index 0000000..ac828af --- /dev/null +++ b/tests/cases/checker/08_unification.py @@ -0,0 +1,14 @@ +def double(value: float) -> float: + return value * 2 + + +def is_odd(value: int) -> bool: + return bool(value % 2) + + +floats: list[float] = [0.2, 0.5, 0.1, 1.2] +ints: list[int] = [1, 2, 6, -3] + +doubled_floats = map(double, floats) +doubled_ints = map(double, ints) +odd_ints = map(is_odd, ints) diff --git a/tests/cases/checker/08_unification.py.ref.json b/tests/cases/checker/08_unification.py.ref.json new file mode 100644 index 0000000..fd36cbd --- /dev/null +++ b/tests/cases/checker/08_unification.py.ref.json @@ -0,0 +1,869 @@ +{ + "diagnostics": [], + "judgments": [ + { + "location": { + "from": "L2:11", + "to": "L2:16" + }, + "expr": { + "_type": "VariableExpr", + "name": "value" + }, + "type": { + "name": "float" + } + }, + { + "location": { + "from": "L2:19", + "to": "L2:20" + }, + "expr": { + "_type": "LiteralExpr", + "value": 2 + }, + "type": { + "name": "int" + } + }, + { + "location": { + "from": "L2:11", + "to": "L2:20" + }, + "expr": { + "_type": "BinaryExpr", + "left": { + "_type": "VariableExpr", + "name": "value" + }, + "operator": "*", + "right": { + "_type": "LiteralExpr", + "value": 2 + } + }, + "type": { + "name": "float" + } + }, + { + "location": { + "from": "L6:11", + "to": "L6:15" + }, + "expr": { + "_type": "VariableExpr", + "name": "bool" + }, + "type": { + "pos_args": [ + { + "pos": 0, + "name": "object", + "type": {}, + "required": false + } + ], + "args": [], + "kw_args": [], + "returns": { + "name": "bool" + } + } + }, + { + "location": { + "from": "L6:16", + "to": "L6:21" + }, + "expr": { + "_type": "VariableExpr", + "name": "value" + }, + "type": { + "name": "int" + } + }, + { + "location": { + "from": "L6:24", + "to": "L6:25" + }, + "expr": { + "_type": "LiteralExpr", + "value": 2 + }, + "type": { + "name": "int" + } + }, + { + "location": { + "from": "L6:16", + "to": "L6:25" + }, + "expr": { + "_type": "BinaryExpr", + "left": { + "_type": "VariableExpr", + "name": "value" + }, + "operator": "%", + "right": { + "_type": "LiteralExpr", + "value": 2 + } + }, + "type": { + "name": "int" + } + }, + { + "location": { + "from": "L6:11", + "to": "L6:26" + }, + "expr": { + "_type": "CallExpr", + "callee": { + "_type": "VariableExpr", + "name": "bool" + }, + "arguments": [ + { + "_type": "BinaryExpr", + "left": { + "_type": "VariableExpr", + "name": "value" + }, + "operator": "%", + "right": { + "_type": "LiteralExpr", + "value": 2 + } + } + ], + "keywords": {} + }, + "type": { + "name": "bool" + } + }, + { + "location": { + "from": "L9:23", + "to": "L9:26" + }, + "expr": { + "_type": "LiteralExpr", + "value": 0.2 + }, + "type": { + "name": "float" + } + }, + { + "location": { + "from": "L9:28", + "to": "L9:31" + }, + "expr": { + "_type": "LiteralExpr", + "value": 0.5 + }, + "type": { + "name": "float" + } + }, + { + "location": { + "from": "L9:33", + "to": "L9:36" + }, + "expr": { + "_type": "LiteralExpr", + "value": 0.1 + }, + "type": { + "name": "float" + } + }, + { + "location": { + "from": "L9:38", + "to": "L9:41" + }, + "expr": { + "_type": "LiteralExpr", + "value": 1.2 + }, + "type": { + "name": "float" + } + }, + { + "location": { + "from": "L9:22", + "to": "L9:42" + }, + "expr": { + "_type": "ListExpr", + "items": [ + { + "_type": "LiteralExpr", + "value": 0.2 + }, + { + "_type": "LiteralExpr", + "value": 0.5 + }, + { + "_type": "LiteralExpr", + "value": 0.1 + }, + { + "_type": "LiteralExpr", + "value": 1.2 + } + ] + }, + "type": { + "name": "list", + "args": [ + { + "name": "float" + } + ], + "body": { + "name": "list" + } + } + }, + { + "location": { + "from": "L10:19", + "to": "L10:20" + }, + "expr": { + "_type": "LiteralExpr", + "value": 1 + }, + "type": { + "name": "int" + } + }, + { + "location": { + "from": "L10:22", + "to": "L10:23" + }, + "expr": { + "_type": "LiteralExpr", + "value": 2 + }, + "type": { + "name": "int" + } + }, + { + "location": { + "from": "L10:25", + "to": "L10:26" + }, + "expr": { + "_type": "LiteralExpr", + "value": 6 + }, + "type": { + "name": "int" + } + }, + { + "location": { + "from": "L10:29", + "to": "L10:30" + }, + "expr": { + "_type": "LiteralExpr", + "value": 3 + }, + "type": { + "name": "int" + } + }, + { + "location": { + "from": "L10:28", + "to": "L10:30" + }, + "expr": { + "_type": "UnaryExpr", + "operator": "-", + "right": { + "_type": "LiteralExpr", + "value": 3 + } + }, + "type": { + "name": "int" + } + }, + { + "location": { + "from": "L10:18", + "to": "L10:31" + }, + "expr": { + "_type": "ListExpr", + "items": [ + { + "_type": "LiteralExpr", + "value": 1 + }, + { + "_type": "LiteralExpr", + "value": 2 + }, + { + "_type": "LiteralExpr", + "value": 6 + }, + { + "_type": "UnaryExpr", + "operator": "-", + "right": { + "_type": "LiteralExpr", + "value": 3 + } + } + ] + }, + "type": { + "name": "list", + "args": [ + { + "name": "int" + } + ], + "body": { + "name": "list" + } + } + }, + { + "location": { + "from": "L12:17", + "to": "L12:20" + }, + "expr": { + "_type": "VariableExpr", + "name": "map" + }, + "type": { + "name": "map", + "params": [ + { + "name": "T", + "bound": null, + "variance": "INVARIANT" + }, + { + "name": "U", + "bound": null, + "variance": "INVARIANT" + } + ], + "body": { + "pos_args": [ + { + "pos": 0, + "name": "transform", + "type": { + "pos_args": [ + { + "pos": 0, + "name": "v", + "type": { + "name": "T", + "bound": null, + "variance": "INVARIANT" + }, + "required": true + } + ], + "args": [], + "kw_args": [], + "returns": { + "name": "U", + "bound": null, + "variance": "INVARIANT" + } + }, + "required": true + }, + { + "pos": 1, + "name": "iterable", + "type": { + "name": "list", + "args": [ + { + "name": "T", + "bound": null, + "variance": "INVARIANT" + } + ], + "body": { + "name": "list" + } + }, + "required": true + } + ], + "args": [], + "kw_args": [], + "returns": { + "name": "list", + "args": [ + { + "name": "U", + "bound": null, + "variance": "INVARIANT" + } + ], + "body": { + "name": "list" + } + } + } + } + }, + { + "location": { + "from": "L12:21", + "to": "L12:27" + }, + "expr": { + "_type": "VariableExpr", + "name": "double" + }, + "type": { + "pos_args": [], + "args": [ + { + "pos": 0, + "name": "value", + "type": { + "name": "float" + }, + "required": true + } + ], + "kw_args": [], + "returns": { + "name": "float" + } + } + }, + { + "location": { + "from": "L12:29", + "to": "L12:35" + }, + "expr": { + "_type": "VariableExpr", + "name": "floats" + }, + "type": { + "name": "list", + "args": [ + { + "name": "float" + } + ], + "body": { + "name": "list" + } + } + }, + { + "location": { + "from": "L12:17", + "to": "L12:36" + }, + "expr": { + "_type": "CallExpr", + "callee": { + "_type": "VariableExpr", + "name": "map" + }, + "arguments": [ + { + "_type": "VariableExpr", + "name": "double" + }, + { + "_type": "VariableExpr", + "name": "floats" + } + ], + "keywords": {} + }, + "type": { + "name": "list", + "args": [ + { + "name": "float" + } + ], + "body": { + "name": "list" + } + } + }, + { + "location": { + "from": "L13:15", + "to": "L13:18" + }, + "expr": { + "_type": "VariableExpr", + "name": "map" + }, + "type": { + "name": "map", + "params": [ + { + "name": "T", + "bound": null, + "variance": "INVARIANT" + }, + { + "name": "U", + "bound": null, + "variance": "INVARIANT" + } + ], + "body": { + "pos_args": [ + { + "pos": 0, + "name": "transform", + "type": { + "pos_args": [ + { + "pos": 0, + "name": "v", + "type": { + "name": "T", + "bound": null, + "variance": "INVARIANT" + }, + "required": true + } + ], + "args": [], + "kw_args": [], + "returns": { + "name": "U", + "bound": null, + "variance": "INVARIANT" + } + }, + "required": true + }, + { + "pos": 1, + "name": "iterable", + "type": { + "name": "list", + "args": [ + { + "name": "T", + "bound": null, + "variance": "INVARIANT" + } + ], + "body": { + "name": "list" + } + }, + "required": true + } + ], + "args": [], + "kw_args": [], + "returns": { + "name": "list", + "args": [ + { + "name": "U", + "bound": null, + "variance": "INVARIANT" + } + ], + "body": { + "name": "list" + } + } + } + } + }, + { + "location": { + "from": "L13:19", + "to": "L13:25" + }, + "expr": { + "_type": "VariableExpr", + "name": "double" + }, + "type": { + "pos_args": [], + "args": [ + { + "pos": 0, + "name": "value", + "type": { + "name": "float" + }, + "required": true + } + ], + "kw_args": [], + "returns": { + "name": "float" + } + } + }, + { + "location": { + "from": "L13:27", + "to": "L13:31" + }, + "expr": { + "_type": "VariableExpr", + "name": "ints" + }, + "type": { + "name": "list", + "args": [ + { + "name": "int" + } + ], + "body": { + "name": "list" + } + } + }, + { + "location": { + "from": "L13:15", + "to": "L13:32" + }, + "expr": { + "_type": "CallExpr", + "callee": { + "_type": "VariableExpr", + "name": "map" + }, + "arguments": [ + { + "_type": "VariableExpr", + "name": "double" + }, + { + "_type": "VariableExpr", + "name": "ints" + } + ], + "keywords": {} + }, + "type": { + "name": "list", + "args": [ + { + "name": "float" + } + ], + "body": { + "name": "list" + } + } + }, + { + "location": { + "from": "L14:11", + "to": "L14:14" + }, + "expr": { + "_type": "VariableExpr", + "name": "map" + }, + "type": { + "name": "map", + "params": [ + { + "name": "T", + "bound": null, + "variance": "INVARIANT" + }, + { + "name": "U", + "bound": null, + "variance": "INVARIANT" + } + ], + "body": { + "pos_args": [ + { + "pos": 0, + "name": "transform", + "type": { + "pos_args": [ + { + "pos": 0, + "name": "v", + "type": { + "name": "T", + "bound": null, + "variance": "INVARIANT" + }, + "required": true + } + ], + "args": [], + "kw_args": [], + "returns": { + "name": "U", + "bound": null, + "variance": "INVARIANT" + } + }, + "required": true + }, + { + "pos": 1, + "name": "iterable", + "type": { + "name": "list", + "args": [ + { + "name": "T", + "bound": null, + "variance": "INVARIANT" + } + ], + "body": { + "name": "list" + } + }, + "required": true + } + ], + "args": [], + "kw_args": [], + "returns": { + "name": "list", + "args": [ + { + "name": "U", + "bound": null, + "variance": "INVARIANT" + } + ], + "body": { + "name": "list" + } + } + } + } + }, + { + "location": { + "from": "L14:15", + "to": "L14:21" + }, + "expr": { + "_type": "VariableExpr", + "name": "is_odd" + }, + "type": { + "pos_args": [], + "args": [ + { + "pos": 0, + "name": "value", + "type": { + "name": "int" + }, + "required": true + } + ], + "kw_args": [], + "returns": { + "name": "bool" + } + } + }, + { + "location": { + "from": "L14:23", + "to": "L14:27" + }, + "expr": { + "_type": "VariableExpr", + "name": "ints" + }, + "type": { + "name": "list", + "args": [ + { + "name": "int" + } + ], + "body": { + "name": "list" + } + } + }, + { + "location": { + "from": "L14:11", + "to": "L14:28" + }, + "expr": { + "_type": "CallExpr", + "callee": { + "_type": "VariableExpr", + "name": "map" + }, + "arguments": [ + { + "_type": "VariableExpr", + "name": "is_odd" + }, + { + "_type": "VariableExpr", + "name": "ints" + } + ], + "keywords": {} + }, + "type": { + "name": "list", + "args": [ + { + "name": "bool" + } + ], + "body": { + "name": "list" + } + } + } + ] +} \ No newline at end of file From 4395e9339be59d266da57ca5d2e08d48702b6182 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sun, 21 Jun 2026 13:35:44 +0200 Subject: [PATCH 4/4] fix(checker): abort unification on conflict --- midas/checker/unifier.py | 36 ++++++++++++++----- .../cases/checker/08_unification.py.ref.json | 29 ++++++++------- 2 files changed, 45 insertions(+), 20 deletions(-) diff --git a/midas/checker/unifier.py b/midas/checker/unifier.py index e414623..aae9310 100644 --- a/midas/checker/unifier.py +++ b/midas/checker/unifier.py @@ -12,6 +12,9 @@ from midas.checker.types import ( ) +class UnificationError(Exception): ... + + class Unifier: def __init__(self, types: TypesRegistry) -> None: self.types: TypesRegistry = types @@ -45,10 +48,20 @@ class Unifier: ], returns=TopType(), # TODO: use expected type ) - return self.unify_generic(type, concrete_func) + return self.unify_generic(type, concrete_func, match_return=False) + + def unify_generic( + self, + template: GenericType, + concrete: Type, + match_return: bool = True, + ) -> Optional[Type]: + substitutions: dict[str, Type] + try: + substitutions = self.match(template.body, concrete, match_return) + except UnificationError: + return None - def unify_generic(self, template: GenericType, concrete: Type) -> Optional[Type]: - substitutions: dict[str, Type] = self.match(template.body, concrete) args: list[Type] = [] for param in template.params: if param.name not in substitutions: @@ -58,7 +71,12 @@ class Unifier: applied: Type = self.types.apply_generic(template, args) return applied - def match(self, template: Type, concrete: Type) -> dict[str, Type]: + def match( + self, + template: Type, + concrete: Type, + match_return: bool = True, + ) -> dict[str, Type]: # TODO: if concrete is Generic, record bound TypeVar. Then when merging # substitutions, check that the constraint is respected match (template, concrete): @@ -91,10 +109,11 @@ class Unifier: ) substitutions = self.merge(substitutions, arg_subs) - return_subs: dict[str, Type] = self.match( - template.returns, concrete.returns - ) - substitutions = self.merge(substitutions, return_subs) + if match_return: + return_subs: dict[str, Type] = self.match( + template.returns, concrete.returns + ) + substitutions = self.merge(substitutions, return_subs) return substitutions @@ -110,6 +129,7 @@ class Unifier: self.logger.debug( f"Substitution already defined for {k} with type {merged[k]}, got {v}" ) + raise UnificationError merged[k] = v return merged diff --git a/tests/cases/checker/08_unification.py.ref.json b/tests/cases/checker/08_unification.py.ref.json index fd36cbd..bfaa7bd 100644 --- a/tests/cases/checker/08_unification.py.ref.json +++ b/tests/cases/checker/08_unification.py.ref.json @@ -1,5 +1,20 @@ { - "diagnostics": [], + "diagnostics": [ + { + "type": "Error", + "location": { + "start": [ + 13, + 15 + ], + "end": [ + 13, + 32 + ] + }, + "message": "Could not unify map[T, U]=(transform: (v: T, /) -> U, iterable: list[T], /) -> list[U] with pos=[(value: float) -> float, list[int]] and kw={}" + } + ], "judgments": [ { "location": { @@ -682,17 +697,7 @@ ], "keywords": {} }, - "type": { - "name": "list", - "args": [ - { - "name": "float" - } - ], - "body": { - "name": "list" - } - } + "type": {} }, { "location": {