13 Commits

9 changed files with 209 additions and 21 deletions

View File

@@ -0,0 +1,10 @@
predicate in_range(min: float, max: float)(v: float) = min <= v & v <= max
predicate is_ratio = in_range(0, 1)
type Money = float
type Price[T <: Money] = T where _ >= 0
type EUR = Money
type USD = Money
type Reduction = float where is_ratio(_)

View File

@@ -0,0 +1,19 @@
from typing import TypeVar, cast
from demo_stubs import EUR, USD, Money, Price, Reduction
T = TypeVar("T", bound=Money)
def apply_reduction(amount: Price[T], reduction: Reduction) -> Price[T]:
return cast(Price[T], (1.0 - reduction) * amount)
a1 = cast(Price[EUR], 3.2)
a2 = cast(Price[USD], 10.4)
r1: Reduction = cast(Reduction, 0.2)
print(apply_reduction(a1, r1))
print(apply_reduction(a2, r1))
# a3 = a1 + a2

View File

@@ -54,6 +54,11 @@ class Preamble(Environment):
returns=self._list_of(map_out), # TODO: replace with Iterable[U]
type_vars=[map_in, map_out],
)
self._def_function(
name="input",
pos=[Param("prompt", TopType(), required=False)],
returns=self._types.get_type("str"),
)
def _list_of(self, item_type: Type) -> Type:
return self._types.apply_generic(self._types.get_type("list"), [item_type])

View File

@@ -22,8 +22,10 @@ from midas.checker.types import (
GenericType,
OverloadedFunction,
Type,
TypeVar,
UnitType,
UnknownType,
Variance,
unfold_type,
)
from midas.checker.unifier import Unifier
@@ -229,7 +231,8 @@ class PythonTyper(
)
pos += 1
for arg in pos_args + args + kw_args:
all_args: list[Function.Argument] = pos_args + args + kw_args
for arg in all_args:
env.define(arg.name, arg.type)
returns_hint: Optional[Type] = None
@@ -270,12 +273,25 @@ class PythonTyper(
returns = inferred_return
# TODO: handle *args and **kwargs sinks
function: Function = Function(
function: Type = Function(
pos_args=pos_args,
args=args,
kw_args=kw_args,
returns=returns,
)
generic_params: list[TypeVar] = []
all_types: list[Type] = [arg.type for arg in all_args] + [returns]
for type in all_types:
if isinstance(type, TypeVar):
if type not in generic_params:
generic_params.append(type)
if len(generic_params) != 0:
function = GenericType(
name=stmt.name,
params=generic_params,
body=function,
)
self.env.define(stmt.name, function)
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
@@ -453,6 +469,10 @@ class PythonTyper(
return result or UnknownType()
def visit_call_expr(self, expr: p.CallExpr) -> Type:
match expr.callee:
case p.VariableExpr(name="TypeVar"):
return self.define_typevar(expr) or UnknownType()
callee: Type = self.type_of(expr.callee)
positional: list[TypedExpr] = [
(arg, self.type_of(arg)) for arg in expr.arguments
@@ -518,6 +538,7 @@ class PythonTyper(
return UnknownType()
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
_ = self.type_of(expr.expr)
return self.resolve_type_expr(expr.type)
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
@@ -1033,3 +1054,57 @@ class PythonTyper(
report_errors=False,
)
return result
def define_typevar(self, call: p.CallExpr) -> Optional[TypeVar]:
def is_kw_true(name: str) -> bool:
match call.keywords.get(name):
case p.LiteralExpr(value=True):
return True
case _:
return False
match call:
case p.CallExpr(
arguments=[p.LiteralExpr(value=str() as name)],
):
bound: Optional[Type] = None
variance: Variance = Variance.INVARIANT
if "bound" in call.keywords:
bound_type: p.MidasType = self._parse_type_from_expr(
call.keywords["bound"]
)
bound = self.resolve_type_expr(bound_type)
if is_kw_true("covariant"):
variance = Variance.COVARIANT
if is_kw_true("contravariant"):
if variance == Variance.COVARIANT:
self.reporter.warning(
call.keywords["contravariant"].location,
"TypeVar cannot be covariant and contravariant at the same time. Marked as invariant",
)
variance = Variance.INVARIANT
else:
variance = Variance.CONTRAVARIANT
var: TypeVar = TypeVar(name=name, bound=bound, variance=variance)
self.types.define_type(name, var)
return var
case _:
self.reporter.warning(
call.location, "Invalid usage of 'TypeVar', skipping"
)
return None
def _parse_type_from_expr(self, expr: p.Expr) -> p.MidasType:
location: Location = expr.location
parser = PythonParser()
match expr:
case p.LiteralExpr(value=str() as value):
node: ast.Expression = ast.parse(value, mode="eval")
return parser._parse_type(node.body)
case p.VariableExpr(name=name):
return p.BaseType(location=location, base=name, param=None)
case _:
raise NotImplementedError

View File

@@ -130,6 +130,19 @@ class TypesRegistry:
case (_, TopType()):
return True
case (_, UnknownType()):
return True
case (TypeVar(bound=bound), _):
if bound is None:
return False
return self.is_subtype(bound, type2)
case (_, TypeVar(bound=bound)):
if bound is None:
return True
return self.is_subtype(type1, bound)
case (AliasType(type=base1), _):
return self.is_subtype(base1, type2)
@@ -147,11 +160,6 @@ class TypesRegistry:
case (Function(), Function()):
return self.is_func_subtype(type1, type2)
case (TypeVar(bound=bound), _):
if bound is None:
return False
return self.is_subtype(bound, type2)
case (ConstraintType(type=base1), _):
return self.is_subtype(base1, type2)
@@ -173,6 +181,10 @@ class TypesRegistry:
return False
return True
# TODO: verify legitimacy
case (AppliedType(body=body), _):
return self.is_subtype(body, type2)
return False
# TODO: verify the logic in here
@@ -389,6 +401,12 @@ class TypesRegistry:
)
return self.lookup_member(base, member_name)
case ConstraintType(type=base):
return self.lookup_member(base, member_name)
case TypeVar(bound=bound) if bound is not None:
return self.lookup_member(bound, member_name)
case UnknownType():
return UnknownType()

View File

@@ -1,27 +1,64 @@
import ast
import time
from pathlib import Path
from typing import TextIO
import black
import click
from watchdog.events import DirModifiedEvent, FileModifiedEvent, FileSystemEventHandler
from watchdog.observers import Observer
from midas.checker.checker import TypeChecker
from midas.generator.stubs import StubsGenerator
@click.command(help="Generate stubs from Midas definitions")
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-")
def stubs(
file: TextIO,
output: TextIO,
):
source_path: Path = Path(file.name).resolve()
def generate_stubs(in_path: Path, out_path: Path):
checker = TypeChecker()
checker.import_midas(source_path)
checker.import_midas(in_path)
generator = StubsGenerator(checker.types)
module: ast.Module = generator.generate_stubs()
module = ast.fix_missing_locations(module)
output.write(ast.unparse(module))
output: str = ast.unparse(module)
output = black.format_str(output, mode=black.Mode(is_pyi=True))
out_path.write_text(output)
class Handler(FileSystemEventHandler):
def __init__(self, in_path: Path, out_path: Path) -> None:
super().__init__()
self.in_path: Path = in_path
self.out_path: Path = out_path
def on_modified(self, event: DirModifiedEvent | FileModifiedEvent) -> None:
generate_stubs(self.in_path, self.out_path)
@click.command(help="Generate stubs from Midas definitions")
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-")
@click.option("-w", "--watch", is_flag=True)
def stubs(
file: TextIO,
output: TextIO,
watch: bool,
):
source_path: Path = Path(file.name).resolve()
out_path: Path = Path(output.name).resolve()
generate_stubs(source_path, out_path)
if watch:
print(f"Watching {source_path}...")
print("Press CTRL+C to stop")
handler = Handler(source_path, out_path)
observer = Observer()
observer.schedule(handler, str(source_path))
observer.start()
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
observer.stop()
observer.join()

View File

@@ -322,8 +322,10 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self._make_cast_asserts(src_location, expr, base)
self._make_constraint_assert(src_location, expr, constraint)
case TypeVar():
raise RuntimeError("Unexpected TypeVar")
case TypeVar(bound=bound):
# TODO: check with type from arguments / use call-site context
if bound is not None:
self._make_cast_asserts(src_location, expr, bound)
case (
TopType()

View File

@@ -39,6 +39,18 @@ class StubsGenerator:
self.stubs = []
self.typing_imports = set()
for name, type in self.types._types.items():
# Skip builtin types, not just based on name so the user can override
# TODO: check if added members on builtin type
match type:
case BaseType(name=name_) if name == name_:
continue
case GenericType(
name=name1,
body=BaseType(name=name2),
) if (
name == name1 == name2
):
continue
self.generate_stub(name, type)
imports = [
@@ -115,6 +127,12 @@ class StubsGenerator:
body_subsitutions | substitutions,
)
case ConstraintType(type=base):
return self.get_bases(base)
case TypeVar(bound=bound) if bound is not None:
return [self.dump_type(bound)], {}
case _:
return [], {}

View File

@@ -8,7 +8,11 @@ authors = [
{ name = "Louis Heredero", email = "louis.heredero@students.hevs.ch" },
]
classifiers = ["Programming Language :: Python :: 3"]
dependencies = ["click>=8.4.1"]
dependencies = [
"black>=26.5.1",
"click>=8.4.1",
"watchdog>=6.0.0",
]
[project.urls]
Homepage = "https://git.kbk28.ch/HEL/midas"