feat(parser): add if statement

This commit is contained in:
2026-06-01 14:10:46 +02:00
parent 5d4df7978b
commit ab0fa1de1a
4 changed files with 75 additions and 1 deletions

View File

@@ -76,6 +76,12 @@ class ReturnStmt:
value: Optional[Expr] value: Optional[Expr]
class IfStmt:
test: Expr
body: list[Stmt]
orelse: list[Stmt]
###< ###<

View File

@@ -419,7 +419,14 @@ class PythonAstPrinter(
self._mark_last() self._mark_last()
self._print_argument(arg) self._print_argument(arg)
self._write_optional_child("returns", stmt.returns, last=True) self._write_optional_child("returns", stmt.returns)
self._write_line("body", last=True)
with self._child_level():
for i, body_stmt in enumerate(stmt.body):
self._idx = i
if i == len(stmt.body) - 1:
self._mark_last()
body_stmt.accept(self)
def _print_argument(self, arg: p.Function.Argument) -> None: def _print_argument(self, arg: p.Function.Argument) -> None:
self._write_line("FunctionArgument") self._write_line("FunctionArgument")
@@ -454,6 +461,26 @@ class PythonAstPrinter(
with self._child_level(): with self._child_level():
self._write_optional_child("value", stmt.value, last=True) self._write_optional_child("value", stmt.value, last=True)
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
self._write_line("IfStmt")
with self._child_level():
self._write_line("test")
stmt.test.accept(self)
self._write_line("body")
with self._child_level():
for i, body_stmt in enumerate(stmt.body):
self._idx = i
if i == len(stmt.body) - 1:
self._mark_last()
body_stmt.accept(self)
self._write_line("orelse", last=True)
with self._child_level():
for i, else_stmt in enumerate(stmt.orelse):
self._idx = i
if i == len(stmt.orelse) - 1:
self._mark_last()
else_stmt.accept(self)
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
self._write_line("BinaryExpr") self._write_line("BinaryExpr")
with self._child_level(): with self._child_level():

View File

@@ -103,6 +103,9 @@ class Stmt(ABC):
@abstractmethod @abstractmethod
def visit_return_stmt(self, stmt: ReturnStmt) -> T: ... def visit_return_stmt(self, stmt: ReturnStmt) -> T: ...
@abstractmethod
def visit_if_stmt(self, stmt: IfStmt) -> T: ...
@dataclass(frozen=True) @dataclass(frozen=True)
class ExpressionStmt(Stmt): class ExpressionStmt(Stmt):
@@ -164,6 +167,16 @@ class ReturnStmt(Stmt):
return visitor.visit_return_stmt(self) return visitor.visit_return_stmt(self)
@dataclass(frozen=True)
class IfStmt(Stmt):
test: Expr
body: list[Stmt]
orelse: list[Stmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_if_stmt(self)
############### ###############
# Expressions # # Expressions #
############### ###############

View File

@@ -16,6 +16,7 @@ from midas.ast.python import (
FrameType, FrameType,
Function, Function,
GetExpr, GetExpr,
IfStmt,
LiteralExpr, LiteralExpr,
LogicalExpr, LogicalExpr,
MidasType, MidasType,
@@ -82,6 +83,9 @@ class PythonParser:
value=self.parse_expr(value) if value is not None else None, value=self.parse_expr(value) if value is not None else None,
) )
case ast.If():
return self.parse_if(node)
case _: case _:
print(f"Unsupported statement: {ast.unparse(node)}") print(f"Unsupported statement: {ast.unparse(node)}")
return None return None
@@ -147,6 +151,30 @@ class PythonParser:
), ),
) )
def parse_if(self, node: ast.If) -> IfStmt:
body: list[Stmt] = []
for stmt in node.body:
stmts = self.parse_stmt(stmt)
if isinstance(stmts, Stmt):
body.append(stmts)
elif stmts is not None:
body.extend(stmts)
orelse: list[Stmt] = []
for stmt in node.orelse:
stmts = self.parse_stmt(stmt)
if isinstance(stmts, Stmt):
orelse.append(stmts)
elif stmts is not None:
orelse.extend(stmts)
return IfStmt(
location=Location.from_ast(node),
test=self.parse_expr(node.test),
body=body,
orelse=orelse,
)
def parse_function(self, node: ast.FunctionDef) -> Function: def parse_function(self, node: ast.FunctionDef) -> Function:
loc: Location = Location.from_ast(node) loc: Location = Location.from_ast(node)
match node: match node: