feat(checker): add unifier
add unifier class to infer type parameters from local call context
This commit is contained in:
@@ -19,12 +19,14 @@ from midas.checker.types import (
|
|||||||
AliasType,
|
AliasType,
|
||||||
AppliedType,
|
AppliedType,
|
||||||
Function,
|
Function,
|
||||||
|
GenericType,
|
||||||
OverloadedFunction,
|
OverloadedFunction,
|
||||||
Type,
|
Type,
|
||||||
UnitType,
|
UnitType,
|
||||||
UnknownType,
|
UnknownType,
|
||||||
unfold_type,
|
unfold_type,
|
||||||
)
|
)
|
||||||
|
from midas.checker.unifier import Unifier
|
||||||
from midas.parser.python import PythonParser
|
from midas.parser.python import PythonParser
|
||||||
from midas.utils import TypedAST
|
from midas.utils import TypedAST
|
||||||
|
|
||||||
@@ -704,6 +706,28 @@ class PythonTyper(
|
|||||||
location, base, positional, keywords, report_errors
|
location, base, positional, keywords, report_errors
|
||||||
)
|
)
|
||||||
|
|
||||||
|
case GenericType():
|
||||||
|
unifier: Unifier = Unifier(self.types)
|
||||||
|
pos: list[Type] = [a[1] for a in positional]
|
||||||
|
kw: dict[str, Type] = {k: v[1] for k, v in keywords.items()}
|
||||||
|
unified: Optional[Type] = unifier.unify_call(callee, pos, kw)
|
||||||
|
if unified is None:
|
||||||
|
if report_errors:
|
||||||
|
pos_str: str = ", ".join(str(t) for t in pos)
|
||||||
|
kw_str: str = ", ".join(f"{k}: {v}" for k, v in kw.items())
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Could not unify {callee}={callee.body} with pos=[{pos_str}] and kw={{{kw_str}}}",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
return self._get_call_result(
|
||||||
|
location,
|
||||||
|
unified,
|
||||||
|
positional,
|
||||||
|
keywords,
|
||||||
|
report_errors,
|
||||||
|
)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
if report_errors:
|
if report_errors:
|
||||||
self.reporter.error(
|
self.reporter.error(
|
||||||
|
|||||||
149
midas/checker/unifier.py
Normal file
149
midas/checker/unifier.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
|
from midas.checker.types import (
|
||||||
|
AppliedType,
|
||||||
|
Function,
|
||||||
|
GenericType,
|
||||||
|
TopType,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Unifier:
|
||||||
|
def __init__(self, types: TypesRegistry) -> None:
|
||||||
|
self.types: TypesRegistry = types
|
||||||
|
self.logger: logging.Logger = logging.getLogger("Unifier")
|
||||||
|
|
||||||
|
def unify_call(
|
||||||
|
self,
|
||||||
|
type: GenericType,
|
||||||
|
positional: list[Type],
|
||||||
|
keywords: dict[str, Type],
|
||||||
|
) -> Optional[Type]:
|
||||||
|
concrete_func: Function = Function(
|
||||||
|
pos_args=[
|
||||||
|
Function.Argument(
|
||||||
|
pos=i,
|
||||||
|
name=str(i),
|
||||||
|
type=arg,
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
for i, arg in enumerate(positional)
|
||||||
|
],
|
||||||
|
args=[],
|
||||||
|
kw_args=[
|
||||||
|
Function.Argument(
|
||||||
|
pos=len(positional) + i,
|
||||||
|
name=name,
|
||||||
|
type=arg,
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
for i, (name, arg) in enumerate(keywords.items())
|
||||||
|
],
|
||||||
|
returns=TopType(), # TODO: use expected type
|
||||||
|
)
|
||||||
|
return self.unify_generic(type, concrete_func)
|
||||||
|
|
||||||
|
def unify_generic(self, template: GenericType, concrete: Type) -> Optional[Type]:
|
||||||
|
substitutions: dict[str, Type] = self.match(template.body, concrete)
|
||||||
|
args: list[Type] = []
|
||||||
|
for param in template.params:
|
||||||
|
if param.name not in substitutions:
|
||||||
|
return None
|
||||||
|
args.append(substitutions[param.name])
|
||||||
|
|
||||||
|
applied: Type = self.types.apply_generic(template, args)
|
||||||
|
return applied
|
||||||
|
|
||||||
|
def match(self, template: Type, concrete: Type) -> dict[str, Type]:
|
||||||
|
# TODO: if concrete is Generic, record bound TypeVar. Then when merging
|
||||||
|
# substitutions, check that the constraint is respected
|
||||||
|
match (template, concrete):
|
||||||
|
case (TypeVar(name=name), _):
|
||||||
|
return {name: concrete}
|
||||||
|
|
||||||
|
case (
|
||||||
|
AppliedType(name=template_name, args=template_args),
|
||||||
|
AppliedType(name=concrete_name, args=concrete_args),
|
||||||
|
) if template_name == concrete_name and len(template_args) == len(
|
||||||
|
concrete_args
|
||||||
|
):
|
||||||
|
substitutions: dict[str, Type] = {}
|
||||||
|
for template_arg, concrete_arg in zip(template_args, concrete_args):
|
||||||
|
new_substistutions: dict[str, Type] = self.match(
|
||||||
|
template_arg, concrete_arg
|
||||||
|
)
|
||||||
|
substitutions = self.merge(substitutions, new_substistutions)
|
||||||
|
|
||||||
|
return substitutions
|
||||||
|
|
||||||
|
case (Function(), Function()):
|
||||||
|
mapped: list[tuple[Function.Argument, Function.Argument]] = (
|
||||||
|
self.map_params(template, concrete)
|
||||||
|
)
|
||||||
|
substitutions: dict[str, Type] = {}
|
||||||
|
for template_arg, concrete_arg in mapped:
|
||||||
|
arg_subs: dict[str, Type] = self.match(
|
||||||
|
template_arg.type, concrete_arg.type
|
||||||
|
)
|
||||||
|
substitutions = self.merge(substitutions, arg_subs)
|
||||||
|
|
||||||
|
return_subs: dict[str, Type] = self.match(
|
||||||
|
template.returns, concrete.returns
|
||||||
|
)
|
||||||
|
substitutions = self.merge(substitutions, return_subs)
|
||||||
|
|
||||||
|
return substitutions
|
||||||
|
|
||||||
|
case _:
|
||||||
|
self.logger.debug(f"Can't match {concrete!r} with {template!r}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def merge(self, subs1: dict[str, Type], subs2: dict[str, Type]) -> dict[str, Type]:
|
||||||
|
merged: dict[str, Type] = subs1.copy()
|
||||||
|
|
||||||
|
for k, v in subs2.items():
|
||||||
|
if k in merged and merged[k] != v:
|
||||||
|
self.logger.debug(
|
||||||
|
f"Substitution already defined for {k} with type {merged[k]}, got {v}"
|
||||||
|
)
|
||||||
|
merged[k] = v
|
||||||
|
return merged
|
||||||
|
|
||||||
|
def map_params(
|
||||||
|
self, func1: Function, func2: Function
|
||||||
|
) -> list[tuple[Function.Argument, Function.Argument]]:
|
||||||
|
pos1: list[Function.Argument] = func1.pos_args
|
||||||
|
mixed1: list[Function.Argument] = func1.args
|
||||||
|
kw1: list[Function.Argument] = func1.kw_args
|
||||||
|
|
||||||
|
pos2: list[Function.Argument] = func2.pos_args
|
||||||
|
mixed2: list[Function.Argument] = func2.args
|
||||||
|
kw2: list[Function.Argument] = func2.kw_args
|
||||||
|
|
||||||
|
mapped: list[tuple[Function.Argument, Function.Argument]] = []
|
||||||
|
|
||||||
|
by_pos2: dict[int, Function.Argument] = {arg.pos: arg for arg in pos2 + mixed2}
|
||||||
|
by_name2: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2 + kw2}
|
||||||
|
|
||||||
|
for arg1 in pos1:
|
||||||
|
if (arg2 := by_pos2.get(arg1.pos)) is not None:
|
||||||
|
mapped.append((arg1, arg2))
|
||||||
|
|
||||||
|
for arg1 in mixed1:
|
||||||
|
# Match both positionally and by name, conflicts are caught
|
||||||
|
# when merging substitutions
|
||||||
|
if (arg2 := by_pos2.get(arg1.pos)) is not None:
|
||||||
|
mapped.append((arg1, arg2))
|
||||||
|
|
||||||
|
if (arg2 := by_name2.get(arg1.name)) is not None:
|
||||||
|
mapped.append((arg1, arg2))
|
||||||
|
|
||||||
|
for arg1 in kw1:
|
||||||
|
if (arg2 := by_name2.get(arg1.name)) is not None:
|
||||||
|
mapped.append((arg1, arg2))
|
||||||
|
|
||||||
|
return mapped
|
||||||
Reference in New Issue
Block a user