diff --git a/tests/serializer/midas.py b/tests/serializer/midas.py index 0991334..6da0dec 100644 --- a/tests/serializer/midas.py +++ b/tests/serializer/midas.py @@ -2,56 +2,61 @@ from typing import Optional, Sequence from midas.ast.midas import ( BinaryExpr, - ComplexTypeStmt, + ComplexType, + ConstraintType, Expr, ExtendStmt, + GenericType, GetExpr, GroupingExpr, LiteralExpr, LogicalExpr, + NamedType, OpStmt, PredicateStmt, PropertyStmt, - SimpleTypeExpr, - SimpleTypeStmt, Stmt, - TemplateExpr, - TypeExpr, + Type, + TypeStmt, UnaryExpr, + UnionType, VariableExpr, WildcardExpr, ) -class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]): +class MidasAstJsonSerializer( + Stmt.Visitor[dict], Expr.Visitor[dict], Type.Visitor[dict] +): """An AST serializer which produces a JSON-compatible structure""" def serialize(self, stmts: list[Stmt]) -> list[dict]: return [stmt.accept(self) for stmt in stmts] - def _serialize_optional(self, element: Optional[Stmt | Expr]) -> Optional[dict]: + def _serialize_optional( + self, element: Optional[Stmt | Expr | Type] + ) -> Optional[dict]: if element is None: return None return element.accept(self) - def _serialize_list(self, elements: Sequence[Stmt | Expr]) -> list[dict]: + def _serialize_list(self, elements: Sequence[Stmt | Expr | Type]) -> list[dict]: return [element.accept(self) for element in elements] - def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> dict: + def visit_type_stmt(self, stmt: TypeStmt) -> dict: return { - "_type": "SimpleTypeStmt", + "_type": "TypeStmt", "name": stmt.name.lexeme, - "template": self._serialize_optional(stmt.template), - "base": stmt.base.accept(self), - "constraint": self._serialize_optional(stmt.constraint), + "params": [ + self._serialize_type_stmt_template_param(param) for param in stmt.params + ], + "type": stmt.type.accept(self), } - def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> dict: + def _serialize_type_stmt_template_param(self, param: TypeStmt.Param) -> dict: return { - "_type": "ComplexTypeStmt", - "name": stmt.name.lexeme, - "template": self._serialize_optional(stmt.template), - "properties": self._serialize_list(stmt.properties), + "name": param.name.lexeme, + "bound": self._serialize_optional(param.bound), } def visit_property_stmt(self, stmt: PropertyStmt) -> dict: @@ -59,7 +64,6 @@ class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]): "_type": "PropertyStmt", "name": stmt.name.lexeme, "type": stmt.type.accept(self), - "constraint": self._serialize_optional(stmt.constraint), } def visit_extend_stmt(self, stmt: ExtendStmt) -> dict: @@ -86,13 +90,6 @@ class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]): "condition": stmt.condition.accept(self), } - def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> dict: - return { - "_type": "SimpleTypeExpr", - "name": expr.name.lexeme, - "optional": expr.optional, - } - def visit_logical_expr(self, expr: LogicalExpr) -> dict: return { "_type": "LogicalExpr", @@ -144,16 +141,34 @@ class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]): def visit_wildcard_expr(self, expr: WildcardExpr) -> dict: return {"_type": "WildcardExpr"} - def visit_template_expr(self, expr: TemplateExpr) -> dict: + def visit_named_type(self, type: NamedType) -> dict: return { - "_type": "TemplateExpr", - "type": expr.type.accept(self), + "_type": "NamedType", + "name": type.name.lexeme, } - def visit_type_expr(self, expr: TypeExpr) -> dict: + def visit_generic_type(self, type: GenericType) -> dict: return { - "_type": "TypeExpr", - "name": expr.name.lexeme, - "template": self._serialize_optional(expr.template), - "optional": expr.optional, + "_type": "GenericType", + "type": type.type.accept(self), + "params": self._serialize_list(type.params), + } + + def visit_constraint_type(self, type: ConstraintType) -> dict: + return { + "_type": "ConstraintType", + "type": type.type.accept(self), + "constraint": type.constraint.accept(self), + } + + def visit_union_type(self, type: UnionType) -> dict: + return { + "_type": "UnionType", + "types": self._serialize_list(type.types), + } + + def visit_complex_type(self, type: ComplexType) -> dict: + return { + "_type": "ComplexType", + "properties": self._serialize_list(type.properties), }