Merge pull request 'Fixes and small demo' (#19) from feat/demonstration into main
Reviewed-on: #19
This commit was merged in pull request #19.
This commit is contained in:
15
examples/02_demonstration/demo.midas
Normal file
15
examples/02_demonstration/demo.midas
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
predicate in_range(min: float, max: float)(v: float) = min <= v & v <= max
|
||||||
|
predicate is_ratio = in_range(0, 1)
|
||||||
|
|
||||||
|
type Currency = float
|
||||||
|
type Price[T <: Currency] = T where _ >= 0
|
||||||
|
|
||||||
|
extend Price[T <: Currency] {
|
||||||
|
def __add__: fn(Price[T], /) -> Price[T]
|
||||||
|
}
|
||||||
|
|
||||||
|
type EUR = Currency
|
||||||
|
type USD = Currency
|
||||||
|
type CHF = Currency
|
||||||
|
|
||||||
|
type Discount = float where is_ratio(_)
|
||||||
30
examples/02_demonstration/demo.py
Normal file
30
examples/02_demonstration/demo.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
from typing import TypeVar, cast
|
||||||
|
|
||||||
|
from demo_stubs import CHF, EUR, USD, Currency, Price, Discount
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=Currency)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_discount(amount: Price[T], discount: Discount) -> Price[T]:
|
||||||
|
return cast(Price[T], (1.0 - discount) * amount)
|
||||||
|
|
||||||
|
|
||||||
|
a1 = cast(Price[EUR], 3.2)
|
||||||
|
a2 = cast(Price[USD], 10.4)
|
||||||
|
r1 = cast(Discount, 0.2)
|
||||||
|
|
||||||
|
print(apply_discount(a1, r1))
|
||||||
|
print(apply_discount(a2, r1))
|
||||||
|
|
||||||
|
a3 = a1 + a1
|
||||||
|
a4 = a1 + a2 # cannot add euros and dollars
|
||||||
|
a3 = a2 # cannot change variable type
|
||||||
|
|
||||||
|
dyn_price = float(input("Price (CHF): "))
|
||||||
|
dyn_discount = float(input("Discount (0.0-1.0): "))
|
||||||
|
discounted = apply_discount(
|
||||||
|
cast(Price[CHF], dyn_price),
|
||||||
|
cast(Discount, dyn_discount),
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Discounted: CHF {discounted}")
|
||||||
14
examples/02_demonstration/demo_stubs.pyi
Normal file
14
examples/02_demonstration/demo_stubs.pyi
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
|
class Currency(float): ...
|
||||||
|
|
||||||
|
_T0 = TypeVar("_T0", bound=Currency, covariant=True)
|
||||||
|
|
||||||
|
class Price(Currency, Generic[_T0]):
|
||||||
|
def __add__(self, _0: Price[_T0], /) -> Price[_T0]: ...
|
||||||
|
|
||||||
|
class EUR(Currency): ...
|
||||||
|
class USD(Currency): ...
|
||||||
|
class CHF(Currency): ...
|
||||||
|
class Discount(float): ...
|
||||||
@@ -54,6 +54,11 @@ class Preamble(Environment):
|
|||||||
returns=self._list_of(map_out), # TODO: replace with Iterable[U]
|
returns=self._list_of(map_out), # TODO: replace with Iterable[U]
|
||||||
type_vars=[map_in, map_out],
|
type_vars=[map_in, map_out],
|
||||||
)
|
)
|
||||||
|
self._def_function(
|
||||||
|
name="input",
|
||||||
|
pos=[Param("prompt", TopType(), required=False)],
|
||||||
|
returns=self._types.get_type("str"),
|
||||||
|
)
|
||||||
|
|
||||||
def _list_of(self, item_type: Type) -> Type:
|
def _list_of(self, item_type: Type) -> Type:
|
||||||
return self._types.apply_generic(self._types.get_type("list"), [item_type])
|
return self._types.apply_generic(self._types.get_type("list"), [item_type])
|
||||||
|
|||||||
@@ -22,8 +22,10 @@ from midas.checker.types import (
|
|||||||
GenericType,
|
GenericType,
|
||||||
OverloadedFunction,
|
OverloadedFunction,
|
||||||
Type,
|
Type,
|
||||||
|
TypeVar,
|
||||||
UnitType,
|
UnitType,
|
||||||
UnknownType,
|
UnknownType,
|
||||||
|
Variance,
|
||||||
unfold_type,
|
unfold_type,
|
||||||
)
|
)
|
||||||
from midas.checker.unifier import Unifier
|
from midas.checker.unifier import Unifier
|
||||||
@@ -229,7 +231,8 @@ class PythonTyper(
|
|||||||
)
|
)
|
||||||
pos += 1
|
pos += 1
|
||||||
|
|
||||||
for arg in pos_args + args + kw_args:
|
all_args: list[Function.Argument] = pos_args + args + kw_args
|
||||||
|
for arg in all_args:
|
||||||
env.define(arg.name, arg.type)
|
env.define(arg.name, arg.type)
|
||||||
|
|
||||||
returns_hint: Optional[Type] = None
|
returns_hint: Optional[Type] = None
|
||||||
@@ -270,12 +273,25 @@ class PythonTyper(
|
|||||||
returns = inferred_return
|
returns = inferred_return
|
||||||
|
|
||||||
# TODO: handle *args and **kwargs sinks
|
# TODO: handle *args and **kwargs sinks
|
||||||
function: Function = Function(
|
function: Type = Function(
|
||||||
pos_args=pos_args,
|
pos_args=pos_args,
|
||||||
args=args,
|
args=args,
|
||||||
kw_args=kw_args,
|
kw_args=kw_args,
|
||||||
returns=returns,
|
returns=returns,
|
||||||
)
|
)
|
||||||
|
generic_params: list[TypeVar] = []
|
||||||
|
all_types: list[Type] = [arg.type for arg in all_args] + [returns]
|
||||||
|
for type in all_types:
|
||||||
|
if isinstance(type, TypeVar):
|
||||||
|
if type not in generic_params:
|
||||||
|
generic_params.append(type)
|
||||||
|
|
||||||
|
if len(generic_params) != 0:
|
||||||
|
function = GenericType(
|
||||||
|
name=stmt.name,
|
||||||
|
params=generic_params,
|
||||||
|
body=function,
|
||||||
|
)
|
||||||
self.env.define(stmt.name, function)
|
self.env.define(stmt.name, function)
|
||||||
|
|
||||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||||
@@ -453,6 +469,10 @@ class PythonTyper(
|
|||||||
return result or UnknownType()
|
return result or UnknownType()
|
||||||
|
|
||||||
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
||||||
|
match expr.callee:
|
||||||
|
case p.VariableExpr(name="TypeVar"):
|
||||||
|
return self.define_typevar(expr) or UnknownType()
|
||||||
|
|
||||||
callee: Type = self.type_of(expr.callee)
|
callee: Type = self.type_of(expr.callee)
|
||||||
positional: list[TypedExpr] = [
|
positional: list[TypedExpr] = [
|
||||||
(arg, self.type_of(arg)) for arg in expr.arguments
|
(arg, self.type_of(arg)) for arg in expr.arguments
|
||||||
@@ -518,6 +538,7 @@ class PythonTyper(
|
|||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
|
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
|
||||||
|
_ = self.type_of(expr.expr)
|
||||||
return self.resolve_type_expr(expr.type)
|
return self.resolve_type_expr(expr.type)
|
||||||
|
|
||||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
|
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
|
||||||
@@ -1033,3 +1054,57 @@ class PythonTyper(
|
|||||||
report_errors=False,
|
report_errors=False,
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def define_typevar(self, call: p.CallExpr) -> Optional[TypeVar]:
|
||||||
|
def is_kw_true(name: str) -> bool:
|
||||||
|
match call.keywords.get(name):
|
||||||
|
case p.LiteralExpr(value=True):
|
||||||
|
return True
|
||||||
|
case _:
|
||||||
|
return False
|
||||||
|
|
||||||
|
match call:
|
||||||
|
case p.CallExpr(
|
||||||
|
arguments=[p.LiteralExpr(value=str() as name)],
|
||||||
|
):
|
||||||
|
bound: Optional[Type] = None
|
||||||
|
variance: Variance = Variance.INVARIANT
|
||||||
|
if "bound" in call.keywords:
|
||||||
|
bound_type: p.MidasType = self._parse_type_from_expr(
|
||||||
|
call.keywords["bound"]
|
||||||
|
)
|
||||||
|
bound = self.resolve_type_expr(bound_type)
|
||||||
|
|
||||||
|
if is_kw_true("covariant"):
|
||||||
|
variance = Variance.COVARIANT
|
||||||
|
|
||||||
|
if is_kw_true("contravariant"):
|
||||||
|
if variance == Variance.COVARIANT:
|
||||||
|
self.reporter.warning(
|
||||||
|
call.keywords["contravariant"].location,
|
||||||
|
"TypeVar cannot be covariant and contravariant at the same time. Marked as invariant",
|
||||||
|
)
|
||||||
|
variance = Variance.INVARIANT
|
||||||
|
else:
|
||||||
|
variance = Variance.CONTRAVARIANT
|
||||||
|
var: TypeVar = TypeVar(name=name, bound=bound, variance=variance)
|
||||||
|
self.types.define_type(name, var)
|
||||||
|
return var
|
||||||
|
|
||||||
|
case _:
|
||||||
|
self.reporter.warning(
|
||||||
|
call.location, "Invalid usage of 'TypeVar', skipping"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _parse_type_from_expr(self, expr: p.Expr) -> p.MidasType:
|
||||||
|
location: Location = expr.location
|
||||||
|
parser = PythonParser()
|
||||||
|
match expr:
|
||||||
|
case p.LiteralExpr(value=str() as value):
|
||||||
|
node: ast.Expression = ast.parse(value, mode="eval")
|
||||||
|
return parser._parse_type(node.body)
|
||||||
|
case p.VariableExpr(name=name):
|
||||||
|
return p.BaseType(location=location, base=name, param=None)
|
||||||
|
case _:
|
||||||
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -130,6 +130,19 @@ class TypesRegistry:
|
|||||||
case (_, TopType()):
|
case (_, TopType()):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
case (_, UnknownType()):
|
||||||
|
return True
|
||||||
|
|
||||||
|
case (TypeVar(bound=bound), _):
|
||||||
|
if bound is None:
|
||||||
|
return False
|
||||||
|
return self.is_subtype(bound, type2)
|
||||||
|
|
||||||
|
case (_, TypeVar(bound=bound)):
|
||||||
|
if bound is None:
|
||||||
|
return True
|
||||||
|
return self.is_subtype(type1, bound)
|
||||||
|
|
||||||
case (AliasType(type=base1), _):
|
case (AliasType(type=base1), _):
|
||||||
return self.is_subtype(base1, type2)
|
return self.is_subtype(base1, type2)
|
||||||
|
|
||||||
@@ -147,11 +160,6 @@ class TypesRegistry:
|
|||||||
case (Function(), Function()):
|
case (Function(), Function()):
|
||||||
return self.is_func_subtype(type1, type2)
|
return self.is_func_subtype(type1, type2)
|
||||||
|
|
||||||
case (TypeVar(bound=bound), _):
|
|
||||||
if bound is None:
|
|
||||||
return False
|
|
||||||
return self.is_subtype(bound, type2)
|
|
||||||
|
|
||||||
case (ConstraintType(type=base1), _):
|
case (ConstraintType(type=base1), _):
|
||||||
return self.is_subtype(base1, type2)
|
return self.is_subtype(base1, type2)
|
||||||
|
|
||||||
@@ -173,6 +181,10 @@ class TypesRegistry:
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
# TODO: verify legitimacy
|
||||||
|
case (AppliedType(body=body), _):
|
||||||
|
return self.is_subtype(body, type2)
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# TODO: verify the logic in here
|
# TODO: verify the logic in here
|
||||||
@@ -389,6 +401,12 @@ class TypesRegistry:
|
|||||||
)
|
)
|
||||||
return self.lookup_member(base, member_name)
|
return self.lookup_member(base, member_name)
|
||||||
|
|
||||||
|
case ConstraintType(type=base):
|
||||||
|
return self.lookup_member(base, member_name)
|
||||||
|
|
||||||
|
case TypeVar(bound=bound) if bound is not None:
|
||||||
|
return self.lookup_member(bound, member_name)
|
||||||
|
|
||||||
case UnknownType():
|
case UnknownType():
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
|
|||||||
@@ -1,27 +1,64 @@
|
|||||||
import ast
|
import ast
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TextIO
|
from typing import TextIO
|
||||||
|
|
||||||
|
import black
|
||||||
import click
|
import click
|
||||||
|
from watchdog.events import DirModifiedEvent, FileModifiedEvent, FileSystemEventHandler
|
||||||
|
from watchdog.observers import Observer
|
||||||
|
|
||||||
from midas.checker.checker import TypeChecker
|
from midas.checker.checker import TypeChecker
|
||||||
from midas.generator.stubs import StubsGenerator
|
from midas.generator.stubs import StubsGenerator
|
||||||
|
|
||||||
|
|
||||||
@click.command(help="Generate stubs from Midas definitions")
|
def generate_stubs(in_path: Path, out_path: Path):
|
||||||
@click.argument("file", type=click.File("r"))
|
|
||||||
@click.option("-o", "--output", type=click.File("w"), default="-")
|
|
||||||
def stubs(
|
|
||||||
file: TextIO,
|
|
||||||
output: TextIO,
|
|
||||||
):
|
|
||||||
source_path: Path = Path(file.name).resolve()
|
|
||||||
|
|
||||||
checker = TypeChecker()
|
checker = TypeChecker()
|
||||||
checker.import_midas(source_path)
|
checker.import_midas(in_path)
|
||||||
|
|
||||||
generator = StubsGenerator(checker.types)
|
generator = StubsGenerator(checker.types)
|
||||||
module: ast.Module = generator.generate_stubs()
|
module: ast.Module = generator.generate_stubs()
|
||||||
module = ast.fix_missing_locations(module)
|
module = ast.fix_missing_locations(module)
|
||||||
|
|
||||||
output.write(ast.unparse(module))
|
output: str = ast.unparse(module)
|
||||||
|
output = black.format_str(output, mode=black.Mode(is_pyi=True))
|
||||||
|
|
||||||
|
out_path.write_text(output)
|
||||||
|
|
||||||
|
|
||||||
|
class Handler(FileSystemEventHandler):
|
||||||
|
def __init__(self, in_path: Path, out_path: Path) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.in_path: Path = in_path
|
||||||
|
self.out_path: Path = out_path
|
||||||
|
|
||||||
|
def on_modified(self, event: DirModifiedEvent | FileModifiedEvent) -> None:
|
||||||
|
generate_stubs(self.in_path, self.out_path)
|
||||||
|
|
||||||
|
|
||||||
|
@click.command(help="Generate stubs from Midas definitions")
|
||||||
|
@click.argument("file", type=click.File("r"))
|
||||||
|
@click.option("-o", "--output", type=click.File("w"), default="-")
|
||||||
|
@click.option("-w", "--watch", is_flag=True)
|
||||||
|
def stubs(
|
||||||
|
file: TextIO,
|
||||||
|
output: TextIO,
|
||||||
|
watch: bool,
|
||||||
|
):
|
||||||
|
source_path: Path = Path(file.name).resolve()
|
||||||
|
out_path: Path = Path(output.name).resolve()
|
||||||
|
generate_stubs(source_path, out_path)
|
||||||
|
|
||||||
|
if watch:
|
||||||
|
print(f"Watching {source_path}...")
|
||||||
|
print("Press CTRL+C to stop")
|
||||||
|
handler = Handler(source_path, out_path)
|
||||||
|
observer = Observer()
|
||||||
|
observer.schedule(handler, str(source_path))
|
||||||
|
observer.start()
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
time.sleep(1)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
observer.stop()
|
||||||
|
observer.join()
|
||||||
|
|||||||
@@ -322,8 +322,10 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
self._make_cast_asserts(src_location, expr, base)
|
self._make_cast_asserts(src_location, expr, base)
|
||||||
self._make_constraint_assert(src_location, expr, constraint)
|
self._make_constraint_assert(src_location, expr, constraint)
|
||||||
|
|
||||||
case TypeVar():
|
case TypeVar(bound=bound):
|
||||||
raise RuntimeError("Unexpected TypeVar")
|
# TODO: check with type from arguments / use call-site context
|
||||||
|
if bound is not None:
|
||||||
|
self._make_cast_asserts(src_location, expr, bound)
|
||||||
|
|
||||||
case (
|
case (
|
||||||
TopType()
|
TopType()
|
||||||
|
|||||||
@@ -39,6 +39,18 @@ class StubsGenerator:
|
|||||||
self.stubs = []
|
self.stubs = []
|
||||||
self.typing_imports = set()
|
self.typing_imports = set()
|
||||||
for name, type in self.types._types.items():
|
for name, type in self.types._types.items():
|
||||||
|
# Skip builtin types, not just based on name so the user can override
|
||||||
|
# TODO: check if added members on builtin type
|
||||||
|
match type:
|
||||||
|
case BaseType(name=name_) if name == name_:
|
||||||
|
continue
|
||||||
|
case GenericType(
|
||||||
|
name=name1,
|
||||||
|
body=BaseType(name=name2),
|
||||||
|
) if (
|
||||||
|
name == name1 == name2
|
||||||
|
):
|
||||||
|
continue
|
||||||
self.generate_stub(name, type)
|
self.generate_stub(name, type)
|
||||||
|
|
||||||
imports = [
|
imports = [
|
||||||
@@ -115,6 +127,12 @@ class StubsGenerator:
|
|||||||
body_subsitutions | substitutions,
|
body_subsitutions | substitutions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
case ConstraintType(type=base):
|
||||||
|
return self.get_bases(base)
|
||||||
|
|
||||||
|
case TypeVar(bound=bound) if bound is not None:
|
||||||
|
return [self.dump_type(bound)], {}
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
return [], {}
|
return [], {}
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,11 @@ authors = [
|
|||||||
{ name = "Louis Heredero", email = "louis.heredero@students.hevs.ch" },
|
{ name = "Louis Heredero", email = "louis.heredero@students.hevs.ch" },
|
||||||
]
|
]
|
||||||
classifiers = ["Programming Language :: Python :: 3"]
|
classifiers = ["Programming Language :: Python :: 3"]
|
||||||
dependencies = ["click>=8.4.1"]
|
dependencies = [
|
||||||
|
"black>=26.5.1",
|
||||||
|
"click>=8.4.1",
|
||||||
|
"watchdog>=6.0.0",
|
||||||
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
Homepage = "https://git.kbk28.ch/HEL/midas"
|
Homepage = "https://git.kbk28.ch/HEL/midas"
|
||||||
|
|||||||
Reference in New Issue
Block a user