diff --git a/gen/gen.py b/gen/gen.py new file mode 100644 index 0000000..18aabba --- /dev/null +++ b/gen/gen.py @@ -0,0 +1,128 @@ +from pathlib import Path +import re + +HEADER = '''""" +This file was generated by a script. Any manual changes might be overwritten. +Please modify gen/ast.py instead and run gen/gen.py +"""''' + +TEMPLATE = """{header} + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Generic, Optional, TypeVar + +from lexer.token import Token + +T = TypeVar("T") + +############## +# Statements # +############## + + +@dataclass(frozen=True) +class Stmt(ABC): + @abstractmethod + def accept(self, visitor: Visitor[T]) -> T: ... + + class Visitor(ABC, Generic[T]): +{stmt_visitor_methods} + + +{statements} + + +############### +# Expressions # +############### + + +@dataclass(frozen=True) +class Expr(ABC): + @abstractmethod + def accept(self, visitor: Visitor[T]) -> T: ... + + class Visitor(ABC, Generic[T]): +{expr_visitor_methods} + + +{expressions} +""" + +VISITOR_METHOD_TEMPLATE = """ + @abstractmethod + def visit_{func_name}(self, {param}: {cls}) -> T: ... +""" + +CLASS_TEMPLATE = """ +@dataclass(frozen=True) +class {cls}({base}): +{body} + + def accept(self, visitor: {base}.Visitor[T]) -> T: + return visitor.visit_{func_name}(self) +""" + +def snake_case(text: str) -> str: + return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_") + +def make_visitor_method(cls: str, param: str): + method: str = VISITOR_METHOD_TEMPLATE.format( + func_name=snake_case(cls), + param=param, + cls=cls + ) + return method.strip("\n") + +def make_class(name: str, cls: str, base: str, param: str): + body: str = cls.split("\n", 1)[1] + func_name: str = snake_case(name) + cls_def: str = CLASS_TEMPLATE.format( + cls=name, + base=base, + body=body, + func_name=func_name, + ) + return cls_def.strip("\n") + +def generate(src: str): + classes: list[str] = src.split("\n\n") + stmt_visitor_methods: list[str] = [] + expr_visitor_methods: list[str] = [] + statements: list[str] = [] + expressions: list[str] = [] + + for cls in classes: + cls = cls.strip("\n") + name: str = re.match("class (.*?):", cls).group(1) # type: ignore + print(f"Processing {name}") + if name.endswith("Stmt"): + stmt_visitor_methods.append(make_visitor_method(name, "stmt")) + statements.append(make_class(name, cls, "Stmt", "stmt")) + elif name.endswith("Expr"): + expr_visitor_methods.append(make_visitor_method(name, "expr")) + expressions.append(make_class(name, cls, "Expr", "expr")) + + return TEMPLATE.format( + header=HEADER, + stmt_visitor_methods="\n\n".join(stmt_visitor_methods), + expr_visitor_methods="\n\n".join(expr_visitor_methods), + statements="\n\n\n".join(statements), + expressions="\n\n\n".join(expressions), + ) + +def main(): + root: Path = Path(__file__).parent.parent + in_path: Path = root / "gen" / "ast.py" + out_path: Path = root / "core" / "ast" / "midas.py" + + src: str = in_path.read_text() + generated: str = generate(src) + out_path.write_text(generated) + + +if __name__ == "__main__": + main()