From 0a3216e07ddd5f78c1694f2f73bdfac7c57de31d Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Fri, 29 May 2026 22:03:39 +0200 Subject: [PATCH] feat(parser):add cast expression --- gen/python.py | 5 +++++ midas/ast/printer.py | 10 ++++++++++ midas/ast/python.py | 12 ++++++++++++ midas/parser/python.py | 40 +++++++++++++++++++++++++------------- tests/serializer/python.py | 8 ++++++++ 5 files changed, 62 insertions(+), 13 deletions(-) diff --git a/gen/python.py b/gen/python.py index 91c6058..c3b5733 100644 --- a/gen/python.py +++ b/gen/python.py @@ -128,4 +128,9 @@ class SetExpr: value: Expr +class CastExpr: + type: MidasType + expr: Expr + + ###< diff --git a/midas/ast/printer.py b/midas/ast/printer.py index c923551..3839fac 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -555,3 +555,13 @@ class PythonAstPrinter( self._write_line("value", last=True) with self._child_level(single=True): expr.value.accept(self) + + def visit_cast_expr(self, expr: p.CastExpr) -> None: + self._write_line("CastExpr") + with self._child_level(): + self._write_line("type") + with self._child_level(single=True): + expr.type.accept(self) + self._write_line("expr", last=True) + with self._child_level(single=True): + expr.expr.accept(self) diff --git a/midas/ast/python.py b/midas/ast/python.py index dee4aa7..9ccefac 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -204,6 +204,9 @@ class Expr(ABC): @abstractmethod def visit_set_expr(self, expr: SetExpr) -> T: ... + @abstractmethod + def visit_cast_expr(self, expr: CastExpr) -> T: ... + @dataclass(frozen=True) class BinaryExpr(Expr): @@ -287,3 +290,12 @@ class SetExpr(Expr): def accept(self, visitor: Expr.Visitor[T]) -> T: return visitor.visit_set_expr(self) + + +@dataclass(frozen=True) +class CastExpr(Expr): + type: MidasType + expr: Expr + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_cast_expr(self) diff --git a/midas/parser/python.py b/midas/parser/python.py index 8966fc2..bc6a9fe 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -7,6 +7,7 @@ from midas.ast.python import ( BaseType, BinaryExpr, CallExpr, + CastExpr, CompareExpr, ConstraintType, Expr, @@ -38,6 +39,8 @@ class UnsupportedSyntaxError(Exception): class PythonParser: + CAST_FUNCTION = "cast" + def parse_module(self, node: ast.Module) -> list[Stmt]: statements: list[Stmt] = [] for stmt in node.body: @@ -90,15 +93,14 @@ class PythonParser: value=value, simple=1, ): - type = self._parse_type(annotation, root=True) - if type is not None: - statements.append( - TypeAssign( - location=loc, - name=target, - type=type, - ) + type = self._parse_type(annotation) + statements.append( + TypeAssign( + location=loc, + name=target, + type=type, ) + ) if value is not None: statements.append( @@ -215,9 +217,7 @@ class PythonParser: default=default, ) - def _parse_type( - self, type_expr: ast.expr, root: bool = False - ) -> Optional[MidasType]: + def _parse_type(self, type_expr: ast.expr) -> MidasType: loc: Location = Location.from_ast(type_expr) match type_expr: case ast.Subscript(value=ast.Name(id="Frame"), slice=schema): @@ -265,8 +265,6 @@ class PythonParser: ) case _: - if root: - return None raise UnsupportedSyntaxError(type_expr) def _parse_frame_type(self, schema: ast.expr) -> FrameType: @@ -339,6 +337,9 @@ class PythonParser: case ast.Compare(): return self.parse_compare(node) + case ast.Call(func=ast.Name(id=self.CAST_FUNCTION)): + return self.parse_cast(node) + case ast.Call(): return self.parse_call(node) @@ -407,6 +408,19 @@ class PythonParser: ) return expr + def parse_cast(self, node: ast.Call) -> CastExpr: + match node: + case ast.Call(args=[type, expr], keywords=[]): + return CastExpr( + location=Location.from_ast(node), + type=self._parse_type(type), + expr=self.parse_expr(expr), + ) + case _: + raise InvalidSyntaxError( + f"Invalid call to {self.CAST_FUNCTION}, expected type and expression" + ) + def parse_call(self, node: ast.Call) -> CallExpr: return CallExpr( location=Location.from_ast(node), diff --git a/tests/serializer/python.py b/tests/serializer/python.py index 9d4dc91..e14dc1b 100644 --- a/tests/serializer/python.py +++ b/tests/serializer/python.py @@ -6,6 +6,7 @@ from midas.ast.python import ( BaseType, BinaryExpr, CallExpr, + CastExpr, CompareExpr, ConstraintType, Expr, @@ -228,3 +229,10 @@ class PythonAstJsonSerializer( "name": expr.name, "value": expr.value.accept(self), } + + def visit_cast_expr(self, expr: CastExpr) -> dict: + return { + "_type": "CastExpr", + "type": expr.type.accept(self), + "expr": expr.expr.accept(self), + }