diff --git a/midas/checker/types.py b/midas/checker/types.py index 83707b6..15079e0 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -1,6 +1,7 @@ from __future__ import annotations from dataclasses import dataclass +from typing import Optional @dataclass(frozen=True, kw_only=True) @@ -57,4 +58,75 @@ class Operation: right: Type -Type = BaseType | AliasType | UnknownType | UnitType | Function | ComplexType +@dataclass(frozen=True, kw_only=True) +class TypeVar: + name: str + bound: Optional[Type] + + +@dataclass(frozen=True, kw_only=True) +class GenericType: + params: list[TypeVar] + body: Type + + +def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type: + def sub_argument(arg: Function.Argument): + return Function.Argument( + pos=arg.pos, + name=arg.name, + type=substitute_typevars(arg.type, substitutions), + required=arg.required, + ) + + match type: + case BaseType(name=name) if name in substitutions: + return substitutions[name] + + case AliasType(name=name, type=type2): + return AliasType(name=name, type=substitute_typevars(type2, substitutions)) + + case Function( + name=name, + pos_args=pos_args, + args=args, + kw_args=kw_args, + returns=returns, + ): + return Function( + name=name, + pos_args=list(map(sub_argument, pos_args)), + args=list(map(sub_argument, args)), + kw_args=list(map(sub_argument, kw_args)), + returns=substitute_typevars(returns, substitutions), + ) + + case ComplexType(properties=properties): + properties2: dict[str, Type] = { + name: substitute_typevars(prop, substitutions) + for name, prop in properties.items() + } + return ComplexType(properties=properties2) + + case TypeVar(name=name): + if name in substitutions: + return substitutions[name] + raise ValueError(f"Missing TypeVar substitution for {name}") + + case UnknownType() | UnitType(): + return type + + case _: + raise NotImplementedError(f"Unsupported type {type}") + + +Type = ( + BaseType + | AliasType + | UnknownType + | UnitType + | Function + | ComplexType + | TypeVar + | GenericType +)