refactor(parser): improve AST class generator

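Make the generation script more flexible: sections, their base classes, and the imports of the generated module are now declared in the definitions file via `###>` / `###<` markers instead of being hard-coded in the generator.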
This commit is contained in:
2026-05-25 20:38:38 +02:00
parent a735113466
commit 939e5af4ce
3 changed files with 156 additions and 110 deletions

View File

@@ -3,58 +3,34 @@ import re
HEADER = '''"""
This file was generated by a script. Any manual changes might be overwritten.
Please modify gen/ast.py instead and run gen/gen.py
Please modify {defs_path} instead and run {gen_path}
"""'''
SECTION_TEMPLATE = """{banner}
@dataclass(frozen=True, kw_only=True)
class {base}(ABC):
location: Optional[Location] = None
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
{visitor_methods}
{classes}"""
TEMPLATE = """{header}
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, Optional, TypeVar
from midas.ast.location import Location
from midas.lexer.token import Token
{imports}
T = TypeVar("T")
##############
# Statements #
##############
@dataclass(frozen=True, kw_only=True)
class Stmt(ABC):
location: Optional[Location] = None
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
{stmt_visitor_methods}
{statements}
###############
# Expressions #
###############
@dataclass(frozen=True, kw_only=True)
class Expr(ABC):
location: Optional[Location] = None
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
{expr_visitor_methods}
{expressions}
{sections}
"""
VISITOR_METHOD_TEMPLATE = """
@@ -71,6 +47,16 @@ class {cls}({base}):
return visitor.visit_{func_name}(self)
"""
# Matches one section of a definitions file.
# A section opens with a header line of the form `###> <base> | <name>` and
# closes with a line consisting of `###<`. Named groups:
#   base - text before the `|` on the header line
#   name - text after the `|` on the header line
#   body - everything between the header line and the closing marker
# MULTILINE lets `^`/`$` anchor at line boundaries; DOTALL lets the lazy
# `body` group span several lines while stopping at the first `###<`.
SECTION_REGEX = re.compile(
r"^###>\s*(?P<base>[^\n]*?)\s*\|\s*(?P<name>[^\n]*?)\s*?\n(?P<body>.*?)\n###<$",
re.MULTILINE | re.DOTALL,
)
# Matches the imports block of a definitions file: a `###> Imports` header
# line, followed by the block content (named group `body`), closed by a line
# consisting of `###<`. The header carries no `|`, which is what distinguishes
# it from the `<base> | <name>` section headers.
# MULTILINE lets `^`/`$` anchor at line boundaries; DOTALL lets the lazy
# `body` group span several lines.
IMPORTS_REGEX = re.compile(
r"^###>\s*Imports\s*?\n(?P<body>.*?)\n###<$",
re.MULTILINE | re.DOTALL,
)
def snake_case(text: str) -> str:
    """Convert a CamelCase identifier to snake_case.

    An underscore is inserted in front of every ASCII uppercase letter, the
    whole string is lowercased, and leading/trailing underscores are removed.
    """
    underscored = re.sub(r"[A-Z]", lambda match: f"_{match.group()}", text)
    return underscored.lower().strip("_")
@@ -95,41 +81,63 @@ def make_class(name: str, cls: str, base: str):
return cls_def.strip("\n")
def generate(src: str):
classes: list[str] = src.split("\n\n")
stmt_visitor_methods: list[str] = []
expr_visitor_methods: list[str] = []
statements: list[str] = []
expressions: list[str] = []
def make_banner(text: str) -> str:
    """Return a three-line comment banner framing ``text`` with ``#`` characters.

    The middle line is ``# <text> #``; the lines above and below it are rows
    of ``#`` with the same length as the middle line.
    """
    title_line = "# " + text + " #"
    border = len(title_line) * "#"
    return f"{border}\n{title_line}\n{border}"
for cls in classes:
def make_section(full_name: str, base: str, body: str) -> str:
visitor_methods: list[str] = []
classes: list[str] = []
definitions: list[str] = body.strip("\n").split("\n\n")
for cls in definitions:
cls = cls.strip("\n")
name: str = re.match("class (.*?):", cls).group(1) # type: ignore
print(f"Processing {name}")
if name.endswith("Stmt"):
stmt_visitor_methods.append(make_visitor_method(name, "stmt"))
statements.append(make_class(name, cls, "Stmt"))
elif name.endswith("Expr"):
expr_visitor_methods.append(make_visitor_method(name, "expr"))
expressions.append(make_class(name, cls, "Expr"))
visitor_methods.append(make_visitor_method(name, base.lower()))
classes.append(make_class(name, cls, base))
return TEMPLATE.format(
header=HEADER,
stmt_visitor_methods="\n\n".join(stmt_visitor_methods),
expr_visitor_methods="\n\n".join(expr_visitor_methods),
statements="\n\n\n".join(statements),
expressions="\n\n\n".join(expressions),
return SECTION_TEMPLATE.format(
banner=make_banner(full_name),
base=base,
visitor_methods="\n\n".join(visitor_methods),
classes="\n\n\n".join(classes),
)
def generate(definitions_path: Path, out_path: Path) -> None:
    """Generate an AST module from a definitions file.

    Reads ``definitions_path``, extracts the optional ``###> Imports`` block
    and every ``###> <base> | <name>`` section, renders each section with
    ``make_section`` and writes the assembled module to ``out_path``.

    Args:
        definitions_path: Definitions file to read. It must be located under
            the project root (the parent of this script's directory) because
            its root-relative path is embedded in the generated header.
        out_path: File the generated module is written to; existing content
            is overwritten.

    Raises:
        ValueError: If ``definitions_path`` is not located under the project
            root (raised by ``Path.relative_to``).
    """
    root_dir: Path = Path(__file__).parent.parent

    # These paths end up inside the docstring of the generated module. Use
    # POSIX separators so the output is identical on every platform and never
    # contains backslashes, which would be read as escape sequences there.
    defs_rel: str = definitions_path.relative_to(root_dir).as_posix()
    gen_rel: str = Path(__file__).relative_to(root_dir).as_posix()

    # Explicit encoding: do not depend on the locale's default encoding.
    src: str = definitions_path.read_text(encoding="utf-8")

    imports: str = ""
    if m := IMPORTS_REGEX.search(src):
        imports = m.group("body").strip("\n")

    sections: list[str] = [
        make_section(
            section_m.group("name"),
            section_m.group("base"),
            section_m.group("body"),
        )
        for section_m in SECTION_REGEX.finditer(src)
    ]

    result: str = TEMPLATE.format(
        header=HEADER.format(defs_path=defs_rel, gen_path=gen_rel),
        imports=imports,
        sections="\n\n\n".join(sections),
    )
    out_path.write_text(result, encoding="utf-8")
def main():
root: Path = Path(__file__).parent.parent
in_path: Path = root / "gen" / "ast.py"
out_path: Path = root / "midas" / "ast" / "midas.py"
src: str = in_path.read_text()
generated: str = generate(src)
out_path.write_text(generated)
defs_dir: Path = root / "gen"
ast_dir: Path = root / "midas" / "ast"
generate(defs_dir / "midas.py", ast_dir / "midas.py")
if __name__ == "__main__":

View File

@@ -1,72 +1,110 @@
# type: ignore
# ruff: disable[F821, F401]
###> Imports
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, Optional, TypeVar
from midas.ast.location import Location
from midas.lexer.token import Token
###<
###> Stmt | Statements
class SimpleTypeStmt:
name: Token
template: Optional[TemplateExpr]
base: TypeExpr
constraint: Optional[Expr]
class SimpleTypeExpr:
name: Token
optional: bool
class LogicalExpr:
left: Expr
operator: Token
right: Expr
class BinaryExpr:
left: Expr
operator: Token
right: Expr
class UnaryExpr:
operator: Token
right: Expr
class GetExpr:
expr: Expr
name: Token
class VariableExpr:
name: Token
class GroupingExpr:
expr: Expr
class LiteralExpr:
value: Any
class WildcardExpr:
token: Token
class TemplateExpr:
type: TypeExpr
class TypeExpr:
name: Token
template: Optional[TemplateExpr]
optional: bool
class ComplexTypeStmt:
name: Token
template: Optional[TemplateExpr]
properties: list[PropertyStmt]
class PropertyStmt:
name: Token
type: TypeExpr
constraint: Optional[Expr]
class ExtendStmt:
type: TypeExpr
operations: list[OpStmt]
class OpStmt:
name: Token
operand: TypeExpr
result: TypeExpr
class PredicateStmt:
name: Token
subject: Token
type: TypeExpr
condition: Expr
###<
###> Expr | Expressions
class SimpleTypeExpr:
name: Token
optional: bool
class LogicalExpr:
left: Expr
operator: Token
right: Expr
class BinaryExpr:
left: Expr
operator: Token
right: Expr
class UnaryExpr:
operator: Token
right: Expr
class GetExpr:
expr: Expr
name: Token
class VariableExpr:
name: Token
class GroupingExpr:
expr: Expr
class LiteralExpr:
value: Any
class WildcardExpr:
token: Token
class TemplateExpr:
type: TypeExpr
class TypeExpr:
name: Token
template: Optional[TemplateExpr]
optional: bool
###<

View File

@@ -1,6 +1,6 @@
"""
This file was generated by a script. Any manual changes might be overwritten.
Please modify gen/ast.py instead and run gen/gen.py
Please modify gen/midas.py instead and run gen/gen.py
"""
from __future__ import annotations