feat(parser): parse python expressions

This commit is contained in:
2026-05-25 23:17:52 +02:00
parent 8a9b4f3989
commit 0b3f33d7fe
5 changed files with 150 additions and 7 deletions

View File

@@ -74,6 +74,12 @@ class BinaryExpr:
right: Expr
class CompareExpr:
left: Expr
operator: ast.cmpop
right: Expr
class UnaryExpr:
operator: ast.unaryop
right: Expr

View File

@@ -462,6 +462,19 @@ class PythonAstPrinter(
with self._child_level(single=True):
expr.right.accept(self)
def visit_compare_expr(self, expr: p.CompareExpr) -> None:
self._write_line("CompareExpr")
with self._child_level():
self._write_line("left")
with self._child_level(single=True):
expr.left.accept(self)
self._write_line(f"operator: {expr.operator.__class__.__name__}")
self._write_line("right", last=True)
with self._child_level(single=True):
expr.right.accept(self)
def visit_unary_expr(self, expr: p.UnaryExpr) -> None:
self._write_line("UnaryExpr")
with self._child_level():
@@ -478,7 +491,7 @@ class PythonAstPrinter(
with self._child_level(single=True):
expr.callee.accept(self)
self._write_line("arguments", last=True)
self._write_line("arguments")
with self._child_level():
for i, arg in enumerate(expr.arguments):
self._idx = i
@@ -486,6 +499,16 @@ class PythonAstPrinter(
self._mark_last()
arg.accept(self)
self._write_line("keywords", last=True)
with self._child_level():
for i, (name, arg) in enumerate(expr.keywords.items()):
self._idx = i
if i == len(expr.keywords) - 1:
self._mark_last()
self._write_line(name)
with self._child_level(single=True):
arg.accept(self)
def visit_get_expr(self, expr: p.GetExpr) -> None:
self._write_line("GetExpr")
with self._child_level():

View File

@@ -161,6 +161,9 @@ class Expr(ABC):
@abstractmethod
def visit_binary_expr(self, expr: BinaryExpr) -> T: ...
@abstractmethod
def visit_compare_expr(self, expr: CompareExpr) -> T: ...
@abstractmethod
def visit_unary_expr(self, expr: UnaryExpr) -> T: ...
@@ -193,6 +196,16 @@ class BinaryExpr(Expr):
return visitor.visit_binary_expr(self)
@dataclass(frozen=True)
class CompareExpr(Expr):
left: Expr
operator: ast.cmpop
right: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_compare_expr(self)
@dataclass(frozen=True)
class UnaryExpr(Expr):
operator: ast.unaryop
@@ -206,6 +219,7 @@ class UnaryExpr(Expr):
class CallExpr(Expr):
callee: Expr
arguments: list[Expr]
keywords: dict[str, Expr]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_call_expr(self)

View File

@@ -155,6 +155,8 @@ class PythonHighlighter(
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ...
def visit_compare_expr(self, expr: p.CompareExpr) -> None: ...
def visit_unary_expr(self, expr: p.UnaryExpr) -> None: ...
def visit_call_expr(self, expr: p.CallExpr) -> None: ...

View File

@@ -6,14 +6,22 @@ from midas.ast.location import Location
from midas.ast.python import (
AssignStmt,
BaseType,
BinaryExpr,
CallExpr,
CompareExpr,
ConstraintType,
Expr,
ExpressionStmt,
FrameColumn,
FrameType,
Function,
GetExpr,
LiteralExpr,
LogicalExpr,
MidasType,
Stmt,
TypeAssign,
UnaryExpr,
VariableExpr,
)
@@ -33,11 +41,15 @@ class PythonParser:
def parse_module(self, node: ast.Module) -> list[Stmt]:
statements: list[Stmt] = []
for stmt in node.body:
parsed: None | Stmt | list[Stmt] = self.parse_stmt(stmt)
if isinstance(parsed, Stmt):
statements.append(parsed)
elif parsed is not None:
statements.extend(parsed)
try:
parsed: None | Stmt | list[Stmt] = self.parse_stmt(stmt)
if isinstance(parsed, Stmt):
statements.append(parsed)
elif parsed is not None:
statements.extend(parsed)
except UnsupportedSyntaxError as e:
print(f"{e}, skipping")
continue
return statements
def parse_stmt(self, node: ast.stmt) -> None | Stmt | list[Stmt]:
@@ -51,6 +63,9 @@ class PythonParser:
case ast.FunctionDef():
return self.parse_function(node)
case ast.Expr(value=expr):
return ExpressionStmt(expr=self.parse_expr(expr))
case _:
print(f"Unsupported statement: {ast.unparse(node)}")
return None
@@ -242,4 +257,87 @@ class PythonParser:
raise UnsupportedSyntaxError(column)
def parse_expr(self, node: ast.expr) -> Expr:
raise NotImplementedError()
match node:
case ast.BoolOp():
return self.parse_bool_op(node)
case ast.BinOp(left=left, op=op, right=right):
return BinaryExpr(
left=self.parse_expr(left),
operator=op,
right=self.parse_expr(right),
)
case ast.UnaryOp(op=op, operand=right):
return UnaryExpr(
operator=op,
right=self.parse_expr(right),
)
case ast.Compare():
return self.parse_compare(node)
case ast.Call():
return self.parse_call(node)
case ast.Constant(value=value):
return LiteralExpr(value=value)
case ast.Attribute(value=object, attr=name):
return GetExpr(
object=self.parse_expr(object),
name=name,
)
case ast.Name(id=name):
return VariableExpr(name=name)
case _:
raise UnsupportedSyntaxError(node)
def parse_bool_op(self, node: ast.BoolOp) -> LogicalExpr:
op: ast.boolop = node.op
values: list[ast.expr] = node.values
expr: LogicalExpr = LogicalExpr(
left=self.parse_expr(values[0]),
operator=op,
right=self.parse_expr(values[1]),
)
for value in values[2:]:
expr = LogicalExpr(
left=expr,
operator=op,
right=self.parse_expr(value),
)
return expr
def parse_compare(self, node: ast.Compare) -> Expr:
ops: list[ast.cmpop] = node.ops
rights: list[Expr] = [self.parse_expr(expr) for expr in node.comparators]
expr: Expr = CompareExpr(
left=self.parse_expr(node.left),
operator=ops[0],
right=rights[0],
)
for i, right in enumerate(rights[1:]):
expr = LogicalExpr(
left=expr,
operator=ast.And(),
right=CompareExpr(
left=rights[i],
operator=ops[i],
right=right,
),
)
return expr
def parse_call(self, node: ast.Call) -> CallExpr:
return CallExpr(
callee=self.parse_expr(node.func),
arguments=[self.parse_expr(arg) for arg in node.args],
keywords={
arg.arg: self.parse_expr(arg.value)
for arg in node.keywords
if arg.arg is not None # Should always be True, type checker happy
},
)