From c3722c7438426ec846a41f00948919e01bff83de Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Fri, 29 May 2026 18:44:53 +0200 Subject: [PATCH] tests: add python parser tester --- tests/python.py | 46 ++++++++ tests/serializer/python.py | 230 +++++++++++++++++++++++++++++++++++++ 2 files changed, 276 insertions(+) create mode 100644 tests/python.py create mode 100644 tests/serializer/python.py diff --git a/tests/python.py b/tests/python.py new file mode 100644 index 0000000..a6cbb7b --- /dev/null +++ b/tests/python.py @@ -0,0 +1,46 @@ +import ast +import json +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Optional + +from midas.ast.python import Stmt +from midas.parser.python import PythonParser +from tests.base import Tester +from tests.serializer.python import PythonAstJsonSerializer + + +@dataclass +class CaseResult: + stmts: Optional[list[dict]] = None + + def dumps(self) -> str: + return json.dumps(asdict(self), indent=2) + + +class PythonTester(Tester): + @property + def namespace(self) -> str: + return "python-parser" + + def _list_tests(self) -> list[Path]: + return list(self.base_dir.rglob("*.py")) + + def _exec_case(self, path: Path) -> CaseResult: + if not path.exists(): + raise FileNotFoundError(f"Could not find test '{path}'") + if not path.is_file(): + raise TypeError(f"Test '{path}' is not a file") + + result: CaseResult = CaseResult() + content: str = path.read_text() + tree: ast.Module = ast.parse(content) + + parser: PythonParser = PythonParser() + stmts: list[Stmt] = parser.parse_module(tree) + result.stmts = PythonAstJsonSerializer().serialize(stmts) + return result + + +if __name__ == "__main__": + PythonTester.main() diff --git a/tests/serializer/python.py b/tests/serializer/python.py new file mode 100644 index 0000000..9d4dc91 --- /dev/null +++ b/tests/serializer/python.py @@ -0,0 +1,230 @@ +import ast +from typing import Optional, Sequence, Type + +from midas.ast.python import ( + AssignStmt, + BaseType, + BinaryExpr, + CallExpr, + CompareExpr, + ConstraintType, + Expr, + ExpressionStmt, + FrameColumn, + FrameType, + Function, + GetExpr, + LiteralExpr, + LogicalExpr, + MidasType, + ReturnStmt, + SetExpr, + Stmt, + TypeAssign, + UnaryExpr, + VariableExpr, +) + +unary_ops: dict[Type[ast.unaryop], str] = { + ast.Invert: "~", + ast.Not: "not", + ast.UAdd: "+", + ast.USub: "-", +} +binary_ops: dict[Type[ast.operator], str] = { + ast.Add: "+", + ast.Sub: "-", + ast.Mult: "*", + ast.MatMult: "@", + ast.Div: "/", + ast.Mod: "%", + ast.LShift: "<<", + ast.RShift: ">>", + ast.BitOr: "|", + ast.BitXor: "^", + ast.BitAnd: "&", + ast.FloorDiv: "//", + ast.Pow: "**", +} +compare_ops: dict[Type[ast.cmpop], str] = { + ast.Eq: "==", + ast.NotEq: "!=", + ast.Lt: "<", + ast.LtE: "<=", + ast.Gt: ">", + ast.GtE: ">=", + ast.Is: "is", + ast.IsNot: "is not", + ast.In: "in", + ast.NotIn: "not in", +} +boolean_ops: dict[Type[ast.boolop], str] = { + ast.And: "and", + ast.Or: "or", +} + + +class PythonAstJsonSerializer( + Stmt.Visitor[dict], Expr.Visitor[dict], MidasType.Visitor[dict] +): + """An AST serializer which produces a JSON-compatible structure""" + + def serialize(self, stmts: list[Stmt]) -> list[dict]: + return [stmt.accept(self) for stmt in stmts] + + def _serialize_optional( + self, element: Optional[Stmt | Expr | MidasType] + ) -> Optional[dict]: + if element is None: + return None + return element.accept(self) + + def _serialize_list( + self, elements: Sequence[Stmt | Expr | MidasType] + ) -> list[dict]: + return [element.accept(self) for element in elements] + + def visit_base_type(self, node: BaseType) -> dict: + return { + "_type": "BaseType", + "base": node.base, + "param": self._serialize_optional(node.param), + } + + def visit_constraint_type(self, node: ConstraintType) -> dict: + return { + "_type": "ConstraintType", + "type": node.type.accept(self), + "constraint": ast.unparse(node.constraint), + } + + def visit_frame_column(self, node: FrameColumn) -> dict: + return { + "_type": "FrameColumn", + "name": node.name, + "type": self._serialize_optional(node.type), + } + + def visit_frame_type(self, node: FrameType) -> dict: + return { + "_type": "FrameType", + "columns": self._serialize_list(node.columns), + } + + def visit_expression_stmt(self, stmt: ExpressionStmt) -> dict: + return { + "_type": "ExpressionStmt", + "expr": stmt.expr.accept(self), + } + + def _serialize_argument(self, arg: Function.Argument) -> dict: + return { + "name": arg.name, + "type": self._serialize_optional(arg.type), + "default": self._serialize_optional(arg.default), + } + + def visit_function(self, stmt: Function) -> dict: + return { + "_type": "Function", + "name": stmt.name, + "posonlyargs": [self._serialize_argument(arg) for arg in stmt.posonlyargs], + "args": [self._serialize_argument(arg) for arg in stmt.args], + "sink": ( + self._serialize_argument(stmt.sink) if stmt.sink is not None else None + ), + "kwonlyargs": [self._serialize_argument(arg) for arg in stmt.kwonlyargs], + "kw_sink": ( + self._serialize_argument(stmt.kw_sink) + if stmt.kw_sink is not None + else None + ), + "returns": self._serialize_optional(stmt.returns), + "body": self._serialize_list(stmt.body), + } + + def visit_type_assign(self, stmt: TypeAssign) -> dict: + return { + "_type": "TypeAssign", + "name": stmt.name, + "type": stmt.type.accept(self), + } + + def visit_assign_stmt(self, stmt: AssignStmt) -> dict: + return { + "_type": "AssignStmt", + "targets": self._serialize_list(stmt.targets), + "value": stmt.value.accept(self), + } + + def visit_return_stmt(self, stmt: ReturnStmt) -> dict: + return { + "_type": "ReturnStmt", + "value": self._serialize_optional(stmt.value), + } + + def visit_binary_expr(self, expr: BinaryExpr) -> dict: + return { + "_type": "BinaryExpr", + "left": expr.left.accept(self), + "operator": binary_ops[expr.operator.__class__], + "right": expr.right.accept(self), + } + + def visit_compare_expr(self, expr: CompareExpr) -> dict: + return { + "_type": "CompareExpr", + "left": expr.left.accept(self), + "operator": compare_ops[expr.operator.__class__], + "right": expr.right.accept(self), + } + + def visit_unary_expr(self, expr: UnaryExpr) -> dict: + return { + "_type": "UnaryExpr", + "operator": unary_ops[expr.operator.__class__], + "right": expr.right.accept(self), + } + + def visit_call_expr(self, expr: CallExpr) -> dict: + return { + "_type": "CallExpr", + "callee": expr.callee.accept(self), + "arguments": self._serialize_list(expr.arguments), + "keywords": {name: arg.accept(self) for name, arg in expr.keywords.items()}, + } + + def visit_get_expr(self, expr: GetExpr) -> dict: + return { + "_type": "GetExpr", + "object": expr.object.accept(self), + "name": expr.name, + } + + def visit_literal_expr(self, expr: LiteralExpr) -> dict: + return { + "_type": "LiteralExpr", + "value": expr.value, + } + + def visit_variable_expr(self, expr: VariableExpr) -> dict: + return { + "_type": "VariableExpr", + "name": expr.name, + } + + def visit_logical_expr(self, expr: LogicalExpr) -> dict: + return { + "_type": "LogicalExpr", + "left": expr.left.accept(self), + "operator": boolean_ops[expr.operator.__class__], + "right": expr.right.accept(self), + } + + def visit_set_expr(self, expr: SetExpr) -> dict: + return { + "_type": "SetExpr", + "object": expr.object.accept(self), + "name": expr.name, + "value": expr.value.accept(self), + }