diff --git a/midas/checker/preamble.py b/midas/checker/preamble.py index ea7001b..96a4ef7 100644 --- a/midas/checker/preamble.py +++ b/midas/checker/preamble.py @@ -52,6 +52,7 @@ class Preamble(Environment): ), ], returns=self._list_of(map_out), # TODO: replace with Iterable[U] + type_vars=[map_in, map_out], ) def _list_of(self, item_type: Type) -> Type: diff --git a/midas/checker/python.py b/midas/checker/python.py index 22ea98c..e1fb788 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -19,12 +19,14 @@ from midas.checker.types import ( AliasType, AppliedType, Function, + GenericType, OverloadedFunction, Type, UnitType, UnknownType, unfold_type, ) +from midas.checker.unifier import Unifier from midas.parser.python import PythonParser from midas.utils import TypedAST @@ -704,6 +706,28 @@ class PythonTyper( location, base, positional, keywords, report_errors ) + case GenericType(): + unifier: Unifier = Unifier(self.types) + pos: list[Type] = [a[1] for a in positional] + kw: dict[str, Type] = {k: v[1] for k, v in keywords.items()} + unified: Optional[Type] = unifier.unify_call(callee, pos, kw) + if unified is None: + if report_errors: + pos_str: str = ", ".join(str(t) for t in pos) + kw_str: str = ", ".join(f"{k}: {v}" for k, v in kw.items()) + self.reporter.error( + location, + f"Could not unify {callee}={callee.body} with pos=[{pos_str}] and kw={{{kw_str}}}", + ) + return None + return self._get_call_result( + location, + unified, + positional, + keywords, + report_errors, + ) + case _: if report_errors: self.reporter.error( diff --git a/midas/checker/unifier.py b/midas/checker/unifier.py new file mode 100644 index 0000000..aae9310 --- /dev/null +++ b/midas/checker/unifier.py @@ -0,0 +1,169 @@ +import logging +from typing import Optional + +from midas.checker.registry import TypesRegistry +from midas.checker.types import ( + AppliedType, + Function, + GenericType, + TopType, + Type, + TypeVar, +) + + +class UnificationError(Exception): ... + + +class Unifier: + def __init__(self, types: TypesRegistry) -> None: + self.types: TypesRegistry = types + self.logger: logging.Logger = logging.getLogger("Unifier") + + def unify_call( + self, + type: GenericType, + positional: list[Type], + keywords: dict[str, Type], + ) -> Optional[Type]: + concrete_func: Function = Function( + pos_args=[ + Function.Argument( + pos=i, + name=str(i), + type=arg, + required=True, + ) + for i, arg in enumerate(positional) + ], + args=[], + kw_args=[ + Function.Argument( + pos=len(positional) + i, + name=name, + type=arg, + required=True, + ) + for i, (name, arg) in enumerate(keywords.items()) + ], + returns=TopType(), # TODO: use expected type + ) + return self.unify_generic(type, concrete_func, match_return=False) + + def unify_generic( + self, + template: GenericType, + concrete: Type, + match_return: bool = True, + ) -> Optional[Type]: + substitutions: dict[str, Type] + try: + substitutions = self.match(template.body, concrete, match_return) + except UnificationError: + return None + + args: list[Type] = [] + for param in template.params: + if param.name not in substitutions: + return None + args.append(substitutions[param.name]) + + applied: Type = self.types.apply_generic(template, args) + return applied + + def match( + self, + template: Type, + concrete: Type, + match_return: bool = True, + ) -> dict[str, Type]: + # TODO: if concrete is Generic, record bound TypeVar. Then when merging + # substitutions, check that the constraint is respected + match (template, concrete): + case (TypeVar(name=name), _): + return {name: concrete} + + case ( + AppliedType(name=template_name, args=template_args), + AppliedType(name=concrete_name, args=concrete_args), + ) if template_name == concrete_name and len(template_args) == len( + concrete_args + ): + substitutions: dict[str, Type] = {} + for template_arg, concrete_arg in zip(template_args, concrete_args): + new_substistutions: dict[str, Type] = self.match( + template_arg, concrete_arg + ) + substitutions = self.merge(substitutions, new_substistutions) + + return substitutions + + case (Function(), Function()): + mapped: list[tuple[Function.Argument, Function.Argument]] = ( + self.map_params(template, concrete) + ) + substitutions: dict[str, Type] = {} + for template_arg, concrete_arg in mapped: + arg_subs: dict[str, Type] = self.match( + template_arg.type, concrete_arg.type + ) + substitutions = self.merge(substitutions, arg_subs) + + if match_return: + return_subs: dict[str, Type] = self.match( + template.returns, concrete.returns + ) + substitutions = self.merge(substitutions, return_subs) + + return substitutions + + case _: + self.logger.debug(f"Can't match {concrete!r} with {template!r}") + return {} + + def merge(self, subs1: dict[str, Type], subs2: dict[str, Type]) -> dict[str, Type]: + merged: dict[str, Type] = subs1.copy() + + for k, v in subs2.items(): + if k in merged and merged[k] != v: + self.logger.debug( + f"Substitution already defined for {k} with type {merged[k]}, got {v}" + ) + raise UnificationError + merged[k] = v + return merged + + def map_params( + self, func1: Function, func2: Function + ) -> list[tuple[Function.Argument, Function.Argument]]: + pos1: list[Function.Argument] = func1.pos_args + mixed1: list[Function.Argument] = func1.args + kw1: list[Function.Argument] = func1.kw_args + + pos2: list[Function.Argument] = func2.pos_args + mixed2: list[Function.Argument] = func2.args + kw2: list[Function.Argument] = func2.kw_args + + mapped: list[tuple[Function.Argument, Function.Argument]] = [] + + by_pos2: dict[int, Function.Argument] = {arg.pos: arg for arg in pos2 + mixed2} + by_name2: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2 + kw2} + + for arg1 in pos1: + if (arg2 := by_pos2.get(arg1.pos)) is not None: + mapped.append((arg1, arg2)) + + for arg1 in mixed1: + # Match both positionally and by name, conflicts are caught + # when merging substitutions + if (arg2 := by_pos2.get(arg1.pos)) is not None: + mapped.append((arg1, arg2)) + + if (arg2 := by_name2.get(arg1.name)) is not None: + mapped.append((arg1, arg2)) + + for arg1 in kw1: + if (arg2 := by_name2.get(arg1.name)) is not None: + mapped.append((arg1, arg2)) + + return mapped diff --git a/tests/cases/checker/08_unification.py b/tests/cases/checker/08_unification.py new file mode 100644 index 0000000..ac828af --- /dev/null +++ b/tests/cases/checker/08_unification.py @@ -0,0 +1,14 @@ +def double(value: float) -> float: + return value * 2 + + +def is_odd(value: int) -> bool: + return bool(value % 2) + + +floats: list[float] = [0.2, 0.5, 0.1, 1.2] +ints: list[int] = [1, 2, 6, -3] + +doubled_floats = map(double, floats) +doubled_ints = map(double, ints) +odd_ints = map(is_odd, ints) diff --git a/tests/cases/checker/08_unification.py.ref.json b/tests/cases/checker/08_unification.py.ref.json new file mode 100644 index 0000000..bfaa7bd --- /dev/null +++ b/tests/cases/checker/08_unification.py.ref.json @@ -0,0 +1,874 @@ +{ + "diagnostics": [ + { + "type": "Error", + "location": { + "start": [ + 13, + 15 + ], + "end": [ + 13, + 32 + ] + }, + "message": "Could not unify map[T, U]=(transform: (v: T, /) -> U, iterable: list[T], /) -> list[U] with pos=[(value: float) -> float, list[int]] and kw={}" + } + ], + "judgments": [ + { + "location": { + "from": "L2:11", + "to": "L2:16" + }, + "expr": { + "_type": "VariableExpr", + "name": "value" + }, + "type": { + "name": "float" + } + }, + { + "location": { + "from": "L2:19", + "to": "L2:20" + }, + "expr": { + "_type": "LiteralExpr", + "value": 2 + }, + "type": { + "name": "int" + } + }, + { + "location": { + "from": "L2:11", + "to": "L2:20" + }, + "expr": { + "_type": "BinaryExpr", + "left": { + "_type": "VariableExpr", + "name": "value" + }, + "operator": "*", + "right": { + "_type": "LiteralExpr", + "value": 2 + } + }, + "type": { + "name": "float" + } + }, + { + "location": { + "from": "L6:11", + "to": "L6:15" + }, + "expr": { + "_type": "VariableExpr", + "name": "bool" + }, + "type": { + "pos_args": [ + { + "pos": 0, + "name": "object", + "type": {}, + "required": false + } + ], + "args": [], + "kw_args": [], + "returns": { + "name": "bool" + } + } + }, + { + "location": { + "from": "L6:16", + "to": "L6:21" + }, + "expr": { + "_type": "VariableExpr", + "name": "value" + }, + "type": { + "name": "int" + } + }, + { + "location": { + "from": "L6:24", + "to": "L6:25" + }, + "expr": { + "_type": "LiteralExpr", + "value": 2 + }, + "type": { + "name": "int" + } + }, + { + "location": { + "from": "L6:16", + "to": "L6:25" + }, + "expr": { + "_type": "BinaryExpr", + "left": { + "_type": "VariableExpr", + "name": "value" + }, + "operator": "%", + "right": { + "_type": "LiteralExpr", + "value": 2 + } + }, + "type": { + "name": "int" + } + }, + { + "location": { + "from": "L6:11", + "to": "L6:26" + }, + "expr": { + "_type": "CallExpr", + "callee": { + "_type": "VariableExpr", + "name": "bool" + }, + "arguments": [ + { + "_type": "BinaryExpr", + "left": { + "_type": "VariableExpr", + "name": "value" + }, + "operator": "%", + "right": { + "_type": "LiteralExpr", + "value": 2 + } + } + ], + "keywords": {} + }, + "type": { + "name": "bool" + } + }, + { + "location": { + "from": "L9:23", + "to": "L9:26" + }, + "expr": { + "_type": "LiteralExpr", + "value": 0.2 + }, + "type": { + "name": "float" + } + }, + { + "location": { + "from": "L9:28", + "to": "L9:31" + }, + "expr": { + "_type": "LiteralExpr", + "value": 0.5 + }, + "type": { + "name": "float" + } + }, + { + "location": { + "from": "L9:33", + "to": "L9:36" + }, + "expr": { + "_type": "LiteralExpr", + "value": 0.1 + }, + "type": { + "name": "float" + } + }, + { + "location": { + "from": "L9:38", + "to": "L9:41" + }, + "expr": { + "_type": "LiteralExpr", + "value": 1.2 + }, + "type": { + "name": "float" + } + }, + { + "location": { + "from": "L9:22", + "to": "L9:42" + }, + "expr": { + "_type": "ListExpr", + "items": [ + { + "_type": "LiteralExpr", + "value": 0.2 + }, + { + "_type": "LiteralExpr", + "value": 0.5 + }, + { + "_type": "LiteralExpr", + "value": 0.1 + }, + { + "_type": "LiteralExpr", + "value": 1.2 + } + ] + }, + "type": { + "name": "list", + "args": [ + { + "name": "float" + } + ], + "body": { + "name": "list" + } + } + }, + { + "location": { + "from": "L10:19", + "to": "L10:20" + }, + "expr": { + "_type": "LiteralExpr", + "value": 1 + }, + "type": { + "name": "int" + } + }, + { + "location": { + "from": "L10:22", + "to": "L10:23" + }, + "expr": { + "_type": "LiteralExpr", + "value": 2 + }, + "type": { + "name": "int" + } + }, + { + "location": { + "from": "L10:25", + "to": "L10:26" + }, + "expr": { + "_type": "LiteralExpr", + "value": 6 + }, + "type": { + "name": "int" + } + }, + { + "location": { + "from": "L10:29", + "to": "L10:30" + }, + "expr": { + "_type": "LiteralExpr", + "value": 3 + }, + "type": { + "name": "int" + } + }, + { + "location": { + "from": "L10:28", + "to": "L10:30" + }, + "expr": { + "_type": "UnaryExpr", + "operator": "-", + "right": { + "_type": "LiteralExpr", + "value": 3 + } + }, + "type": { + "name": "int" + } + }, + { + "location": { + "from": "L10:18", + "to": "L10:31" + }, + "expr": { + "_type": "ListExpr", + "items": [ + { + "_type": "LiteralExpr", + "value": 1 + }, + { + "_type": "LiteralExpr", + "value": 2 + }, + { + "_type": "LiteralExpr", + "value": 6 + }, + { + "_type": "UnaryExpr", + "operator": "-", + "right": { + "_type": "LiteralExpr", + "value": 3 + } + } + ] + }, + "type": { + "name": "list", + "args": [ + { + "name": "int" + } + ], + "body": { + "name": "list" + } + } + }, + { + "location": { + "from": "L12:17", + "to": "L12:20" + }, + "expr": { + "_type": "VariableExpr", + "name": "map" + }, + "type": { + "name": "map", + "params": [ + { + "name": "T", + "bound": null, + "variance": "INVARIANT" + }, + { + "name": "U", + "bound": null, + "variance": "INVARIANT" + } + ], + "body": { + "pos_args": [ + { + "pos": 0, + "name": "transform", + "type": { + "pos_args": [ + { + "pos": 0, + "name": "v", + "type": { + "name": "T", + "bound": null, + "variance": "INVARIANT" + }, + "required": true + } + ], + "args": [], + "kw_args": [], + "returns": { + "name": "U", + "bound": null, + "variance": "INVARIANT" + } + }, + "required": true + }, + { + "pos": 1, + "name": "iterable", + "type": { + "name": "list", + "args": [ + { + "name": "T", + "bound": null, + "variance": "INVARIANT" + } + ], + "body": { + "name": "list" + } + }, + "required": true + } + ], + "args": [], + "kw_args": [], + "returns": { + "name": "list", + "args": [ + { + "name": "U", + "bound": null, + "variance": "INVARIANT" + } + ], + "body": { + "name": "list" + } + } + } + } + }, + { + "location": { + "from": "L12:21", + "to": "L12:27" + }, + "expr": { + "_type": "VariableExpr", + "name": "double" + }, + "type": { + "pos_args": [], + "args": [ + { + "pos": 0, + "name": "value", + "type": { + "name": "float" + }, + "required": true + } + ], + "kw_args": [], + "returns": { + "name": "float" + } + } + }, + { + "location": { + "from": "L12:29", + "to": "L12:35" + }, + "expr": { + "_type": "VariableExpr", + "name": "floats" + }, + "type": { + "name": "list", + "args": [ + { + "name": "float" + } + ], + "body": { + "name": "list" + } + } + }, + { + "location": { + "from": "L12:17", + "to": "L12:36" + }, + "expr": { + "_type": "CallExpr", + "callee": { + "_type": "VariableExpr", + "name": "map" + }, + "arguments": [ + { + "_type": "VariableExpr", + "name": "double" + }, + { + "_type": "VariableExpr", + "name": "floats" + } + ], + "keywords": {} + }, + "type": { + "name": "list", + "args": [ + { + "name": "float" + } + ], + "body": { + "name": "list" + } + } + }, + { + "location": { + "from": "L13:15", + "to": "L13:18" + }, + "expr": { + "_type": "VariableExpr", + "name": "map" + }, + "type": { + "name": "map", + "params": [ + { + "name": "T", + "bound": null, + "variance": "INVARIANT" + }, + { + "name": "U", + "bound": null, + "variance": "INVARIANT" + } + ], + "body": { + "pos_args": [ + { + "pos": 0, + "name": "transform", + "type": { + "pos_args": [ + { + "pos": 0, + "name": "v", + "type": { + "name": "T", + "bound": null, + "variance": "INVARIANT" + }, + "required": true + } + ], + "args": [], + "kw_args": [], + "returns": { + "name": "U", + "bound": null, + "variance": "INVARIANT" + } + }, + "required": true + }, + { + "pos": 1, + "name": "iterable", + "type": { + "name": "list", + "args": [ + { + "name": "T", + "bound": null, + "variance": "INVARIANT" + } + ], + "body": { + "name": "list" + } + }, + "required": true + } + ], + "args": [], + "kw_args": [], + "returns": { + "name": "list", + "args": [ + { + "name": "U", + "bound": null, + "variance": "INVARIANT" + } + ], + "body": { + "name": "list" + } + } + } + } + }, + { + "location": { + "from": "L13:19", + "to": "L13:25" + }, + "expr": { + "_type": "VariableExpr", + "name": "double" + }, + "type": { + "pos_args": [], + "args": [ + { + "pos": 0, + "name": "value", + "type": { + "name": "float" + }, + "required": true + } + ], + "kw_args": [], + "returns": { + "name": "float" + } + } + }, + { + "location": { + "from": "L13:27", + "to": "L13:31" + }, + "expr": { + "_type": "VariableExpr", + "name": "ints" + }, + "type": { + "name": "list", + "args": [ + { + "name": "int" + } + ], + "body": { + "name": "list" + } + } + }, + { + "location": { + "from": "L13:15", + "to": "L13:32" + }, + "expr": { + "_type": "CallExpr", + "callee": { + "_type": "VariableExpr", + "name": "map" + }, + "arguments": [ + { + "_type": "VariableExpr", + "name": "double" + }, + { + "_type": "VariableExpr", + "name": "ints" + } + ], + "keywords": {} + }, + "type": {} + }, + { + "location": { + "from": "L14:11", + "to": "L14:14" + }, + "expr": { + "_type": "VariableExpr", + "name": "map" + }, + "type": { + "name": "map", + "params": [ + { + "name": "T", + "bound": null, + "variance": "INVARIANT" + }, + { + "name": "U", + "bound": null, + "variance": "INVARIANT" + } + ], + "body": { + "pos_args": [ + { + "pos": 0, + "name": "transform", + "type": { + "pos_args": [ + { + "pos": 0, + "name": "v", + "type": { + "name": "T", + "bound": null, + "variance": "INVARIANT" + }, + "required": true + } + ], + "args": [], + "kw_args": [], + "returns": { + "name": "U", + "bound": null, + "variance": "INVARIANT" + } + }, + "required": true + }, + { + "pos": 1, + "name": "iterable", + "type": { + "name": "list", + "args": [ + { + "name": "T", + "bound": null, + "variance": "INVARIANT" + } + ], + "body": { + "name": "list" + } + }, + "required": true + } + ], + "args": [], + "kw_args": [], + "returns": { + "name": "list", + "args": [ + { + "name": "U", + "bound": null, + "variance": "INVARIANT" + } + ], + "body": { + "name": "list" + } + } + } + } + }, + { + "location": { + "from": "L14:15", + "to": "L14:21" + }, + "expr": { + "_type": "VariableExpr", + "name": "is_odd" + }, + "type": { + "pos_args": [], + "args": [ + { + "pos": 0, + "name": "value", + "type": { + "name": "int" + }, + "required": true + } + ], + "kw_args": [], + "returns": { + "name": "bool" + } + } + }, + { + "location": { + "from": "L14:23", + "to": "L14:27" + }, + "expr": { + "_type": "VariableExpr", + "name": "ints" + }, + "type": { + "name": "list", + "args": [ + { + "name": "int" + } + ], + "body": { + "name": "list" + } + } + }, + { + "location": { + "from": "L14:11", + "to": "L14:28" + }, + "expr": { + "_type": "CallExpr", + "callee": { + "_type": "VariableExpr", + "name": "map" + }, + "arguments": [ + { + "_type": "VariableExpr", + "name": "is_odd" + }, + { + "_type": "VariableExpr", + "name": "ints" + } + ], + "keywords": {} + }, + "type": { + "name": "list", + "args": [ + { + "name": "bool" + } + ], + "body": { + "name": "list" + } + } + } + ] +} \ No newline at end of file