feat(parser): parse functions in python

This commit is contained in:
2026-05-22 19:32:15 +02:00
parent 8d7c115432
commit 5aedddfabb
5 changed files with 143 additions and 8 deletions

View File

@@ -0,0 +1,15 @@
# type: ignore
# ruff: disable[F821]
from __future__ import annotations
def func(
col1: Column[float + (0 <= _ <= 1)],
col2: Column[float + (0 <= _ <= 1)],
) -> Column[float + (0 <= _ <= 2)]:
result: Column[float + (0 <= _ <= 2)] = col1 + col2
return result
def func2(a: int, /, b: float, *, c: str):
pass

View File

@@ -350,7 +350,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
return f"{expr.name.lexeme}{template}{'?' if expr.optional else ''}" return f"{expr.name.lexeme}{template}{'?' if expr.optional else ''}"
class PythonAstPrinter(AstPrinter, p.MidasType.Visitor[None]): class PythonAstPrinter(AstPrinter, p.Expr.Visitor[None]):
def visit_base_type(self, node: p.BaseType) -> None: def visit_base_type(self, node: p.BaseType) -> None:
self._write_line("BaseType") self._write_line("BaseType")
with self._child_level(): with self._child_level():
@@ -381,3 +381,40 @@ class PythonAstPrinter(AstPrinter, p.MidasType.Visitor[None]):
if i == len(node.columns) - 1: if i == len(node.columns) - 1:
self._mark_last() self._mark_last()
col.accept(self) col.accept(self)
def visit_function(self, node: p.Function) -> None:
self._write_line("Function")
with self._child_level():
self._write_line(f"name: {node.name}")
self._write_line("posonlyargs")
with self._child_level():
for i, arg in enumerate(node.posonlyargs):
self._idx = i
if i == len(node.posonlyargs) - 1:
self._mark_last()
arg.accept(self)
self._write_line("args")
with self._child_level():
for i, arg in enumerate(node.args):
self._idx = i
if i == len(node.args) - 1:
self._mark_last()
arg.accept(self)
self._write_line("kwonlyargs")
with self._child_level():
for i, arg in enumerate(node.kwonlyargs):
self._idx = i
if i == len(node.kwonlyargs) - 1:
self._mark_last()
arg.accept(self)
self._write_optional_child("returns", node.returns, last=True)
def visit_function_argument(self, node: p.FunctionArgument) -> None:
self._write_line("FunctionArgument")
with self._child_level():
self._write_line(f"name: {node.name}")
self._write_optional_child("type", node.type, last=True)

View File

@@ -9,7 +9,7 @@ T = TypeVar("T")
@dataclass(frozen=True) @dataclass(frozen=True)
class MidasType(ABC): class Expr(ABC):
@abstractmethod @abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ... def accept(self, visitor: Visitor[T]) -> T: ...
@@ -26,13 +26,24 @@ class MidasType(ABC):
@abstractmethod @abstractmethod
def visit_frame_type(self, node: FrameType) -> T: ... def visit_frame_type(self, node: FrameType) -> T: ...
@abstractmethod
def visit_function(self, node: Function) -> T: ...
@abstractmethod
def visit_function_argument(self, node: FunctionArgument) -> T: ...
@dataclass(frozen=True)
class MidasType(Expr):
pass
@dataclass(frozen=True) @dataclass(frozen=True)
class BaseType(MidasType): class BaseType(MidasType):
base: str base: str
param: Optional[MidasType] param: Optional[MidasType]
def accept(self, visitor: MidasType.Visitor[T]) -> T: def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_base_type(self) return visitor.visit_base_type(self)
@@ -41,7 +52,7 @@ class ConstraintType(MidasType):
type: MidasType type: MidasType
constraint: ast.expr constraint: ast.expr
def accept(self, visitor: MidasType.Visitor[T]) -> T: def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_constraint_type(self) return visitor.visit_constraint_type(self)
@@ -50,7 +61,7 @@ class FrameColumn(MidasType):
name: Optional[str] name: Optional[str]
type: Optional[MidasType] type: Optional[MidasType]
def accept(self, visitor: MidasType.Visitor[T]) -> T: def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_frame_column(self) return visitor.visit_frame_column(self)
@@ -58,5 +69,26 @@ class FrameColumn(MidasType):
class FrameType(MidasType): class FrameType(MidasType):
columns: list[FrameColumn] columns: list[FrameColumn]
def accept(self, visitor: MidasType.Visitor[T]) -> T: def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_frame_type(self) return visitor.visit_frame_type(self)
@dataclass(frozen=True)
class Function(Expr):
name: str
posonlyargs: list[FunctionArgument]
args: list[FunctionArgument]
kwonlyargs: list[FunctionArgument]
returns: Optional[MidasType]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_function(self)
@dataclass(frozen=True)
class FunctionArgument(Expr):
name: Optional[str]
type: Optional[MidasType]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_function_argument(self)

View File

@@ -44,6 +44,11 @@ def dump_ast(output: Optional[TextIO], parse: bool, file: TextIO):
else: else:
dump += printer.print(annotation) dump += printer.print(annotation)
dump += "\n" dump += "\n"
dump += "\n# Functions\n\n"
for func in parser.functions:
dump += printer.print(func) + "\n"
else: else:
dump = ast.dump(tree, indent=4) dump = ast.dump(tree, indent=4)

View File

@@ -1,7 +1,15 @@
import ast import ast
from typing import Any, Optional from typing import Any, Optional
from midas.ast.python import BaseType, ConstraintType, FrameColumn, FrameType, MidasType from midas.ast.python import (
BaseType,
ConstraintType,
FrameColumn,
FrameType,
Function,
FunctionArgument,
MidasType,
)
class InvalidSyntaxError(Exception): class InvalidSyntaxError(Exception):
@@ -20,6 +28,7 @@ class PythonParser(ast.NodeVisitor):
super().__init__() super().__init__()
self.annotations: list[tuple[str, Optional[MidasType]]] = [] self.annotations: list[tuple[str, Optional[MidasType]]] = []
self.functions: list[Function] = []
def visit_AnnAssign(self, node: ast.AnnAssign) -> Any: def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
match node: match node:
@@ -33,6 +42,43 @@ class PythonParser(ast.NodeVisitor):
case _: case _:
print(f"Unsupported annotation: {ast.unparse(node)}") print(f"Unsupported annotation: {ast.unparse(node)}")
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
self.functions.append(self._parse_function(node))
# Call visit on children to process body
# TODO: scope the resulting nodes to the function
self.generic_visit(node)
def _parse_function(self, node: ast.FunctionDef) -> Function:
match node:
case ast.FunctionDef(
name=name,
args=ast.arguments(
posonlyargs=posonlyargs,
args=args,
kwonlyargs=kwonlyargs,
),
returns=returns,
):
def parse_args(args_list: list[ast.arg]) -> list[FunctionArgument]:
return [self._parse_function_argument(arg) for arg in args_list]
return Function(
name=name,
posonlyargs=parse_args(posonlyargs),
args=parse_args(args),
kwonlyargs=parse_args(kwonlyargs),
returns=self._parse_type(returns) if returns is not None else None,
)
def _parse_function_argument(self, arg: ast.arg) -> FunctionArgument:
name: str = arg.arg
type: Optional[MidasType] = None
if arg.annotation is not None:
type = self._parse_type(arg.annotation)
return FunctionArgument(name=name, type=type)
def _parse_type( def _parse_type(
self, type_expr: ast.expr, root: bool = False self, type_expr: ast.expr, root: bool = False
) -> Optional[MidasType]: ) -> Optional[MidasType]:
@@ -50,7 +96,7 @@ class PythonParser(ast.NodeVisitor):
left = self._parse_type(left_expr) left = self._parse_type(left_expr)
match left: match left:
case None: case None:
raise InvalidSyntaxError("") raise InvalidSyntaxError()
# If chained constraints, separate base type and rebuild constraint # If chained constraints, separate base type and rebuild constraint
case ConstraintType(type=left_type, constraint=left_constraint): case ConstraintType(type=left_type, constraint=left_constraint):