feat(parser): parse python expressions
This commit is contained in:
@@ -74,6 +74,12 @@ class BinaryExpr:
|
|||||||
right: Expr
|
right: Expr
|
||||||
|
|
||||||
|
|
||||||
|
class CompareExpr:
|
||||||
|
left: Expr
|
||||||
|
operator: ast.cmpop
|
||||||
|
right: Expr
|
||||||
|
|
||||||
|
|
||||||
class UnaryExpr:
|
class UnaryExpr:
|
||||||
operator: ast.unaryop
|
operator: ast.unaryop
|
||||||
right: Expr
|
right: Expr
|
||||||
|
|||||||
@@ -462,6 +462,19 @@ class PythonAstPrinter(
|
|||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
expr.right.accept(self)
|
expr.right.accept(self)
|
||||||
|
|
||||||
|
def visit_compare_expr(self, expr: p.CompareExpr) -> None:
|
||||||
|
self._write_line("CompareExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("left")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.left.accept(self)
|
||||||
|
|
||||||
|
self._write_line(f"operator: {expr.operator.__class__.__name__}")
|
||||||
|
|
||||||
|
self._write_line("right", last=True)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.right.accept(self)
|
||||||
|
|
||||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> None:
|
def visit_unary_expr(self, expr: p.UnaryExpr) -> None:
|
||||||
self._write_line("UnaryExpr")
|
self._write_line("UnaryExpr")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
@@ -478,7 +491,7 @@ class PythonAstPrinter(
|
|||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
expr.callee.accept(self)
|
expr.callee.accept(self)
|
||||||
|
|
||||||
self._write_line("arguments", last=True)
|
self._write_line("arguments")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
for i, arg in enumerate(expr.arguments):
|
for i, arg in enumerate(expr.arguments):
|
||||||
self._idx = i
|
self._idx = i
|
||||||
@@ -486,6 +499,16 @@ class PythonAstPrinter(
|
|||||||
self._mark_last()
|
self._mark_last()
|
||||||
arg.accept(self)
|
arg.accept(self)
|
||||||
|
|
||||||
|
self._write_line("keywords", last=True)
|
||||||
|
with self._child_level():
|
||||||
|
for i, (name, arg) in enumerate(expr.keywords.items()):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(expr.keywords) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._write_line(name)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
arg.accept(self)
|
||||||
|
|
||||||
def visit_get_expr(self, expr: p.GetExpr) -> None:
|
def visit_get_expr(self, expr: p.GetExpr) -> None:
|
||||||
self._write_line("GetExpr")
|
self._write_line("GetExpr")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
|
|||||||
@@ -161,6 +161,9 @@ class Expr(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_binary_expr(self, expr: BinaryExpr) -> T: ...
|
def visit_binary_expr(self, expr: BinaryExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_compare_expr(self, expr: CompareExpr) -> T: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_unary_expr(self, expr: UnaryExpr) -> T: ...
|
def visit_unary_expr(self, expr: UnaryExpr) -> T: ...
|
||||||
|
|
||||||
@@ -193,6 +196,16 @@ class BinaryExpr(Expr):
|
|||||||
return visitor.visit_binary_expr(self)
|
return visitor.visit_binary_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class CompareExpr(Expr):
|
||||||
|
left: Expr
|
||||||
|
operator: ast.cmpop
|
||||||
|
right: Expr
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_compare_expr(self)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class UnaryExpr(Expr):
|
class UnaryExpr(Expr):
|
||||||
operator: ast.unaryop
|
operator: ast.unaryop
|
||||||
@@ -206,6 +219,7 @@ class UnaryExpr(Expr):
|
|||||||
class CallExpr(Expr):
|
class CallExpr(Expr):
|
||||||
callee: Expr
|
callee: Expr
|
||||||
arguments: list[Expr]
|
arguments: list[Expr]
|
||||||
|
keywords: dict[str, Expr]
|
||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
return visitor.visit_call_expr(self)
|
return visitor.visit_call_expr(self)
|
||||||
|
|||||||
@@ -155,6 +155,8 @@ class PythonHighlighter(
|
|||||||
|
|
||||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ...
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_compare_expr(self, expr: p.CompareExpr) -> None: ...
|
||||||
|
|
||||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> None: ...
|
def visit_unary_expr(self, expr: p.UnaryExpr) -> None: ...
|
||||||
|
|
||||||
def visit_call_expr(self, expr: p.CallExpr) -> None: ...
|
def visit_call_expr(self, expr: p.CallExpr) -> None: ...
|
||||||
|
|||||||
@@ -6,14 +6,22 @@ from midas.ast.location import Location
|
|||||||
from midas.ast.python import (
|
from midas.ast.python import (
|
||||||
AssignStmt,
|
AssignStmt,
|
||||||
BaseType,
|
BaseType,
|
||||||
|
BinaryExpr,
|
||||||
|
CallExpr,
|
||||||
|
CompareExpr,
|
||||||
ConstraintType,
|
ConstraintType,
|
||||||
Expr,
|
Expr,
|
||||||
|
ExpressionStmt,
|
||||||
FrameColumn,
|
FrameColumn,
|
||||||
FrameType,
|
FrameType,
|
||||||
Function,
|
Function,
|
||||||
|
GetExpr,
|
||||||
|
LiteralExpr,
|
||||||
|
LogicalExpr,
|
||||||
MidasType,
|
MidasType,
|
||||||
Stmt,
|
Stmt,
|
||||||
TypeAssign,
|
TypeAssign,
|
||||||
|
UnaryExpr,
|
||||||
VariableExpr,
|
VariableExpr,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -33,11 +41,15 @@ class PythonParser:
|
|||||||
def parse_module(self, node: ast.Module) -> list[Stmt]:
|
def parse_module(self, node: ast.Module) -> list[Stmt]:
|
||||||
statements: list[Stmt] = []
|
statements: list[Stmt] = []
|
||||||
for stmt in node.body:
|
for stmt in node.body:
|
||||||
|
try:
|
||||||
parsed: None | Stmt | list[Stmt] = self.parse_stmt(stmt)
|
parsed: None | Stmt | list[Stmt] = self.parse_stmt(stmt)
|
||||||
if isinstance(parsed, Stmt):
|
if isinstance(parsed, Stmt):
|
||||||
statements.append(parsed)
|
statements.append(parsed)
|
||||||
elif parsed is not None:
|
elif parsed is not None:
|
||||||
statements.extend(parsed)
|
statements.extend(parsed)
|
||||||
|
except UnsupportedSyntaxError as e:
|
||||||
|
print(f"{e}, skipping")
|
||||||
|
continue
|
||||||
return statements
|
return statements
|
||||||
|
|
||||||
def parse_stmt(self, node: ast.stmt) -> None | Stmt | list[Stmt]:
|
def parse_stmt(self, node: ast.stmt) -> None | Stmt | list[Stmt]:
|
||||||
@@ -51,6 +63,9 @@ class PythonParser:
|
|||||||
case ast.FunctionDef():
|
case ast.FunctionDef():
|
||||||
return self.parse_function(node)
|
return self.parse_function(node)
|
||||||
|
|
||||||
|
case ast.Expr(value=expr):
|
||||||
|
return ExpressionStmt(expr=self.parse_expr(expr))
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
print(f"Unsupported statement: {ast.unparse(node)}")
|
print(f"Unsupported statement: {ast.unparse(node)}")
|
||||||
return None
|
return None
|
||||||
@@ -242,4 +257,87 @@ class PythonParser:
|
|||||||
raise UnsupportedSyntaxError(column)
|
raise UnsupportedSyntaxError(column)
|
||||||
|
|
||||||
def parse_expr(self, node: ast.expr) -> Expr:
|
def parse_expr(self, node: ast.expr) -> Expr:
|
||||||
raise NotImplementedError()
|
match node:
|
||||||
|
case ast.BoolOp():
|
||||||
|
return self.parse_bool_op(node)
|
||||||
|
|
||||||
|
case ast.BinOp(left=left, op=op, right=right):
|
||||||
|
return BinaryExpr(
|
||||||
|
left=self.parse_expr(left),
|
||||||
|
operator=op,
|
||||||
|
right=self.parse_expr(right),
|
||||||
|
)
|
||||||
|
|
||||||
|
case ast.UnaryOp(op=op, operand=right):
|
||||||
|
return UnaryExpr(
|
||||||
|
operator=op,
|
||||||
|
right=self.parse_expr(right),
|
||||||
|
)
|
||||||
|
|
||||||
|
case ast.Compare():
|
||||||
|
return self.parse_compare(node)
|
||||||
|
|
||||||
|
case ast.Call():
|
||||||
|
return self.parse_call(node)
|
||||||
|
|
||||||
|
case ast.Constant(value=value):
|
||||||
|
return LiteralExpr(value=value)
|
||||||
|
|
||||||
|
case ast.Attribute(value=object, attr=name):
|
||||||
|
return GetExpr(
|
||||||
|
object=self.parse_expr(object),
|
||||||
|
name=name,
|
||||||
|
)
|
||||||
|
|
||||||
|
case ast.Name(id=name):
|
||||||
|
return VariableExpr(name=name)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
raise UnsupportedSyntaxError(node)
|
||||||
|
|
||||||
|
def parse_bool_op(self, node: ast.BoolOp) -> LogicalExpr:
|
||||||
|
op: ast.boolop = node.op
|
||||||
|
values: list[ast.expr] = node.values
|
||||||
|
expr: LogicalExpr = LogicalExpr(
|
||||||
|
left=self.parse_expr(values[0]),
|
||||||
|
operator=op,
|
||||||
|
right=self.parse_expr(values[1]),
|
||||||
|
)
|
||||||
|
for value in values[2:]:
|
||||||
|
expr = LogicalExpr(
|
||||||
|
left=expr,
|
||||||
|
operator=op,
|
||||||
|
right=self.parse_expr(value),
|
||||||
|
)
|
||||||
|
return expr
|
||||||
|
|
||||||
|
def parse_compare(self, node: ast.Compare) -> Expr:
|
||||||
|
ops: list[ast.cmpop] = node.ops
|
||||||
|
rights: list[Expr] = [self.parse_expr(expr) for expr in node.comparators]
|
||||||
|
expr: Expr = CompareExpr(
|
||||||
|
left=self.parse_expr(node.left),
|
||||||
|
operator=ops[0],
|
||||||
|
right=rights[0],
|
||||||
|
)
|
||||||
|
for i, right in enumerate(rights[1:]):
|
||||||
|
expr = LogicalExpr(
|
||||||
|
left=expr,
|
||||||
|
operator=ast.And(),
|
||||||
|
right=CompareExpr(
|
||||||
|
left=rights[i],
|
||||||
|
operator=ops[i],
|
||||||
|
right=right,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return expr
|
||||||
|
|
||||||
|
def parse_call(self, node: ast.Call) -> CallExpr:
|
||||||
|
return CallExpr(
|
||||||
|
callee=self.parse_expr(node.func),
|
||||||
|
arguments=[self.parse_expr(arg) for arg in node.args],
|
||||||
|
keywords={
|
||||||
|
arg.arg: self.parse_expr(arg.value)
|
||||||
|
for arg in node.keywords
|
||||||
|
if arg.arg is not None # Should always be True, type checker happy
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user