diff --git a/midas/checker/midas.py b/midas/checker/midas.py index 60d2b85..e79719c 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -26,6 +26,7 @@ from midas.checker.types import ( UnknownType, unfold_type, ) +from midas.checker.variance import VarianceInferrer from midas.lexer.midas import MidasLexer from midas.lexer.token import Token from midas.parser.midas import MidasParser @@ -132,6 +133,11 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type for stmt in stmts: stmt.accept(self) + for name, type in self.types._types.items(): + if isinstance(type, GenericType): + inferrer = VarianceInferrer(self.types) + self.types._types[name] = inferrer.infer(type) + def assert_bool(self, expr: m.Expr): type: Type = self.type_of(expr) if not self.types.is_subtype(type, self._bool): diff --git a/midas/checker/registry.py b/midas/checker/registry.py index fa2d1bd..b8f7dfa 100644 --- a/midas/checker/registry.py +++ b/midas/checker/registry.py @@ -17,6 +17,7 @@ from midas.checker.types import ( Type, TypeVar, UnknownType, + Variance, substitute_typevars, ) @@ -134,6 +135,24 @@ class TypesRegistry: case (ConstraintType(type=base1), _): return self.is_subtype(base1, type2) + case ( + AppliedType(name=name1, args=args1), + AppliedType(name=name2, args=args2), + ) if ( + name1 == name2 + ): + generic: Type = self.get_type(name1) + assert isinstance(generic, GenericType) + for param, arg1, arg2 in zip(generic.params, args1, args2): + variance: Variance = param.variance + if variance in {Variance.INVARIANT, Variance.COVARIANT}: + if not self.is_subtype(arg1, arg2): + return False + if variance in {Variance.INVARIANT, Variance.CONTRAVARIANT}: + if not self.is_subtype(arg2, arg1): + return False + return True + return False # TODO: verify the logic in here diff --git a/midas/checker/types.py b/midas/checker/types.py index 82a08ba..60fb9c7 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -1,6 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field +from enum import StrEnum from typing import Optional, assert_never import midas.ast.midas as m @@ -102,15 +103,27 @@ class ExtensionType: return f"{self.base} & {self.extension}" +class Variance(StrEnum): + INVARIANT = "INVARIANT" + COVARIANT = "COVARIANT" + CONTRAVARIANT = "CONTRAVARIANT" + + @dataclass(frozen=True, kw_only=True) class TypeVar: name: str bound: Optional[Type] + variance: Variance = Variance.INVARIANT def __str__(self) -> str: + variance: str = { + Variance.COVARIANT: "+", + Variance.CONTRAVARIANT: "-", + }.get(self.variance, "") + res: str = f"{variance}{self.name}" if self.bound is not None: - return f"{self.name} <: {self.bound}" - return self.name + res = f"{res} <: {self.bound}" + return res @dataclass(frozen=True, kw_only=True) diff --git a/midas/checker/variance.py b/midas/checker/variance.py new file mode 100644 index 0000000..b620958 --- /dev/null +++ b/midas/checker/variance.py @@ -0,0 +1,129 @@ +from typing import Literal, Optional, cast + +from midas.checker.registry import TypesRegistry +from midas.checker.types import ( + AppliedType, + ConstraintType, + Function, + GenericType, + OverloadedFunction, + Type, + TypeVar, + Variance, +) + +Polarity = Literal[-1, 0, 1] + + +class Tracker: + def __init__(self, vars: list[TypeVar]) -> None: + self.vars: list[TypeVar] = vars + self.refs: dict[str, set[Polarity]] = {var.name: set() for var in self.vars} + + def record(self, var: TypeVar, polarity: Polarity): + self.refs[var.name].add(polarity) + + def get_updated_vars(self) -> list[TypeVar]: + return [ + TypeVar( + name=var.name, bound=var.bound, variance=self.get_variance(var.name) + ) + for var in self.vars + ] + + def get_variance(self, name: str) -> Variance: + refs: set[Polarity] = self.refs[name] + if refs == {-1}: + return Variance.CONTRAVARIANT + if refs == {1}: + return Variance.COVARIANT + return Variance.INVARIANT + + def __contains__(self, item: TypeVar | str): + if isinstance(item, TypeVar): + return item.name in self + return item in self.refs + + +class VarianceInferrer: + def __init__(self, types: TypesRegistry) -> None: + self.types: TypesRegistry = types + self.tracker: Tracker = Tracker([]) + + def infer(self, type: GenericType) -> GenericType: + self.tracker = Tracker(type.params) + + self.walk(type.body, 1, type.name) + members: dict[str, Type] = self.types._members.get(type.name, {}) + for name, member in members.items(): + self.walk(member, 1, type.name, [f"member:'{name}'"]) + + return GenericType( + name=type.name, + params=self.tracker.get_updated_vars(), + body=type.body, + ) + + def walk( + self, + type: Type, + polarity: Polarity, + base_name: str, + path: Optional[list[str]] = None, + ): + if path is None: + path = [] + + match type: + # Arguments are negative positions -> flip polarity + # Return is positive position -> keep polarity + case Function(pos_args=pos_args, args=mixed_args, kw_args=kw_args): + all_args: list[Function.Argument] = pos_args + mixed_args + kw_args + for arg in all_args: + self.walk( + arg.type, + -polarity, + base_name, + path + [f"arg:'{arg.name}'"], + ) + + self.walk(type.returns, polarity, base_name, path + ["return"]) + + # Walk all overloads + case OverloadedFunction(overloads=overloads): + for overload in overloads: + self.walk(overload, polarity, base_name, path) + + # If same name as root generic -> skip + # Get inferred variance of parameters and multiply with current + # polarity to recurse through arguments + case AppliedType(name=name, args=args): + # TODO: handle mutually recursive types + if name == base_name: + return + generic: Type = self.types.get_type(name) + assert isinstance(generic, GenericType) + params: list[TypeVar] = generic.params + polarities: dict[Variance, Polarity] = { + Variance.INVARIANT: 0, + Variance.COVARIANT: 1, + Variance.CONTRAVARIANT: -1, + } + for arg, param in zip(args, params): + param_polarity: Polarity = polarities[param.variance] + self.walk( + arg, + cast(Polarity, polarity * param_polarity), + base_name, + path + [f"applied:'{name}'"], + ) + + # Walk base type + case ConstraintType(type=base): + self.walk(base, polarity, base_name, path + ["constraint"]) + + # Reached end + # If tracked, record polarity + case TypeVar(): + if type in self.tracker: + self.tracker.record(type, polarity) diff --git a/midas/cli/commands/registry.py b/midas/cli/commands/registry.py index 41fc616..7c0c521 100644 --- a/midas/cli/commands/registry.py +++ b/midas/cli/commands/registry.py @@ -39,7 +39,11 @@ def dump_registry( print("##### Types #####") for name, type in checker.types._types.items(): members: dict[str, Type] = checker.types._members.get(name, {}) - print(f"{name} = {base_type(type)}") + params: str = "" + if isinstance(type, GenericType): + params = ", ".join(map(str, type.params)) + params = f"[{params}]" + print(f"{name}{params} = {base_type(type)}") if len(members) != 0: print(" " * 4 + "Members:") for member_name, member_type in members.items(): diff --git a/tests/cases/checker/07_variance.midas b/tests/cases/checker/07_variance.midas new file mode 100644 index 0000000..c117390 --- /dev/null +++ b/tests/cases/checker/07_variance.midas @@ -0,0 +1,59 @@ +// T is invariant (unused) +type Unused[T] = object + +// T is covariant +type Covariant[T] = object + +// T is contravariant +type Contravariant[T] = object + +// T is invariant +type Invariant[T] = object + +extend Covariant[T] { + def foo: fn() -> T +} + +extend Contravariant[T] { + def foo: fn(T, /) -> None +} + +extend Invariant[T] { + def foo: fn(T, /) -> T +} + +// T is covariant +type Coco[T] = object +extend Coco[T] { + def foo: fn() -> Covariant[T] +} + +// T is contravariant +type Cocontra[T] = object +extend Cocontra[T] { + def foo: fn() -> Contravariant[T] +} + +// T is contravariant +type Contraco[T] = object +extend Contraco[T] { + def foo: fn(Covariant[T], /) -> None +} + +// T is covariant +type Contracontra[T] = object +extend Contracontra[T] { + def foo: fn(Contravariant[T], /) -> None +} + + +type T1[T] = object +type T2[T] = object + +extend T1[T] { + def foo: fn() -> T2[T] +} + +extend T2[T] { + def foo: fn() -> T1[T] +} diff --git a/tests/cases/checker/07_variance.py b/tests/cases/checker/07_variance.py new file mode 100644 index 0000000..9b251da --- /dev/null +++ b/tests/cases/checker/07_variance.py @@ -0,0 +1,52 @@ +from _ import ( + T1, + T2, + Coco, + Cocontra, + Contraco, + Contracontra, + Contravariant, + Covariant, + Invariant, + Unused, +) + +unused: Unused +covariant: Covariant +contravariant: Contravariant +invariant: Invariant +coco: Coco +cocontra: Cocontra +contraco: Contraco +contracontra: Contracontra +t1: T1 +t2: T2 + +# Dummy print to prudce judgements for the expressions +print( + unused, + covariant, + contravariant, + invariant, + coco, + cocontra, + contraco, + contracontra, + t1, + t2, +) + +cov1: Covariant[float] +cov2: Covariant[int] +cov1 = cov2 # Ok because int <: float => Covariant[int] <: Covariant[float] +cov2 = cov1 # Invalid + +contra1: Contravariant[float] +contra2: Contravariant[int] +contra1 = contra2 # Invalid +contra2 = contra1 # Ok because int <: float => Covariant[float] <: Covariant[int] + +inv1: Invariant[float] +inv2: Invariant[int] +inv1 = inv2 # Invalid +inv2 = inv1 # Invalid diff --git a/tests/cases/checker/07_variance.py.ref.json b/tests/cases/checker/07_variance.py.ref.json new file mode 100644 index 0000000..a43b957 --- /dev/null +++ b/tests/cases/checker/07_variance.py.ref.json @@ -0,0 +1,512 @@ +{ + "diagnostics": [ + { + "type": "Error", + "location": { + "start": [ + 28, + 4 + ], + "end": [ + 28, + 13 + ] + }, + "message": "Too many positional arguments" + }, + { + "type": "Error", + "location": { + "start": [ + 42, + 0 + ], + "end": [ + 42, + 11 + ] + }, + "message": "Cannot assign Covariant[float] to variable 'cov2' of type Covariant[int]" + }, + { + "type": "Error", + "location": { + "start": [ + 46, + 0 + ], + "end": [ + 46, + 17 + ] + }, + "message": "Cannot assign Contravariant[int] to variable 'contra1' of type Contravariant[float]" + }, + { + "type": "Error", + "location": { + "start": [ + 51, + 0 + ], + "end": [ + 51, + 11 + ] + }, + "message": "Cannot assign Invariant[int] to variable 'inv1' of type Invariant[float]" + }, + { + "type": "Error", + "location": { + "start": [ + 52, + 0 + ], + "end": [ + 52, + 11 + ] + }, + "message": "Cannot assign Invariant[float] to variable 'inv2' of type Invariant[int]" + } + ], + "judgments": [ + { + "location": { + "from": "L26:0", + "to": "L26:5" + }, + "expr": { + "_type": "VariableExpr", + "name": "print" + }, + "type": { + "pos_args": [ + { + "pos": 0, + "name": "object", + "type": {}, + "required": true + } + ], + "args": [], + "kw_args": [], + "returns": {} + } + }, + { + "location": { + "from": "L27:4", + "to": "L27:10" + }, + "expr": { + "_type": "VariableExpr", + "name": "unused" + }, + "type": { + "name": "Unused", + "params": [ + { + "name": "T", + "bound": null, + "variance": "INVARIANT" + } + ], + "body": { + "name": "object" + } + } + }, + { + "location": { + "from": "L28:4", + "to": "L28:13" + }, + "expr": { + "_type": "VariableExpr", + "name": "covariant" + }, + "type": { + "name": "Covariant", + "params": [ + { + "name": "T", + "bound": null, + "variance": "COVARIANT" + } + ], + "body": { + "name": "object" + } + } + }, + { + "location": { + "from": "L29:4", + "to": "L29:17" + }, + "expr": { + "_type": "VariableExpr", + "name": "contravariant" + }, + "type": { + "name": "Contravariant", + "params": [ + { + "name": "T", + "bound": null, + "variance": "CONTRAVARIANT" + } + ], + "body": { + "name": "object" + } + } + }, + { + "location": { + "from": "L30:4", + "to": "L30:13" + }, + "expr": { + "_type": "VariableExpr", + "name": "invariant" + }, + "type": { + "name": "Invariant", + "params": [ + { + "name": "T", + "bound": null, + "variance": "INVARIANT" + } + ], + "body": { + "name": "object" + } + } + }, + { + "location": { + "from": "L31:4", + "to": "L31:8" + }, + "expr": { + "_type": "VariableExpr", + "name": "coco" + }, + "type": { + "name": "Coco", + "params": [ + { + "name": "T", + "bound": null, + "variance": "COVARIANT" + } + ], + "body": { + "name": "object" + } + } + }, + { + "location": { + "from": "L32:4", + "to": "L32:12" + }, + "expr": { + "_type": "VariableExpr", + "name": "cocontra" + }, + "type": { + "name": "Cocontra", + "params": [ + { + "name": "T", + "bound": null, + "variance": "CONTRAVARIANT" + } + ], + "body": { + "name": "object" + } + } + }, + { + "location": { + "from": "L33:4", + "to": "L33:12" + }, + "expr": { + "_type": "VariableExpr", + "name": "contraco" + }, + "type": { + "name": "Contraco", + "params": [ + { + "name": "T", + "bound": null, + "variance": "CONTRAVARIANT" + } + ], + "body": { + "name": "object" + } + } + }, + { + "location": { + "from": "L34:4", + "to": "L34:16" + }, + "expr": { + "_type": "VariableExpr", + "name": "contracontra" + }, + "type": { + "name": "Contracontra", + "params": [ + { + "name": "T", + "bound": null, + "variance": "COVARIANT" + } + ], + "body": { + "name": "object" + } + } + }, + { + "location": { + "from": "L35:4", + "to": "L35:6" + }, + "expr": { + "_type": "VariableExpr", + "name": "t1" + }, + "type": { + "name": "T1", + "params": [ + { + "name": "T", + "bound": null, + "variance": "INVARIANT" + } + ], + "body": { + "name": "object" + } + } + }, + { + "location": { + "from": "L36:4", + "to": "L36:6" + }, + "expr": { + "_type": "VariableExpr", + "name": "t2" + }, + "type": { + "name": "T2", + "params": [ + { + "name": "T", + "bound": null, + "variance": "INVARIANT" + } + ], + "body": { + "name": "object" + } + } + }, + { + "location": { + "from": "L26:0", + "to": "L37:1" + }, + "expr": { + "_type": "CallExpr", + "callee": { + "_type": "VariableExpr", + "name": "print" + }, + "arguments": [ + { + "_type": "VariableExpr", + "name": "unused" + }, + { + "_type": "VariableExpr", + "name": "covariant" + }, + { + "_type": "VariableExpr", + "name": "contravariant" + }, + { + "_type": "VariableExpr", + "name": "invariant" + }, + { + "_type": "VariableExpr", + "name": "coco" + }, + { + "_type": "VariableExpr", + "name": "cocontra" + }, + { + "_type": "VariableExpr", + "name": "contraco" + }, + { + "_type": "VariableExpr", + "name": "contracontra" + }, + { + "_type": "VariableExpr", + "name": "t1" + }, + { + "_type": "VariableExpr", + "name": "t2" + } + ], + "keywords": {} + }, + "type": {} + }, + { + "location": { + "from": "L41:7", + "to": "L41:11" + }, + "expr": { + "_type": "VariableExpr", + "name": "cov2" + }, + "type": { + "name": "Covariant", + "args": [ + { + "name": "int" + } + ], + "body": { + "name": "object" + } + } + }, + { + "location": { + "from": "L42:7", + "to": "L42:11" + }, + "expr": { + "_type": "VariableExpr", + "name": "cov1" + }, + "type": { + "name": "Covariant", + "args": [ + { + "name": "float" + } + ], + "body": { + "name": "object" + } + } + }, + { + "location": { + "from": "L46:10", + "to": "L46:17" + }, + "expr": { + "_type": "VariableExpr", + "name": "contra2" + }, + "type": { + "name": "Contravariant", + "args": [ + { + "name": "int" + } + ], + "body": { + "name": "object" + } + } + }, + { + "location": { + "from": "L47:10", + "to": "L47:17" + }, + "expr": { + "_type": "VariableExpr", + "name": "contra1" + }, + "type": { + "name": "Contravariant", + "args": [ + { + "name": "float" + } + ], + "body": { + "name": "object" + } + } + }, + { + "location": { + "from": "L51:7", + "to": "L51:11" + }, + "expr": { + "_type": "VariableExpr", + "name": "inv2" + }, + "type": { + "name": "Invariant", + "args": [ + { + "name": "int" + } + ], + "body": { + "name": "object" + } + } + }, + { + "location": { + "from": "L52:7", + "to": "L52:11" + }, + "expr": { + "_type": "VariableExpr", + "name": "inv1" + }, + "type": { + "name": "Invariant", + "args": [ + { + "name": "float" + } + ], + "body": { + "name": "object" + } + } + } + ] +} \ No newline at end of file