feat(parser): add for loop node

This commit is contained in:
2026-06-16 00:35:05 +02:00
parent 274e366561
commit faa98ce0ef
6 changed files with 72 additions and 0 deletions

View File

@@ -86,6 +86,12 @@ class Pass:
pass pass
class ForStmt:
target: Expr
iterator: Expr
body: list[Stmt]
###< ###<

View File

@@ -596,6 +596,23 @@ class PythonAstPrinter(
def visit_pass(self, stmt: p.Pass) -> None: def visit_pass(self, stmt: p.Pass) -> None:
self._write_line("Pass") self._write_line("Pass")
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
self._write_line("ForStmt")
with self._child_level():
self._write_line("target")
with self._child_level(single=True):
stmt.target.accept(self)
self._write_line("iterator")
with self._child_level(single=True):
stmt.iterator.accept(self)
self._write_line("body", last=True)
with self._child_level():
for i, body_stmt in enumerate(stmt.body):
self._idx = i
if i == len(stmt.body) - 1:
self._mark_last()
body_stmt.accept(self)
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
self._write_line("BinaryExpr") self._write_line("BinaryExpr")
with self._child_level(): with self._child_level():

View File

@@ -110,6 +110,9 @@ class Stmt(ABC):
@abstractmethod @abstractmethod
def visit_pass(self, stmt: Pass) -> T: ... def visit_pass(self, stmt: Pass) -> T: ...
@abstractmethod
def visit_for_stmt(self, stmt: ForStmt) -> T: ...
@dataclass(frozen=True) @dataclass(frozen=True)
class ExpressionStmt(Stmt): class ExpressionStmt(Stmt):
@@ -189,6 +192,16 @@ class Pass(Stmt):
return visitor.visit_pass(self) return visitor.visit_pass(self)
@dataclass(frozen=True)
class ForStmt(Stmt):
target: Expr
iterator: Expr
body: list[Stmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_for_stmt(self)
############### ###############
# Expressions # # Expressions #
############### ###############

View File

@@ -191,6 +191,13 @@ class PythonHighlighter(
def visit_pass(self, stmt: p.Pass) -> None: def visit_pass(self, stmt: p.Pass) -> None:
pass pass
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
self.wrap(stmt, "for")
stmt.iterator.accept(self)
stmt.target.accept(self)
for body_stmt in stmt.body:
body_stmt.accept(self)
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ... def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ...
def visit_compare_expr(self, expr: p.CompareExpr) -> None: ... def visit_compare_expr(self, expr: p.CompareExpr) -> None: ...

View File

@@ -12,6 +12,7 @@ from midas.ast.python import (
ConstraintType, ConstraintType,
Expr, Expr,
ExpressionStmt, ExpressionStmt,
ForStmt,
FrameColumn, FrameColumn,
FrameType, FrameType,
Function, Function,
@@ -93,6 +94,9 @@ class PythonParser:
case ast.Pass(): case ast.Pass():
return None return None
case ast.For(orelse=[]):
return self.parse_for(node)
case _: case _:
print(f"Unsupported statement: {ast.unparse(node)}") print(f"Unsupported statement: {ast.unparse(node)}")
return None return None
@@ -182,6 +186,22 @@ class PythonParser:
orelse=orelse, orelse=orelse,
) )
def parse_for(self, node: ast.For) -> ForStmt:
body: list[Stmt] = []
for stmt in node.body:
stmts = self.parse_stmt(stmt)
if isinstance(stmts, Stmt):
body.append(stmts)
elif stmts is not None:
body.extend(stmts)
return ForStmt(
location=Location.from_ast(node),
target=self.parse_expr(node.target),
iterator=self.parse_expr(node.iter),
body=body,
)
def parse_function(self, node: ast.FunctionDef) -> Function: def parse_function(self, node: ast.FunctionDef) -> Function:
loc: Location = Location.from_ast(node) loc: Location = Location.from_ast(node)
match node: match node:

View File

@@ -11,6 +11,7 @@ from midas.ast.python import (
ConstraintType, ConstraintType,
Expr, Expr,
ExpressionStmt, ExpressionStmt,
ForStmt,
FrameColumn, FrameColumn,
FrameType, FrameType,
Function, Function,
@@ -182,6 +183,14 @@ class PythonAstJsonSerializer(
"_type": "Pass", "_type": "Pass",
} }
def visit_for_stmt(self, stmt: ForStmt) -> dict:
return {
"_type": "ForStmt",
"target": stmt.target.accept(self),
"iterator": stmt.iterator.accept(self),
"body": self._serialize_list(stmt.body),
}
def visit_binary_expr(self, expr: BinaryExpr) -> dict: def visit_binary_expr(self, expr: BinaryExpr) -> dict:
return { return {
"_type": "BinaryExpr", "_type": "BinaryExpr",