tool: add AST class generator script
This commit is contained in:
128
gen/gen.py
Normal file
128
gen/gen.py
Normal file
@@ -0,0 +1,128 @@
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
||||
HEADER = '''"""
|
||||
This file was generated by a script. Any manual changes might be overwritten.
|
||||
Please modify gen/ast.py instead and run gen/gen.py
|
||||
"""'''
|
||||
|
||||
TEMPLATE = """{header}
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from lexer.token import Token
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
##############
|
||||
# Statements #
|
||||
##############
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Stmt(ABC):
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
{stmt_visitor_methods}
|
||||
|
||||
|
||||
{statements}
|
||||
|
||||
|
||||
###############
|
||||
# Expressions #
|
||||
###############
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Expr(ABC):
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
{expr_visitor_methods}
|
||||
|
||||
|
||||
{expressions}
|
||||
"""
|
||||
|
||||
VISITOR_METHOD_TEMPLATE = """
|
||||
@abstractmethod
|
||||
def visit_{func_name}(self, {param}: {cls}) -> T: ...
|
||||
"""
|
||||
|
||||
CLASS_TEMPLATE = """
|
||||
@dataclass(frozen=True)
|
||||
class {cls}({base}):
|
||||
{body}
|
||||
|
||||
def accept(self, visitor: {base}.Visitor[T]) -> T:
|
||||
return visitor.visit_{func_name}(self)
|
||||
"""
|
||||
|
||||
def snake_case(text: str) -> str:
|
||||
return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_")
|
||||
|
||||
def make_visitor_method(cls: str, param: str):
|
||||
method: str = VISITOR_METHOD_TEMPLATE.format(
|
||||
func_name=snake_case(cls),
|
||||
param=param,
|
||||
cls=cls
|
||||
)
|
||||
return method.strip("\n")
|
||||
|
||||
def make_class(name: str, cls: str, base: str, param: str):
|
||||
body: str = cls.split("\n", 1)[1]
|
||||
func_name: str = snake_case(name)
|
||||
cls_def: str = CLASS_TEMPLATE.format(
|
||||
cls=name,
|
||||
base=base,
|
||||
body=body,
|
||||
func_name=func_name,
|
||||
)
|
||||
return cls_def.strip("\n")
|
||||
|
||||
def generate(src: str):
|
||||
classes: list[str] = src.split("\n\n")
|
||||
stmt_visitor_methods: list[str] = []
|
||||
expr_visitor_methods: list[str] = []
|
||||
statements: list[str] = []
|
||||
expressions: list[str] = []
|
||||
|
||||
for cls in classes:
|
||||
cls = cls.strip("\n")
|
||||
name: str = re.match("class (.*?):", cls).group(1) # type: ignore
|
||||
print(f"Processing {name}")
|
||||
if name.endswith("Stmt"):
|
||||
stmt_visitor_methods.append(make_visitor_method(name, "stmt"))
|
||||
statements.append(make_class(name, cls, "Stmt", "stmt"))
|
||||
elif name.endswith("Expr"):
|
||||
expr_visitor_methods.append(make_visitor_method(name, "expr"))
|
||||
expressions.append(make_class(name, cls, "Expr", "expr"))
|
||||
|
||||
return TEMPLATE.format(
|
||||
header=HEADER,
|
||||
stmt_visitor_methods="\n\n".join(stmt_visitor_methods),
|
||||
expr_visitor_methods="\n\n".join(expr_visitor_methods),
|
||||
statements="\n\n\n".join(statements),
|
||||
expressions="\n\n\n".join(expressions),
|
||||
)
|
||||
|
||||
def main():
|
||||
root: Path = Path(__file__).parent.parent
|
||||
in_path: Path = root / "gen" / "ast.py"
|
||||
out_path: Path = root / "core" / "ast" / "midas.py"
|
||||
|
||||
src: str = in_path.read_text()
|
||||
generated: str = generate(src)
|
||||
out_path.write_text(generated)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user