diff --git a/gen/python.py b/gen/python.py index 6240d5d..09d21b8 100644 --- a/gen/python.py +++ b/gen/python.py @@ -139,4 +139,10 @@ class CastExpr: expr: Expr +class TernaryExpr: + test: Expr + if_true: Expr + if_false: Expr + + ###< diff --git a/midas/ast/printer.py b/midas/ast/printer.py index a687936..45e4a64 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -465,7 +465,8 @@ class PythonAstPrinter( self._write_line("IfStmt") with self._child_level(): self._write_line("test") - stmt.test.accept(self) + with self._child_level(single=True): + stmt.test.accept(self) self._write_line("body") with self._child_level(): for i, body_stmt in enumerate(stmt.body): @@ -592,3 +593,18 @@ class PythonAstPrinter( self._write_line("expr", last=True) with self._child_level(single=True): expr.expr.accept(self) + + def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: + self._write_line("TernaryExpr") + with self._child_level(): + self._write_line("test") + with self._child_level(single=True): + expr.test.accept(self) + + self._write_line("if_true") + with self._child_level(single=True): + expr.if_true.accept(self) + + self._write_line("if_false", last=True) + with self._child_level(single=True): + expr.if_false.accept(self) diff --git a/midas/ast/python.py b/midas/ast/python.py index 4b9d08a..8607cd2 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -220,6 +220,9 @@ class Expr(ABC): @abstractmethod def visit_cast_expr(self, expr: CastExpr) -> T: ... + @abstractmethod + def visit_ternary_expr(self, expr: TernaryExpr) -> T: ... + @dataclass(frozen=True) class BinaryExpr(Expr): @@ -312,3 +315,13 @@ class CastExpr(Expr): def accept(self, visitor: Expr.Visitor[T]) -> T: return visitor.visit_cast_expr(self) + + +@dataclass(frozen=True) +class TernaryExpr(Expr): + test: Expr + if_true: Expr + if_false: Expr + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_ternary_expr(self) diff --git a/midas/parser/python.py b/midas/parser/python.py index 9073953..265e8be 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -22,6 +22,7 @@ from midas.ast.python import ( MidasType, ReturnStmt, Stmt, + TernaryExpr, TypeAssign, UnaryExpr, VariableExpr, @@ -389,6 +390,9 @@ class PythonParser: case ast.Call(): return self.parse_call(node) + case ast.IfExp(): + return self.parse_ternary(node) + case ast.Constant(value=value): return LiteralExpr(location=location, value=value) @@ -478,3 +482,11 @@ class PythonParser: if arg.arg is not None # Should always be True, type checker happy }, ) + + def parse_ternary(self, node: ast.IfExp) -> TernaryExpr: + return TernaryExpr( + location=Location.from_ast(node), + test=self.parse_expr(node.test), + if_true=self.parse_expr(node.body), + if_false=self.parse_expr(node.orelse), + )