# Source listing: midas/gen/gen.py
# (129 lines, 3.1 KiB, Python)
from pathlib import Path
import re
# Module docstring written at the top of every generated file.
HEADER = '''"""
This file was generated by a script. Any manual changes might be overwritten.
Please modify gen/ast.py instead and run gen/gen.py
"""'''

# Skeleton of the generated module. Placeholders:
#   {header}                 -> HEADER
#   {stmt_visitor_methods}   -> rendered VISITOR_METHOD_TEMPLATEs for Stmt nodes
#   {statements}             -> rendered CLASS_TEMPLATEs for Stmt nodes
#   {expr_visitor_methods}   -> rendered VISITOR_METHOD_TEMPLATEs for Expr nodes
#   {expressions}            -> rendered CLASS_TEMPLATEs for Expr nodes
# Each abstract base nests its own ``Visitor`` class, which is why node classes
# refer to ``Stmt.Visitor`` / ``Expr.Visitor``. Whitespace inside these
# literals is significant: the output must be valid Python. The docstring in
# each Visitor keeps the class body non-empty even when no nodes of that kind
# exist.
TEMPLATE = '''{header}

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, Optional, TypeVar

from lexer.token import Token

T = TypeVar("T")


##############
# Statements #
##############


@dataclass(frozen=True)
class Stmt(ABC):
    @abstractmethod
    def accept(self, visitor: Visitor[T]) -> T: ...

    class Visitor(ABC, Generic[T]):
        """Visitor interface with one visit_* method per Stmt subclass."""

{stmt_visitor_methods}


{statements}


###############
# Expressions #
###############


@dataclass(frozen=True)
class Expr(ABC):
    @abstractmethod
    def accept(self, visitor: Visitor[T]) -> T: ...

    class Visitor(ABC, Generic[T]):
        """Visitor interface with one visit_* method per Expr subclass."""

{expr_visitor_methods}


{expressions}
'''

# One abstract visit method. Indented 8 spaces because it is spliced into the
# ``Visitor`` class nested inside ``Stmt``/``Expr`` in TEMPLATE.
VISITOR_METHOD_TEMPLATE = """
        @abstractmethod
        def visit_{func_name}(self, {param}: {cls}) -> T: ...
"""

# One concrete node class. ``{body}`` is inserted as-is, so its lines must
# already carry their own class-body indentation (presumably 4 spaces in
# gen/ast.py - confirm against the spec file).
CLASS_TEMPLATE = """
@dataclass(frozen=True)
class {cls}({base}):
{body}

    def accept(self, visitor: {base}.Visitor[T]) -> T:
        return visitor.visit_{func_name}(self)
"""
def snake_case(text: str) -> str:
    """Convert a PascalCase identifier to snake_case.

    Every ASCII capital letter becomes ``_`` followed by its lowercase form.
    The result is then lowercased and stripped of leading/trailing
    underscores. Runs of capitals are split letter by letter
    (``"IOExpr"`` -> ``"i_o_expr"``).
    """
    pieces: list[str] = []
    for char in text:
        # Same set as the regex class [A-Z]: ASCII capitals only.
        if "A" <= char <= "Z":
            pieces.append("_" + char.lower())
        else:
            pieces.append(char)
    joined = "".join(pieces)
    return joined.lower().strip("_")
def make_visitor_method(cls: str, param: str):
    """Render the abstract ``visit_*`` declaration for one node class.

    Args:
        cls: Node class name, e.g. ``"BinaryExpr"``; its snake_case form
            names the method (``visit_binary_expr``).
        param: Name given to the visited-node parameter.

    Returns:
        VISITOR_METHOD_TEMPLATE filled in, with leading/trailing newlines
        removed so callers can join declarations with their own separators.
    """
    fields = {
        "func_name": snake_case(cls),
        "param": param,
        "cls": cls,
    }
    return VISITOR_METHOD_TEMPLATE.format(**fields).strip("\n")
def make_class(name: str, cls: str, base: str, param: str) -> str:
    """Render one generated node class from its specification chunk.

    Args:
        name: Class name taken from the chunk's header line,
            e.g. ``"BinaryExpr"``.
        cls: The full specification chunk. Its first line is the ``class``
            header and is discarded; every following line is copied verbatim
            as the generated class body.
        base: Name of the generated base class the node inherits from.
        param: Not used by this function; accepted so existing callers that
            pass it keep working.

    Returns:
        CLASS_TEMPLATE filled in for this class, with leading/trailing
        newlines removed.
    """
    # partition tolerates a header-only chunk (no body lines) and yields an
    # empty body; the previous split(...)[1] raised IndexError there.
    _header, _sep, body = cls.partition("\n")
    cls_def: str = CLASS_TEMPLATE.format(
        cls=name,
        base=base,
        body=body,
        func_name=snake_case(name),
    )
    return cls_def.strip("\n")
def generate(src: str) -> str:
    """Build the text of the generated AST module from a node specification.

    ``src`` is split into chunks on blank lines. Each non-empty chunk must be
    one class definition whose first line matches ``class <Name>:``, so a
    class body must not itself contain blank lines. Classes whose name ends
    in ``Stmt`` are emitted as ``Stmt`` subclasses and those ending in
    ``Expr`` as ``Expr`` subclasses. Each one also gets an abstract
    ``visit_<snake_name>`` method on the matching nested ``Visitor``.
    Classes with any other suffix are skipped. A "Processing <Name>" line is
    printed for every class found.

    Args:
        src: Contents of the specification file (gen/ast.py).

    Returns:
        The complete source text of the generated module.

    Raises:
        ValueError: If a non-empty chunk does not start with a ``class``
            header.
    """
    header_re = re.compile(r"class (.*?):")
    stmt_visitor_methods: list[str] = []
    expr_visitor_methods: list[str] = []
    statements: list[str] = []
    expressions: list[str] = []
    for chunk in src.split("\n\n"):
        chunk = chunk.strip("\n")
        if not chunk.strip():
            # Four or more consecutive newlines, or a trailing blank line,
            # produce empty chunks; they contain no definition.
            continue
        header = header_re.match(chunk)
        if header is None:
            first_line = chunk.partition("\n")[0]
            raise ValueError(
                f"Expected a class definition in AST spec, got: {first_line!r}"
            )
        name: str = header.group(1)
        print(f"Processing {name}")
        if name.endswith("Stmt"):
            stmt_visitor_methods.append(make_visitor_method(name, "stmt"))
            statements.append(make_class(name, chunk, "Stmt", "stmt"))
        elif name.endswith("Expr"):
            expr_visitor_methods.append(make_visitor_method(name, "expr"))
            expressions.append(make_class(name, chunk, "Expr", "expr"))
    return TEMPLATE.format(
        header=HEADER,
        stmt_visitor_methods="\n\n".join(stmt_visitor_methods),
        expr_visitor_methods="\n\n".join(expr_visitor_methods),
        statements="\n\n\n".join(statements),
        expressions="\n\n\n".join(expressions),
    )
def main() -> None:
    """Regenerate core/ast/midas.py from the specification in gen/ast.py.

    Both paths are resolved relative to the parent of this script's
    directory, not the current working directory. The output file is
    overwritten.
    """
    root: Path = Path(__file__).parent.parent
    in_path: Path = root / "gen" / "ast.py"
    out_path: Path = root / "core" / "ast" / "midas.py"
    # Pin UTF-8: without an explicit encoding, read_text/write_text use the
    # locale's preferred encoding, which varies by platform (e.g. on Windows).
    src: str = in_path.read_text(encoding="utf-8")
    generated: str = generate(src)
    out_path.write_text(generated, encoding="utf-8")


if __name__ == "__main__":
    main()