diff --git a/midas/parser/python.py b/midas/parser/python.py index 082cab1..277a71c 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -1,14 +1,20 @@ import ast -from typing import Any, Optional +from typing import Optional from midas.ast.location import Location + from midas.ast.python import ( + AssignExpr, BaseType, ConstraintType, + Expr, + ExpressionStmt, FrameColumn, FrameType, Function, MidasType, + Stmt, + TypeAssign, ) @@ -23,33 +29,66 @@ class UnsupportedSyntaxError(Exception): ) -class PythonParser(ast.NodeVisitor): - def __init__(self) -> None: - super().__init__() +class PythonParser: + def parse_module(self, node: ast.Module) -> list[Stmt]: + statements: list[Stmt] = [] + for stmt in node.body: + parsed: None | Stmt | list[Stmt] = self.parse_stmt(stmt) + if isinstance(parsed, Stmt): + statements.append(parsed) + elif parsed is not None: + statements.extend(parsed) + return statements - self.annotations: list[tuple[str, Optional[MidasType]]] = [] - self.functions: list[Function] = [] - - def visit_AnnAssign(self, node: ast.AnnAssign) -> Any: + def parse_stmt(self, node: ast.stmt) -> None | Stmt | list[Stmt]: match node: - case ast.AnnAssign( - target=ast.Name(id=target), annotation=annotation, simple=1 - ): - self.annotations.append( - (target, self._parse_type(annotation, root=True)) - ) + case ast.AnnAssign(): + return self.parse_annotation_assign(node) + + case ast.FunctionDef(): + return self.parse_function(node) + case _: + print(f"Unsupported assignment: {ast.unparse(node)}") + return None + + def parse_annotation_assign(self, node: ast.AnnAssign) -> list[Stmt]: + statements: list[Stmt] = [] + loc: Location = Location.from_ast(node) + match node: + case ast.AnnAssign( + target=ast.Name(id=target), + annotation=annotation, + value=value, + simple=1, + ): + type = self._parse_type(annotation, root=True) + if type is not None: + statements.append( + TypeAssign( + location=loc, + name=target, + type=type, + ) + ) + + if value is not None: + parsed_value: Expr = self.parse_expr(value) + statements.append( + ExpressionStmt( + location=loc, + expr=AssignExpr( + location=loc, + name=target, + value=parsed_value, + ), + ) + ) case _: print(f"Unsupported annotation: {ast.unparse(node)}") + return statements - def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: - self.functions.append(self._parse_function(node)) - - # Call visit on children to process body - # TODO: scope the resulting nodes to the function - self.generic_visit(node) - - def _parse_function(self, node: ast.FunctionDef) -> Function: + def parse_function(self, node: ast.FunctionDef) -> Function: loc: Location = Location.from_ast(node) match node: case ast.FunctionDef( @@ -73,6 +112,8 @@ class PythonParser(ast.NodeVisitor): kwonlyargs=parse_args(kwonlyargs), returns=self._parse_type(returns) if returns is not None else None, ) + case _: + print(f"Unsupported function definition: {ast.unparse(node)}") def _parse_function_argument(self, arg: ast.arg) -> Function.Argument: loc: Location = Location.from_ast(arg) @@ -185,3 +226,6 @@ class PythonParser(ast.NodeVisitor): case _: raise UnsupportedSyntaxError(column) + + def parse_expr(self, node: ast.expr) -> Expr: + raise NotImplementedError()