Compare commits

..

26 Commits

Author SHA1 Message Date
776a3fb86c fix(checker): change back warning to errors 2026-06-19 22:12:47 +02:00
fcbbff0177 tests: add predicates and constraints test 2026-06-19 22:09:16 +02:00
e2efbf693e fix(checker) minor tweaks 2026-06-19 21:54:13 +02:00
68f83ab6cf feat(parser): parse strings in Midas files 2026-06-19 21:53:35 +02:00
4118f95753 fix(parser): correctly parse keyword arguments 2026-06-19 21:10:59 +02:00
d5972df3f6 fix(checker): handle all operations and calls in predicates 2026-06-19 21:10:33 +02:00
3411aa9953 fix(checker): lookup predicate variables in preamble 2026-06-19 21:09:02 +02:00
7a2ee5a4cc feat(cli): print predicate with dump-registry 2026-06-19 15:06:55 +02:00
359ed21bb8 fix(checker): typo in docstring 2026-06-19 15:05:49 +02:00
d0f1178c17 fix(checker): change some diagnostics to warnings
temporarily change type errors in predicates to warnings until operations are fully type checked
2026-06-19 14:41:43 +02:00
0eca23b894 feat(gen): generate type hints for functions 2026-06-19 14:11:38 +02:00
f664fb4a4f feat(gen): handle predicate aliases
handle cases where a predicate is defined as an alias, i.e. without any parameters
2026-06-19 14:05:34 +02:00
32330243c6 fix(parser): fix call expr location span 2026-06-19 13:57:49 +02:00
96e76065cf feat(types): detect constraint base subtyping 2026-06-19 13:57:21 +02:00
7b7d87e59a feat(checker): type check predicate body 2026-06-19 13:55:32 +02:00
1eb90164e6 fix(gen): remove id from named predicate function 2026-06-19 10:15:09 +02:00
35ec0d0db8 fix(tests): update generator tester 2026-06-18 22:49:08 +02:00
48fcb499a1 feat(gen): generate predicate functions 2026-06-18 22:48:10 +02:00
bdc1b265a6 feat(gen): generate basic constraint assertion 2026-06-18 13:19:17 +02:00
1fb4b6f8c6 feat(types): add ConstraintType 2026-06-18 12:52:39 +02:00
48c1ecc1c8 refactor: ensure exhaustiveness in some match/case 2026-06-18 12:51:28 +02:00
04853eac70 tests: update with new predicate AST representation 2026-06-18 12:43:24 +02:00
020824d1f8 fix(tests): correctly serialize param name 2026-06-18 12:43:02 +02:00
ad86446a2d feat(midas): generalize param spec of predicate and parse 2026-06-18 12:38:24 +02:00
94d84ab170 feat(midas): add CallExpr 2026-06-18 12:34:29 +02:00
8381f4f31d refactor: add param spec for FunctionType 2026-06-18 11:06:02 +02:00
15 changed files with 38 additions and 1374 deletions

View File

@@ -1,4 +1,4 @@
<h1>Midas</h1>
# Midas
*Midas* is a type system to _Maintain Integrity of Data with Annotated Structures_. In Greek mythology, [Midas](https://en.wikipedia.org/wiki/Midas) was a Phrygian king who was blessed with the gift of turning everything he touched into gold.
@@ -6,25 +6,6 @@
This framework is being developed as part of a Bachelor's Thesis by Louis Heredero at HEI Sion.
<details>
<summary><strong>Table of Contents</strong></summary>
- [Requirements](#requirements)
- [Installation](#installation)
- [Commands](#commands)
- [Type Checking](#type-checking)
- [Compiling](#compiling)
- [Formatting](#formatting)
- [Highlighting](#highlighting)
- [Dumping the AST](#dumping-the-ast)
- [Dumping the Registry](#dumping-the-registry)
- [Generating Stubs](#generating-stubs)
- [Showing Type Judgements](#showing-type-judgements)
- [Validating Definitions](#validating-definitions)
- [Tests](#tests)
</details>
## Requirements
- Python 3.11+
@@ -51,49 +32,25 @@ This framework is being developed as part of a Bachelor's Thesis by Louis Herede
## Commands
<!--
check
compile
format
highlight
parse
dump_registry
types
validate
-->
### Type Checking
```shell
midas check -t types.midas source.py
```
This command parses the given files and run the type checkers against the Midas definitions and Python program. Diagnostics are then printed showing warnings and errors.
### Compiling
> [!NOTE]
> In the current state of the project, the `compile` command doesn't generate any runnable code, it only runs the parsers and type checker on the provided files
```shell
midas compile -t types.midas source.py
```
With the `compile` command, you can process a source Python file, with any number of custom type definition files (`-t FILE` option), and the type checker will verify the coherence of your program and generate the runnable code with valid syntax and runtime assertions.
### Formatting
```shell
midas format types.midas
midas format types.midas -o formatted.midas
```
This command parses the given Midas file and outputs a pretty printed file from the AST.
The optional `-l FILE` option lets you produce a highlighted version of the source code showing diagnostics from the type checker (see [Highlighting](#highlighting))
### Highlighting
```shell
midas highlight source.py
midas highlight source.py -o highlighted.html
midas highlight types.midas
midas highlight types.midas -o highlighted.html
midas utils highlight source.py
# or
midas utils highlight types.midas
```
The `highlight` command takes in a source file (Python or Midas), runs the appropriate parser and outputs an HTML file containing the source code with added highlighting. This highlighting takes the form of hoverable annotations showing some of the parsed structures (e.g. a function definition, an assignment, a generic type, etc.)
@@ -103,43 +60,14 @@ The optional `-o FILE` option can be used to specify an output path. By default,
### Dumping the AST
```shell
midas parse source.py
midas parse types.midas
midas utils dump-ast source.py
# or
midas utils dump-ast types.midas
```
For debugging purposes, you can output the AST parsed from a Python or Midas file. For Python files, the `--raw` flags lets you toggle the custom AST parsing. With `--raw`, the raw AST is returned, as produced by the builtin `ast` module. This flag has no effect on Midas files.
For debugging purposes, you can output the AST parsed from a Python or Midas file. For Python files, the `-p` flags lets you toggle the custom AST parsing. Without `-p`, the raw AST is returned, as produced by the builtin `ast` module. This flag has no effect on Midas files.
### Dumping the Registry
```shell
midas dump-registry -t types.midas
```
This command processes the given Midas definitions and dumps the contents of the types registry.
### Generating Stubs
```shell
midas stubs types.midas -o stubs.pyi
```
This command generate Python stubs from a Midas definition file
### Showing Type Judgements
```shell
midas types -t types.midas source.py
```
This command type checks the given Python source file and logs all typing judgements made by the type checker.
### Validating Definitions
```shell
midas validate types.midas
```
This command lets you validate a Midas definition file by running the parser and type checker, verifying syntax and references.
The optional `-o FILE` option can be used to specify an output path. By default, the file is printed in stdout (equivalent to `-o -`).
## Tests
@@ -149,7 +77,6 @@ Several snapshot tests are available to assert the good behaviour of the parsers
uv run -m tests.midas run -a
uv run -m tests.python run -a
uv run -m tests.checker run -a
uv run -m tests.generator run -a
```
**Available subcommands:**

View File

@@ -26,7 +26,6 @@ from midas.checker.types import (
UnknownType,
unfold_type,
)
from midas.checker.variance import VarianceInferrer
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token
from midas.parser.midas import MidasParser
@@ -133,11 +132,6 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
for stmt in stmts:
stmt.accept(self)
for name, type in self.types._types.items():
if isinstance(type, GenericType):
inferrer = VarianceInferrer(self.types)
self.types._types[name] = inferrer.infer(type)
def assert_bool(self, expr: m.Expr):
type: Type = self.type_of(expr)
if not self.types.is_subtype(type, self._bool):
@@ -173,7 +167,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
base_name,
member.name.lexeme,
member_type,
member.kind,
member.kind == m.MemberKind.METHOD,
)
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:

View File

@@ -16,7 +16,6 @@ from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter
from midas.checker.resolver import Resolver
from midas.checker.types import (
AliasType,
AppliedType,
Function,
OverloadedFunction,
@@ -699,17 +698,9 @@ class PythonTyper(
case UnknownType():
return UnknownType()
case AliasType(type=base):
return self._get_call_result(
location, base, positional, keywords, report_errors
)
case _:
if report_errors:
self.reporter.error(
location,
f"{callee} ({callee.__class__.__name__}) is not callable",
)
self.reporter.error(location, f"{callee} is not callable")
return None
def _are_arguments_valid(

View File

@@ -1,8 +1,6 @@
import logging
from dataclasses import dataclass
from typing import Optional
from midas.ast.midas import MemberKind
from midas.checker.builtins import BUILTIN_SUBTYPES
from midas.checker.types import (
AliasType,
@@ -19,22 +17,15 @@ from midas.checker.types import (
Type,
TypeVar,
UnknownType,
Variance,
substitute_typevars,
)
@dataclass
class Member:
kind: MemberKind
type: Type
class TypesRegistry:
def __init__(self) -> None:
self.logger: logging.Logger = logging.getLogger("TypesRegistry")
self._types: dict[str, Type] = {}
self._members: dict[str, dict[str, Member]] = {}
self._members: dict[str, dict[str, Type]] = {}
self._predicates: dict[str, Predicate] = {}
def get_type(self, name: str) -> Type:
@@ -72,38 +63,26 @@ class TypesRegistry:
return type
def define_member(
self,
type_name: str,
member_name: str,
member_type: Type,
kind: MemberKind,
self, type_name: str, member_name: str, member_type: Type, is_method: bool
):
members: dict[str, Member] = self._members.setdefault(type_name, {})
members: dict[str, Type] = self._members.setdefault(type_name, {})
if member_name in members:
current: Member = members[member_name]
if current.kind != kind:
if not is_method:
self.logger.error(
f"Member '{member_name}' is already defined as a {current.kind},"
+ f" cannot define a {kind} with the same name"
f"Member '{member_name}' already defined for type {type_name}"
)
return
if kind != MemberKind.METHOD:
self.logger.error(
f"Member '{member_name}' already defined for type {type_name},"
+ " only methods can be overloaded"
)
return
current: Type = members[member_name]
combined: Type
match current.type:
match current:
case OverloadedFunction(overloads=overloads):
combined = OverloadedFunction(overloads=overloads + [member_type])
case _:
combined = OverloadedFunction(overloads=[current.type, member_type])
members[member_name] = Member(kind=current.kind, type=combined)
combined = OverloadedFunction(overloads=[current, member_type])
members[member_name] = combined
else:
members[member_name] = Member(kind=kind, type=member_type)
members[member_name] = member_type
def define_predicate(self, name: str, predicate: Predicate):
if name in self._predicates:
@@ -155,24 +134,6 @@ class TypesRegistry:
case (ConstraintType(type=base1), _):
return self.is_subtype(base1, type2)
case (
AppliedType(name=name1, args=args1),
AppliedType(name=name2, args=args2),
) if (
name1 == name2
):
generic: Type = self.get_type(name1)
assert isinstance(generic, GenericType)
for param, arg1, arg2 in zip(generic.params, args1, args2):
variance: Variance = param.variance
if variance in {Variance.INVARIANT, Variance.COVARIANT}:
if not self.is_subtype(arg1, arg2):
return False
if variance in {Variance.INVARIANT, Variance.CONTRAVARIANT}:
if not self.is_subtype(arg2, arg1):
return False
return True
return False
# TODO: verify the logic in here
@@ -347,13 +308,13 @@ class TypesRegistry:
case BaseType(name=name):
if name in self._members:
if member_name in self._members[name]:
return self._members[name][member_name].type
return self._members[name][member_name]
return None
case AliasType(name=name, type=base):
if name in self._members:
if member_name in self._members[name]:
return self._members[name][member_name].type
return self._members[name][member_name]
return self.lookup_member(base, member_name)
case AppliedType(name=name, body=body, args=args):
@@ -367,7 +328,7 @@ class TypesRegistry:
}
if name in self._members:
if member_name in self._members[name]:
member_type: Type = self._members[name][member_name].type
member_type: Type = self._members[name][member_name]
return substitute_typevars(member_type, substitutions)
member_type2: Optional[Type] = self.lookup_member(body, member_name)

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Optional, assert_never
import midas.ast.midas as m
@@ -103,27 +102,15 @@ class ExtensionType:
return f"{self.base} & {self.extension}"
class Variance(StrEnum):
INVARIANT = "INVARIANT"
COVARIANT = "COVARIANT"
CONTRAVARIANT = "CONTRAVARIANT"
@dataclass(frozen=True, kw_only=True)
class TypeVar:
name: str
bound: Optional[Type]
variance: Variance = Variance.INVARIANT
def __str__(self) -> str:
variance: str = {
Variance.COVARIANT: "+",
Variance.CONTRAVARIANT: "-",
}.get(self.variance, "")
res: str = f"{variance}{self.name}"
if self.bound is not None:
res = f"{res} <: {self.bound}"
return res
return f"{self.name} <: {self.bound}"
return self.name
@dataclass(frozen=True, kw_only=True)
@@ -166,9 +153,6 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
)
match type:
case TopType():
return type
case BaseType(name=name) if name in substitutions:
return substitutions[name]
@@ -235,21 +219,6 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
return substitutions[name]
raise ValueError(f"Missing TypeVar substitution for {name}")
case GenericType(name=name, params=params, body=body):
params2: list[TypeVar] = []
for param in params:
param2: Type = substitute_typevars(param, substitutions)
if not isinstance(param2, TypeVar):
raise ValueError(
f"Invalid type parameter substitution, expected TypeVar, got {param2}"
)
params2.append(param2)
return GenericType(
name=name,
params=params2,
body=substitute_typevars(body, substitutions),
)
case UnknownType() | UnitType():
return type

View File

@@ -1,129 +0,0 @@
from typing import Literal, Optional, cast
from midas.checker.registry import Member, TypesRegistry
from midas.checker.types import (
AppliedType,
ConstraintType,
Function,
GenericType,
OverloadedFunction,
Type,
TypeVar,
Variance,
)
Polarity = Literal[-1, 0, 1]
class Tracker:
def __init__(self, vars: list[TypeVar]) -> None:
self.vars: list[TypeVar] = vars
self.refs: dict[str, set[Polarity]] = {var.name: set() for var in self.vars}
def record(self, var: TypeVar, polarity: Polarity):
self.refs[var.name].add(polarity)
def get_updated_vars(self) -> list[TypeVar]:
return [
TypeVar(
name=var.name, bound=var.bound, variance=self.get_variance(var.name)
)
for var in self.vars
]
def get_variance(self, name: str) -> Variance:
refs: set[Polarity] = self.refs[name]
if refs == {-1}:
return Variance.CONTRAVARIANT
if refs == {1}:
return Variance.COVARIANT
return Variance.INVARIANT
def __contains__(self, item: TypeVar | str):
if isinstance(item, TypeVar):
return item.name in self
return item in self.refs
class VarianceInferrer:
def __init__(self, types: TypesRegistry) -> None:
self.types: TypesRegistry = types
self.tracker: Tracker = Tracker([])
def infer(self, type: GenericType) -> GenericType:
self.tracker = Tracker(type.params)
self.walk(type.body, 1, type.name)
members: dict[str, Member] = self.types._members.get(type.name, {})
for name, member in members.items():
self.walk(member.type, 1, type.name, [f"member:'{name}'"])
return GenericType(
name=type.name,
params=self.tracker.get_updated_vars(),
body=type.body,
)
def walk(
self,
type: Type,
polarity: Polarity,
base_name: str,
path: Optional[list[str]] = None,
):
if path is None:
path = []
match type:
# Arguments are negative positions -> flip polarity
# Return is positive position -> keep polarity
case Function(pos_args=pos_args, args=mixed_args, kw_args=kw_args):
all_args: list[Function.Argument] = pos_args + mixed_args + kw_args
for arg in all_args:
self.walk(
arg.type,
-polarity,
base_name,
path + [f"arg:'{arg.name}'"],
)
self.walk(type.returns, polarity, base_name, path + ["return"])
# Walk all overloads
case OverloadedFunction(overloads=overloads):
for overload in overloads:
self.walk(overload, polarity, base_name, path)
# If same name as root generic -> skip
# Get inferred variance of parameters and multiply with current
# polarity to recurse through arguments
case AppliedType(name=name, args=args):
# TODO: handle mutually recursive types
if name == base_name:
return
generic: Type = self.types.get_type(name)
assert isinstance(generic, GenericType)
params: list[TypeVar] = generic.params
polarities: dict[Variance, Polarity] = {
Variance.INVARIANT: 0,
Variance.COVARIANT: 1,
Variance.CONTRAVARIANT: -1,
}
for arg, param in zip(args, params):
param_polarity: Polarity = polarities[param.variance]
self.walk(
arg,
cast(Polarity, polarity * param_polarity),
base_name,
path + [f"applied:'{name}'"],
)
# Walk base type
case ConstraintType(type=base):
self.walk(base, polarity, base_name, path + ["constraint"])
# Reached end
# If tracked, record polarity
case TypeVar():
if type in self.tracker:
self.tracker.record(type, polarity)

View File

@@ -4,6 +4,5 @@ from .format import format as format
from .highlight import highlight as highlight
from .parse import parse as parse
from .registry import dump_registry as dump_registry
from .stubs import stubs as stubs
from .types import types as types
from .validate import validate as validate

View File

@@ -10,7 +10,6 @@ import click
from midas.ast.printer import MidasPrinter
from midas.checker.checker import TypeChecker
from midas.checker.registry import Member
from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type
@@ -39,17 +38,12 @@ def dump_registry(
print("##### Types #####")
for name, type in checker.types._types.items():
members: dict[str, Member] = checker.types._members.get(name, {})
params: str = ""
if isinstance(type, GenericType):
params = ", ".join(map(str, type.params))
params = f"[{params}]"
print(f"{name}{params} = {base_type(type)}")
members: dict[str, Type] = checker.types._members.get(name, {})
print(f"{name} = {base_type(type)}")
if len(members) != 0:
print(" " * 4 + "Members:")
for member_name, member in members.items():
kind: str = member.kind.name
print(" " * 8 + f"({kind:8}) {member_name}: {member.type}")
for member_name, member_type in members.items():
print(" " * 8 + f"{member_name}: {member_type}")
print("##### Predicates #####")
printer = MidasPrinter()

View File

@@ -1,27 +0,0 @@
import ast
from pathlib import Path
from typing import TextIO
import click
from midas.checker.checker import TypeChecker
from midas.generator.stubs import StubsGenerator
@click.command(help="Generate stubs from Midas definitions")
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-")
def stubs(
file: TextIO,
output: TextIO,
):
source_path: Path = Path(file.name).resolve()
checker = TypeChecker()
checker.import_midas(source_path)
generator = StubsGenerator(checker.types)
module: ast.Module = generator.generate_stubs()
module = ast.fix_missing_locations(module)
output.write(ast.unparse(module))

View File

@@ -228,13 +228,6 @@ class PythonHighlighter(
for item in expr.items:
item.accept(self)
def visit_dict_expr(self, expr: p.DictExpr) -> None:
for key in expr.keys:
if key is not None:
key.accept(self)
for value in expr.values:
value.accept(self)
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
expr.object.accept(self)
expr.index.accept(self)
@@ -247,10 +240,6 @@ class PythonHighlighter(
if expr.step is not None:
expr.step.accept(self)
def visit_raw_expr(self, expr: p.RawExpr) -> None: ...
def visit_raw_stmt(self, stmt: p.RawStmt) -> None: ...
class MidasHighlighter(
Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None]
@@ -277,9 +266,8 @@ class MidasHighlighter(
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
self.wrap(stmt, "predicate")
self.wrap(LocatableToken(stmt.name), "predicate-name")
for spec in stmt.params:
self._visit_param_spec(spec)
stmt.body.accept(self)
stmt.type.accept(self)
stmt.condition.accept(self)
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
self.wrap(expr, "logical-expr")
@@ -295,14 +283,6 @@ class MidasHighlighter(
self.wrap(expr, "unary-expr")
expr.right.accept(self)
def visit_call_expr(self, expr: m.CallExpr) -> None:
self.wrap(expr, "call-expr")
expr.callee.accept(self)
for arg in expr.arguments:
arg.accept(self)
for arg in expr.keywords.values():
arg.accept(self)
def visit_get_expr(self, expr: m.GetExpr) -> None:
self.wrap(expr, "get-expr")
expr.expr.accept(self)
@@ -338,7 +318,8 @@ class MidasHighlighter(
def visit_function_type(self, type: m.FunctionType) -> None:
self.wrap(type, "function")
self._visit_param_spec(type.params)
for arg in type.pos_args + type.args + type.kw_args:
arg.type.accept(self)
type.returns.accept(self)
def visit_extension_type(self, type: m.ExtensionType) -> None:
@@ -346,10 +327,6 @@ class MidasHighlighter(
type.base.accept(self)
type.extension.accept(self)
def _visit_param_spec(self, spec: m.ParamSpec) -> None:
for param in spec.pos + spec.mixed + spec.kw:
param.type.accept(self)
class DiagnosticsHighlighter(Highlighter):
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"

View File

@@ -18,7 +18,6 @@ midas.add_command(commands.highlight)
midas.add_command(commands.parse)
midas.add_command(commands.dump_registry)
midas.add_command(commands.types)
midas.add_command(commands.stubs)
midas.add_command(commands.validate)

View File

@@ -1,368 +0,0 @@
import ast
from typing import Optional, assert_never
import midas.ast.midas as m
from midas.checker.registry import Member, TypesRegistry
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ComplexType,
ConstraintType,
ExtensionType,
Function,
GenericType,
OverloadedFunction,
TopType,
Type,
TypeVar,
UnitType,
UnknownType,
Variance,
substitute_typevars,
)
Empty = ast.Constant(value=...)
class StubsGenerator:
def __init__(self, types: TypesRegistry) -> None:
self.types: TypesRegistry = types
self.stubs: list[ast.stmt] = []
self.typing_imports: set[str] = set()
self.protocol_idx: int = 0
self.stub_idx: int = 0
self.type_var_idx: int = 0
self.substitutions: dict[str, dict[str, Type]] = {}
def generate_stubs(self) -> ast.Module:
self.stubs = []
self.typing_imports = set()
for name, type in self.types._types.items():
self.generate_stub(name, type)
imports = [
ast.ImportFrom(
module="__future__",
names=[ast.alias(name="annotations")],
level=0,
)
]
if len(self.typing_imports) != 0:
imports.append(
ast.ImportFrom(
module="typing",
names=[
ast.alias(name=name) for name in sorted(self.typing_imports)
],
level=0,
)
)
return ast.Module(body=imports + self.stubs, type_ignores=[])
def generate_stub(self, name: str, type: Type):
base_type: Type = type
members: dict[str, Member] = self.types._members.get(name, {})
if isinstance(base_type, (BaseType, TopType, UnitType)) and len(members) == 0:
return
bases: list[ast.expr] = []
substitutions: dict[str, Type] = {}
bases, substitutions = self.get_bases(type)
self.substitutions[name] = substitutions
body = self.generate_body(members, substitutions)
stub = ast.ClassDef(
name=name,
bases=bases,
body=body,
keywords=[],
decorator_list=[],
)
self.add_stub(stub)
def get_bases(self, type: Type) -> tuple[list[ast.expr], dict[str, Type]]:
match type:
case AliasType(type=base):
return [self.dump_type(base)], {}
case GenericType(params=params, body=body):
self.add_typing_import("Generic")
type_vars: ast.expr
params2: list[TypeVar] = self.define_type_vars(params)
if len(params) == 1:
type_vars = ast.Name(id=params2[0].name)
else:
type_vars = ast.Tuple(
elts=[ast.Name(id=param.name) for param in params2]
)
substitutions: dict[str, TypeVar] = {
param.name: param2 for param, param2 in zip(params, params2)
}
body_bases, body_subsitutions = self.get_bases(body)
return (
body_bases
+ [
ast.Subscript(
value=ast.Name(id="Generic"),
slice=type_vars,
)
],
body_subsitutions | substitutions,
)
case _:
return [], {}
def generate_body(
self, members: dict[str, Member], substitutions: dict[str, Type]
) -> list[ast.stmt]:
if len(members) == 0:
return [ast.Expr(value=Empty)]
body: list[ast.stmt] = []
for name, member in members.items():
type: Type = member.type
type = substitute_typevars(type, substitutions)
match member.kind:
case m.MemberKind.PROPERTY:
body.append(
ast.AnnAssign(
target=ast.Name(id=name),
annotation=self.dump_type(type),
simple=1,
)
)
case m.MemberKind.METHOD:
body.extend(self.dump_method(name, type))
return body
def dump_type(self, type: Type) -> ast.expr:
match type:
case AliasType(name=name) | GenericType(name=name) if (
name in self.substitutions
):
type = substitute_typevars(type, self.substitutions[name])
match type:
case TopType() | UnknownType():
self.add_typing_import("Any")
return ast.Name(id="Any")
case BaseType(name=name):
return ast.Name(id=name)
case AliasType(name=name):
return ast.Name(id=name)
case UnitType():
return ast.Constant(value=None)
case Function():
name: str = self.define_protocol(type)
return ast.Name(id=name)
case OverloadedFunction(overloads=overloads):
if len(overloads) == 1:
return self.dump_type(overloads[0])
return ast.BinOp(
left=self.dump_type(OverloadedFunction(overloads=overloads[:-1])),
op=ast.BitOr(),
right=self.dump_type(overloads[-1]),
)
case ComplexType():
name: str = self.new_stub_name()
self.generate_stub(name, type)
return ast.Name(id=name)
case ExtensionType():
raise NotImplementedError
case TypeVar():
return ast.Name(id=type.name)
case GenericType(name=name):
params: ast.expr
if len(type.params) == 1:
params = self.dump_type(type.params[0])
else:
params = ast.Tuple(
elts=[self.dump_type(param) for param in type.params]
)
return ast.Subscript(
value=ast.Name(id=type.name),
slice=params,
)
case AppliedType():
args: ast.expr
if len(type.args) == 1:
args = self.dump_type(type.args[0])
else:
args = ast.Tuple(elts=[self.dump_type(arg) for arg in type.args])
return ast.Subscript(
value=ast.Name(id=type.name),
slice=args,
)
case ConstraintType():
return self.dump_type(type.type)
case _:
assert_never(type)
def dump_method(
self, name: str, method: Type, overloaded: bool = False
) -> list[ast.stmt]:
match method:
case Function():
if overloaded:
self.add_typing_import("overload")
return [
ast.FunctionDef(
name=name,
args=self.dump_args(method, with_self=True),
returns=self.dump_type(method.returns),
body=[ast.Expr(value=Empty)],
decorator_list=[ast.Name(id="overload")] if overloaded else [],
)
]
case OverloadedFunction(overloads=overloads):
stmts: list[ast.stmt] = []
for overload in overloads:
stmts.extend(self.dump_method(name, overload, True))
return stmts
case _:
return [
ast.AnnAssign(
target=ast.Name(id=name),
annotation=self.dump_type(method),
simple=1,
)
]
def dump_args(self, func: Function, with_self: bool = False) -> ast.arguments:
pos: list[ast.arg] = [
ast.arg(arg=f"_{arg.pos}", annotation=self.dump_type(arg.type))
for arg in func.pos_args
]
mixed: list[ast.arg] = [
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type))
for arg in func.args
]
kw: list[ast.arg] = [
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type))
for arg in func.kw_args
]
defaults: list[ast.expr] = [
Empty for arg in func.pos_args + func.args if not arg.required
]
kw_defaults: list[Optional[ast.expr]] = [
None if arg.required else Empty for arg in func.kw_args
]
if with_self:
arg = ast.arg(arg="self", annotation=None)
if len(pos) != 0:
pos.insert(0, arg)
else:
mixed.insert(0, arg)
return ast.arguments(
posonlyargs=pos,
args=mixed,
kwonlyargs=kw,
defaults=defaults,
kw_defaults=kw_defaults,
)
def define_protocol(self, func: Function) -> str:
self.add_typing_import("Protocol")
name: str = self.new_protocol_name()
protocol = ast.ClassDef(
name=name,
bases=[ast.Name(id="Protocol")],
keywords=[],
body=[
ast.FunctionDef(
name="__call__",
args=self.dump_args(func, with_self=True),
returns=self.dump_type(func.returns),
body=[ast.Expr(value=Empty)],
decorator_list=[],
),
],
decorator_list=[],
)
self.add_stub(protocol)
return name
def new_protocol_name(self) -> str:
name: str = f"_Protocol{self.protocol_idx}"
self.protocol_idx += 1
return name
def new_stub_name(self) -> str:
name: str = f"_Stub_{self.stub_idx}"
self.stub_idx += 1
return name
def new_type_var_name(self) -> str:
name: str = f"_T{self.type_var_idx}"
self.type_var_idx += 1
return name
def add_stub(self, stub: ast.stmt):
self.stubs.append(stub)
def add_typing_import(self, name: str):
self.typing_imports.add(name)
def define_type_vars(self, vars: list[TypeVar]) -> list[TypeVar]:
vars2: list[TypeVar] = []
for var in vars:
vars2.append(self.define_type_var(var))
return vars2
def define_type_var(self, var: TypeVar) -> TypeVar:
name: str = self.new_type_var_name()
self.add_typing_import("TypeVar")
kwargs: list[ast.keyword] = []
if var.bound is not None:
kwargs.append(
ast.keyword(
arg="bound",
value=self.dump_type(var.bound),
)
)
if var.variance == Variance.COVARIANT:
kwargs.append(
ast.keyword(
arg="covariant",
value=ast.Constant(value=True),
)
)
elif var.variance == Variance.CONTRAVARIANT:
kwargs.append(
ast.keyword(
arg="contravariant",
value=ast.Constant(value=True),
)
)
self.add_stub(
ast.Assign(
targets=[ast.Name(id=name)],
value=ast.Call(
func=ast.Name(id="TypeVar"),
args=[
ast.Constant(value=name),
],
keywords=kwargs,
),
)
)
return TypeVar(name=name, bound=None)

View File

@@ -1,59 +0,0 @@
// T is invariant (unused)
type Unused[T] = object
// T is covariant
type Covariant[T] = object
// T is contravariant
type Contravariant[T] = object
// T is invariant
type Invariant[T] = object
extend Covariant[T] {
def foo: fn() -> T
}
extend Contravariant[T] {
def foo: fn(T, /) -> None
}
extend Invariant[T] {
def foo: fn(T, /) -> T
}
// T is covariant
type Coco[T] = object
extend Coco[T] {
def foo: fn() -> Covariant[T]
}
// T is contravariant
type Cocontra[T] = object
extend Cocontra[T] {
def foo: fn() -> Contravariant[T]
}
// T is contravariant
type Contraco[T] = object
extend Contraco[T] {
def foo: fn(Covariant[T], /) -> None
}
// T is covariant
type Contracontra[T] = object
extend Contracontra[T] {
def foo: fn(Contravariant[T], /) -> None
}
type T1[T] = object
type T2[T] = object
extend T1[T] {
def foo: fn() -> T2[T]
}
extend T2[T] {
def foo: fn() -> T1[T]
}

View File

@@ -1,52 +0,0 @@
from _ import (
T1,
T2,
Coco,
Cocontra,
Contraco,
Contracontra,
Contravariant,
Covariant,
Invariant,
Unused,
)
unused: Unused
covariant: Covariant
contravariant: Contravariant
invariant: Invariant
coco: Coco
cocontra: Cocontra
contraco: Contraco
contracontra: Contracontra
t1: T1
t2: T2
# Dummy print to prudce judgements for the expressions
print(
unused,
covariant,
contravariant,
invariant,
coco,
cocontra,
contraco,
contracontra,
t1,
t2,
)
cov1: Covariant[float]
cov2: Covariant[int]
cov1 = cov2 # Ok because int <: float => Covariant[int] <: Covariant[float]
cov2 = cov1 # Invalid
contra1: Contravariant[float]
contra2: Contravariant[int]
contra1 = contra2 # Invalid
contra2 = contra1 # Ok because int <: float => Covariant[float] <: Covariant[int]
inv1: Invariant[float]
inv2: Invariant[int]
inv1 = inv2 # Invalid
inv2 = inv1 # Invalid

View File

@@ -1,512 +0,0 @@
{
"diagnostics": [
{
"type": "Error",
"location": {
"start": [
28,
4
],
"end": [
28,
13
]
},
"message": "Too many positional arguments"
},
{
"type": "Error",
"location": {
"start": [
42,
0
],
"end": [
42,
11
]
},
"message": "Cannot assign Covariant[float] to variable 'cov2' of type Covariant[int]"
},
{
"type": "Error",
"location": {
"start": [
46,
0
],
"end": [
46,
17
]
},
"message": "Cannot assign Contravariant[int] to variable 'contra1' of type Contravariant[float]"
},
{
"type": "Error",
"location": {
"start": [
51,
0
],
"end": [
51,
11
]
},
"message": "Cannot assign Invariant[int] to variable 'inv1' of type Invariant[float]"
},
{
"type": "Error",
"location": {
"start": [
52,
0
],
"end": [
52,
11
]
},
"message": "Cannot assign Invariant[float] to variable 'inv2' of type Invariant[int]"
}
],
"judgments": [
{
"location": {
"from": "L26:0",
"to": "L26:5"
},
"expr": {
"_type": "VariableExpr",
"name": "print"
},
"type": {
"pos_args": [
{
"pos": 0,
"name": "object",
"type": {},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {}
}
},
{
"location": {
"from": "L27:4",
"to": "L27:10"
},
"expr": {
"_type": "VariableExpr",
"name": "unused"
},
"type": {
"name": "Unused",
"params": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L28:4",
"to": "L28:13"
},
"expr": {
"_type": "VariableExpr",
"name": "covariant"
},
"type": {
"name": "Covariant",
"params": [
{
"name": "T",
"bound": null,
"variance": "COVARIANT"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L29:4",
"to": "L29:17"
},
"expr": {
"_type": "VariableExpr",
"name": "contravariant"
},
"type": {
"name": "Contravariant",
"params": [
{
"name": "T",
"bound": null,
"variance": "CONTRAVARIANT"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L30:4",
"to": "L30:13"
},
"expr": {
"_type": "VariableExpr",
"name": "invariant"
},
"type": {
"name": "Invariant",
"params": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L31:4",
"to": "L31:8"
},
"expr": {
"_type": "VariableExpr",
"name": "coco"
},
"type": {
"name": "Coco",
"params": [
{
"name": "T",
"bound": null,
"variance": "COVARIANT"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L32:4",
"to": "L32:12"
},
"expr": {
"_type": "VariableExpr",
"name": "cocontra"
},
"type": {
"name": "Cocontra",
"params": [
{
"name": "T",
"bound": null,
"variance": "CONTRAVARIANT"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L33:4",
"to": "L33:12"
},
"expr": {
"_type": "VariableExpr",
"name": "contraco"
},
"type": {
"name": "Contraco",
"params": [
{
"name": "T",
"bound": null,
"variance": "CONTRAVARIANT"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L34:4",
"to": "L34:16"
},
"expr": {
"_type": "VariableExpr",
"name": "contracontra"
},
"type": {
"name": "Contracontra",
"params": [
{
"name": "T",
"bound": null,
"variance": "COVARIANT"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L35:4",
"to": "L35:6"
},
"expr": {
"_type": "VariableExpr",
"name": "t1"
},
"type": {
"name": "T1",
"params": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L36:4",
"to": "L36:6"
},
"expr": {
"_type": "VariableExpr",
"name": "t2"
},
"type": {
"name": "T2",
"params": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L26:0",
"to": "L37:1"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "VariableExpr",
"name": "print"
},
"arguments": [
{
"_type": "VariableExpr",
"name": "unused"
},
{
"_type": "VariableExpr",
"name": "covariant"
},
{
"_type": "VariableExpr",
"name": "contravariant"
},
{
"_type": "VariableExpr",
"name": "invariant"
},
{
"_type": "VariableExpr",
"name": "coco"
},
{
"_type": "VariableExpr",
"name": "cocontra"
},
{
"_type": "VariableExpr",
"name": "contraco"
},
{
"_type": "VariableExpr",
"name": "contracontra"
},
{
"_type": "VariableExpr",
"name": "t1"
},
{
"_type": "VariableExpr",
"name": "t2"
}
],
"keywords": {}
},
"type": {}
},
{
"location": {
"from": "L41:7",
"to": "L41:11"
},
"expr": {
"_type": "VariableExpr",
"name": "cov2"
},
"type": {
"name": "Covariant",
"args": [
{
"name": "int"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L42:7",
"to": "L42:11"
},
"expr": {
"_type": "VariableExpr",
"name": "cov1"
},
"type": {
"name": "Covariant",
"args": [
{
"name": "float"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L46:10",
"to": "L46:17"
},
"expr": {
"_type": "VariableExpr",
"name": "contra2"
},
"type": {
"name": "Contravariant",
"args": [
{
"name": "int"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L47:10",
"to": "L47:17"
},
"expr": {
"_type": "VariableExpr",
"name": "contra1"
},
"type": {
"name": "Contravariant",
"args": [
{
"name": "float"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L51:7",
"to": "L51:11"
},
"expr": {
"_type": "VariableExpr",
"name": "inv2"
},
"type": {
"name": "Invariant",
"args": [
{
"name": "int"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L52:7",
"to": "L52:11"
},
"expr": {
"_type": "VariableExpr",
"name": "inv1"
},
"type": {
"name": "Invariant",
"args": [
{
"name": "float"
}
],
"body": {
"name": "object"
}
}
}
]
}