348 lines
12 KiB
Python
348 lines
12 KiB
Python
import logging
|
|
from typing import Optional
|
|
|
|
from midas.checker.builtins import BUILTIN_SUBTYPES
|
|
from midas.checker.types import (
|
|
AliasType,
|
|
AppliedType,
|
|
BaseType,
|
|
ComplexType,
|
|
ExtensionType,
|
|
Function,
|
|
GenericType,
|
|
OverloadedFunction,
|
|
TopType,
|
|
Type,
|
|
TypeVar,
|
|
UnknownType,
|
|
substitute_typevars,
|
|
)
|
|
|
|
|
|
class TypesRegistry:
|
|
def __init__(self) -> None:
|
|
self.logger: logging.Logger = logging.getLogger("TypesRegistry")
|
|
self._types: dict[str, Type] = {}
|
|
self._members: dict[str, dict[str, Type]] = {}
|
|
|
|
def get_type(self, name: str) -> Type:
|
|
"""Get a type from its name
|
|
|
|
Args:
|
|
name (str): the name of the type
|
|
|
|
Raises:
|
|
NameError: if the type is not defined
|
|
|
|
Returns:
|
|
Type: the type
|
|
"""
|
|
if name in self._types:
|
|
return self._types[name]
|
|
raise NameError(f"Undefined type {name}")
|
|
|
|
def define_type(self, name: str, type: Type) -> Type:
|
|
"""Define a type in the registry
|
|
|
|
Args:
|
|
name (str): the name of the type
|
|
type (Type): the type to define
|
|
|
|
Raises:
|
|
ValueError: if a type is already defined with that name
|
|
|
|
Returns:
|
|
Type: the defined type
|
|
"""
|
|
if name in self._types:
|
|
raise ValueError(f"Type {name} already defined")
|
|
self._types[name] = type
|
|
return type
|
|
|
|
def define_member(
|
|
self, type_name: str, member_name: str, member_type: Type, is_method: bool
|
|
):
|
|
members: dict[str, Type] = self._members.setdefault(type_name, {})
|
|
if member_name in members:
|
|
if not is_method:
|
|
self.logger.error(
|
|
f"Member '{member_name}' already defined for type {type_name}"
|
|
)
|
|
return
|
|
current: Type = members[member_name]
|
|
combined: Type
|
|
match current:
|
|
case OverloadedFunction(overloads=overloads):
|
|
combined = OverloadedFunction(overloads=overloads + [member_type])
|
|
case _:
|
|
combined = OverloadedFunction(overloads=[current, member_type])
|
|
members[member_name] = combined
|
|
|
|
else:
|
|
members[member_name] = member_type
|
|
|
|
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
|
"""Check whether `type1` is a subtype of `type2`
|
|
|
|
For more details on the rules checked here, see TAPL Chap. 15-16-17
|
|
|
|
Args:
|
|
type1 (Type): the potential subtype
|
|
type2 (Type): the potential supertype
|
|
|
|
Returns:
|
|
bool: whether `type1` is a subtype of `type2`
|
|
"""
|
|
|
|
if type1 == type2:
|
|
return True
|
|
|
|
match (type1, type2):
|
|
case (_, TopType()):
|
|
return True
|
|
|
|
case (AliasType(type=base1), _):
|
|
return self.is_subtype(base1, type2)
|
|
|
|
case (BaseType(name=name1), BaseType(name=name2)):
|
|
return name1 in BUILTIN_SUBTYPES.get(name2, set())
|
|
|
|
case (ComplexType(properties=props1), ComplexType(properties=props2)):
|
|
for k, t in props2.items():
|
|
if k not in props1:
|
|
return False
|
|
if not self.is_subtype(props1[k], t):
|
|
return False
|
|
return True
|
|
|
|
case (Function(), Function()):
|
|
return self.is_func_subtype(type1, type2)
|
|
|
|
case (TypeVar(bound=bound), _):
|
|
if bound is None:
|
|
return False
|
|
return self.is_subtype(bound, type2)
|
|
|
|
return False
|
|
|
|
# TODO: verify the logic in here
|
|
def is_func_subtype(self, func1: Function, func2: Function) -> bool:
|
|
"""Check whether a function is a subtype of another
|
|
|
|
Args:
|
|
func1 (Function): the potential function subtype
|
|
func2 (Function): the potential function supertype
|
|
|
|
Returns:
|
|
bool: whether `func1` is a subtype of `func2`
|
|
"""
|
|
if not self.is_subtype(func1.returns, func2.returns):
|
|
return False
|
|
|
|
pos1: list[Function.Argument] = func1.pos_args
|
|
mixed1: list[Function.Argument] = func1.args
|
|
kw1: dict[str, Function.Argument] = {a.name: a for a in func1.kw_args}
|
|
pos2: list[Function.Argument] = func2.pos_args
|
|
mixed2: list[Function.Argument] = func2.args
|
|
kw2: dict[str, Function.Argument] = {a.name: a for a in func2.kw_args}
|
|
|
|
mixed_by_pos: dict[int, Function.Argument] = {arg.pos: arg for arg in mixed2}
|
|
mixed_by_name: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2}
|
|
|
|
def is_arg_subtype(sub: Function.Argument, sup: Function.Argument) -> bool:
|
|
if not self.is_subtype(sub.type, sup.type):
|
|
return False
|
|
if not sup.required and sub.required:
|
|
return False
|
|
return True
|
|
|
|
for arg1 in pos1:
|
|
arg2: Function.Argument
|
|
if arg1.pos < len(pos2):
|
|
arg2 = pos2[arg1.pos]
|
|
elif arg1.pos in mixed_by_pos:
|
|
arg2 = mixed_by_pos[arg1.pos]
|
|
elif not arg1.required:
|
|
continue
|
|
else:
|
|
return False
|
|
if not is_arg_subtype(arg2, arg1):
|
|
return False
|
|
|
|
for name, arg1 in kw1.items():
|
|
arg2: Function.Argument
|
|
if name in kw2:
|
|
arg2 = kw2[name]
|
|
elif name in mixed_by_name:
|
|
arg2 = mixed_by_name[name]
|
|
elif not arg1.required:
|
|
continue
|
|
else:
|
|
return False
|
|
if not is_arg_subtype(arg2, arg1):
|
|
return False
|
|
|
|
for arg1 in mixed1:
|
|
pos_arg2: Optional[Function.Argument] = None
|
|
kw_arg2: Optional[Function.Argument] = None
|
|
if arg1.name in kw2:
|
|
kw_arg2 = kw2[arg1.name]
|
|
elif arg1.name in mixed_by_name:
|
|
kw_arg2 = mixed_by_name[arg1.name]
|
|
if arg1.pos < len(pos2):
|
|
pos_arg2 = pos2[arg1.pos]
|
|
elif arg1.pos in mixed_by_pos:
|
|
pos_arg2 = mixed_by_pos[arg1.pos]
|
|
|
|
# No match in func2 and arg is required
|
|
if pos_arg2 is None and kw_arg2 is None and arg1.required:
|
|
return False
|
|
|
|
# Matching keyword argument
|
|
if kw_arg2 is not None and not is_arg_subtype(kw_arg2, arg1):
|
|
return False
|
|
|
|
# Matching positional argument
|
|
if pos_arg2 is not None and not is_arg_subtype(pos_arg2, arg1):
|
|
return False
|
|
|
|
mixed_positions: set[int] = {a.pos for a in mixed1}
|
|
mixed_names: set[str] = {a.name for a in mixed1}
|
|
for arg2 in pos2:
|
|
if not arg2.required:
|
|
continue
|
|
if arg2.pos >= len(pos1) and arg2.pos not in mixed_positions:
|
|
return False
|
|
|
|
for name, arg2 in kw2.items():
|
|
if not arg2.required:
|
|
continue
|
|
if name not in kw1 and name not in mixed_names:
|
|
return False
|
|
|
|
for arg2 in mixed2:
|
|
if arg2.required:
|
|
continue
|
|
pos_match: bool = arg2.pos < len(pos1) or arg2.pos in mixed_positions
|
|
kw_match: bool = arg2.name in kw1 or arg2.name in mixed_names
|
|
if not pos_match or not kw_match:
|
|
return False
|
|
|
|
return True
|
|
|
|
def apply_generic(self, type: Type, args: list[Type]) -> Type:
|
|
match type:
|
|
case AliasType(name=name, type=base):
|
|
return AliasType(name=name, type=self.apply_generic(base, args))
|
|
|
|
case GenericType(name=name, params=type_vars, body=body):
|
|
n_args: int = len(args)
|
|
n_type_vars: int = len(type_vars)
|
|
if n_args < n_type_vars:
|
|
raise ValueError(
|
|
f"Missing type arguments, expected {n_type_vars} but only {n_args} provided"
|
|
)
|
|
if n_args > n_type_vars:
|
|
raise ValueError(
|
|
f"Too many type arguments, expected {n_type_vars} but {n_args} provided"
|
|
)
|
|
substitutions: dict[str, Type] = {}
|
|
for arg, type_var in zip(args, type_vars):
|
|
if type_var.bound is not None and not self.is_subtype(
|
|
arg, type_var.bound
|
|
):
|
|
raise ValueError(
|
|
f"Type argument {arg} is not a subtype of {type_var.bound}"
|
|
)
|
|
substitutions[type_var.name] = arg
|
|
return AppliedType(
|
|
name=name,
|
|
args=args,
|
|
body=substitute_typevars(body, substitutions),
|
|
)
|
|
|
|
case _:
|
|
raise ValueError(f"{type} is not a generic type")
|
|
|
|
def reduce_types(self, types: list[Type]) -> list[Type]:
|
|
"""Reduce a list of types to remove subtypes and only keep the highest types
|
|
|
|
Args:
|
|
types (list[Type]): the types to reduce
|
|
|
|
Returns:
|
|
list[Type]: the reduced list of types
|
|
"""
|
|
|
|
reduced: bool = True
|
|
keep: list[int] = list(range(len(types)))
|
|
while reduced:
|
|
reduced = False
|
|
for i, i1 in enumerate(keep):
|
|
type1: Type = types[i1]
|
|
for i2 in keep[i + 1 :]:
|
|
type2 = types[i2]
|
|
if self.is_subtype(type1, type2):
|
|
keep.remove(i1)
|
|
elif self.is_subtype(type2, type1):
|
|
keep.remove(i2)
|
|
else:
|
|
continue
|
|
reduced = True
|
|
break
|
|
return [types[i] for i in keep]
|
|
|
|
def lookup_member(self, type: Type, member_name: str) -> Optional[Type]:
|
|
match type:
|
|
case BaseType(name=name):
|
|
if name in self._members:
|
|
if member_name in self._members[name]:
|
|
return self._members[name][member_name]
|
|
return None
|
|
|
|
case AliasType(name=name, type=base):
|
|
if name in self._members:
|
|
if member_name in self._members[name]:
|
|
return self._members[name][member_name]
|
|
return self.lookup_member(base, member_name)
|
|
|
|
case AppliedType(name=name, body=body, args=args):
|
|
generic: Type = self.get_type(name)
|
|
|
|
if not isinstance(generic, GenericType):
|
|
raise ValueError("AppliedType not derived from a GenericType")
|
|
|
|
substitutions = {
|
|
type_var.name: arg for arg, type_var in zip(args, generic.params)
|
|
}
|
|
if name in self._members:
|
|
if member_name in self._members[name]:
|
|
member_type: Type = self._members[name][member_name]
|
|
return substitute_typevars(member_type, substitutions)
|
|
|
|
member_type2: Optional[Type] = self.lookup_member(body, member_name)
|
|
if member_type2 is not None:
|
|
member_type2 = substitute_typevars(member_type2, substitutions)
|
|
return member_type2
|
|
|
|
case ComplexType(members=members):
|
|
if member_name in members:
|
|
return members[member_name]
|
|
self.logger.debug(f"No member '{member_name}' in {type}")
|
|
return None
|
|
|
|
case ExtensionType(base=base, extension=ComplexType(members=members)):
|
|
if member_name in members:
|
|
return members[member_name]
|
|
self.logger.debug(
|
|
f"No member '{member_name}' on {type}, looking up in base"
|
|
)
|
|
return self.lookup_member(base, member_name)
|
|
|
|
case UnknownType():
|
|
return UnknownType()
|
|
|
|
case _:
|
|
self.logger.debug(f"Can't get member on {type}")
|
|
return None
|