From 9f05ba3224722ce7579919154ead4e8b6e612f4b Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Wed, 24 Jun 2026 13:51:01 +0200 Subject: [PATCH] feat: handle unsafe casts --- gen/python.py | 1 + midas/ast/printer.py | 3 ++- midas/ast/python.py | 1 + midas/generator/generator.py | 2 +- midas/parser/python.py | 9 ++++++++- tests/cases/checker/04_custom_types.py.ref.json | 6 ++++-- tests/serializer/python.py | 1 + 7 files changed, 18 insertions(+), 5 deletions(-) diff --git a/gen/python.py b/gen/python.py index 4af901a..df83e6f 100644 --- a/gen/python.py +++ b/gen/python.py @@ -145,6 +145,7 @@ class LogicalExpr: class CastExpr: type: MidasType expr: Expr + unsafe: bool class TernaryExpr: diff --git a/midas/ast/printer.py b/midas/ast/printer.py index 1c75a44..680fd79 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -757,9 +757,10 @@ class PythonAstPrinter( self._write_line("type") with self._child_level(single=True): expr.type.accept(self) - self._write_line("expr", last=True) + self._write_line("expr") with self._child_level(single=True): expr.expr.accept(self) + self._write_line(f"unsafe: {expr.unsafe}", last=True) def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: self._write_line("TernaryExpr") diff --git a/midas/ast/python.py b/midas/ast/python.py index 7770de6..20d7279 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -350,6 +350,7 @@ class LogicalExpr(Expr): class CastExpr(Expr): type: MidasType expr: Expr + unsafe: bool def accept(self, visitor: Expr.Visitor[T]) -> T: return visitor.visit_cast_expr(self) diff --git a/midas/generator/generator.py b/midas/generator/generator.py index 0fba91e..0af3fcd 100644 --- a/midas/generator/generator.py +++ b/midas/generator/generator.py @@ -133,7 +133,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr: expr2: ast.expr = expr.expr.accept(self) - if expr in self._typed_ast.evaluated_casts: + if expr in self._typed_ast.evaluated_casts or expr.unsafe: return expr2 alias: ast.expr = self._make_alias(expr2) diff --git a/midas/parser/python.py b/midas/parser/python.py index 4110feb..7839f52 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -49,6 +49,7 @@ class UnsupportedSyntaxError(Exception): class PythonParser: CAST_FUNCTION = "cast" + UNSAFE_CAST_FUNCTION = "unsafe_cast" def parse_module(self, node: ast.Module) -> list[Stmt]: statements: list[Stmt] = [] @@ -423,6 +424,9 @@ class PythonParser: case ast.Call(func=ast.Name(id=self.CAST_FUNCTION)): return self.parse_cast(node) + case ast.Call(func=ast.Name(id=self.UNSAFE_CAST_FUNCTION)): + return self.parse_cast(node) + case ast.Call(): return self.parse_call(node) @@ -527,16 +531,19 @@ class PythonParser: return expr def parse_cast(self, node: ast.Call) -> CastExpr: + assert isinstance(node.func, ast.Name) + func: str = node.func.id match node: case ast.Call(args=[type, expr], keywords=[]): return CastExpr( location=Location.from_ast(node), type=self._parse_type(type), expr=self.parse_expr(expr), + unsafe=func == self.UNSAFE_CAST_FUNCTION, ) case _: raise InvalidSyntaxError( - f"Invalid call to {self.CAST_FUNCTION}, expected type and expression" + f"Invalid call to {func}, expected type and expression" ) def parse_call(self, node: ast.Call) -> CallExpr: diff --git a/tests/cases/checker/04_custom_types.py.ref.json b/tests/cases/checker/04_custom_types.py.ref.json index d502a97..01177d9 100644 --- a/tests/cases/checker/04_custom_types.py.ref.json +++ b/tests/cases/checker/04_custom_types.py.ref.json @@ -29,7 +29,8 @@ "expr": { "_type": "LiteralExpr", "value": 123.45 - } + }, + "unsafe": false }, "type": { "name": "Meter", @@ -66,7 +67,8 @@ "expr": { "_type": "LiteralExpr", "value": 6.7 - } + }, + "unsafe": false }, "type": { "name": "Second", diff --git a/tests/serializer/python.py b/tests/serializer/python.py index 038b496..3904739 100644 --- a/tests/serializer/python.py +++ b/tests/serializer/python.py @@ -263,6 +263,7 @@ class PythonAstJsonSerializer( "_type": "CastExpr", "type": expr.type.accept(self), "expr": expr.expr.accept(self), + "unsafe": expr.unsafe, } def visit_ternary_expr(self, expr: TernaryExpr) -> dict: