Merge pull request 'Unsafe cast' (#21) from feat/unsafe-cast into main
Reviewed-on: #21
This commit was merged in pull request #21.
This commit is contained in:
@@ -1,6 +1,8 @@
|
|||||||
from typing import TypeVar, cast
|
from typing import TypeVar
|
||||||
|
|
||||||
from demo_stubs import CHF, EUR, USD, Currency, Price, Discount
|
from demo_stubs import CHF, EUR, USD, Currency, Discount, Price
|
||||||
|
|
||||||
|
from midas.typing import cast, unsafe_cast
|
||||||
|
|
||||||
T = TypeVar("T", bound=Currency)
|
T = TypeVar("T", bound=Currency)
|
||||||
|
|
||||||
@@ -28,3 +30,6 @@ discounted = apply_discount(
|
|||||||
)
|
)
|
||||||
|
|
||||||
print(f"Discounted: CHF {discounted}")
|
print(f"Discounted: CHF {discounted}")
|
||||||
|
|
||||||
|
large_data = [i * 10 for i in range(100)]
|
||||||
|
prices = unsafe_cast(list[Price[EUR]], large_data)
|
||||||
|
|||||||
@@ -145,6 +145,7 @@ class LogicalExpr:
|
|||||||
class CastExpr:
|
class CastExpr:
|
||||||
type: MidasType
|
type: MidasType
|
||||||
expr: Expr
|
expr: Expr
|
||||||
|
unsafe: bool
|
||||||
|
|
||||||
|
|
||||||
class TernaryExpr:
|
class TernaryExpr:
|
||||||
|
|||||||
@@ -757,9 +757,10 @@ class PythonAstPrinter(
|
|||||||
self._write_line("type")
|
self._write_line("type")
|
||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
expr.type.accept(self)
|
expr.type.accept(self)
|
||||||
self._write_line("expr", last=True)
|
self._write_line("expr")
|
||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
expr.expr.accept(self)
|
expr.expr.accept(self)
|
||||||
|
self._write_line(f"unsafe: {expr.unsafe}", last=True)
|
||||||
|
|
||||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
|
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
|
||||||
self._write_line("TernaryExpr")
|
self._write_line("TernaryExpr")
|
||||||
|
|||||||
@@ -350,6 +350,7 @@ class LogicalExpr(Expr):
|
|||||||
class CastExpr(Expr):
|
class CastExpr(Expr):
|
||||||
type: MidasType
|
type: MidasType
|
||||||
expr: Expr
|
expr: Expr
|
||||||
|
unsafe: bool
|
||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
return visitor.visit_cast_expr(self)
|
return visitor.visit_cast_expr(self)
|
||||||
|
|||||||
@@ -133,7 +133,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
|
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
|
||||||
expr2: ast.expr = expr.expr.accept(self)
|
expr2: ast.expr = expr.expr.accept(self)
|
||||||
|
|
||||||
if expr in self._typed_ast.evaluated_casts:
|
if expr in self._typed_ast.evaluated_casts or expr.unsafe:
|
||||||
return expr2
|
return expr2
|
||||||
|
|
||||||
alias: ast.expr = self._make_alias(expr2)
|
alias: ast.expr = self._make_alias(expr2)
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ class UnsupportedSyntaxError(Exception):
|
|||||||
|
|
||||||
class PythonParser:
|
class PythonParser:
|
||||||
CAST_FUNCTION = "cast"
|
CAST_FUNCTION = "cast"
|
||||||
|
UNSAFE_CAST_FUNCTION = "unsafe_cast"
|
||||||
|
|
||||||
def parse_module(self, node: ast.Module) -> list[Stmt]:
|
def parse_module(self, node: ast.Module) -> list[Stmt]:
|
||||||
statements: list[Stmt] = []
|
statements: list[Stmt] = []
|
||||||
@@ -423,6 +424,9 @@ class PythonParser:
|
|||||||
case ast.Call(func=ast.Name(id=self.CAST_FUNCTION)):
|
case ast.Call(func=ast.Name(id=self.CAST_FUNCTION)):
|
||||||
return self.parse_cast(node)
|
return self.parse_cast(node)
|
||||||
|
|
||||||
|
case ast.Call(func=ast.Name(id=self.UNSAFE_CAST_FUNCTION)):
|
||||||
|
return self.parse_cast(node)
|
||||||
|
|
||||||
case ast.Call():
|
case ast.Call():
|
||||||
return self.parse_call(node)
|
return self.parse_call(node)
|
||||||
|
|
||||||
@@ -527,16 +531,19 @@ class PythonParser:
|
|||||||
return expr
|
return expr
|
||||||
|
|
||||||
def parse_cast(self, node: ast.Call) -> CastExpr:
|
def parse_cast(self, node: ast.Call) -> CastExpr:
|
||||||
|
assert isinstance(node.func, ast.Name)
|
||||||
|
func: str = node.func.id
|
||||||
match node:
|
match node:
|
||||||
case ast.Call(args=[type, expr], keywords=[]):
|
case ast.Call(args=[type, expr], keywords=[]):
|
||||||
return CastExpr(
|
return CastExpr(
|
||||||
location=Location.from_ast(node),
|
location=Location.from_ast(node),
|
||||||
type=self._parse_type(type),
|
type=self._parse_type(type),
|
||||||
expr=self.parse_expr(expr),
|
expr=self.parse_expr(expr),
|
||||||
|
unsafe=func == self.UNSAFE_CAST_FUNCTION,
|
||||||
)
|
)
|
||||||
case _:
|
case _:
|
||||||
raise InvalidSyntaxError(
|
raise InvalidSyntaxError(
|
||||||
f"Invalid call to {self.CAST_FUNCTION}, expected type and expression"
|
f"Invalid call to {func}, expected type and expression"
|
||||||
)
|
)
|
||||||
|
|
||||||
def parse_call(self, node: ast.Call) -> CallExpr:
|
def parse_call(self, node: ast.Call) -> CallExpr:
|
||||||
|
|||||||
34
midas/typing.py
Normal file
34
midas/typing.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
from typing import cast as typing_cast
|
||||||
|
|
||||||
|
cast = typing_cast
|
||||||
|
"""### Midas documentation
|
||||||
|
Cast a value to a type.
|
||||||
|
|
||||||
|
- **Compile-time**: tells the type checker that the return value has the designated type.
|
||||||
|
- **Run-time**: generates assertions to ensure the value can be interpreted as the given type.
|
||||||
|
|
||||||
|
---
|
||||||
|
<br>
|
||||||
|
<br>
|
||||||
|
<br>
|
||||||
|
|
||||||
|
_**Internal Python documentation**_
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
unsafe_cast = typing_cast
|
||||||
|
"""### Midas documentation
|
||||||
|
Cast a value to a type.
|
||||||
|
|
||||||
|
- **Compile-time**: tells the type checker that the return value has the designated type.
|
||||||
|
- **Run-time**: -
|
||||||
|
|
||||||
|
This operation is unsound, use at your own risk!
|
||||||
|
|
||||||
|
---
|
||||||
|
<br>
|
||||||
|
<br>
|
||||||
|
<br>
|
||||||
|
|
||||||
|
_**Internal Python documentation**_
|
||||||
|
"""
|
||||||
@@ -29,7 +29,8 @@
|
|||||||
"expr": {
|
"expr": {
|
||||||
"_type": "LiteralExpr",
|
"_type": "LiteralExpr",
|
||||||
"value": 123.45
|
"value": 123.45
|
||||||
}
|
},
|
||||||
|
"unsafe": false
|
||||||
},
|
},
|
||||||
"type": {
|
"type": {
|
||||||
"name": "Meter",
|
"name": "Meter",
|
||||||
@@ -66,7 +67,8 @@
|
|||||||
"expr": {
|
"expr": {
|
||||||
"_type": "LiteralExpr",
|
"_type": "LiteralExpr",
|
||||||
"value": 6.7
|
"value": 6.7
|
||||||
}
|
},
|
||||||
|
"unsafe": false
|
||||||
},
|
},
|
||||||
"type": {
|
"type": {
|
||||||
"name": "Second",
|
"name": "Second",
|
||||||
|
|||||||
@@ -263,6 +263,7 @@ class PythonAstJsonSerializer(
|
|||||||
"_type": "CastExpr",
|
"_type": "CastExpr",
|
||||||
"type": expr.type.accept(self),
|
"type": expr.type.accept(self),
|
||||||
"expr": expr.expr.accept(self),
|
"expr": expr.expr.accept(self),
|
||||||
|
"unsafe": expr.unsafe,
|
||||||
}
|
}
|
||||||
|
|
||||||
def visit_ternary_expr(self, expr: TernaryExpr) -> dict:
|
def visit_ternary_expr(self, expr: TernaryExpr) -> dict:
|
||||||
|
|||||||
Reference in New Issue
Block a user