diff --git a/examples/02_demonstration/demo.midas b/examples/02_demonstration/demo.midas new file mode 100644 index 0000000..661a583 --- /dev/null +++ b/examples/02_demonstration/demo.midas @@ -0,0 +1,15 @@ +predicate in_range(min: float, max: float)(v: float) = min <= v & v <= max +predicate is_ratio = in_range(0, 1) + +type Currency = float +type Price[T <: Currency] = T where _ >= 0 + +extend Price[T <: Currency] { + def __add__: fn(Price[T], /) -> Price[T] +} + +type EUR = Currency +type USD = Currency +type CHF = Currency + +type Discount = float where is_ratio(_) diff --git a/examples/02_demonstration/demo.py b/examples/02_demonstration/demo.py new file mode 100644 index 0000000..c4ec322 --- /dev/null +++ b/examples/02_demonstration/demo.py @@ -0,0 +1,30 @@ +from typing import TypeVar, cast + +from demo_stubs import CHF, EUR, USD, Currency, Price, Discount + +T = TypeVar("T", bound=Currency) + + +def apply_discount(amount: Price[T], discount: Discount) -> Price[T]: + return cast(Price[T], (1.0 - discount) * amount) + + +a1 = cast(Price[EUR], 3.2) +a2 = cast(Price[USD], 10.4) +r1 = cast(Discount, 0.2) + +print(apply_discount(a1, r1)) +print(apply_discount(a2, r1)) + +a3 = a1 + a1 +a4 = a1 + a2 # cannot add euros and dollars +a3 = a2 # cannot change variable type + +dyn_price = float(input("Price (CHF): ")) +dyn_discount = float(input("Discount (0.0-1.0): ")) +discounted = apply_discount( + cast(Price[CHF], dyn_price), + cast(Discount, dyn_discount), +) + +print(f"Discounted: CHF {discounted}") diff --git a/examples/02_demonstration/demo_stubs.pyi b/examples/02_demonstration/demo_stubs.pyi new file mode 100644 index 0000000..6615018 --- /dev/null +++ b/examples/02_demonstration/demo_stubs.pyi @@ -0,0 +1,14 @@ +from __future__ import annotations +from typing import Generic, TypeVar + +class Currency(float): ... + +_T0 = TypeVar("_T0", bound=Currency, covariant=True) + +class Price(Currency, Generic[_T0]): + def __add__(self, _0: Price[_T0], /) -> Price[_T0]: ... + +class EUR(Currency): ... +class USD(Currency): ... +class CHF(Currency): ... +class Discount(float): ... diff --git a/midas/checker/preamble.py b/midas/checker/preamble.py index 96a4ef7..1dcd157 100644 --- a/midas/checker/preamble.py +++ b/midas/checker/preamble.py @@ -54,6 +54,11 @@ class Preamble(Environment): returns=self._list_of(map_out), # TODO: replace with Iterable[U] type_vars=[map_in, map_out], ) + self._def_function( + name="input", + pos=[Param("prompt", TopType(), required=False)], + returns=self._types.get_type("str"), + ) def _list_of(self, item_type: Type) -> Type: return self._types.apply_generic(self._types.get_type("list"), [item_type]) diff --git a/midas/checker/python.py b/midas/checker/python.py index e1fb788..435f6f1 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -22,8 +22,10 @@ from midas.checker.types import ( GenericType, OverloadedFunction, Type, + TypeVar, UnitType, UnknownType, + Variance, unfold_type, ) from midas.checker.unifier import Unifier @@ -229,7 +231,8 @@ class PythonTyper( ) pos += 1 - for arg in pos_args + args + kw_args: + all_args: list[Function.Argument] = pos_args + args + kw_args + for arg in all_args: env.define(arg.name, arg.type) returns_hint: Optional[Type] = None @@ -270,12 +273,25 @@ class PythonTyper( returns = inferred_return # TODO: handle *args and **kwargs sinks - function: Function = Function( + function: Type = Function( pos_args=pos_args, args=args, kw_args=kw_args, returns=returns, ) + generic_params: list[TypeVar] = [] + all_types: list[Type] = [arg.type for arg in all_args] + [returns] + for type in all_types: + if isinstance(type, TypeVar): + if type not in generic_params: + generic_params.append(type) + + if len(generic_params) != 0: + function = GenericType( + name=stmt.name, + params=generic_params, + body=function, + ) self.env.define(stmt.name, function) def visit_type_assign(self, stmt: p.TypeAssign) -> None: @@ -453,6 +469,10 @@ class PythonTyper( return result or UnknownType() def visit_call_expr(self, expr: p.CallExpr) -> Type: + match expr.callee: + case p.VariableExpr(name="TypeVar"): + return self.define_typevar(expr) or UnknownType() + callee: Type = self.type_of(expr.callee) positional: list[TypedExpr] = [ (arg, self.type_of(arg)) for arg in expr.arguments @@ -518,6 +538,7 @@ class PythonTyper( return UnknownType() def visit_cast_expr(self, expr: p.CastExpr) -> Type: + _ = self.type_of(expr.expr) return self.resolve_type_expr(expr.type) def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type: @@ -1033,3 +1054,57 @@ class PythonTyper( report_errors=False, ) return result + + def define_typevar(self, call: p.CallExpr) -> Optional[TypeVar]: + def is_kw_true(name: str) -> bool: + match call.keywords.get(name): + case p.LiteralExpr(value=True): + return True + case _: + return False + + match call: + case p.CallExpr( + arguments=[p.LiteralExpr(value=str() as name)], + ): + bound: Optional[Type] = None + variance: Variance = Variance.INVARIANT + if "bound" in call.keywords: + bound_type: p.MidasType = self._parse_type_from_expr( + call.keywords["bound"] + ) + bound = self.resolve_type_expr(bound_type) + + if is_kw_true("covariant"): + variance = Variance.COVARIANT + + if is_kw_true("contravariant"): + if variance == Variance.COVARIANT: + self.reporter.warning( + call.keywords["contravariant"].location, + "TypeVar cannot be covariant and contravariant at the same time. Marked as invariant", + ) + variance = Variance.INVARIANT + else: + variance = Variance.CONTRAVARIANT + var: TypeVar = TypeVar(name=name, bound=bound, variance=variance) + self.types.define_type(name, var) + return var + + case _: + self.reporter.warning( + call.location, "Invalid usage of 'TypeVar', skipping" + ) + return None + + def _parse_type_from_expr(self, expr: p.Expr) -> p.MidasType: + location: Location = expr.location + parser = PythonParser() + match expr: + case p.LiteralExpr(value=str() as value): + node: ast.Expression = ast.parse(value, mode="eval") + return parser._parse_type(node.body) + case p.VariableExpr(name=name): + return p.BaseType(location=location, base=name, param=None) + case _: + raise NotImplementedError diff --git a/midas/checker/registry.py b/midas/checker/registry.py index b787f20..bbd3d85 100644 --- a/midas/checker/registry.py +++ b/midas/checker/registry.py @@ -130,6 +130,19 @@ class TypesRegistry: case (_, TopType()): return True + case (_, UnknownType()): + return True + + case (TypeVar(bound=bound), _): + if bound is None: + return False + return self.is_subtype(bound, type2) + + case (_, TypeVar(bound=bound)): + if bound is None: + return True + return self.is_subtype(type1, bound) + case (AliasType(type=base1), _): return self.is_subtype(base1, type2) @@ -147,11 +160,6 @@ class TypesRegistry: case (Function(), Function()): return self.is_func_subtype(type1, type2) - case (TypeVar(bound=bound), _): - if bound is None: - return False - return self.is_subtype(bound, type2) - case (ConstraintType(type=base1), _): return self.is_subtype(base1, type2) @@ -173,6 +181,10 @@ class TypesRegistry: return False return True + # TODO: verify legitimacy + case (AppliedType(body=body), _): + return self.is_subtype(body, type2) + return False # TODO: verify the logic in here @@ -389,6 +401,12 @@ class TypesRegistry: ) return self.lookup_member(base, member_name) + case ConstraintType(type=base): + return self.lookup_member(base, member_name) + + case TypeVar(bound=bound) if bound is not None: + return self.lookup_member(bound, member_name) + case UnknownType(): return UnknownType() diff --git a/midas/cli/commands/stubs.py b/midas/cli/commands/stubs.py index 98b3cd4..a5267c9 100644 --- a/midas/cli/commands/stubs.py +++ b/midas/cli/commands/stubs.py @@ -1,27 +1,64 @@ import ast +import time from pathlib import Path from typing import TextIO +import black import click +from watchdog.events import DirModifiedEvent, FileModifiedEvent, FileSystemEventHandler +from watchdog.observers import Observer from midas.checker.checker import TypeChecker from midas.generator.stubs import StubsGenerator -@click.command(help="Generate stubs from Midas definitions") -@click.argument("file", type=click.File("r")) -@click.option("-o", "--output", type=click.File("w"), default="-") -def stubs( - file: TextIO, - output: TextIO, -): - source_path: Path = Path(file.name).resolve() - +def generate_stubs(in_path: Path, out_path: Path): checker = TypeChecker() - checker.import_midas(source_path) + checker.import_midas(in_path) generator = StubsGenerator(checker.types) module: ast.Module = generator.generate_stubs() module = ast.fix_missing_locations(module) - output.write(ast.unparse(module)) + output: str = ast.unparse(module) + output = black.format_str(output, mode=black.Mode(is_pyi=True)) + + out_path.write_text(output) + + +class Handler(FileSystemEventHandler): + def __init__(self, in_path: Path, out_path: Path) -> None: + super().__init__() + self.in_path: Path = in_path + self.out_path: Path = out_path + + def on_modified(self, event: DirModifiedEvent | FileModifiedEvent) -> None: + generate_stubs(self.in_path, self.out_path) + + +@click.command(help="Generate stubs from Midas definitions") +@click.argument("file", type=click.File("r")) +@click.option("-o", "--output", type=click.File("w"), default="-") +@click.option("-w", "--watch", is_flag=True) +def stubs( + file: TextIO, + output: TextIO, + watch: bool, +): + source_path: Path = Path(file.name).resolve() + out_path: Path = Path(output.name).resolve() + generate_stubs(source_path, out_path) + + if watch: + print(f"Watching {source_path}...") + print("Press CTRL+C to stop") + handler = Handler(source_path, out_path) + observer = Observer() + observer.schedule(handler, str(source_path)) + observer.start() + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + observer.stop() + observer.join() diff --git a/midas/generator/generator.py b/midas/generator/generator.py index 22eab41..e66f532 100644 --- a/midas/generator/generator.py +++ b/midas/generator/generator.py @@ -322,8 +322,10 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): self._make_cast_asserts(src_location, expr, base) self._make_constraint_assert(src_location, expr, constraint) - case TypeVar(): - raise RuntimeError("Unexpected TypeVar") + case TypeVar(bound=bound): + # TODO: check with type from arguments / use call-site context + if bound is not None: + self._make_cast_asserts(src_location, expr, bound) case ( TopType() diff --git a/midas/generator/stubs.py b/midas/generator/stubs.py index d54c948..c9a3804 100644 --- a/midas/generator/stubs.py +++ b/midas/generator/stubs.py @@ -39,6 +39,18 @@ class StubsGenerator: self.stubs = [] self.typing_imports = set() for name, type in self.types._types.items(): + # Skip builtin types, not just based on name so the user can override + # TODO: check if added members on builtin type + match type: + case BaseType(name=name_) if name == name_: + continue + case GenericType( + name=name1, + body=BaseType(name=name2), + ) if ( + name == name1 == name2 + ): + continue self.generate_stub(name, type) imports = [ @@ -115,6 +127,12 @@ class StubsGenerator: body_subsitutions | substitutions, ) + case ConstraintType(type=base): + return self.get_bases(base) + + case TypeVar(bound=bound) if bound is not None: + return [self.dump_type(bound)], {} + case _: return [], {} diff --git a/pyproject.toml b/pyproject.toml index 69a9f7e..48f3b21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,11 @@ authors = [ { name = "Louis Heredero", email = "louis.heredero@students.hevs.ch" }, ] classifiers = ["Programming Language :: Python :: 3"] -dependencies = ["click>=8.4.1"] +dependencies = [ + "black>=26.5.1", + "click>=8.4.1", + "watchdog>=6.0.0", +] [project.urls] Homepage = "https://git.kbk28.ch/HEL/midas"