diff --git a/gen/python.py b/gen/python.py index 33a01cc..db12f42 100644 --- a/gen/python.py +++ b/gen/python.py @@ -59,15 +59,15 @@ class TypeAssign: type: MidasType +class AssignStmt: + targets: list[Expr] + value: Expr + + ###< ###> Expr | Expressions -class AssignExpr: - name: str - value: Expr - - class BinaryExpr: left: Expr operator: ast.operator diff --git a/midas/ast/printer.py b/midas/ast/printer.py index f467a85..9ac012f 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -435,13 +435,19 @@ class PythonAstPrinter( with self._child_level(single=True): stmt.type.accept(self) - def visit_assign_expr(self, expr: p.AssignExpr) -> None: - self._write_line("AssignExpr") + def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: + self._write_line("AssignStmt") with self._child_level(): - self._write_line(f"name: {expr.name}") + self._write_line("targets") + with self._child_level(): + for i, target in enumerate(stmt.targets): + self._idx = i + if i == len(stmt.targets) - 1: + self._mark_last() + target.accept(self) self._write_line("value", last=True) with self._child_level(single=True): - expr.value.accept(self) + stmt.value.accept(self) def visit_binary_expr(self, expr: p.BinaryExpr) -> None: self._write_line("BinaryExpr") diff --git a/midas/ast/python.py b/midas/ast/python.py index 5e219c8..7ca700b 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -97,6 +97,9 @@ class Stmt(ABC): @abstractmethod def visit_type_assign(self, stmt: TypeAssign) -> T: ... + @abstractmethod + def visit_assign_stmt(self, stmt: AssignStmt) -> T: ... + @dataclass(frozen=True) class ExpressionStmt(Stmt): @@ -133,6 +136,15 @@ class TypeAssign(Stmt): return visitor.visit_type_assign(self) +@dataclass(frozen=True) +class AssignStmt(Stmt): + targets: list[Expr] + value: Expr + + def accept(self, visitor: Stmt.Visitor[T]) -> T: + return visitor.visit_assign_stmt(self) + + ############### # Expressions # ############### @@ -146,9 +158,6 @@ class Expr(ABC): def accept(self, visitor: Visitor[T]) -> T: ... class Visitor(ABC, Generic[T]): - @abstractmethod - def visit_assign_expr(self, expr: AssignExpr) -> T: ... - @abstractmethod def visit_binary_expr(self, expr: BinaryExpr) -> T: ... @@ -174,15 +183,6 @@ class Expr(ABC): def visit_set_expr(self, expr: SetExpr) -> T: ... -@dataclass(frozen=True) -class AssignExpr(Expr): - name: str - value: Expr - - def accept(self, visitor: Expr.Visitor[T]) -> T: - return visitor.visit_assign_expr(self) - - @dataclass(frozen=True) class BinaryExpr(Expr): left: Expr diff --git a/midas/cli/highlighter.py b/midas/cli/highlighter.py index 45ed55c..e9c3c4e 100644 --- a/midas/cli/highlighter.py +++ b/midas/cli/highlighter.py @@ -151,7 +151,7 @@ class PythonHighlighter( def visit_type_assign(self, stmt: p.TypeAssign) -> None: stmt.type.accept(self) - def visit_assign_expr(self, expr: p.AssignExpr) -> None: ... + def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: ... def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ... diff --git a/midas/parser/python.py b/midas/parser/python.py index 277a71c..95fe0c0 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -4,17 +4,17 @@ from typing import Optional from midas.ast.location import Location from midas.ast.python import ( - AssignExpr, + AssignStmt, BaseType, ConstraintType, Expr, - ExpressionStmt, FrameColumn, FrameType, Function, MidasType, Stmt, TypeAssign, + VariableExpr, ) @@ -45,11 +45,14 @@ class PythonParser: case ast.AnnAssign(): return self.parse_annotation_assign(node) + case ast.Assign(): + return self.parse_assign(node) + case ast.FunctionDef(): return self.parse_function(node) case _: - print(f"Unsupported assignment: {ast.unparse(node)}") + print(f"Unsupported statement: {ast.unparse(node)}") return None def parse_annotation_assign(self, node: ast.AnnAssign) -> list[Stmt]: @@ -73,21 +76,32 @@ class PythonParser: ) if value is not None: - parsed_value: Expr = self.parse_expr(value) statements.append( - ExpressionStmt( + AssignStmt( location=loc, - expr=AssignExpr( - location=loc, - name=target, - value=parsed_value, - ), - ) + targets=[ + VariableExpr( + location=Location.from_ast(node.target), name=target + ), + ], + value=self.parse_expr(value), + ), ) case _: print(f"Unsupported annotation: {ast.unparse(node)}") return statements + def parse_assign(self, node: ast.Assign) -> AssignStmt: + targets: list[Expr] = [] + for target in node.targets: + targets.append(self.parse_expr(target)) + value: Expr = self.parse_expr(node.value) + return AssignStmt( + location=Location.from_ast(node), + targets=targets, + value=value, + ) + def parse_function(self, node: ast.FunctionDef) -> Function: loc: Location = Location.from_ast(node) match node: