diff --git a/examples/02_demonstration/demo.py b/examples/02_demonstration/demo.py
index c4ec322..2d7d5be 100644
--- a/examples/02_demonstration/demo.py
+++ b/examples/02_demonstration/demo.py
@@ -1,6 +1,8 @@
-from typing import TypeVar, cast
+from typing import TypeVar
-from demo_stubs import CHF, EUR, USD, Currency, Price, Discount
+from demo_stubs import CHF, EUR, USD, Currency, Discount, Price
+
+from midas.typing import cast, unsafe_cast
T = TypeVar("T", bound=Currency)
@@ -28,3 +30,6 @@ discounted = apply_discount(
)
print(f"Discounted: CHF {discounted}")
+
+large_data = [i * 10 for i in range(100)]
+prices = unsafe_cast(list[Price[EUR]], large_data)
diff --git a/gen/python.py b/gen/python.py
index 4af901a..df83e6f 100644
--- a/gen/python.py
+++ b/gen/python.py
@@ -145,6 +145,7 @@ class LogicalExpr:
class CastExpr:
type: MidasType
expr: Expr
+ unsafe: bool
class TernaryExpr:
diff --git a/midas/ast/printer.py b/midas/ast/printer.py
index 1c75a44..680fd79 100644
--- a/midas/ast/printer.py
+++ b/midas/ast/printer.py
@@ -757,9 +757,10 @@ class PythonAstPrinter(
self._write_line("type")
with self._child_level(single=True):
expr.type.accept(self)
- self._write_line("expr", last=True)
+ self._write_line("expr")
with self._child_level(single=True):
expr.expr.accept(self)
+ self._write_line(f"unsafe: {expr.unsafe}", last=True)
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
self._write_line("TernaryExpr")
diff --git a/midas/ast/python.py b/midas/ast/python.py
index 7770de6..20d7279 100644
--- a/midas/ast/python.py
+++ b/midas/ast/python.py
@@ -350,6 +350,7 @@ class LogicalExpr(Expr):
class CastExpr(Expr):
type: MidasType
expr: Expr
+ unsafe: bool
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_cast_expr(self)
diff --git a/midas/generator/generator.py b/midas/generator/generator.py
index 0fba91e..0af3fcd 100644
--- a/midas/generator/generator.py
+++ b/midas/generator/generator.py
@@ -133,7 +133,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
expr2: ast.expr = expr.expr.accept(self)
- if expr in self._typed_ast.evaluated_casts:
+ if expr in self._typed_ast.evaluated_casts or expr.unsafe:
return expr2
alias: ast.expr = self._make_alias(expr2)
diff --git a/midas/parser/python.py b/midas/parser/python.py
index 4110feb..7839f52 100644
--- a/midas/parser/python.py
+++ b/midas/parser/python.py
@@ -49,6 +49,7 @@ class UnsupportedSyntaxError(Exception):
class PythonParser:
CAST_FUNCTION = "cast"
+ UNSAFE_CAST_FUNCTION = "unsafe_cast"
def parse_module(self, node: ast.Module) -> list[Stmt]:
statements: list[Stmt] = []
@@ -423,6 +424,9 @@ class PythonParser:
case ast.Call(func=ast.Name(id=self.CAST_FUNCTION)):
return self.parse_cast(node)
+ case ast.Call(func=ast.Name(id=self.UNSAFE_CAST_FUNCTION)):
+ return self.parse_cast(node)
+
case ast.Call():
return self.parse_call(node)
@@ -527,16 +531,19 @@ class PythonParser:
return expr
def parse_cast(self, node: ast.Call) -> CastExpr:
+ assert isinstance(node.func, ast.Name)
+ func: str = node.func.id
match node:
case ast.Call(args=[type, expr], keywords=[]):
return CastExpr(
location=Location.from_ast(node),
type=self._parse_type(type),
expr=self.parse_expr(expr),
+ unsafe=func == self.UNSAFE_CAST_FUNCTION,
)
case _:
raise InvalidSyntaxError(
- f"Invalid call to {self.CAST_FUNCTION}, expected type and expression"
+ f"Invalid call to {func}, expected type and expression"
)
def parse_call(self, node: ast.Call) -> CallExpr:
diff --git a/midas/typing.py b/midas/typing.py
new file mode 100644
index 0000000..9c5b407
--- /dev/null
+++ b/midas/typing.py
@@ -0,0 +1,34 @@
+from typing import cast as typing_cast
+
+cast = typing_cast
+"""### Midas documentation
+Cast a value to a type.
+
+- **Compile-time**: tells the type checker that the return value has the designated type.
+- **Run-time**: generates assertions to ensure the value can be interpreted as the given type.
+
+---
+
+
+
+
+_**Internal Python documentation**_
+"""
+
+
+unsafe_cast = typing_cast
+"""### Midas documentation
+Cast a value to a type.
+
+- **Compile-time**: tells the type checker that the return value has the designated type.
+- **Run-time**: -
+
+This operation is unsound, use at your own risk!
+
+---
+
+
+
+
+_**Internal Python documentation**_
+"""
diff --git a/tests/cases/checker/04_custom_types.py.ref.json b/tests/cases/checker/04_custom_types.py.ref.json
index d502a97..01177d9 100644
--- a/tests/cases/checker/04_custom_types.py.ref.json
+++ b/tests/cases/checker/04_custom_types.py.ref.json
@@ -29,7 +29,8 @@
"expr": {
"_type": "LiteralExpr",
"value": 123.45
- }
+ },
+ "unsafe": false
},
"type": {
"name": "Meter",
@@ -66,7 +67,8 @@
"expr": {
"_type": "LiteralExpr",
"value": 6.7
- }
+ },
+ "unsafe": false
},
"type": {
"name": "Second",
diff --git a/tests/serializer/python.py b/tests/serializer/python.py
index 038b496..3904739 100644
--- a/tests/serializer/python.py
+++ b/tests/serializer/python.py
@@ -263,6 +263,7 @@ class PythonAstJsonSerializer(
"_type": "CastExpr",
"type": expr.type.accept(self),
"expr": expr.expr.accept(self),
+ "unsafe": expr.unsafe,
}
def visit_ternary_expr(self, expr: TernaryExpr) -> dict: