diff --git a/examples/02_demonstration/demo.py b/examples/02_demonstration/demo.py index c4ec322..2d7d5be 100644 --- a/examples/02_demonstration/demo.py +++ b/examples/02_demonstration/demo.py @@ -1,6 +1,8 @@ -from typing import TypeVar, cast +from typing import TypeVar -from demo_stubs import CHF, EUR, USD, Currency, Price, Discount +from demo_stubs import CHF, EUR, USD, Currency, Discount, Price + +from midas.typing import cast, unsafe_cast T = TypeVar("T", bound=Currency) @@ -28,3 +30,6 @@ discounted = apply_discount( ) print(f"Discounted: CHF {discounted}") + +large_data = [i * 10 for i in range(100)] +prices = unsafe_cast(list[Price[EUR]], large_data) diff --git a/gen/python.py b/gen/python.py index 4af901a..df83e6f 100644 --- a/gen/python.py +++ b/gen/python.py @@ -145,6 +145,7 @@ class LogicalExpr: class CastExpr: type: MidasType expr: Expr + unsafe: bool class TernaryExpr: diff --git a/midas/ast/printer.py b/midas/ast/printer.py index 1c75a44..680fd79 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -757,9 +757,10 @@ class PythonAstPrinter( self._write_line("type") with self._child_level(single=True): expr.type.accept(self) - self._write_line("expr", last=True) + self._write_line("expr") with self._child_level(single=True): expr.expr.accept(self) + self._write_line(f"unsafe: {expr.unsafe}", last=True) def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: self._write_line("TernaryExpr") diff --git a/midas/ast/python.py b/midas/ast/python.py index 7770de6..20d7279 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -350,6 +350,7 @@ class LogicalExpr(Expr): class CastExpr(Expr): type: MidasType expr: Expr + unsafe: bool def accept(self, visitor: Expr.Visitor[T]) -> T: return visitor.visit_cast_expr(self) diff --git a/midas/generator/generator.py b/midas/generator/generator.py index 0fba91e..0af3fcd 100644 --- a/midas/generator/generator.py +++ b/midas/generator/generator.py @@ -133,7 +133,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr: expr2: ast.expr = expr.expr.accept(self) - if expr in self._typed_ast.evaluated_casts: + if expr in self._typed_ast.evaluated_casts or expr.unsafe: return expr2 alias: ast.expr = self._make_alias(expr2) diff --git a/midas/parser/python.py b/midas/parser/python.py index 4110feb..7839f52 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -49,6 +49,7 @@ class UnsupportedSyntaxError(Exception): class PythonParser: CAST_FUNCTION = "cast" + UNSAFE_CAST_FUNCTION = "unsafe_cast" def parse_module(self, node: ast.Module) -> list[Stmt]: statements: list[Stmt] = [] @@ -423,6 +424,9 @@ class PythonParser: case ast.Call(func=ast.Name(id=self.CAST_FUNCTION)): return self.parse_cast(node) + case ast.Call(func=ast.Name(id=self.UNSAFE_CAST_FUNCTION)): + return self.parse_cast(node) + case ast.Call(): return self.parse_call(node) @@ -527,16 +531,19 @@ class PythonParser: return expr def parse_cast(self, node: ast.Call) -> CastExpr: + assert isinstance(node.func, ast.Name) + func: str = node.func.id match node: case ast.Call(args=[type, expr], keywords=[]): return CastExpr( location=Location.from_ast(node), type=self._parse_type(type), expr=self.parse_expr(expr), + unsafe=func == self.UNSAFE_CAST_FUNCTION, ) case _: raise InvalidSyntaxError( - f"Invalid call to {self.CAST_FUNCTION}, expected type and expression" + f"Invalid call to {func}, expected type and expression" ) def parse_call(self, node: ast.Call) -> CallExpr: diff --git a/midas/typing.py b/midas/typing.py new file mode 100644 index 0000000..9c5b407 --- /dev/null +++ b/midas/typing.py @@ -0,0 +1,34 @@ +from typing import cast as typing_cast + +cast = typing_cast +"""### Midas documentation +Cast a value to a type. + +- **Compile-time**: tells the type checker that the return value has the designated type. +- **Run-time**: generates assertions to ensure the value can be interpreted as the given type. + +--- +
+
+
+ +_**Internal Python documentation**_ +""" + + +unsafe_cast = typing_cast +"""### Midas documentation +Cast a value to a type. + +- **Compile-time**: tells the type checker that the return value has the designated type. +- **Run-time**: - + +This operation is unsound, use at your own risk! + +--- +
+
+
+ +_**Internal Python documentation**_ +""" diff --git a/tests/cases/checker/04_custom_types.py.ref.json b/tests/cases/checker/04_custom_types.py.ref.json index d502a97..01177d9 100644 --- a/tests/cases/checker/04_custom_types.py.ref.json +++ b/tests/cases/checker/04_custom_types.py.ref.json @@ -29,7 +29,8 @@ "expr": { "_type": "LiteralExpr", "value": 123.45 - } + }, + "unsafe": false }, "type": { "name": "Meter", @@ -66,7 +67,8 @@ "expr": { "_type": "LiteralExpr", "value": 6.7 - } + }, + "unsafe": false }, "type": { "name": "Second", diff --git a/tests/serializer/python.py b/tests/serializer/python.py index 038b496..3904739 100644 --- a/tests/serializer/python.py +++ b/tests/serializer/python.py @@ -263,6 +263,7 @@ class PythonAstJsonSerializer( "_type": "CastExpr", "type": expr.type.accept(self), "expr": expr.expr.accept(self), + "unsafe": expr.unsafe, } def visit_ternary_expr(self, expr: TernaryExpr) -> dict: