diff --git a/gen/python.py b/gen/python.py index 99b926d..f67f540 100644 --- a/gen/python.py +++ b/gen/python.py @@ -92,6 +92,10 @@ class ForStmt: body: list[Stmt] +class RawStmt: + stmt: ast.stmt + + ###< @@ -164,4 +168,8 @@ class SliceExpr: step: Optional[Expr] +class RawExpr: + expr: ast.expr + + ###< diff --git a/midas/ast/printer.py b/midas/ast/printer.py index 364da35..68ff7ba 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -613,6 +613,11 @@ class PythonAstPrinter( self._mark_last() body_stmt.accept(self) + def visit_raw_stmt(self, stmt: p.RawStmt) -> None: + self._write_line("RawStmt") + with self._child_level(single=True): + self._write_line(f"stmt: {ast.unparse(stmt.stmt)}") + def visit_binary_expr(self, expr: p.BinaryExpr) -> None: self._write_line("BinaryExpr") with self._child_level(): @@ -756,3 +761,8 @@ class PythonAstPrinter( self._write_optional_child("lower", expr.lower) self._write_optional_child("upper", expr.upper) self._write_optional_child("step", expr.step, last=True) + + def visit_raw_expr(self, expr: p.RawExpr) -> None: + self._write_line("RawExpr") + with self._child_level(single=True): + self._write_line(f"expr: {ast.unparse(expr.expr)}") diff --git a/midas/ast/python.py b/midas/ast/python.py index b781d2e..73d49e5 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -113,6 +113,9 @@ class Stmt(ABC): @abstractmethod def visit_for_stmt(self, stmt: ForStmt) -> T: ... + @abstractmethod + def visit_raw_stmt(self, stmt: RawStmt) -> T: ... + @dataclass(frozen=True) class ExpressionStmt(Stmt): @@ -202,6 +205,14 @@ class ForStmt(Stmt): return visitor.visit_for_stmt(self) +@dataclass(frozen=True) +class RawStmt(Stmt): + stmt: ast.stmt + + def accept(self, visitor: Stmt.Visitor[T]) -> T: + return visitor.visit_raw_stmt(self) + + ############### # Expressions # ############### @@ -254,6 +265,9 @@ class Expr(ABC): @abstractmethod def visit_slice_expr(self, expr: SliceExpr) -> T: ... + @abstractmethod + def visit_raw_expr(self, expr: RawExpr) -> T: ... + @dataclass(frozen=True) class BinaryExpr(Expr): @@ -373,3 +387,11 @@ class SliceExpr(Expr): def accept(self, visitor: Expr.Visitor[T]) -> T: return visitor.visit_slice_expr(self) + + +@dataclass(frozen=True) +class RawExpr(Expr): + expr: ast.expr + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_raw_expr(self) diff --git a/midas/parser/python.py b/midas/parser/python.py index 55edd34..90c029a 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -22,6 +22,8 @@ from midas.ast.python import ( LiteralExpr, LogicalExpr, MidasType, + RawExpr, + RawStmt, ReturnStmt, SliceExpr, Stmt, @@ -99,7 +101,7 @@ class PythonParser: case _: print(f"Unsupported statement: {ast.unparse(node)}") - return None + return RawStmt(location=location, stmt=node) def parse_annotation_assign(self, node: ast.AnnAssign) -> list[Stmt]: statements: list[Stmt] = [] @@ -461,7 +463,8 @@ class PythonParser: ) case _: - raise UnsupportedSyntaxError(node) + print(f"Unsupported expression: {ast.unparse(node)}") + return RawExpr(location=location, expr=node) def parse_bool_op(self, node: ast.BoolOp) -> LogicalExpr: op: ast.boolop = node.op diff --git a/tests/cases/python-parser/01_simple_types.py.ref.json b/tests/cases/python-parser/01_simple_types.py.ref.json index e4fd591..452b9c0 100644 --- a/tests/cases/python-parser/01_simple_types.py.ref.json +++ b/tests/cases/python-parser/01_simple_types.py.ref.json @@ -1,5 +1,9 @@ { "stmts": [ + { + "_type": "RawStmt", + "stmt": "from __future__ import annotations" + }, { "_type": "TypeAssign", "name": "df", diff --git a/tests/cases/python-parser/02_custom_types.py.ref.json b/tests/cases/python-parser/02_custom_types.py.ref.json index 82c726c..9d77ebd 100644 --- a/tests/cases/python-parser/02_custom_types.py.ref.json +++ b/tests/cases/python-parser/02_custom_types.py.ref.json @@ -1,5 +1,9 @@ { "stmts": [ + { + "_type": "RawStmt", + "stmt": "from __future__ import annotations" + }, { "_type": "TypeAssign", "name": "df", diff --git a/tests/cases/python-parser/03_functions.py.ref.json b/tests/cases/python-parser/03_functions.py.ref.json index 529455b..a8f261f 100644 --- a/tests/cases/python-parser/03_functions.py.ref.json +++ b/tests/cases/python-parser/03_functions.py.ref.json @@ -1,5 +1,9 @@ { "stmts": [ + { + "_type": "RawStmt", + "stmt": "from __future__ import annotations" + }, { "_type": "Function", "name": "func", diff --git a/tests/serializer/python.py b/tests/serializer/python.py index 56171b8..45951df 100644 --- a/tests/serializer/python.py +++ b/tests/serializer/python.py @@ -22,6 +22,8 @@ from midas.ast.python import ( LogicalExpr, MidasType, Pass, + RawExpr, + RawStmt, ReturnStmt, SliceExpr, Stmt, @@ -191,6 +193,12 @@ class PythonAstJsonSerializer( "body": self._serialize_list(stmt.body), } + def visit_raw_stmt(self, stmt: RawStmt) -> dict: + return { + "_type": "RawStmt", + "stmt": ast.unparse(stmt.stmt), + } + def visit_binary_expr(self, expr: BinaryExpr) -> dict: return { "_type": "BinaryExpr", @@ -284,3 +292,9 @@ class PythonAstJsonSerializer( "upper": self._serialize_optional(expr.upper), "step": self._serialize_optional(expr.step), } + + def visit_raw_expr(self, expr: RawExpr) -> dict: + return { + "_type": "RawExpr", + "expr": ast.unparse(expr.expr), + }