147 lines
3.7 KiB
Python
147 lines
3.7 KiB
Python
import re
|
|
from pathlib import Path
|
|
|
|
HEADER = '''"""
|
|
This file was generated by a script. Any manual changes might be overwritten.
|
|
Please modify {defs_path} instead and run {gen_path}
|
|
"""'''
|
|
|
|
SECTION_TEMPLATE = """{banner}
|
|
|
|
|
|
@dataclass(frozen=True, kw_only=True)
|
|
class {base}(ABC):
|
|
location: Location
|
|
|
|
@abstractmethod
|
|
def accept(self, visitor: Visitor[T]) -> T: ...
|
|
|
|
class Visitor(ABC, Generic[T]):
|
|
{visitor_methods}
|
|
|
|
|
|
{classes}"""
|
|
|
|
TEMPLATE = """{header}
|
|
|
|
from __future__ import annotations
|
|
|
|
{imports}
|
|
|
|
T = TypeVar("T")
|
|
|
|
{sections}
|
|
"""
|
|
|
|
VISITOR_METHOD_TEMPLATE = """
|
|
@abstractmethod
|
|
def visit_{func_name}(self, {param}: {cls}) -> T: ...
|
|
"""
|
|
|
|
CLASS_TEMPLATE = """
|
|
@dataclass(frozen=True)
|
|
class {cls}({base}):
|
|
{body}
|
|
|
|
def accept(self, visitor: {base}.Visitor[T]) -> T:
|
|
return visitor.visit_{func_name}(self)
|
|
"""
|
|
|
|
# Matches one section block of a definitions file. The block opens with a
# line of the form
#     ###> base | name
# or  ###> base | name | param
# and is closed by a line consisting of `###<`. Named groups: `base`, `name`,
# the optional `param` (None when the third field is absent) and `body`
# (everything between the opening and closing marker lines).
# MULTILINE anchors ^/$ at line boundaries; DOTALL lets `body` span lines.
SECTION_REGEX = re.compile(
    r"^###>\s*(?P<base>[^\n]*?)\s*\|\s*(?P<name>[^\n]*?)(\s*\|\s*(?P<param>[^\n]*?))?\s*?\n(?P<body>.*?)\n###<$",
    re.MULTILINE | re.DOTALL,
)
|
|
|
|
# Matches the imports block of a definitions file: a `###> Imports` marker
# line, then the text captured as `body`, closed by a `###<` line.
# MULTILINE anchors ^/$ at line boundaries; DOTALL lets `body` span lines.
IMPORTS_REGEX = re.compile(
    r"^###>\s*Imports\s*?\n(?P<body>.*?)\n###<$",
    re.MULTILINE | re.DOTALL,
)
|
|
|
|
|
|
def snake_case(text: str) -> str:
    """Convert a CamelCase identifier to snake_case.

    Every ASCII capital letter gets an underscore in front of it, the whole
    string is lowercased, and underscores at either end are removed. Runs of
    capitals are split letter by letter (``"ABc"`` becomes ``"a_bc"``).
    """
    # \g<0> re-inserts the matched capital; lowercasing happens in one pass
    # over the whole string afterwards.
    underscored = re.sub(r"[A-Z]", r"_\g<0>", text)
    lowered = underscored.lower()
    return lowered.strip("_")
|
|
|
|
|
|
def make_visitor_method(cls: str, param: str):
    """Render the abstract ``visit_*`` method for one node class.

    Args:
        cls: Name of the node class; used as the parameter's type and, after
            ``snake_case`` conversion, as the method-name suffix.
        param: Name of the method's single node parameter.

    Returns:
        ``VISITOR_METHOD_TEMPLATE`` filled in, with leading and trailing
        newlines removed (other whitespace is kept).
    """
    fields = {"func_name": snake_case(cls), "param": param, "cls": cls}
    return VISITOR_METHOD_TEMPLATE.format(**fields).strip("\n")
|
|
|
|
|
|
def make_class(name: str, cls: str, base: str):
    """Render the generated dataclass for one node definition.

    Args:
        name: Name of the class to generate; also converted with
            ``snake_case`` to select the ``visit_*`` method that the
            generated ``accept`` calls.
        cls: Source text of the definition. Its first line (the ``class``
            header) is dropped and the remainder is copied verbatim as the
            class body.
        base: Name of the base class the generated class derives from.

    Returns:
        ``CLASS_TEMPLATE`` filled in, with leading and trailing newlines
        removed.
    """
    # partition() instead of split("\n", 1)[1]: a definition that consists of
    # the header line only has no newline, and should produce an empty body
    # rather than an IndexError.
    _header, _newline, body = cls.partition("\n")
    cls_def: str = CLASS_TEMPLATE.format(
        cls=name,
        base=base,
        body=body,
        func_name=snake_case(name),
    )
    return cls_def.strip("\n")
|
|
|
|
|
|
def make_banner(text: str) -> str:
    """Return a three-line comment banner framing *text* with ``#`` characters.

    The middle line is ``# <text> #``; the lines above and below are rows of
    ``#`` with the same length as the middle line.
    """
    # "# " before and " #" after the text add four characters in total.
    border = "#" * (len(text) + 4)
    return f"{border}\n# {text} #\n{border}"
|
|
|
|
|
|
def make_section(full_name: str, base: str, param: str, body: str) -> str:
    """Render one section: banner, base class with visitor, and node classes.

    Args:
        full_name: Human-readable section title placed in the banner.
        base: Name of the abstract base class generated for the section.
        param: Parameter name used in every generated ``visit_*`` method.
        body: Text of the section. Definitions are separated by two blank
            lines and each must start with a ``class Name:`` header.

    Returns:
        ``SECTION_TEMPLATE`` filled in for this section.

    Raises:
        ValueError: If a non-blank definition does not start with a
            ``class ...:`` header.
    """
    visitor_methods: list[str] = []
    classes: list[str] = []

    for chunk in body.strip("\n").split("\n\n\n"):
        cls = chunk.strip("\n")
        if not cls.strip():
            # More than two blank lines between definitions (or an empty
            # section) leaves empty chunks behind; they carry no definition.
            continue

        header = re.match("class (.*?):", cls)
        if header is None:
            first_line = cls.split("\n", 1)[0]
            raise ValueError(
                f"Definition in section {full_name!r} does not start with a "
                f"'class <Name>:' header: {first_line!r}"
            )
        name: str = header.group(1)

        print(f"Processing {name}")
        visitor_methods.append(make_visitor_method(name, param))
        classes.append(make_class(name, cls, base))

    return SECTION_TEMPLATE.format(
        banner=make_banner(full_name),
        base=base,
        visitor_methods="\n\n".join(visitor_methods),
        classes="\n\n\n".join(classes),
    )
|
|
|
|
|
|
def generate(definitions_path: Path, out_path: Path):
    """Generate one module of node classes from a definitions file.

    Reads *definitions_path*, extracts the optional imports block and every
    section block, renders them through the module templates and writes the
    result to *out_path*, replacing any existing content.

    Args:
        definitions_path: Definitions file to read. Must be located below the
            parent of this script's directory, because the generated header
            names it relative to that directory.
        out_path: File that receives the generated code.

    Raises:
        ValueError: If *definitions_path* is not below the root directory
            (raised by ``Path.relative_to``).
    """
    root_dir: Path = Path(__file__).parent.parent
    rel_path: Path = definitions_path.relative_to(root_dir)
    # Python source is UTF-8 by default, so read and write it explicitly as
    # UTF-8 instead of relying on the platform's locale encoding.
    src: str = definitions_path.read_text(encoding="utf-8")

    imports: str = ""
    if m := IMPORTS_REGEX.search(src):
        imports = m.group("body").strip("\n")

    sections: list[str] = []
    for section_m in SECTION_REGEX.finditer(src):
        full_name: str = section_m.group("name")
        base: str = section_m.group("base")
        # Without an explicit third header field, the visitor parameter is
        # named after the lowercased base class.
        param: str = section_m.group("param") or base.lower()
        body: str = section_m.group("body")
        sections.append(make_section(full_name, base, param, body))

    result: str = TEMPLATE.format(
        header=HEADER.format(
            defs_path=rel_path,
            gen_path=Path(__file__).relative_to(root_dir),
        ),
        imports=imports,
        sections="\n\n\n".join(sections),
    )
    out_path.write_text(result, encoding="utf-8")
|
|
|
|
|
|
def main():
    """Regenerate the ``midas`` and ``python`` AST modules.

    Definitions are read from ``<root>/gen`` and the generated files are
    written to ``<root>/midas/ast`` under the same file names, where
    ``<root>`` is the parent of the directory containing this script.
    """
    project_root = Path(__file__).parent.parent
    source_dir = project_root / "gen"
    target_dir = project_root / "midas" / "ast"
    for filename in ("midas.py", "python.py"):
        generate(source_dir / filename, target_dir / filename)
|
|
|
|
|
|
# Run the generator only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
|