feat(parser): generate python AST classes

Use the generation script to create the Python AST node classes, and distinguish between Midas type-annotation nodes and statement nodes.
This commit is contained in:
2026-05-25 20:53:36 +02:00
parent 939e5af4ce
commit 0bbdf04621
5 changed files with 122 additions and 56 deletions

View File

@@ -48,7 +48,7 @@ class {cls}({base}):
""" """
SECTION_REGEX = re.compile( SECTION_REGEX = re.compile(
r"^###>\s*(?P<base>[^\n]*?)\s*\|\s*(?P<name>[^\n]*?)\s*?\n(?P<body>.*?)\n###<$", r"^###>\s*(?P<base>[^\n]*?)\s*\|\s*(?P<name>[^\n]*?)(\s*\|\s*(?P<param>[^\n]*?))?\s*?\n(?P<body>.*?)\n###<$",
re.MULTILINE | re.DOTALL, re.MULTILINE | re.DOTALL,
) )
@@ -87,15 +87,15 @@ def make_banner(text: str) -> str:
return "\n".join((rule, middle, rule)) return "\n".join((rule, middle, rule))
def make_section(full_name: str, base: str, body: str) -> str: def make_section(full_name: str, base: str, param: str, body: str) -> str:
visitor_methods: list[str] = [] visitor_methods: list[str] = []
classes: list[str] = [] classes: list[str] = []
definitions: list[str] = body.strip("\n").split("\n\n") definitions: list[str] = body.strip("\n").split("\n\n\n")
for cls in definitions: for cls in definitions:
cls = cls.strip("\n") cls = cls.strip("\n")
name: str = re.match("class (.*?):", cls).group(1) # type: ignore name: str = re.match("class (.*?):", cls).group(1) # type: ignore
print(f"Processing {name}") print(f"Processing {name}")
visitor_methods.append(make_visitor_method(name, base.lower())) visitor_methods.append(make_visitor_method(name, param))
classes.append(make_class(name, cls, base)) classes.append(make_class(name, cls, base))
return SECTION_TEMPLATE.format( return SECTION_TEMPLATE.format(
@@ -119,8 +119,9 @@ def generate(definitions_path: Path, out_path: Path):
for section_m in SECTION_REGEX.finditer(src): for section_m in SECTION_REGEX.finditer(src):
full_name: str = section_m.group("name") full_name: str = section_m.group("name")
base: str = section_m.group("base") base: str = section_m.group("base")
param: str = section_m.group("param") or base.lower()
body: str = section_m.group("body") body: str = section_m.group("body")
sections.append(make_section(full_name, base, body)) sections.append(make_section(full_name, base, param, body))
result: str = TEMPLATE.format( result: str = TEMPLATE.format(
header=HEADER.format( header=HEADER.format(
@@ -138,6 +139,7 @@ def main():
defs_dir: Path = root / "gen" defs_dir: Path = root / "gen"
ast_dir: Path = root / "midas" / "ast" ast_dir: Path = root / "midas" / "ast"
generate(defs_dir / "midas.py", ast_dir / "midas.py") generate(defs_dir / "midas.py", ast_dir / "midas.py")
generate(defs_dir / "python.py", ast_dir / "python.py")
if __name__ == "__main__": if __name__ == "__main__":

53
gen/python.py Normal file
View File

@@ -0,0 +1,53 @@
# type: ignore
# ruff: disable[F821, F401]
###> Imports
import ast
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, Optional, TypeVar
from midas.ast.location import Location
###<
###> MidasType | Type annotations | node
class BaseType:
base: str
param: Optional[MidasType]
class ConstraintType:
type: MidasType
constraint: ast.expr
class FrameColumn:
name: Optional[str]
type: Optional[MidasType]
class FrameType:
columns: list[FrameColumn]
###<
###> Stmt | Statements
class Function:
name: str
posonlyargs: list[Argument]
args: list[Argument]
kwonlyargs: list[Argument]
returns: Optional[MidasType]
@dataclass(frozen=True, kw_only=True)
class Argument:
location: Optional[Location] = None
name: Optional[str]
type: Optional[MidasType]
###<

View File

@@ -350,7 +350,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
return f"{expr.name.lexeme}{template}{'?' if expr.optional else ''}" return f"{expr.name.lexeme}{template}{'?' if expr.optional else ''}"
class PythonAstPrinter(AstPrinter, p.Expr.Visitor[None]): class PythonAstPrinter(AstPrinter, p.MidasType.Visitor[None], p.Stmt.Visitor[None]):
def visit_base_type(self, node: p.BaseType) -> None: def visit_base_type(self, node: p.BaseType) -> None:
self._write_line("BaseType") self._write_line("BaseType")
with self._child_level(): with self._child_level():
@@ -382,39 +382,39 @@ class PythonAstPrinter(AstPrinter, p.Expr.Visitor[None]):
self._mark_last() self._mark_last()
col.accept(self) col.accept(self)
def visit_function(self, node: p.Function) -> None: def visit_function(self, stmt: p.Function) -> None:
self._write_line("Function") self._write_line("Function")
with self._child_level(): with self._child_level():
self._write_line(f"name: {node.name}") self._write_line(f"name: {stmt.name}")
self._write_line("posonlyargs") self._write_line("posonlyargs")
with self._child_level(): with self._child_level():
for i, arg in enumerate(node.posonlyargs): for i, arg in enumerate(stmt.posonlyargs):
self._idx = i self._idx = i
if i == len(node.posonlyargs) - 1: if i == len(stmt.posonlyargs) - 1:
self._mark_last() self._mark_last()
arg.accept(self) self._print_argument(arg)
self._write_line("args") self._write_line("args")
with self._child_level(): with self._child_level():
for i, arg in enumerate(node.args): for i, arg in enumerate(stmt.args):
self._idx = i self._idx = i
if i == len(node.args) - 1: if i == len(stmt.args) - 1:
self._mark_last() self._mark_last()
arg.accept(self) self._print_argument(arg)
self._write_line("kwonlyargs") self._write_line("kwonlyargs")
with self._child_level(): with self._child_level():
for i, arg in enumerate(node.kwonlyargs): for i, arg in enumerate(stmt.kwonlyargs):
self._idx = i self._idx = i
if i == len(node.kwonlyargs) - 1: if i == len(stmt.kwonlyargs) - 1:
self._mark_last() self._mark_last()
arg.accept(self) self._print_argument(arg)
self._write_optional_child("returns", node.returns, last=True) self._write_optional_child("returns", stmt.returns, last=True)
def visit_function_argument(self, node: p.FunctionArgument) -> None: def _print_argument(self, arg: p.Function.Argument) -> None:
self._write_line("FunctionArgument") self._write_line("FunctionArgument")
with self._child_level(): with self._child_level():
self._write_line(f"name: {node.name}") self._write_line(f"name: {arg.name}")
self._write_optional_child("type", node.type, last=True) self._write_optional_child("type", arg.type, last=True)

View File

@@ -1,7 +1,12 @@
"""
This file was generated by a script. Any manual changes might be overwritten.
Please modify gen/python.py instead and run gen/gen.py
"""
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod
import ast import ast
from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Generic, Optional, TypeVar from typing import Generic, Optional, TypeVar
@@ -9,9 +14,13 @@ from midas.ast.location import Location
T = TypeVar("T") T = TypeVar("T")
####################
# Type annotations #
####################
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class Expr(ABC): class MidasType(ABC):
location: Optional[Location] = None location: Optional[Location] = None
@abstractmethod @abstractmethod
@@ -30,24 +39,13 @@ class Expr(ABC):
@abstractmethod @abstractmethod
def visit_frame_type(self, node: FrameType) -> T: ... def visit_frame_type(self, node: FrameType) -> T: ...
@abstractmethod
def visit_function(self, node: Function) -> T: ...
@abstractmethod
def visit_function_argument(self, node: FunctionArgument) -> T: ...
@dataclass(frozen=True)
class MidasType(Expr):
pass
@dataclass(frozen=True) @dataclass(frozen=True)
class BaseType(MidasType): class BaseType(MidasType):
base: str base: str
param: Optional[MidasType] param: Optional[MidasType]
def accept(self, visitor: Expr.Visitor[T]) -> T: def accept(self, visitor: MidasType.Visitor[T]) -> T:
return visitor.visit_base_type(self) return visitor.visit_base_type(self)
@@ -56,7 +54,7 @@ class ConstraintType(MidasType):
type: MidasType type: MidasType
constraint: ast.expr constraint: ast.expr
def accept(self, visitor: Expr.Visitor[T]) -> T: def accept(self, visitor: MidasType.Visitor[T]) -> T:
return visitor.visit_constraint_type(self) return visitor.visit_constraint_type(self)
@@ -65,7 +63,7 @@ class FrameColumn(MidasType):
name: Optional[str] name: Optional[str]
type: Optional[MidasType] type: Optional[MidasType]
def accept(self, visitor: Expr.Visitor[T]) -> T: def accept(self, visitor: MidasType.Visitor[T]) -> T:
return visitor.visit_frame_column(self) return visitor.visit_frame_column(self)
@@ -73,26 +71,40 @@ class FrameColumn(MidasType):
class FrameType(MidasType): class FrameType(MidasType):
columns: list[FrameColumn] columns: list[FrameColumn]
def accept(self, visitor: Expr.Visitor[T]) -> T: def accept(self, visitor: MidasType.Visitor[T]) -> T:
return visitor.visit_frame_type(self) return visitor.visit_frame_type(self)
##############
# Statements #
##############
@dataclass(frozen=True, kw_only=True)
class Stmt(ABC):
location: Optional[Location] = None
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_function(self, stmt: Function) -> T: ...
@dataclass(frozen=True) @dataclass(frozen=True)
class Function(Expr): class Function(Stmt):
name: str name: str
posonlyargs: list[FunctionArgument] posonlyargs: list[Argument]
args: list[FunctionArgument] args: list[Argument]
kwonlyargs: list[FunctionArgument] kwonlyargs: list[Argument]
returns: Optional[MidasType] returns: Optional[MidasType]
def accept(self, visitor: Expr.Visitor[T]) -> T: @dataclass(frozen=True, kw_only=True)
class Argument:
location: Optional[Location] = None
name: Optional[str]
type: Optional[MidasType]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_function(self) return visitor.visit_function(self)
@dataclass(frozen=True)
class FunctionArgument(Expr):
name: Optional[str]
type: Optional[MidasType]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_function_argument(self)

View File

@@ -8,7 +8,6 @@ from midas.ast.python import (
FrameColumn, FrameColumn,
FrameType, FrameType,
Function, Function,
FunctionArgument,
MidasType, MidasType,
) )
@@ -63,7 +62,7 @@ class PythonParser(ast.NodeVisitor):
returns=returns, returns=returns,
): ):
def parse_args(args_list: list[ast.arg]) -> list[FunctionArgument]: def parse_args(args_list: list[ast.arg]) -> list[Function.Argument]:
return [self._parse_function_argument(arg) for arg in args_list] return [self._parse_function_argument(arg) for arg in args_list]
return Function( return Function(
@@ -75,13 +74,13 @@ class PythonParser(ast.NodeVisitor):
returns=self._parse_type(returns) if returns is not None else None, returns=self._parse_type(returns) if returns is not None else None,
) )
def _parse_function_argument(self, arg: ast.arg) -> FunctionArgument: def _parse_function_argument(self, arg: ast.arg) -> Function.Argument:
loc: Location = Location.from_ast(arg) loc: Location = Location.from_ast(arg)
name: str = arg.arg name: str = arg.arg
type: Optional[MidasType] = None type: Optional[MidasType] = None
if arg.annotation is not None: if arg.annotation is not None:
type = self._parse_type(arg.annotation) type = self._parse_type(arg.annotation)
return FunctionArgument( return Function.Argument(
location=loc, location=loc,
name=name, name=name,
type=type, type=type,