diff --git a/gen/python.py b/gen/python.py index f67f540..4af901a 100644 --- a/gen/python.py +++ b/gen/python.py @@ -157,6 +157,11 @@ class ListExpr: items: list[Expr] +class DictExpr: + keys: list[Optional[Expr]] + values: list[Expr] + + class SubscriptExpr: object: Expr index: Expr diff --git a/midas/ast/printer.py b/midas/ast/printer.py index 68ff7ba..694c272 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -745,6 +745,27 @@ class PythonAstPrinter( self._mark_last() item.accept(self) + def visit_dict_expr(self, expr: p.DictExpr) -> None: + self._write_line("DictExpr") + with self._child_level(): + self._write_line("keys") + with self._child_level(): + for i, key in enumerate(expr.keys): + self._idx = i + if i == len(expr.keys) - 1: + self._mark_last() + if key is None: + self._write_line("None") + else: + key.accept(self) + self._write_line("values", last=True) + with self._child_level(): + for i, value in enumerate(expr.values): + self._idx = i + if i == len(expr.values) - 1: + self._mark_last() + value.accept(self) + def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None: self._write_line("SubscriptExpr") with self._child_level(): diff --git a/midas/ast/python.py b/midas/ast/python.py index 73d49e5..7770de6 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -259,6 +259,9 @@ class Expr(ABC): @abstractmethod def visit_list_expr(self, expr: ListExpr) -> T: ... + @abstractmethod + def visit_dict_expr(self, expr: DictExpr) -> T: ... + @abstractmethod def visit_subscript_expr(self, expr: SubscriptExpr) -> T: ... @@ -370,6 +373,15 @@ class ListExpr(Expr): return visitor.visit_list_expr(self) +@dataclass(frozen=True) +class DictExpr(Expr): + keys: list[Optional[Expr]] + values: list[Expr] + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_dict_expr(self) + + @dataclass(frozen=True) class SubscriptExpr(Expr): object: Expr diff --git a/midas/parser/python.py b/midas/parser/python.py index 90c029a..4110feb 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -10,6 +10,7 @@ from midas.ast.python import ( CastExpr, CompareExpr, ConstraintType, + DictExpr, Expr, ExpressionStmt, ForStmt, @@ -447,6 +448,16 @@ class PythonParser: items=[self.parse_expr(item) for item in items], ) + case ast.Dict(keys=keys, values=values): + return DictExpr( + location=location, + keys=[ + self.parse_expr(key) if key is not None else None + for key in keys + ], + values=[self.parse_expr(value) for value in values], + ) + case ast.Subscript(value=value, slice=index): return SubscriptExpr( location=location, diff --git a/tests/serializer/python.py b/tests/serializer/python.py index 45951df..038b496 100644 --- a/tests/serializer/python.py +++ b/tests/serializer/python.py @@ -9,6 +9,7 @@ from midas.ast.python import ( CastExpr, CompareExpr, ConstraintType, + DictExpr, Expr, ExpressionStmt, ForStmt, @@ -278,6 +279,13 @@ class PythonAstJsonSerializer( "items": [item.accept(self) for item in expr.items], } + def visit_dict_expr(self, expr: DictExpr) -> dict: + return { + "_type": "DictExpr", + "keys": [self._serialize_optional(key) for key in expr.keys], + "values": self._serialize_list(expr.values), + } + def visit_subscript_expr(self, expr: SubscriptExpr) -> dict: return { "_type": "SubscriptExpr",