132 lines
3.0 KiB
Python
132 lines
3.0 KiB
Python
from pathlib import Path
|
|
import re
|
|
|
|
# Module docstring placed at the top of the generated file (inserted into
# TEMPLATE's {header} slot), warning readers not to edit the output by hand.
HEADER = '''"""
This file was generated by a script. Any manual changes might be overwritten.
Please modify gen/ast.py instead and run gen/gen.py
"""'''
|
|
|
|
# Skeleton of the generated AST module. Placeholders filled in by generate():
#   {header}               -- module docstring (HEADER)
#   {stmt_visitor_methods} -- abstract visit_* methods of Stmt.Visitor
#   {statements}           -- concrete Stmt subclasses (CLASS_TEMPLATE)
#   {expr_visitor_methods} -- abstract visit_* methods of Expr.Visitor
#   {expressions}          -- concrete Expr subclasses (CLASS_TEMPLATE)
# Visitor is nested inside each base class so node classes can refer to it as
# ``Stmt.Visitor`` / ``Expr.Visitor``; visitor-method placeholders sit at
# column 0 because each rendered method carries its own indentation.
TEMPLATE = """{header}

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, Optional, TypeVar

from lexer.token import Token

T = TypeVar("T")

##############
# Statements #
##############


@dataclass(frozen=True)
class Stmt(ABC):
    @abstractmethod
    def accept(self, visitor: Visitor[T]) -> T: ...

    class Visitor(ABC, Generic[T]):
{stmt_visitor_methods}


{statements}


###############
# Expressions #
###############


@dataclass(frozen=True)
class Expr(ABC):
    @abstractmethod
    def accept(self, visitor: Visitor[T]) -> T: ...

    class Visitor(ABC, Generic[T]):
{expr_visitor_methods}


{expressions}
"""
|
|
|
|
# One abstract visitor method. Indented 8 spaces because it is spliced at
# column 0 into ``class Stmt:`` -> ``class Visitor:`` (two nesting levels).
VISITOR_METHOD_TEMPLATE = """
        @abstractmethod
        def visit_{func_name}(self, {param}: {cls}) -> T: ...
"""
|
|
|
|
# One concrete node class. {body} is placed at column 0 because it is copied
# from gen/ast.py with its original (already indented) field lines; accept()
# dispatches to the matching visit_* method on the base class's Visitor.
CLASS_TEMPLATE = """
@dataclass(frozen=True)
class {cls}({base}):
{body}

    def accept(self, visitor: {base}.Visitor[T]) -> T:
        return visitor.visit_{func_name}(self)
"""
|
|
|
|
|
|
def snake_case(text: str) -> str:
    """Convert a CamelCase identifier such as ``"PrintStmt"`` to ``"print_stmt"``.

    Each ASCII uppercase letter is replaced by an underscore followed by its
    lowercase form; the result is then lowercased as a whole and any leading or
    trailing underscores are trimmed. Consecutive capitals are split one by one
    (``"HTTPExpr"`` -> ``"h_t_t_p_expr"``).
    """
    pieces = [f"_{ch.lower()}" if "A" <= ch <= "Z" else ch for ch in text]
    return "".join(pieces).lower().strip("_")
|
|
|
|
|
|
def make_visitor_method(cls: str, param: str):
    """Render the abstract ``visit_*`` declaration for node class ``cls``.

    Args:
        cls: Node class name (e.g. ``"PrintStmt"``); it names the method via
            snake_case() and is used as the parameter's type annotation.
        param: Name of the method's node parameter (generate() passes
            ``"stmt"`` or ``"expr"``).

    Returns:
        VISITOR_METHOD_TEMPLATE filled in, with surrounding newlines removed
        (leading indentation is preserved).
    """
    visit_suffix = snake_case(cls)
    rendered = VISITOR_METHOD_TEMPLATE.format(
        cls=cls,
        param=param,
        func_name=visit_suffix,
    )
    return rendered.strip("\n")
|
|
|
|
|
|
def make_class(name: str, cls: str, base: str):
    """Render a concrete AST node class from its definition in gen/ast.py.

    Args:
        name: Class name parsed from the definition's ``class`` header line.
        cls: Full source text of the definition; its first line is the
            ``class`` header and any following lines form the class body.
        base: Base class to inherit from (generate() passes ``"Stmt"`` or
            ``"Expr"``).

    Returns:
        CLASS_TEMPLATE filled in, with surrounding newlines removed.
    """
    # Drop the header line and keep the already-indented body. ``partition``
    # yields an empty body for a one-line definition (``class EmptyStmt: ...``)
    # where ``split("\n", 1)[1]`` would raise IndexError.
    body: str = cls.partition("\n")[2]
    cls_def: str = CLASS_TEMPLATE.format(
        cls=name,
        base=base,
        body=body,
        func_name=snake_case(name),
    )
    return cls_def.strip("\n")
|
|
|
|
|
|
def generate(src: str):
    """Build the generated AST module source from the node definitions in ``src``.

    ``src`` is split on blank lines into chunks, each expected to hold one
    ``class`` definition. Classes whose names end in ``Stmt`` are emitted as
    ``Stmt`` subclasses and those ending in ``Expr`` as ``Expr`` subclasses;
    each also gets an abstract ``visit_*`` method on the matching ``Visitor``.
    A progress line is printed for every class found.

    Args:
        src: Source text of gen/ast.py.

    Returns:
        The complete source of the generated module.

    Raises:
        ValueError: If a non-blank chunk does not start with a ``class`` header.
    """
    stmt_visitor_methods: list[str] = []
    expr_visitor_methods: list[str] = []
    statements: list[str] = []
    expressions: list[str] = []

    for chunk in src.split("\n\n"):
        chunk = chunk.strip("\n")
        # Extra blank lines (e.g. a trailing blank line at end of file, or
        # three or more blank lines between classes) produce empty chunks.
        if not chunk.strip():
            continue

        # Capture only the identifier so a header like ``class X(Base):``
        # yields ``X`` rather than ``X(Base)``.
        header = re.match(r"class\s+(\w+)", chunk)
        if header is None:
            first_line = chunk.splitlines()[0]
            raise ValueError(f"Expected a class definition, got: {first_line!r}")
        name: str = header.group(1)
        print(f"Processing {name}")

        if name.endswith("Stmt"):
            stmt_visitor_methods.append(make_visitor_method(name, "stmt"))
            statements.append(make_class(name, chunk, "Stmt"))
        elif name.endswith("Expr"):
            expr_visitor_methods.append(make_visitor_method(name, "expr"))
            expressions.append(make_class(name, chunk, "Expr"))
        else:
            # Report classes that match neither suffix instead of dropping
            # them without a trace.
            print(f"Skipping {name}: name must end with 'Stmt' or 'Expr'")

    # A class with an empty body is a syntax error, so a Visitor with no
    # methods gets ``pass``, indented like a visitor method (two levels deep:
    # base class -> Visitor).
    empty_visitor_body = "        pass"
    return TEMPLATE.format(
        header=HEADER,
        stmt_visitor_methods="\n\n".join(stmt_visitor_methods) or empty_visitor_body,
        expr_visitor_methods="\n\n".join(expr_visitor_methods) or empty_visitor_body,
        statements="\n\n\n".join(statements),
        expressions="\n\n\n".join(expressions),
    )
|
|
|
|
|
|
def main():
    """Regenerate midas/ast/midas.py from the node definitions in gen/ast.py.

    The repository root is taken to be the parent of the directory that
    contains this script.
    """
    # Resolve first: if __file__ is relative (e.g. "gen.py" when run from
    # inside gen/), ``.parent.parent`` of the unresolved path is just ".".
    root: Path = Path(__file__).resolve().parent.parent
    in_path: Path = root / "gen" / "ast.py"
    out_path: Path = root / "midas" / "ast" / "midas.py"

    # Pin the encoding so reading and writing do not depend on the
    # platform's locale default (e.g. cp1252 on Windows).
    src: str = in_path.read_text(encoding="utf-8")
    generated: str = generate(src)
    out_path.write_text(generated, encoding="utf-8")


if __name__ == "__main__":
    main()
|