From 0bbdf04621ce2e3c68bbc961266dfaaecdf66059 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Mon, 25 May 2026 20:53:36 +0200 Subject: [PATCH] feat(parser): generate python AST classes use the generation script to create Python AST node classes, also distinguish between Midas type annotation nodes and statements --- gen/gen.py | 12 ++++--- gen/python.py | 53 ++++++++++++++++++++++++++++++ midas/ast/printer.py | 32 +++++++++--------- midas/ast/python.py | 74 ++++++++++++++++++++++++------------------ midas/parser/python.py | 7 ++-- 5 files changed, 122 insertions(+), 56 deletions(-) create mode 100644 gen/python.py diff --git a/gen/gen.py b/gen/gen.py index 03f38d0..75e6100 100644 --- a/gen/gen.py +++ b/gen/gen.py @@ -48,7 +48,7 @@ class {cls}({base}): """ SECTION_REGEX = re.compile( - r"^###>\s*(?P[^\n]*?)\s*\|\s*(?P[^\n]*?)\s*?\n(?P.*?)\n###<$", + r"^###>\s*(?P[^\n]*?)\s*\|\s*(?P[^\n]*?)(\s*\|\s*(?P[^\n]*?))?\s*?\n(?P.*?)\n###<$", re.MULTILINE | re.DOTALL, ) @@ -87,15 +87,15 @@ def make_banner(text: str) -> str: return "\n".join((rule, middle, rule)) -def make_section(full_name: str, base: str, body: str) -> str: +def make_section(full_name: str, base: str, param: str, body: str) -> str: visitor_methods: list[str] = [] classes: list[str] = [] - definitions: list[str] = body.strip("\n").split("\n\n") + definitions: list[str] = body.strip("\n").split("\n\n\n") for cls in definitions: cls = cls.strip("\n") name: str = re.match("class (.*?):", cls).group(1) # type: ignore print(f"Processing {name}") - visitor_methods.append(make_visitor_method(name, base.lower())) + visitor_methods.append(make_visitor_method(name, param)) classes.append(make_class(name, cls, base)) return SECTION_TEMPLATE.format( @@ -119,8 +119,9 @@ def generate(definitions_path: Path, out_path: Path): for section_m in SECTION_REGEX.finditer(src): full_name: str = section_m.group("name") base: str = section_m.group("base") + param: str = section_m.group("param") or base.lower() body: str = section_m.group("body") - sections.append(make_section(full_name, base, body)) + sections.append(make_section(full_name, base, param, body)) result: str = TEMPLATE.format( header=HEADER.format( @@ -138,6 +139,7 @@ def main(): defs_dir: Path = root / "gen" ast_dir: Path = root / "midas" / "ast" generate(defs_dir / "midas.py", ast_dir / "midas.py") + generate(defs_dir / "python.py", ast_dir / "python.py") if __name__ == "__main__": diff --git a/gen/python.py b/gen/python.py new file mode 100644 index 0000000..15df1d9 --- /dev/null +++ b/gen/python.py @@ -0,0 +1,53 @@ +# type: ignore +# ruff: disable[F821, F401] + +###> Imports +import ast +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Generic, Optional, TypeVar + +from midas.ast.location import Location + +###< + + +###> MidasType | Type annotations | node +class BaseType: + base: str + param: Optional[MidasType] + + +class ConstraintType: + type: MidasType + constraint: ast.expr + + +class FrameColumn: + name: Optional[str] + type: Optional[MidasType] + + +class FrameType: + columns: list[FrameColumn] + + +###< + + +###> Stmt | Statements +class Function: + name: str + posonlyargs: list[Argument] + args: list[Argument] + kwonlyargs: list[Argument] + returns: Optional[MidasType] + + @dataclass(frozen=True, kw_only=True) + class Argument: + location: Optional[Location] = None + name: Optional[str] + type: Optional[MidasType] + + +###< diff --git a/midas/ast/printer.py b/midas/ast/printer.py index b92e40f..2d38241 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -350,7 +350,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]): return f"{expr.name.lexeme}{template}{'?' if expr.optional else ''}" -class PythonAstPrinter(AstPrinter, p.Expr.Visitor[None]): +class PythonAstPrinter(AstPrinter, p.MidasType.Visitor[None], p.Stmt.Visitor[None]): def visit_base_type(self, node: p.BaseType) -> None: self._write_line("BaseType") with self._child_level(): @@ -382,39 +382,39 @@ class PythonAstPrinter(AstPrinter, p.Expr.Visitor[None]): self._mark_last() col.accept(self) - def visit_function(self, node: p.Function) -> None: + def visit_function(self, stmt: p.Function) -> None: self._write_line("Function") with self._child_level(): - self._write_line(f"name: {node.name}") + self._write_line(f"name: {stmt.name}") self._write_line("posonlyargs") with self._child_level(): - for i, arg in enumerate(node.posonlyargs): + for i, arg in enumerate(stmt.posonlyargs): self._idx = i - if i == len(node.posonlyargs) - 1: + if i == len(stmt.posonlyargs) - 1: self._mark_last() - arg.accept(self) + self._print_argument(arg) self._write_line("args") with self._child_level(): - for i, arg in enumerate(node.args): + for i, arg in enumerate(stmt.args): self._idx = i - if i == len(node.args) - 1: + if i == len(stmt.args) - 1: self._mark_last() - arg.accept(self) + self._print_argument(arg) self._write_line("kwonlyargs") with self._child_level(): - for i, arg in enumerate(node.kwonlyargs): + for i, arg in enumerate(stmt.kwonlyargs): self._idx = i - if i == len(node.kwonlyargs) - 1: + if i == len(stmt.kwonlyargs) - 1: self._mark_last() - arg.accept(self) + self._print_argument(arg) - self._write_optional_child("returns", node.returns, last=True) + self._write_optional_child("returns", stmt.returns, last=True) - def visit_function_argument(self, node: p.FunctionArgument) -> None: + def _print_argument(self, arg: p.Function.Argument) -> None: self._write_line("FunctionArgument") with self._child_level(): - self._write_line(f"name: {node.name}") - self._write_optional_child("type", node.type, last=True) + self._write_line(f"name: {arg.name}") + self._write_optional_child("type", arg.type, last=True) diff --git a/midas/ast/python.py b/midas/ast/python.py index c25b438..cd120ee 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -1,7 +1,12 @@ +""" +This file was generated by a script. Any manual changes might be overwritten. +Please modify gen/python.py instead and run gen/gen.py +""" + from __future__ import annotations -from abc import ABC, abstractmethod import ast +from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Generic, Optional, TypeVar @@ -9,9 +14,13 @@ from midas.ast.location import Location T = TypeVar("T") +#################### +# Type annotations # +#################### + @dataclass(frozen=True, kw_only=True) -class Expr(ABC): +class MidasType(ABC): location: Optional[Location] = None @abstractmethod @@ -30,24 +39,13 @@ class Expr(ABC): @abstractmethod def visit_frame_type(self, node: FrameType) -> T: ... - @abstractmethod - def visit_function(self, node: Function) -> T: ... - - @abstractmethod - def visit_function_argument(self, node: FunctionArgument) -> T: ... - - -@dataclass(frozen=True) -class MidasType(Expr): - pass - @dataclass(frozen=True) class BaseType(MidasType): base: str param: Optional[MidasType] - def accept(self, visitor: Expr.Visitor[T]) -> T: + def accept(self, visitor: MidasType.Visitor[T]) -> T: return visitor.visit_base_type(self) @@ -56,7 +54,7 @@ class ConstraintType(MidasType): type: MidasType constraint: ast.expr - def accept(self, visitor: Expr.Visitor[T]) -> T: + def accept(self, visitor: MidasType.Visitor[T]) -> T: return visitor.visit_constraint_type(self) @@ -65,7 +63,7 @@ class FrameColumn(MidasType): name: Optional[str] type: Optional[MidasType] - def accept(self, visitor: Expr.Visitor[T]) -> T: + def accept(self, visitor: MidasType.Visitor[T]) -> T: return visitor.visit_frame_column(self) @@ -73,26 +71,40 @@ class FrameColumn(MidasType): class FrameType(MidasType): columns: list[FrameColumn] - def accept(self, visitor: Expr.Visitor[T]) -> T: + def accept(self, visitor: MidasType.Visitor[T]) -> T: return visitor.visit_frame_type(self) +############## +# Statements # +############## + + +@dataclass(frozen=True, kw_only=True) +class Stmt(ABC): + location: Optional[Location] = None + + @abstractmethod + def accept(self, visitor: Visitor[T]) -> T: ... + + class Visitor(ABC, Generic[T]): + @abstractmethod + def visit_function(self, stmt: Function) -> T: ... + + @dataclass(frozen=True) -class Function(Expr): +class Function(Stmt): name: str - posonlyargs: list[FunctionArgument] - args: list[FunctionArgument] - kwonlyargs: list[FunctionArgument] + posonlyargs: list[Argument] + args: list[Argument] + kwonlyargs: list[Argument] returns: Optional[MidasType] - def accept(self, visitor: Expr.Visitor[T]) -> T: + @dataclass(frozen=True, kw_only=True) + class Argument: + location: Optional[Location] = None + name: Optional[str] + type: Optional[MidasType] + + def accept(self, visitor: Stmt.Visitor[T]) -> T: return visitor.visit_function(self) - - -@dataclass(frozen=True) -class FunctionArgument(Expr): - name: Optional[str] - type: Optional[MidasType] - - def accept(self, visitor: Expr.Visitor[T]) -> T: - return visitor.visit_function_argument(self) diff --git a/midas/parser/python.py b/midas/parser/python.py index 6e0ffe1..082cab1 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -8,7 +8,6 @@ from midas.ast.python import ( FrameColumn, FrameType, Function, - FunctionArgument, MidasType, ) @@ -63,7 +62,7 @@ class PythonParser(ast.NodeVisitor): returns=returns, ): - def parse_args(args_list: list[ast.arg]) -> list[FunctionArgument]: + def parse_args(args_list: list[ast.arg]) -> list[Function.Argument]: return [self._parse_function_argument(arg) for arg in args_list] return Function( @@ -75,13 +74,13 @@ class PythonParser(ast.NodeVisitor): returns=self._parse_type(returns) if returns is not None else None, ) - def _parse_function_argument(self, arg: ast.arg) -> FunctionArgument: + def _parse_function_argument(self, arg: ast.arg) -> Function.Argument: loc: Location = Location.from_ast(arg) name: str = arg.arg type: Optional[MidasType] = None if arg.annotation is not None: type = self._parse_type(arg.annotation) - return FunctionArgument( + return Function.Argument( location=loc, name=name, type=type,