Merge pull request 'Unsafe cast' (#21) from feat/unsafe-cast into main

Reviewed-on: #21
This commit was merged in pull request #21.
This commit is contained in:
2026-06-24 12:00:03 +00:00
9 changed files with 59 additions and 7 deletions

View File

@@ -1,6 +1,8 @@
from typing import TypeVar, cast from typing import TypeVar
from demo_stubs import CHF, EUR, USD, Currency, Price, Discount from demo_stubs import CHF, EUR, USD, Currency, Discount, Price
from midas.typing import cast, unsafe_cast
T = TypeVar("T", bound=Currency) T = TypeVar("T", bound=Currency)
@@ -28,3 +30,6 @@ discounted = apply_discount(
) )
print(f"Discounted: CHF {discounted}") print(f"Discounted: CHF {discounted}")
large_data = [i * 10 for i in range(100)]
prices = unsafe_cast(list[Price[EUR]], large_data)

View File

@@ -145,6 +145,7 @@ class LogicalExpr:
class CastExpr: class CastExpr:
type: MidasType type: MidasType
expr: Expr expr: Expr
unsafe: bool
class TernaryExpr: class TernaryExpr:

View File

@@ -757,9 +757,10 @@ class PythonAstPrinter(
self._write_line("type") self._write_line("type")
with self._child_level(single=True): with self._child_level(single=True):
expr.type.accept(self) expr.type.accept(self)
self._write_line("expr", last=True) self._write_line("expr")
with self._child_level(single=True): with self._child_level(single=True):
expr.expr.accept(self) expr.expr.accept(self)
self._write_line(f"unsafe: {expr.unsafe}", last=True)
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
self._write_line("TernaryExpr") self._write_line("TernaryExpr")

View File

@@ -350,6 +350,7 @@ class LogicalExpr(Expr):
class CastExpr(Expr): class CastExpr(Expr):
type: MidasType type: MidasType
expr: Expr expr: Expr
unsafe: bool
def accept(self, visitor: Expr.Visitor[T]) -> T: def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_cast_expr(self) return visitor.visit_cast_expr(self)

View File

@@ -133,7 +133,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr: def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
expr2: ast.expr = expr.expr.accept(self) expr2: ast.expr = expr.expr.accept(self)
if expr in self._typed_ast.evaluated_casts: if expr in self._typed_ast.evaluated_casts or expr.unsafe:
return expr2 return expr2
alias: ast.expr = self._make_alias(expr2) alias: ast.expr = self._make_alias(expr2)

View File

@@ -49,6 +49,7 @@ class UnsupportedSyntaxError(Exception):
class PythonParser: class PythonParser:
CAST_FUNCTION = "cast" CAST_FUNCTION = "cast"
UNSAFE_CAST_FUNCTION = "unsafe_cast"
def parse_module(self, node: ast.Module) -> list[Stmt]: def parse_module(self, node: ast.Module) -> list[Stmt]:
statements: list[Stmt] = [] statements: list[Stmt] = []
@@ -423,6 +424,9 @@ class PythonParser:
case ast.Call(func=ast.Name(id=self.CAST_FUNCTION)): case ast.Call(func=ast.Name(id=self.CAST_FUNCTION)):
return self.parse_cast(node) return self.parse_cast(node)
case ast.Call(func=ast.Name(id=self.UNSAFE_CAST_FUNCTION)):
return self.parse_cast(node)
case ast.Call(): case ast.Call():
return self.parse_call(node) return self.parse_call(node)
@@ -527,16 +531,19 @@ class PythonParser:
return expr return expr
def parse_cast(self, node: ast.Call) -> CastExpr: def parse_cast(self, node: ast.Call) -> CastExpr:
assert isinstance(node.func, ast.Name)
func: str = node.func.id
match node: match node:
case ast.Call(args=[type, expr], keywords=[]): case ast.Call(args=[type, expr], keywords=[]):
return CastExpr( return CastExpr(
location=Location.from_ast(node), location=Location.from_ast(node),
type=self._parse_type(type), type=self._parse_type(type),
expr=self.parse_expr(expr), expr=self.parse_expr(expr),
unsafe=func == self.UNSAFE_CAST_FUNCTION,
) )
case _: case _:
raise InvalidSyntaxError( raise InvalidSyntaxError(
f"Invalid call to {self.CAST_FUNCTION}, expected type and expression" f"Invalid call to {func}, expected type and expression"
) )
def parse_call(self, node: ast.Call) -> CallExpr: def parse_call(self, node: ast.Call) -> CallExpr:

34
midas/typing.py Normal file
View File

@@ -0,0 +1,34 @@
from typing import cast as typing_cast
cast = typing_cast
"""### Midas documentation
Cast a value to a type.
- **Compile-time**: tells the type checker that the return value has the designated type.
- **Run-time**: generates assertions to ensure the value can be interpreted as the given type.
---
<br>
<br>
<br>
_**Internal Python documentation**_
"""
unsafe_cast = typing_cast
"""### Midas documentation
Cast a value to a type.
- **Compile-time**: tells the type checker that the return value has the designated type.
- **Run-time**: -
This operation is unsound, use at your own risk!
---
<br>
<br>
<br>
_**Internal Python documentation**_
"""

View File

@@ -29,7 +29,8 @@
"expr": { "expr": {
"_type": "LiteralExpr", "_type": "LiteralExpr",
"value": 123.45 "value": 123.45
} },
"unsafe": false
}, },
"type": { "type": {
"name": "Meter", "name": "Meter",
@@ -66,7 +67,8 @@
"expr": { "expr": {
"_type": "LiteralExpr", "_type": "LiteralExpr",
"value": 6.7 "value": 6.7
} },
"unsafe": false
}, },
"type": { "type": {
"name": "Second", "name": "Second",

View File

@@ -263,6 +263,7 @@ class PythonAstJsonSerializer(
"_type": "CastExpr", "_type": "CastExpr",
"type": expr.type.accept(self), "type": expr.type.accept(self),
"expr": expr.expr.accept(self), "expr": expr.expr.accept(self),
"unsafe": expr.unsafe,
} }
def visit_ternary_expr(self, expr: TernaryExpr) -> dict: def visit_ternary_expr(self, expr: TernaryExpr) -> dict: