diff --git a/midas/checker/python.py b/midas/checker/python.py index 63a076e..a920990 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -534,7 +534,11 @@ class PythonTyper( return self.types.apply_generic(list_type, [UnknownType()]) def visit_base_type(self, node: p.BaseType) -> Type: - return self.types.get_type(node.base) + base: Type = self.types.get_type(node.base) + if node.param is not None: + param: Type = node.param.accept(self) + return self.types.apply_generic(base, [param]) + return base def visit_constraint_type(self, node: p.ConstraintType) -> Type: ... diff --git a/midas/checker/registry.py b/midas/checker/registry.py index 1324bbd..da5a7ee 100644 --- a/midas/checker/registry.py +++ b/midas/checker/registry.py @@ -3,6 +3,7 @@ from typing import Optional from midas.checker.builtins import BUILTIN_SUBTYPES from midas.checker.types import ( AliasType, + AppliedType, BaseType, ComplexType, Function, @@ -254,7 +255,7 @@ class TypesRegistry: case AliasType(name=name, type=base): return AliasType(name=name, type=self.apply_generic(base, params)) - case GenericType(params=type_vars, body=body): + case GenericType(name=name, params=type_vars, body=body): n_params: int = len(params) n_type_vars: int = len(type_vars) if n_params < n_type_vars: @@ -274,7 +275,11 @@ class TypesRegistry: f"Type parameter {param} is not a subtype of {type_var.bound}" ) substitutions[type_var.name] = param - return substitute_typevars(body, substitutions) + return AppliedType( + name=name, + args=params, + body=substitute_typevars(body, substitutions), + ) case _: raise ValueError(f"{type} is not a generic type") diff --git a/midas/checker/types.py b/midas/checker/types.py index 8c95134..9081a95 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -71,6 +71,13 @@ class GenericType: body: Type +@dataclass(frozen=True, kw_only=True) +class AppliedType: + name: str + args: list[Type] + body: Type + + def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type: def sub_argument(arg: Function.Argument): return Function.Argument( @@ -138,4 +145,5 @@ Type = ( | ComplexType | TypeVar | GenericType + | AppliedType )