diff --git a/gen/python.py b/gen/python.py index c3b5733..6240d5d 100644 --- a/gen/python.py +++ b/gen/python.py @@ -76,6 +76,12 @@ class ReturnStmt: value: Optional[Expr] +class IfStmt: + test: Expr + body: list[Stmt] + orelse: list[Stmt] + + ###< diff --git a/midas/ast/printer.py b/midas/ast/printer.py index 3839fac..a687936 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -419,7 +419,14 @@ class PythonAstPrinter( self._mark_last() self._print_argument(arg) - self._write_optional_child("returns", stmt.returns, last=True) + self._write_optional_child("returns", stmt.returns) + self._write_line("body", last=True) + with self._child_level(): + for i, body_stmt in enumerate(stmt.body): + self._idx = i + if i == len(stmt.body) - 1: + self._mark_last() + body_stmt.accept(self) def _print_argument(self, arg: p.Function.Argument) -> None: self._write_line("FunctionArgument") @@ -454,6 +461,26 @@ class PythonAstPrinter( with self._child_level(): self._write_optional_child("value", stmt.value, last=True) + def visit_if_stmt(self, stmt: p.IfStmt) -> None: + self._write_line("IfStmt") + with self._child_level(): + self._write_line("test") + stmt.test.accept(self) + self._write_line("body") + with self._child_level(): + for i, body_stmt in enumerate(stmt.body): + self._idx = i + if i == len(stmt.body) - 1: + self._mark_last() + body_stmt.accept(self) + self._write_line("orelse", last=True) + with self._child_level(): + for i, else_stmt in enumerate(stmt.orelse): + self._idx = i + if i == len(stmt.orelse) - 1: + self._mark_last() + else_stmt.accept(self) + def visit_binary_expr(self, expr: p.BinaryExpr) -> None: self._write_line("BinaryExpr") with self._child_level(): diff --git a/midas/ast/python.py b/midas/ast/python.py index 9ccefac..4b9d08a 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -103,6 +103,9 @@ class Stmt(ABC): @abstractmethod def visit_return_stmt(self, stmt: ReturnStmt) -> T: ... + @abstractmethod + def visit_if_stmt(self, stmt: IfStmt) -> T: ... + @dataclass(frozen=True) class ExpressionStmt(Stmt): @@ -164,6 +167,16 @@ class ReturnStmt(Stmt): return visitor.visit_return_stmt(self) +@dataclass(frozen=True) +class IfStmt(Stmt): + test: Expr + body: list[Stmt] + orelse: list[Stmt] + + def accept(self, visitor: Stmt.Visitor[T]) -> T: + return visitor.visit_if_stmt(self) + + ############### # Expressions # ############### diff --git a/midas/parser/python.py b/midas/parser/python.py index 6a01b32..9073953 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -16,6 +16,7 @@ from midas.ast.python import ( FrameType, Function, GetExpr, + IfStmt, LiteralExpr, LogicalExpr, MidasType, @@ -82,6 +83,9 @@ class PythonParser: value=self.parse_expr(value) if value is not None else None, ) + case ast.If(): + return self.parse_if(node) + case _: print(f"Unsupported statement: {ast.unparse(node)}") return None @@ -147,6 +151,30 @@ class PythonParser: ), ) + def parse_if(self, node: ast.If) -> IfStmt: + body: list[Stmt] = [] + for stmt in node.body: + stmts = self.parse_stmt(stmt) + if isinstance(stmts, Stmt): + body.append(stmts) + elif stmts is not None: + body.extend(stmts) + + orelse: list[Stmt] = [] + for stmt in node.orelse: + stmts = self.parse_stmt(stmt) + if isinstance(stmts, Stmt): + orelse.append(stmts) + elif stmts is not None: + orelse.extend(stmts) + + return IfStmt( + location=Location.from_ast(node), + test=self.parse_expr(node.test), + body=body, + orelse=orelse, + ) + def parse_function(self, node: ast.FunctionDef) -> Function: loc: Location = Location.from_ast(node) match node: