diff --git a/gen/python.py b/gen/python.py index 2e6c7d3..99b926d 100644 --- a/gen/python.py +++ b/gen/python.py @@ -86,6 +86,12 @@ class Pass: pass +class ForStmt: + target: Expr + iterator: Expr + body: list[Stmt] + + ###< diff --git a/midas/ast/printer.py b/midas/ast/printer.py index 77a6069..364da35 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -596,6 +596,23 @@ class PythonAstPrinter( def visit_pass(self, stmt: p.Pass) -> None: self._write_line("Pass") + def visit_for_stmt(self, stmt: p.ForStmt) -> None: + self._write_line("ForStmt") + with self._child_level(): + self._write_line("target") + with self._child_level(single=True): + stmt.target.accept(self) + self._write_line("iterator") + with self._child_level(single=True): + stmt.iterator.accept(self) + self._write_line("body", last=True) + with self._child_level(): + for i, body_stmt in enumerate(stmt.body): + self._idx = i + if i == len(stmt.body) - 1: + self._mark_last() + body_stmt.accept(self) + def visit_binary_expr(self, expr: p.BinaryExpr) -> None: self._write_line("BinaryExpr") with self._child_level(): diff --git a/midas/ast/python.py b/midas/ast/python.py index f3c9e1c..b781d2e 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -110,6 +110,9 @@ class Stmt(ABC): @abstractmethod def visit_pass(self, stmt: Pass) -> T: ... + @abstractmethod + def visit_for_stmt(self, stmt: ForStmt) -> T: ... + @dataclass(frozen=True) class ExpressionStmt(Stmt): @@ -189,6 +192,16 @@ class Pass(Stmt): return visitor.visit_pass(self) +@dataclass(frozen=True) +class ForStmt(Stmt): + target: Expr + iterator: Expr + body: list[Stmt] + + def accept(self, visitor: Stmt.Visitor[T]) -> T: + return visitor.visit_for_stmt(self) + + ############### # Expressions # ############### diff --git a/midas/cli/highlighter.py b/midas/cli/highlighter.py index ce63d62..1303f94 100644 --- a/midas/cli/highlighter.py +++ b/midas/cli/highlighter.py @@ -191,6 +191,13 @@ class PythonHighlighter( def visit_pass(self, stmt: p.Pass) -> None: pass + def visit_for_stmt(self, stmt: p.ForStmt) -> None: + self.wrap(stmt, "for") + stmt.iterator.accept(self) + stmt.target.accept(self) + for body_stmt in stmt.body: + body_stmt.accept(self) + def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ... def visit_compare_expr(self, expr: p.CompareExpr) -> None: ... diff --git a/midas/parser/python.py b/midas/parser/python.py index a0726da..55edd34 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -12,6 +12,7 @@ from midas.ast.python import ( ConstraintType, Expr, ExpressionStmt, + ForStmt, FrameColumn, FrameType, Function, @@ -93,6 +94,9 @@ class PythonParser: case ast.Pass(): return None + case ast.For(orelse=[]): + return self.parse_for(node) + case _: print(f"Unsupported statement: {ast.unparse(node)}") return None @@ -182,6 +186,22 @@ class PythonParser: orelse=orelse, ) + def parse_for(self, node: ast.For) -> ForStmt: + body: list[Stmt] = [] + for stmt in node.body: + stmts = self.parse_stmt(stmt) + if isinstance(stmts, Stmt): + body.append(stmts) + elif stmts is not None: + body.extend(stmts) + + return ForStmt( + location=Location.from_ast(node), + target=self.parse_expr(node.target), + iterator=self.parse_expr(node.iter), + body=body, + ) + def parse_function(self, node: ast.FunctionDef) -> Function: loc: Location = Location.from_ast(node) match node: diff --git a/tests/serializer/python.py b/tests/serializer/python.py index e1f3fd4..56171b8 100644 --- a/tests/serializer/python.py +++ b/tests/serializer/python.py @@ -11,6 +11,7 @@ from midas.ast.python import ( ConstraintType, Expr, ExpressionStmt, + ForStmt, FrameColumn, FrameType, Function, @@ -182,6 +183,14 @@ class PythonAstJsonSerializer( "_type": "Pass", } + def visit_for_stmt(self, stmt: ForStmt) -> dict: + return { + "_type": "ForStmt", + "target": stmt.target.accept(self), + "iterator": stmt.iterator.accept(self), + "body": self._serialize_list(stmt.body), + } + def visit_binary_expr(self, expr: BinaryExpr) -> dict: return { "_type": "BinaryExpr",