# File: midas/midas/checker/registry.py (Python, 314 lines, 10 KiB)

from typing import Optional
from midas.checker.builtins import BUILTIN_SUBTYPES
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ComplexType,
Function,
GenericType,
Operation,
Type,
substitute_typevars,
)
class TypesRegistry:
def __init__(self) -> None:
self._types: dict[str, Type] = {}
self._operations: dict[Operation.CallSignature, Type] = {}
def get_type(self, name: str) -> Type:
"""Get a type from its name
Args:
name (str): the name of the type
Raises:
NameError: if the type is not defined
Returns:
Type: the type
"""
if name in self._types:
return self._types[name]
raise NameError(f"Undefined type {name}")
def get_operation_result(
self, left: Type, operator: str, right: Type
) -> Optional[Type]:
"""Get the resulting type of an operation
Args:
left (Type): the type of the left operand
operator (str): the operation name
right (Type): the type of the right operand
Returns:
Optional[Type]: the result type, or None if no matching operation was found
"""
signature: Operation.CallSignature = Operation.CallSignature(
left=left,
method=operator,
right=right,
)
result: Optional[Type] = self._operations.get(signature)
return result
def get_operations_by_name(self, name: str) -> list[Operation]:
operations: list[Operation] = []
for signature, result in self._operations.items():
if signature.method == name:
operations.append(
Operation(
signature=signature,
result=result,
)
)
return operations
def define_type(self, name: str, type: Type) -> Type:
"""Define a type in the registry
Args:
name (str): the name of the type
type (Type): the type to define
Raises:
ValueError: if a type is already defined with that name
Returns:
Type: the defined type
"""
if name in self._types:
raise ValueError(f"Type {name} already defined")
self._types[name] = type
return type
def define_operation(self, left: Type, operator: str, right: Type, result: Type):
"""Define an operation in the registry
Args:
left (Type): the type of the left operand
operator (str): the operation name
right (Type): the type of the right operand
result (Type): the result type
Raises:
ValueError: if an operation is already defined with these operands and name
"""
signature: Operation.CallSignature = Operation.CallSignature(
left=left,
method=operator,
right=right,
)
if signature in self._operations:
raise ValueError(
f"Operation {operator} already defined between {left} and {right}"
)
self._operations[signature] = result
def is_subtype(self, type1: Type, type2: Type) -> bool:
"""Check whether `type1` is a subtype of `type2`
For more details on the rules checked here, see TAPL Chap. 15-16-17
Args:
type1 (Type): the potential subtype
type2 (Type): the potential supertype
Returns:
bool: whether `type1` is a subtype of `type2`
"""
if type1 == type2:
return True
match (type1, type2):
case (AliasType(type=base1), _):
return self.is_subtype(base1, type2)
case (BaseType(name=name1), BaseType(name=name2)):
return name1 in BUILTIN_SUBTYPES.get(name2, set())
case (ComplexType(properties=props1), ComplexType(properties=props2)):
for k, t in props2.items():
if k not in props1:
return False
if not self.is_subtype(props1[k], t):
return False
return True
case (Function(), Function()):
return self.is_func_subtype(type1, type2)
return False
# TODO: verify the logic in here
def is_func_subtype(self, func1: Function, func2: Function) -> bool:
"""Check whether a function is a subtype of another
Args:
func1 (Function): the potential function subtype
func2 (Function): the potential function supertype
Returns:
bool: whether `func1` is a subtype of `func2`
"""
if not self.is_subtype(func1.returns, func2.returns):
return False
pos1: list[Function.Argument] = func1.pos_args
mixed1: list[Function.Argument] = func1.args
kw1: dict[str, Function.Argument] = {a.name: a for a in func1.kw_args}
pos2: list[Function.Argument] = func2.pos_args
mixed2: list[Function.Argument] = func2.args
kw2: dict[str, Function.Argument] = {a.name: a for a in func2.kw_args}
mixed_by_pos: dict[int, Function.Argument] = {arg.pos: arg for arg in mixed2}
mixed_by_name: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2}
def is_arg_subtype(sub: Function.Argument, sup: Function.Argument) -> bool:
if not self.is_subtype(sub.type, sup.type):
return False
if not sup.required and sub.required:
return False
return True
for arg1 in pos1:
arg2: Function.Argument
if arg1.pos < len(pos2):
arg2 = pos2[arg1.pos]
elif arg1.pos in mixed_by_pos:
arg2 = mixed_by_pos[arg1.pos]
elif not arg1.required:
continue
else:
return False
if not is_arg_subtype(arg2, arg1):
return False
for name, arg1 in kw1.items():
arg2: Function.Argument
if name in kw2:
arg2 = kw2[name]
elif name in mixed_by_name:
arg2 = mixed_by_name[name]
elif not arg1.required:
continue
else:
return False
if not is_arg_subtype(arg2, arg1):
return False
for arg1 in mixed1:
pos_arg2: Optional[Function.Argument] = None
kw_arg2: Optional[Function.Argument] = None
if arg1.name in kw2:
kw_arg2 = kw2[arg1.name]
elif arg1.name in mixed_by_name:
kw_arg2 = mixed_by_name[arg1.name]
if arg1.pos < len(pos2):
pos_arg2 = pos2[arg1.pos]
elif arg1.pos in mixed_by_pos:
pos_arg2 = mixed_by_pos[arg1.pos]
# No match in func2 and arg is required
if pos_arg2 is None and kw_arg2 is None and arg1.required:
return False
# Matching keyword argument
if kw_arg2 is not None and not is_arg_subtype(kw_arg2, arg1):
return False
# Matching positional argument
if pos_arg2 is not None and not is_arg_subtype(pos_arg2, arg1):
return False
mixed_positions: set[int] = {a.pos for a in mixed1}
mixed_names: set[str] = {a.name for a in mixed1}
for arg2 in pos2:
if not arg2.required:
continue
if arg2.pos >= len(pos1) and arg2.pos not in mixed_positions:
return False
for name, arg2 in kw2.items():
if not arg2.required:
continue
if name not in kw1 and name not in mixed_names:
return False
for arg2 in mixed2:
if arg2.required:
continue
pos_match: bool = arg2.pos < len(pos1) or arg2.pos in mixed_positions
kw_match: bool = arg2.name in kw1 or arg2.name in mixed_names
if not pos_match or not kw_match:
return False
return True
def apply_generic(self, type: Type, params: list[Type]) -> Type:
match type:
case AliasType(name=name, type=base):
return AliasType(name=name, type=self.apply_generic(base, params))
case GenericType(name=name, params=type_vars, body=body):
n_params: int = len(params)
n_type_vars: int = len(type_vars)
if n_params < n_type_vars:
raise ValueError(
f"Missing type parameters, expected {n_type_vars} but only {n_params} provided"
)
if n_params > n_type_vars:
raise ValueError(
f"Too many type parameters, expected {n_type_vars} but {n_params} provided"
)
substitutions: dict[str, Type] = {}
for param, type_var in zip(params, type_vars):
if type_var.bound is not None and not self.is_subtype(
param, type_var.bound
):
raise ValueError(
f"Type parameter {param} is not a subtype of {type_var.bound}"
)
substitutions[type_var.name] = param
return AppliedType(
name=name,
args=params,
body=substitute_typevars(body, substitutions),
)
case _:
raise ValueError(f"{type} is not a generic type")
def reduce_types(self, types: list[Type]) -> list[Type]:
"""Reduce a list of types to remove subtypes and only keep the highest types
Args:
types (list[Type]): the types to reduce
Returns:
list[Type]: the reduced list of types
"""
reduced: bool = True
keep: list[int] = list(range(len(types)))
while reduced:
reduced = False
for i, i1 in enumerate(keep):
type1: Type = types[i1]
for i2 in keep[i + 1 :]:
type2 = types[i2]
if self.is_subtype(type1, type2):
keep.remove(i1)
elif self.is_subtype(type2, type1):
keep.remove(i2)
else:
continue
reduced = True
break
return [types[i] for i in keep]