feat(parser): add if statement
This commit is contained in:
@@ -76,6 +76,12 @@ class ReturnStmt:
|
|||||||
value: Optional[Expr]
|
value: Optional[Expr]
|
||||||
|
|
||||||
|
|
||||||
|
class IfStmt:
|
||||||
|
test: Expr
|
||||||
|
body: list[Stmt]
|
||||||
|
orelse: list[Stmt]
|
||||||
|
|
||||||
|
|
||||||
###<
|
###<
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -419,7 +419,14 @@ class PythonAstPrinter(
|
|||||||
self._mark_last()
|
self._mark_last()
|
||||||
self._print_argument(arg)
|
self._print_argument(arg)
|
||||||
|
|
||||||
self._write_optional_child("returns", stmt.returns, last=True)
|
self._write_optional_child("returns", stmt.returns)
|
||||||
|
self._write_line("body", last=True)
|
||||||
|
with self._child_level():
|
||||||
|
for i, body_stmt in enumerate(stmt.body):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(stmt.body) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
body_stmt.accept(self)
|
||||||
|
|
||||||
def _print_argument(self, arg: p.Function.Argument) -> None:
|
def _print_argument(self, arg: p.Function.Argument) -> None:
|
||||||
self._write_line("FunctionArgument")
|
self._write_line("FunctionArgument")
|
||||||
@@ -454,6 +461,26 @@ class PythonAstPrinter(
|
|||||||
with self._child_level():
|
with self._child_level():
|
||||||
self._write_optional_child("value", stmt.value, last=True)
|
self._write_optional_child("value", stmt.value, last=True)
|
||||||
|
|
||||||
|
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||||
|
self._write_line("IfStmt")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("test")
|
||||||
|
stmt.test.accept(self)
|
||||||
|
self._write_line("body")
|
||||||
|
with self._child_level():
|
||||||
|
for i, body_stmt in enumerate(stmt.body):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(stmt.body) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
body_stmt.accept(self)
|
||||||
|
self._write_line("orelse", last=True)
|
||||||
|
with self._child_level():
|
||||||
|
for i, else_stmt in enumerate(stmt.orelse):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(stmt.orelse) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
else_stmt.accept(self)
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
|
||||||
self._write_line("BinaryExpr")
|
self._write_line("BinaryExpr")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
|
|||||||
@@ -103,6 +103,9 @@ class Stmt(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_return_stmt(self, stmt: ReturnStmt) -> T: ...
|
def visit_return_stmt(self, stmt: ReturnStmt) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_if_stmt(self, stmt: IfStmt) -> T: ...
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ExpressionStmt(Stmt):
|
class ExpressionStmt(Stmt):
|
||||||
@@ -164,6 +167,16 @@ class ReturnStmt(Stmt):
|
|||||||
return visitor.visit_return_stmt(self)
|
return visitor.visit_return_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class IfStmt(Stmt):
|
||||||
|
test: Expr
|
||||||
|
body: list[Stmt]
|
||||||
|
orelse: list[Stmt]
|
||||||
|
|
||||||
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_if_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
###############
|
###############
|
||||||
# Expressions #
|
# Expressions #
|
||||||
###############
|
###############
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from midas.ast.python import (
|
|||||||
FrameType,
|
FrameType,
|
||||||
Function,
|
Function,
|
||||||
GetExpr,
|
GetExpr,
|
||||||
|
IfStmt,
|
||||||
LiteralExpr,
|
LiteralExpr,
|
||||||
LogicalExpr,
|
LogicalExpr,
|
||||||
MidasType,
|
MidasType,
|
||||||
@@ -82,6 +83,9 @@ class PythonParser:
|
|||||||
value=self.parse_expr(value) if value is not None else None,
|
value=self.parse_expr(value) if value is not None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
case ast.If():
|
||||||
|
return self.parse_if(node)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
print(f"Unsupported statement: {ast.unparse(node)}")
|
print(f"Unsupported statement: {ast.unparse(node)}")
|
||||||
return None
|
return None
|
||||||
@@ -147,6 +151,30 @@ class PythonParser:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def parse_if(self, node: ast.If) -> IfStmt:
|
||||||
|
body: list[Stmt] = []
|
||||||
|
for stmt in node.body:
|
||||||
|
stmts = self.parse_stmt(stmt)
|
||||||
|
if isinstance(stmts, Stmt):
|
||||||
|
body.append(stmts)
|
||||||
|
elif stmts is not None:
|
||||||
|
body.extend(stmts)
|
||||||
|
|
||||||
|
orelse: list[Stmt] = []
|
||||||
|
for stmt in node.orelse:
|
||||||
|
stmts = self.parse_stmt(stmt)
|
||||||
|
if isinstance(stmts, Stmt):
|
||||||
|
orelse.append(stmts)
|
||||||
|
elif stmts is not None:
|
||||||
|
orelse.extend(stmts)
|
||||||
|
|
||||||
|
return IfStmt(
|
||||||
|
location=Location.from_ast(node),
|
||||||
|
test=self.parse_expr(node.test),
|
||||||
|
body=body,
|
||||||
|
orelse=orelse,
|
||||||
|
)
|
||||||
|
|
||||||
def parse_function(self, node: ast.FunctionDef) -> Function:
|
def parse_function(self, node: ast.FunctionDef) -> Function:
|
||||||
loc: Location = Location.from_ast(node)
|
loc: Location = Location.from_ast(node)
|
||||||
match node:
|
match node:
|
||||||
|
|||||||
Reference in New Issue
Block a user