diff --git a/examples/00_syntax_prototype/04_functions.py b/examples/00_syntax_prototype/04_functions.py new file mode 100644 index 0000000..3b07899 --- /dev/null +++ b/examples/00_syntax_prototype/04_functions.py @@ -0,0 +1,15 @@ +# type: ignore +# ruff: disable[F821] +from __future__ import annotations + + +def func( + col1: Column[float + (0 <= _ <= 1)], + col2: Column[float + (0 <= _ <= 1)], +) -> Column[float + (0 <= _ <= 2)]: + result: Column[float + (0 <= _ <= 2)] = col1 + col2 + return result + + +def func2(a: int, /, b: float, *, c: str): + pass diff --git a/midas/ast/printer.py b/midas/ast/printer.py index e9033e2..b92e40f 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -350,7 +350,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]): return f"{expr.name.lexeme}{template}{'?' if expr.optional else ''}" -class PythonAstPrinter(AstPrinter, p.MidasType.Visitor[None]): +class PythonAstPrinter(AstPrinter, p.Expr.Visitor[None]): def visit_base_type(self, node: p.BaseType) -> None: self._write_line("BaseType") with self._child_level(): @@ -381,3 +381,40 @@ class PythonAstPrinter(AstPrinter, p.MidasType.Visitor[None]): if i == len(node.columns) - 1: self._mark_last() col.accept(self) + + def visit_function(self, node: p.Function) -> None: + self._write_line("Function") + with self._child_level(): + self._write_line(f"name: {node.name}") + + self._write_line("posonlyargs") + with self._child_level(): + for i, arg in enumerate(node.posonlyargs): + self._idx = i + if i == len(node.posonlyargs) - 1: + self._mark_last() + arg.accept(self) + + self._write_line("args") + with self._child_level(): + for i, arg in enumerate(node.args): + self._idx = i + if i == len(node.args) - 1: + self._mark_last() + arg.accept(self) + + self._write_line("kwonlyargs") + with self._child_level(): + for i, arg in enumerate(node.kwonlyargs): + self._idx = i + if i == len(node.kwonlyargs) - 1: + self._mark_last() + arg.accept(self) + + self._write_optional_child("returns", node.returns, last=True) + + def visit_function_argument(self, node: p.FunctionArgument) -> None: + self._write_line("FunctionArgument") + with self._child_level(): + self._write_line(f"name: {node.name}") + self._write_optional_child("type", node.type, last=True) diff --git a/midas/ast/python.py b/midas/ast/python.py index 8b7f03e..9350fd0 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -9,7 +9,7 @@ T = TypeVar("T") @dataclass(frozen=True) -class MidasType(ABC): +class Expr(ABC): @abstractmethod def accept(self, visitor: Visitor[T]) -> T: ... @@ -26,13 +26,24 @@ class MidasType(ABC): @abstractmethod def visit_frame_type(self, node: FrameType) -> T: ... + @abstractmethod + def visit_function(self, node: Function) -> T: ... + + @abstractmethod + def visit_function_argument(self, node: FunctionArgument) -> T: ... + + +@dataclass(frozen=True) +class MidasType(Expr): + pass + @dataclass(frozen=True) class BaseType(MidasType): base: str param: Optional[MidasType] - def accept(self, visitor: MidasType.Visitor[T]) -> T: + def accept(self, visitor: Expr.Visitor[T]) -> T: return visitor.visit_base_type(self) @@ -41,7 +52,7 @@ class ConstraintType(MidasType): type: MidasType constraint: ast.expr - def accept(self, visitor: MidasType.Visitor[T]) -> T: + def accept(self, visitor: Expr.Visitor[T]) -> T: return visitor.visit_constraint_type(self) @@ -50,7 +61,7 @@ class FrameColumn(MidasType): name: Optional[str] type: Optional[MidasType] - def accept(self, visitor: MidasType.Visitor[T]) -> T: + def accept(self, visitor: Expr.Visitor[T]) -> T: return visitor.visit_frame_column(self) @@ -58,5 +69,26 @@ class FrameColumn(MidasType): class FrameType(MidasType): columns: list[FrameColumn] - def accept(self, visitor: MidasType.Visitor[T]) -> T: + def accept(self, visitor: Expr.Visitor[T]) -> T: return visitor.visit_frame_type(self) + + +@dataclass(frozen=True) +class Function(Expr): + name: str + posonlyargs: list[FunctionArgument] + args: list[FunctionArgument] + kwonlyargs: list[FunctionArgument] + returns: Optional[MidasType] + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_function(self) + + +@dataclass(frozen=True) +class FunctionArgument(Expr): + name: Optional[str] + type: Optional[MidasType] + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_function_argument(self) diff --git a/midas/cli/main.py b/midas/cli/main.py index 3c033f0..65ed210 100644 --- a/midas/cli/main.py +++ b/midas/cli/main.py @@ -44,6 +44,11 @@ def dump_ast(output: Optional[TextIO], parse: bool, file: TextIO): else: dump += printer.print(annotation) dump += "\n" + + dump += "\n# Functions\n\n" + + for func in parser.functions: + dump += printer.print(func) + "\n" else: dump = ast.dump(tree, indent=4) diff --git a/midas/parser/python.py b/midas/parser/python.py index 6139801..dc24022 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -1,7 +1,15 @@ import ast from typing import Any, Optional -from midas.ast.python import BaseType, ConstraintType, FrameColumn, FrameType, MidasType +from midas.ast.python import ( + BaseType, + ConstraintType, + FrameColumn, + FrameType, + Function, + FunctionArgument, + MidasType, +) class InvalidSyntaxError(Exception): @@ -20,6 +28,7 @@ class PythonParser(ast.NodeVisitor): super().__init__() self.annotations: list[tuple[str, Optional[MidasType]]] = [] + self.functions: list[Function] = [] def visit_AnnAssign(self, node: ast.AnnAssign) -> Any: match node: @@ -33,6 +42,43 @@ class PythonParser(ast.NodeVisitor): case _: print(f"Unsupported annotation: {ast.unparse(node)}") + def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: + self.functions.append(self._parse_function(node)) + + # Call visit on children to process body + # TODO: scope the resulting nodes to the function + self.generic_visit(node) + + def _parse_function(self, node: ast.FunctionDef) -> Function: + match node: + case ast.FunctionDef( + name=name, + args=ast.arguments( + posonlyargs=posonlyargs, + args=args, + kwonlyargs=kwonlyargs, + ), + returns=returns, + ): + + def parse_args(args_list: list[ast.arg]) -> list[FunctionArgument]: + return [self._parse_function_argument(arg) for arg in args_list] + + return Function( + name=name, + posonlyargs=parse_args(posonlyargs), + args=parse_args(args), + kwonlyargs=parse_args(kwonlyargs), + returns=self._parse_type(returns) if returns is not None else None, + ) + + def _parse_function_argument(self, arg: ast.arg) -> FunctionArgument: + name: str = arg.arg + type: Optional[MidasType] = None + if arg.annotation is not None: + type = self._parse_type(arg.annotation) + return FunctionArgument(name=name, type=type) + def _parse_type( self, type_expr: ast.expr, root: bool = False ) -> Optional[MidasType]: @@ -50,7 +96,7 @@ class PythonParser(ast.NodeVisitor): left = self._parse_type(left_expr) match left: case None: - raise InvalidSyntaxError("") + raise InvalidSyntaxError() # If chained constraints, separate base type and rebuild constraint case ConstraintType(type=left_type, constraint=left_constraint):