diff --git a/midas/checker/python.py b/midas/checker/python.py index 22ea98c..e1fb788 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -19,12 +19,14 @@ from midas.checker.types import ( AliasType, AppliedType, Function, + GenericType, OverloadedFunction, Type, UnitType, UnknownType, unfold_type, ) +from midas.checker.unifier import Unifier from midas.parser.python import PythonParser from midas.utils import TypedAST @@ -704,6 +706,28 @@ class PythonTyper( location, base, positional, keywords, report_errors ) + case GenericType(): + unifier: Unifier = Unifier(self.types) + pos: list[Type] = [a[1] for a in positional] + kw: dict[str, Type] = {k: v[1] for k, v in keywords.items()} + unified: Optional[Type] = unifier.unify_call(callee, pos, kw) + if unified is None: + if report_errors: + pos_str: str = ", ".join(str(t) for t in pos) + kw_str: str = ", ".join(f"{k}: {v}" for k, v in kw.items()) + self.reporter.error( + location, + f"Could not unify {callee}={callee.body} with pos=[{pos_str}] and kw={{{kw_str}}}", + ) + return None + return self._get_call_result( + location, + unified, + positional, + keywords, + report_errors, + ) + case _: if report_errors: self.reporter.error( diff --git a/midas/checker/unifier.py b/midas/checker/unifier.py new file mode 100644 index 0000000..e414623 --- /dev/null +++ b/midas/checker/unifier.py @@ -0,0 +1,149 @@ +import logging +from typing import Optional + +from midas.checker.registry import TypesRegistry +from midas.checker.types import ( + AppliedType, + Function, + GenericType, + TopType, + Type, + TypeVar, +) + + +class Unifier: + def __init__(self, types: TypesRegistry) -> None: + self.types: TypesRegistry = types + self.logger: logging.Logger = logging.getLogger("Unifier") + + def unify_call( + self, + type: GenericType, + positional: list[Type], + keywords: dict[str, Type], + ) -> Optional[Type]: + concrete_func: Function = Function( + pos_args=[ + Function.Argument( + pos=i, + name=str(i), + type=arg, + required=True, + ) + for i, arg in enumerate(positional) + ], + args=[], + kw_args=[ + Function.Argument( + pos=len(positional) + i, + name=name, + type=arg, + required=True, + ) + for i, (name, arg) in enumerate(keywords.items()) + ], + returns=TopType(), # TODO: use expected type + ) + return self.unify_generic(type, concrete_func) + + def unify_generic(self, template: GenericType, concrete: Type) -> Optional[Type]: + substitutions: dict[str, Type] = self.match(template.body, concrete) + args: list[Type] = [] + for param in template.params: + if param.name not in substitutions: + return None + args.append(substitutions[param.name]) + + applied: Type = self.types.apply_generic(template, args) + return applied + + def match(self, template: Type, concrete: Type) -> dict[str, Type]: + # TODO: if concrete is Generic, record bound TypeVar. Then when merging + # substitutions, check that the constraint is respected + match (template, concrete): + case (TypeVar(name=name), _): + return {name: concrete} + + case ( + AppliedType(name=template_name, args=template_args), + AppliedType(name=concrete_name, args=concrete_args), + ) if template_name == concrete_name and len(template_args) == len( + concrete_args + ): + substitutions: dict[str, Type] = {} + for template_arg, concrete_arg in zip(template_args, concrete_args): + new_substistutions: dict[str, Type] = self.match( + template_arg, concrete_arg + ) + substitutions = self.merge(substitutions, new_substistutions) + + return substitutions + + case (Function(), Function()): + mapped: list[tuple[Function.Argument, Function.Argument]] = ( + self.map_params(template, concrete) + ) + substitutions: dict[str, Type] = {} + for template_arg, concrete_arg in mapped: + arg_subs: dict[str, Type] = self.match( + template_arg.type, concrete_arg.type + ) + substitutions = self.merge(substitutions, arg_subs) + + return_subs: dict[str, Type] = self.match( + template.returns, concrete.returns + ) + substitutions = self.merge(substitutions, return_subs) + + return substitutions + + case _: + self.logger.debug(f"Can't match {concrete!r} with {template!r}") + return {} + + def merge(self, subs1: dict[str, Type], subs2: dict[str, Type]) -> dict[str, Type]: + merged: dict[str, Type] = subs1.copy() + + for k, v in subs2.items(): + if k in merged and merged[k] != v: + self.logger.debug( + f"Substitution already defined for {k} with type {merged[k]}, got {v}" + ) + merged[k] = v + return merged + + def map_params( + self, func1: Function, func2: Function + ) -> list[tuple[Function.Argument, Function.Argument]]: + pos1: list[Function.Argument] = func1.pos_args + mixed1: list[Function.Argument] = func1.args + kw1: list[Function.Argument] = func1.kw_args + + pos2: list[Function.Argument] = func2.pos_args + mixed2: list[Function.Argument] = func2.args + kw2: list[Function.Argument] = func2.kw_args + + mapped: list[tuple[Function.Argument, Function.Argument]] = [] + + by_pos2: dict[int, Function.Argument] = {arg.pos: arg for arg in pos2 + mixed2} + by_name2: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2 + kw2} + + for arg1 in pos1: + if (arg2 := by_pos2.get(arg1.pos)) is not None: + mapped.append((arg1, arg2)) + + for arg1 in mixed1: + # Match both positionally and by name, conflicts are caught + # when merging substitutions + if (arg2 := by_pos2.get(arg1.pos)) is not None: + mapped.append((arg1, arg2)) + + if (arg2 := by_name2.get(arg1.name)) is not None: + mapped.append((arg1, arg2)) + + for arg1 in kw1: + if (arg2 := by_name2.get(arg1.name)) is not None: + mapped.append((arg1, arg2)) + + return mapped