feat: handle unsafe casts
This commit is contained in:
@@ -145,6 +145,7 @@ class LogicalExpr:
|
||||
class CastExpr:
|
||||
type: MidasType
|
||||
expr: Expr
|
||||
unsafe: bool
|
||||
|
||||
|
||||
class TernaryExpr:
|
||||
|
||||
@@ -757,9 +757,10 @@ class PythonAstPrinter(
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
expr.type.accept(self)
|
||||
self._write_line("expr", last=True)
|
||||
self._write_line("expr")
|
||||
with self._child_level(single=True):
|
||||
expr.expr.accept(self)
|
||||
self._write_line(f"unsafe: {expr.unsafe}", last=True)
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
|
||||
self._write_line("TernaryExpr")
|
||||
|
||||
@@ -350,6 +350,7 @@ class LogicalExpr(Expr):
|
||||
class CastExpr(Expr):
|
||||
type: MidasType
|
||||
expr: Expr
|
||||
unsafe: bool
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_cast_expr(self)
|
||||
|
||||
@@ -133,7 +133,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
|
||||
expr2: ast.expr = expr.expr.accept(self)
|
||||
|
||||
if expr in self._typed_ast.evaluated_casts:
|
||||
if expr in self._typed_ast.evaluated_casts or expr.unsafe:
|
||||
return expr2
|
||||
|
||||
alias: ast.expr = self._make_alias(expr2)
|
||||
|
||||
@@ -49,6 +49,7 @@ class UnsupportedSyntaxError(Exception):
|
||||
|
||||
class PythonParser:
|
||||
CAST_FUNCTION = "cast"
|
||||
UNSAFE_CAST_FUNCTION = "unsafe_cast"
|
||||
|
||||
def parse_module(self, node: ast.Module) -> list[Stmt]:
|
||||
statements: list[Stmt] = []
|
||||
@@ -423,6 +424,9 @@ class PythonParser:
|
||||
case ast.Call(func=ast.Name(id=self.CAST_FUNCTION)):
|
||||
return self.parse_cast(node)
|
||||
|
||||
case ast.Call(func=ast.Name(id=self.UNSAFE_CAST_FUNCTION)):
|
||||
return self.parse_cast(node)
|
||||
|
||||
case ast.Call():
|
||||
return self.parse_call(node)
|
||||
|
||||
@@ -527,16 +531,19 @@ class PythonParser:
|
||||
return expr
|
||||
|
||||
def parse_cast(self, node: ast.Call) -> CastExpr:
|
||||
assert isinstance(node.func, ast.Name)
|
||||
func: str = node.func.id
|
||||
match node:
|
||||
case ast.Call(args=[type, expr], keywords=[]):
|
||||
return CastExpr(
|
||||
location=Location.from_ast(node),
|
||||
type=self._parse_type(type),
|
||||
expr=self.parse_expr(expr),
|
||||
unsafe=func == self.UNSAFE_CAST_FUNCTION,
|
||||
)
|
||||
case _:
|
||||
raise InvalidSyntaxError(
|
||||
f"Invalid call to {self.CAST_FUNCTION}, expected type and expression"
|
||||
f"Invalid call to {func}, expected type and expression"
|
||||
)
|
||||
|
||||
def parse_call(self, node: ast.Call) -> CallExpr:
|
||||
|
||||
@@ -29,7 +29,8 @@
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 123.45
|
||||
}
|
||||
},
|
||||
"unsafe": false
|
||||
},
|
||||
"type": {
|
||||
"name": "Meter",
|
||||
@@ -66,7 +67,8 @@
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 6.7
|
||||
}
|
||||
},
|
||||
"unsafe": false
|
||||
},
|
||||
"type": {
|
||||
"name": "Second",
|
||||
|
||||
@@ -263,6 +263,7 @@ class PythonAstJsonSerializer(
|
||||
"_type": "CastExpr",
|
||||
"type": expr.type.accept(self),
|
||||
"expr": expr.expr.accept(self),
|
||||
"unsafe": expr.unsafe,
|
||||
}
|
||||
|
||||
def visit_ternary_expr(self, expr: TernaryExpr) -> dict:
|
||||
|
||||
Reference in New Issue
Block a user