From 0b3f33d7fe3ec47ab133a3eb22190a340eb2d7db Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Mon, 25 May 2026 23:17:52 +0200 Subject: [PATCH] feat(parser): parse python expressions --- gen/python.py | 6 +++ midas/ast/printer.py | 25 ++++++++- midas/ast/python.py | 14 +++++ midas/cli/highlighter.py | 2 + midas/parser/python.py | 110 ++++++++++++++++++++++++++++++++++++--- 5 files changed, 150 insertions(+), 7 deletions(-) diff --git a/gen/python.py b/gen/python.py index db12f42..9bf984d 100644 --- a/gen/python.py +++ b/gen/python.py @@ -74,6 +74,12 @@ class BinaryExpr: right: Expr +class CompareExpr: + left: Expr + operator: ast.cmpop + right: Expr + + class UnaryExpr: operator: ast.unaryop right: Expr diff --git a/midas/ast/printer.py b/midas/ast/printer.py index 9ac012f..e3ecde9 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -462,6 +462,19 @@ class PythonAstPrinter( with self._child_level(single=True): expr.right.accept(self) + def visit_compare_expr(self, expr: p.CompareExpr) -> None: + self._write_line("CompareExpr") + with self._child_level(): + self._write_line("left") + with self._child_level(single=True): + expr.left.accept(self) + + self._write_line(f"operator: {expr.operator.__class__.__name__}") + + self._write_line("right", last=True) + with self._child_level(single=True): + expr.right.accept(self) + def visit_unary_expr(self, expr: p.UnaryExpr) -> None: self._write_line("UnaryExpr") with self._child_level(): @@ -478,7 +491,7 @@ class PythonAstPrinter( with self._child_level(single=True): expr.callee.accept(self) - self._write_line("arguments", last=True) + self._write_line("arguments") with self._child_level(): for i, arg in enumerate(expr.arguments): self._idx = i @@ -486,6 +499,16 @@ class PythonAstPrinter( self._mark_last() arg.accept(self) + self._write_line("keywords", last=True) + with self._child_level(): + for i, (name, arg) in enumerate(expr.keywords.items()): + self._idx = i + if i == len(expr.keywords) - 1: + self._mark_last() + self._write_line(name) + with self._child_level(single=True): + arg.accept(self) + def visit_get_expr(self, expr: p.GetExpr) -> None: self._write_line("GetExpr") with self._child_level(): diff --git a/midas/ast/python.py b/midas/ast/python.py index 7ca700b..d4fc032 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -161,6 +161,9 @@ class Expr(ABC): @abstractmethod def visit_binary_expr(self, expr: BinaryExpr) -> T: ... + @abstractmethod + def visit_compare_expr(self, expr: CompareExpr) -> T: ... + @abstractmethod def visit_unary_expr(self, expr: UnaryExpr) -> T: ... @@ -193,6 +196,16 @@ class BinaryExpr(Expr): return visitor.visit_binary_expr(self) +@dataclass(frozen=True) +class CompareExpr(Expr): + left: Expr + operator: ast.cmpop + right: Expr + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_compare_expr(self) + + @dataclass(frozen=True) class UnaryExpr(Expr): operator: ast.unaryop @@ -206,6 +219,7 @@ class UnaryExpr(Expr): class CallExpr(Expr): callee: Expr arguments: list[Expr] + keywords: dict[str, Expr] def accept(self, visitor: Expr.Visitor[T]) -> T: return visitor.visit_call_expr(self) diff --git a/midas/cli/highlighter.py b/midas/cli/highlighter.py index e9c3c4e..f4801bb 100644 --- a/midas/cli/highlighter.py +++ b/midas/cli/highlighter.py @@ -155,6 +155,8 @@ class PythonHighlighter( def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ... + def visit_compare_expr(self, expr: p.CompareExpr) -> None: ... + def visit_unary_expr(self, expr: p.UnaryExpr) -> None: ... def visit_call_expr(self, expr: p.CallExpr) -> None: ... diff --git a/midas/parser/python.py b/midas/parser/python.py index 95fe0c0..4b6a3f1 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -6,14 +6,22 @@ from midas.ast.location import Location from midas.ast.python import ( AssignStmt, BaseType, + BinaryExpr, + CallExpr, + CompareExpr, ConstraintType, Expr, + ExpressionStmt, FrameColumn, FrameType, Function, + GetExpr, + LiteralExpr, + LogicalExpr, MidasType, Stmt, TypeAssign, + UnaryExpr, VariableExpr, ) @@ -33,11 +41,15 @@ class PythonParser: def parse_module(self, node: ast.Module) -> list[Stmt]: statements: list[Stmt] = [] for stmt in node.body: - parsed: None | Stmt | list[Stmt] = self.parse_stmt(stmt) - if isinstance(parsed, Stmt): - statements.append(parsed) - elif parsed is not None: - statements.extend(parsed) + try: + parsed: None | Stmt | list[Stmt] = self.parse_stmt(stmt) + if isinstance(parsed, Stmt): + statements.append(parsed) + elif parsed is not None: + statements.extend(parsed) + except UnsupportedSyntaxError as e: + print(f"{e}, skipping") + continue return statements def parse_stmt(self, node: ast.stmt) -> None | Stmt | list[Stmt]: @@ -51,6 +63,9 @@ class PythonParser: case ast.FunctionDef(): return self.parse_function(node) + case ast.Expr(value=expr): + return ExpressionStmt(expr=self.parse_expr(expr)) + case _: print(f"Unsupported statement: {ast.unparse(node)}") return None @@ -242,4 +257,87 @@ class PythonParser: raise UnsupportedSyntaxError(column) def parse_expr(self, node: ast.expr) -> Expr: - raise NotImplementedError() + match node: + case ast.BoolOp(): + return self.parse_bool_op(node) + + case ast.BinOp(left=left, op=op, right=right): + return BinaryExpr( + left=self.parse_expr(left), + operator=op, + right=self.parse_expr(right), + ) + + case ast.UnaryOp(op=op, operand=right): + return UnaryExpr( + operator=op, + right=self.parse_expr(right), + ) + + case ast.Compare(): + return self.parse_compare(node) + + case ast.Call(): + return self.parse_call(node) + + case ast.Constant(value=value): + return LiteralExpr(value=value) + + case ast.Attribute(value=object, attr=name): + return GetExpr( + object=self.parse_expr(object), + name=name, + ) + + case ast.Name(id=name): + return VariableExpr(name=name) + + case _: + raise UnsupportedSyntaxError(node) + + def parse_bool_op(self, node: ast.BoolOp) -> LogicalExpr: + op: ast.boolop = node.op + values: list[ast.expr] = node.values + expr: LogicalExpr = LogicalExpr( + left=self.parse_expr(values[0]), + operator=op, + right=self.parse_expr(values[1]), + ) + for value in values[2:]: + expr = LogicalExpr( + left=expr, + operator=op, + right=self.parse_expr(value), + ) + return expr + + def parse_compare(self, node: ast.Compare) -> Expr: + ops: list[ast.cmpop] = node.ops + rights: list[Expr] = [self.parse_expr(expr) for expr in node.comparators] + expr: Expr = CompareExpr( + left=self.parse_expr(node.left), + operator=ops[0], + right=rights[0], + ) + for i, right in enumerate(rights[1:]): + expr = LogicalExpr( + left=expr, + operator=ast.And(), + right=CompareExpr( + left=rights[i], + operator=ops[i], + right=right, + ), + ) + return expr + + def parse_call(self, node: ast.Call) -> CallExpr: + return CallExpr( + callee=self.parse_expr(node.func), + arguments=[self.parse_expr(arg) for arg in node.args], + keywords={ + arg.arg: self.parse_expr(arg.value) + for arg in node.keywords + if arg.arg is not None # Should always be True, type checker happy + }, + )