237 lines
5.8 KiB
Python
237 lines
5.8 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Optional
|
|
|
|
|
|
@dataclass(frozen=True, kw_only=True)
|
|
class BaseType:
|
|
name: str
|
|
|
|
def __str__(self) -> str:
|
|
return self.name
|
|
|
|
|
|
@dataclass(frozen=True, kw_only=True)
|
|
class AliasType:
|
|
name: str
|
|
type: Type
|
|
|
|
def __str__(self) -> str:
|
|
return self.name
|
|
|
|
|
|
@dataclass(frozen=True, kw_only=True)
|
|
class UnknownType:
|
|
def __str__(self) -> str:
|
|
return "<Unknown>"
|
|
|
|
|
|
@dataclass(frozen=True, kw_only=True)
|
|
class UnitType:
|
|
def __str__(self) -> str:
|
|
return "None"
|
|
|
|
|
|
@dataclass(frozen=True, kw_only=True)
|
|
class Function:
|
|
pos_args: list[Argument]
|
|
args: list[Argument]
|
|
kw_args: list[Argument]
|
|
returns: Type
|
|
|
|
def __str__(self) -> str:
|
|
args: list[str] = []
|
|
if len(self.pos_args) != 0:
|
|
args += list(map(str, self.pos_args))
|
|
if len(self.args) + len(self.kw_args) != 0:
|
|
args.append("/")
|
|
|
|
if len(self.args) != 0:
|
|
args += list(map(str, self.args))
|
|
|
|
if len(self.kw_args) != 0:
|
|
if len(args) != 0:
|
|
args.append("*")
|
|
args += list(map(str, self.kw_args))
|
|
|
|
return f"({', '.join(args)}) -> {self.returns}"
|
|
|
|
@dataclass(frozen=True, kw_only=True)
|
|
class Argument:
|
|
pos: int
|
|
name: str
|
|
type: Type
|
|
required: bool
|
|
|
|
def __str__(self) -> str:
|
|
opt: str = "" if self.required else "?"
|
|
return f"{self.name}: {self.type}{opt}"
|
|
|
|
|
|
@dataclass(frozen=True, kw_only=True)
|
|
class OverloadedFunction:
|
|
overloads: list[Type]
|
|
|
|
def __str__(self) -> str:
|
|
return "<overloaded function>"
|
|
|
|
|
|
@dataclass(frozen=True, kw_only=True)
|
|
class ComplexType:
|
|
members: dict[str, Type]
|
|
|
|
def __str__(self) -> str:
|
|
props: list[str] = [f"{name}: {type}" for name, type in self.members.items()]
|
|
return f"{{{', '.join(props)}}}"
|
|
|
|
|
|
@dataclass(frozen=True, kw_only=True)
|
|
class ExtensionType:
|
|
base: Type
|
|
extension: ComplexType
|
|
|
|
def __str__(self) -> str:
|
|
return f"{self.base} & {self.extension}"
|
|
|
|
|
|
@dataclass(frozen=True, kw_only=True)
|
|
class Operation:
|
|
signature: CallSignature
|
|
result: Type
|
|
|
|
def __str__(self) -> str:
|
|
return f"{self.signature} -> {self.result}"
|
|
|
|
@dataclass(frozen=True, kw_only=True)
|
|
class CallSignature:
|
|
left: Type
|
|
method: str
|
|
right: Type
|
|
|
|
def __str__(self) -> str:
|
|
return f"{self.method}({self.left}, {self.right})"
|
|
|
|
|
|
@dataclass(frozen=True, kw_only=True)
|
|
class TypeVar:
|
|
name: str
|
|
bound: Optional[Type]
|
|
|
|
def __str__(self) -> str:
|
|
if self.bound is not None:
|
|
return f"{self.name} <: {self.bound}"
|
|
return self.name
|
|
|
|
|
|
@dataclass(frozen=True, kw_only=True)
|
|
class GenericType:
|
|
name: str
|
|
params: list[TypeVar]
|
|
body: Type
|
|
|
|
def __str__(self) -> str:
|
|
return f"{self.name}[{', '.join(map(str, self.params))}]"
|
|
|
|
|
|
@dataclass(frozen=True, kw_only=True)
|
|
class AppliedType:
|
|
name: str
|
|
args: list[Type]
|
|
body: Type
|
|
|
|
def __str__(self) -> str:
|
|
return f"{self.name}[{', '.join(map(str, self.args))}]"
|
|
|
|
|
|
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
|
def sub_argument(arg: Function.Argument):
|
|
return Function.Argument(
|
|
pos=arg.pos,
|
|
name=arg.name,
|
|
type=substitute_typevars(arg.type, substitutions),
|
|
required=arg.required,
|
|
)
|
|
|
|
match type:
|
|
case BaseType(name=name) if name in substitutions:
|
|
return substitutions[name]
|
|
|
|
case BaseType():
|
|
return type
|
|
|
|
case AliasType(name=name, type=type2):
|
|
return AliasType(name=name, type=substitute_typevars(type2, substitutions))
|
|
|
|
case Function(
|
|
pos_args=pos_args,
|
|
args=args,
|
|
kw_args=kw_args,
|
|
returns=returns,
|
|
):
|
|
return Function(
|
|
pos_args=list(map(sub_argument, pos_args)),
|
|
args=list(map(sub_argument, args)),
|
|
kw_args=list(map(sub_argument, kw_args)),
|
|
returns=substitute_typevars(returns, substitutions),
|
|
)
|
|
|
|
case ComplexType(members=members):
|
|
members2: dict[str, Type] = {
|
|
name: substitute_typevars(prop, substitutions)
|
|
for name, prop in members.items()
|
|
}
|
|
return ComplexType(members=members2)
|
|
|
|
case ExtensionType(base=base, extension=ComplexType(members=members)):
|
|
return ExtensionType(
|
|
base=substitute_typevars(base, substitutions),
|
|
extension=ComplexType(
|
|
members={
|
|
name: substitute_typevars(prop, substitutions)
|
|
for name, prop in members.items()
|
|
}
|
|
),
|
|
)
|
|
|
|
case AppliedType(name=name, args=args, body=body):
|
|
return AppliedType(
|
|
name=name,
|
|
args=[substitute_typevars(arg, substitutions) for arg in args],
|
|
body=substitute_typevars(body, substitutions),
|
|
)
|
|
|
|
case TypeVar(name=name):
|
|
if name in substitutions:
|
|
return substitutions[name]
|
|
raise ValueError(f"Missing TypeVar substitution for {name}")
|
|
|
|
case UnknownType() | UnitType():
|
|
return type
|
|
|
|
case _:
|
|
raise NotImplementedError(f"Unsupported type {type}")
|
|
|
|
|
|
def unfold_type(type: Type) -> Type:
    """Strip every leading ``AliasType`` wrapper and return the aliased type.

    A non-alias input is returned unchanged. Aliases nested deeper inside a
    composite type (e.g. a function's argument types) are not touched.
    """
    current = type
    while isinstance(current, AliasType):
        current = current.type
    return current
|
|
|
|
|
|
# Union of every node class that can appear as a type in this model.
# Operation and CallSignature are not members of this union.
# Unlike the annotations in this module, which the __future__ import makes
# lazy, this expression runs at import time. It must therefore follow every
# class definition above.
Type = (
    BaseType | AliasType | UnknownType | UnitType
    | Function | OverloadedFunction
    | ComplexType | ExtensionType
    | TypeVar | GenericType | AppliedType
)
|