Merge pull request 'Fixes and small demo' (#19) from feat/demonstration into main

Reviewed-on: #19
This commit was merged in pull request #19.
This commit is contained in:
2026-06-23 08:15:56 +00:00
10 changed files with 239 additions and 21 deletions

View File

@@ -0,0 +1,15 @@
predicate in_range(min: float, max: float)(v: float) = min <= v & v <= max
predicate is_ratio = in_range(0, 1)
type Currency = float
type Price[T <: Currency] = T where _ >= 0
extend Price[T <: Currency] {
def __add__: fn(Price[T], /) -> Price[T]
}
type EUR = Currency
type USD = Currency
type CHF = Currency
type Discount = float where is_ratio(_)

View File

@@ -0,0 +1,30 @@
from typing import TypeVar, cast
from demo_stubs import CHF, EUR, USD, Currency, Price, Discount
T = TypeVar("T", bound=Currency)
def apply_discount(amount: Price[T], discount: Discount) -> Price[T]:
return cast(Price[T], (1.0 - discount) * amount)
a1 = cast(Price[EUR], 3.2)
a2 = cast(Price[USD], 10.4)
r1 = cast(Discount, 0.2)
print(apply_discount(a1, r1))
print(apply_discount(a2, r1))
a3 = a1 + a1
a4 = a1 + a2 # cannot add euros and dollars
a3 = a2 # cannot change variable type
dyn_price = float(input("Price (CHF): "))
dyn_discount = float(input("Discount (0.0-1.0): "))
discounted = apply_discount(
cast(Price[CHF], dyn_price),
cast(Discount, dyn_discount),
)
print(f"Discounted: CHF {discounted}")

View File

@@ -0,0 +1,14 @@
from __future__ import annotations
from typing import Generic, TypeVar
class Currency(float): ...
_T0 = TypeVar("_T0", bound=Currency, covariant=True)
class Price(Currency, Generic[_T0]):
def __add__(self, _0: Price[_T0], /) -> Price[_T0]: ...
class EUR(Currency): ...
class USD(Currency): ...
class CHF(Currency): ...
class Discount(float): ...

View File

@@ -54,6 +54,11 @@ class Preamble(Environment):
returns=self._list_of(map_out), # TODO: replace with Iterable[U]
type_vars=[map_in, map_out],
)
self._def_function(
name="input",
pos=[Param("prompt", TopType(), required=False)],
returns=self._types.get_type("str"),
)
def _list_of(self, item_type: Type) -> Type:
return self._types.apply_generic(self._types.get_type("list"), [item_type])

View File

@@ -22,8 +22,10 @@ from midas.checker.types import (
GenericType,
OverloadedFunction,
Type,
TypeVar,
UnitType,
UnknownType,
Variance,
unfold_type,
)
from midas.checker.unifier import Unifier
@@ -229,7 +231,8 @@ class PythonTyper(
)
pos += 1
for arg in pos_args + args + kw_args:
all_args: list[Function.Argument] = pos_args + args + kw_args
for arg in all_args:
env.define(arg.name, arg.type)
returns_hint: Optional[Type] = None
@@ -270,12 +273,25 @@ class PythonTyper(
returns = inferred_return
# TODO: handle *args and **kwargs sinks
function: Function = Function(
function: Type = Function(
pos_args=pos_args,
args=args,
kw_args=kw_args,
returns=returns,
)
generic_params: list[TypeVar] = []
all_types: list[Type] = [arg.type for arg in all_args] + [returns]
for type in all_types:
if isinstance(type, TypeVar):
if type not in generic_params:
generic_params.append(type)
if len(generic_params) != 0:
function = GenericType(
name=stmt.name,
params=generic_params,
body=function,
)
self.env.define(stmt.name, function)
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
@@ -453,6 +469,10 @@ class PythonTyper(
return result or UnknownType()
def visit_call_expr(self, expr: p.CallExpr) -> Type:
match expr.callee:
case p.VariableExpr(name="TypeVar"):
return self.define_typevar(expr) or UnknownType()
callee: Type = self.type_of(expr.callee)
positional: list[TypedExpr] = [
(arg, self.type_of(arg)) for arg in expr.arguments
@@ -518,6 +538,7 @@ class PythonTyper(
return UnknownType()
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
_ = self.type_of(expr.expr)
return self.resolve_type_expr(expr.type)
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
@@ -1033,3 +1054,57 @@ class PythonTyper(
report_errors=False,
)
return result
def define_typevar(self, call: p.CallExpr) -> Optional[TypeVar]:
def is_kw_true(name: str) -> bool:
match call.keywords.get(name):
case p.LiteralExpr(value=True):
return True
case _:
return False
match call:
case p.CallExpr(
arguments=[p.LiteralExpr(value=str() as name)],
):
bound: Optional[Type] = None
variance: Variance = Variance.INVARIANT
if "bound" in call.keywords:
bound_type: p.MidasType = self._parse_type_from_expr(
call.keywords["bound"]
)
bound = self.resolve_type_expr(bound_type)
if is_kw_true("covariant"):
variance = Variance.COVARIANT
if is_kw_true("contravariant"):
if variance == Variance.COVARIANT:
self.reporter.warning(
call.keywords["contravariant"].location,
"TypeVar cannot be covariant and contravariant at the same time. Marked as invariant",
)
variance = Variance.INVARIANT
else:
variance = Variance.CONTRAVARIANT
var: TypeVar = TypeVar(name=name, bound=bound, variance=variance)
self.types.define_type(name, var)
return var
case _:
self.reporter.warning(
call.location, "Invalid usage of 'TypeVar', skipping"
)
return None
def _parse_type_from_expr(self, expr: p.Expr) -> p.MidasType:
location: Location = expr.location
parser = PythonParser()
match expr:
case p.LiteralExpr(value=str() as value):
node: ast.Expression = ast.parse(value, mode="eval")
return parser._parse_type(node.body)
case p.VariableExpr(name=name):
return p.BaseType(location=location, base=name, param=None)
case _:
raise NotImplementedError

View File

@@ -130,6 +130,19 @@ class TypesRegistry:
case (_, TopType()):
return True
case (_, UnknownType()):
return True
case (TypeVar(bound=bound), _):
if bound is None:
return False
return self.is_subtype(bound, type2)
case (_, TypeVar(bound=bound)):
if bound is None:
return True
return self.is_subtype(type1, bound)
case (AliasType(type=base1), _):
return self.is_subtype(base1, type2)
@@ -147,11 +160,6 @@ class TypesRegistry:
case (Function(), Function()):
return self.is_func_subtype(type1, type2)
case (TypeVar(bound=bound), _):
if bound is None:
return False
return self.is_subtype(bound, type2)
case (ConstraintType(type=base1), _):
return self.is_subtype(base1, type2)
@@ -173,6 +181,10 @@ class TypesRegistry:
return False
return True
# TODO: verify legitimacy
case (AppliedType(body=body), _):
return self.is_subtype(body, type2)
return False
# TODO: verify the logic in here
@@ -389,6 +401,12 @@ class TypesRegistry:
)
return self.lookup_member(base, member_name)
case ConstraintType(type=base):
return self.lookup_member(base, member_name)
case TypeVar(bound=bound) if bound is not None:
return self.lookup_member(bound, member_name)
case UnknownType():
return UnknownType()

View File

@@ -1,27 +1,64 @@
import ast
import time
from pathlib import Path
from typing import TextIO
import black
import click
from watchdog.events import DirModifiedEvent, FileModifiedEvent, FileSystemEventHandler
from watchdog.observers import Observer
from midas.checker.checker import TypeChecker
from midas.generator.stubs import StubsGenerator
@click.command(help="Generate stubs from Midas definitions")
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-")
def stubs(
file: TextIO,
output: TextIO,
):
source_path: Path = Path(file.name).resolve()
def generate_stubs(in_path: Path, out_path: Path):
checker = TypeChecker()
checker.import_midas(source_path)
checker.import_midas(in_path)
generator = StubsGenerator(checker.types)
module: ast.Module = generator.generate_stubs()
module = ast.fix_missing_locations(module)
output.write(ast.unparse(module))
output: str = ast.unparse(module)
output = black.format_str(output, mode=black.Mode(is_pyi=True))
out_path.write_text(output)
class Handler(FileSystemEventHandler):
def __init__(self, in_path: Path, out_path: Path) -> None:
super().__init__()
self.in_path: Path = in_path
self.out_path: Path = out_path
def on_modified(self, event: DirModifiedEvent | FileModifiedEvent) -> None:
generate_stubs(self.in_path, self.out_path)
@click.command(help="Generate stubs from Midas definitions")
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-")
@click.option("-w", "--watch", is_flag=True)
def stubs(
file: TextIO,
output: TextIO,
watch: bool,
):
source_path: Path = Path(file.name).resolve()
out_path: Path = Path(output.name).resolve()
generate_stubs(source_path, out_path)
if watch:
print(f"Watching {source_path}...")
print("Press CTRL+C to stop")
handler = Handler(source_path, out_path)
observer = Observer()
observer.schedule(handler, str(source_path))
observer.start()
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
observer.stop()
observer.join()

View File

@@ -322,8 +322,10 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self._make_cast_asserts(src_location, expr, base)
self._make_constraint_assert(src_location, expr, constraint)
case TypeVar():
raise RuntimeError("Unexpected TypeVar")
case TypeVar(bound=bound):
# TODO: check with type from arguments / use call-site context
if bound is not None:
self._make_cast_asserts(src_location, expr, bound)
case (
TopType()

View File

@@ -39,6 +39,18 @@ class StubsGenerator:
self.stubs = []
self.typing_imports = set()
for name, type in self.types._types.items():
# Skip builtin types, not just based on name so the user can override
# TODO: check if added members on builtin type
match type:
case BaseType(name=name_) if name == name_:
continue
case GenericType(
name=name1,
body=BaseType(name=name2),
) if (
name == name1 == name2
):
continue
self.generate_stub(name, type)
imports = [
@@ -115,6 +127,12 @@ class StubsGenerator:
body_subsitutions | substitutions,
)
case ConstraintType(type=base):
return self.get_bases(base)
case TypeVar(bound=bound) if bound is not None:
return [self.dump_type(bound)], {}
case _:
return [], {}

View File

@@ -8,7 +8,11 @@ authors = [
{ name = "Louis Heredero", email = "louis.heredero@students.hevs.ch" },
]
classifiers = ["Programming Language :: Python :: 3"]
dependencies = ["click>=8.4.1"]
dependencies = [
"black>=26.5.1",
"click>=8.4.1",
"watchdog>=6.0.0",
]
[project.urls]
Homepage = "https://git.kbk28.ch/HEL/midas"