Files
midas/gen/gen.py
2026-05-29 17:25:12 +02:00

147 lines
3.7 KiB
Python

import re
from pathlib import Path
# Module docstring placed at the top of every generated file. Formatted in
# generate() with the definitions-file path and this script's path, both
# relative to the project root.
HEADER = '''"""
This file was generated by a script. Any manual changes might be overwritten.
Please modify {defs_path} instead and run {gen_path}
"""'''
# Skeleton of one section of the generated module, formatted in make_section():
# a banner, then the abstract base class `{base}` (a frozen, keyword-only
# dataclass with a `location` field and an abstract `accept`), its `Visitor`
# ABC holding the rendered `visit_*` methods, and finally the concrete node
# classes. CLASS_TEMPLATE refers to the visitor as `{base}.Visitor`.
SECTION_TEMPLATE = """{banner}
@dataclass(frozen=True, kw_only=True)
class {base}(ABC):
location: Location
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
{visitor_methods}
{classes}"""
# Layout of a whole generated module, formatted in generate(): header
# docstring, the `from __future__` import (which must come right after the
# docstring), the definitions file's Imports block, the `T` type variable,
# then every rendered section.
# NOTE(review): this template does not import dataclass/ABC/abstractmethod/
# Generic/TypeVar/Location itself; presumably the definitions file's Imports
# section supplies them — verify.
TEMPLATE = """{header}
from __future__ import annotations
{imports}
T = TypeVar("T")
{sections}
"""
# One abstract `visit_<snake_case name>` method of a section's Visitor.
# Formatted in make_visitor_method().
VISITOR_METHOD_TEMPLATE = """
@abstractmethod
def visit_{func_name}(self, {param}: {cls}) -> T: ...
"""
# One concrete node class, formatted in make_class(): a frozen dataclass
# deriving from the section base, whose body is the definition's field lines,
# plus an `accept` that dispatches to the matching `visit_<name>` method.
CLASS_TEMPLATE = """
@dataclass(frozen=True)
class {cls}({base}):
{body}
def accept(self, visitor: {base}.Visitor[T]) -> T:
return visitor.visit_{func_name}(self)
"""
# A definitions-file section looks like:
#   ###> Base | Full Name | param      (the "| param" part is optional)
#   ...class definitions...
#   ###<
# The pieces below concatenate into a single pattern.
SECTION_REGEX = re.compile(
    r"^###>\s*(?P<base>[^\n]*?)"         # opening marker + base class name
    r"\s*\|\s*(?P<name>[^\n]*?)"         # human-readable section title
    r"(\s*\|\s*(?P<param>[^\n]*?))?"     # optional visitor parameter name
    r"\s*?\n(?P<body>.*?)\n###<$",       # body up to the closing marker
    flags=re.MULTILINE | re.DOTALL,
)
# The Imports block is copied verbatim into the generated module:
#   ###> Imports
#   ...import lines...
#   ###<
IMPORTS_REGEX = re.compile(
    r"^###>\s*Imports\s*?\n"
    r"(?P<body>.*?)\n###<$",
    flags=re.MULTILINE | re.DOTALL,
)
def snake_case(text: str) -> str:
    """Convert a CamelCase identifier to snake_case.

    Every ASCII capital letter is prefixed with an underscore, the result is
    lower-cased, and leading/trailing underscores are trimmed, so
    ``"BinaryExpr"`` becomes ``"binary_expr"``.
    """
    underscored: str = re.sub(r"([A-Z])", r"_\1", text)
    return underscored.lower().strip("_")
def make_visitor_method(cls: str, param: str):
    """Render the abstract ``visit_*`` method for one node class.

    Args:
        cls: Name of the node class the method visits.
        param: Parameter name for the visited node.

    Returns:
        The rendered method with surrounding newlines removed.
    """
    rendered: str = VISITOR_METHOD_TEMPLATE.format(
        cls=cls,
        param=param,
        func_name=snake_case(cls),
    )
    return rendered.strip("\n")
def make_class(name: str, cls: str, base: str):
    """Render one concrete node class from its definition text.

    Args:
        name: Class name taken from the definition's ``class <name>:`` header.
        cls: Full definition text. Its first line (the header) is dropped,
            because CLASS_TEMPLATE supplies its own header with *base* added.
        base: Name of the section's abstract base class.

    Returns:
        The rendered class source with surrounding newlines removed.
    """
    # partition() yields an empty body for a one-line definition such as
    # ``class Empty: ...`` instead of raising IndexError like split()[1] did.
    # The rendered class stays valid because the template adds ``accept``.
    body: str = cls.partition("\n")[2]
    cls_def: str = CLASS_TEMPLATE.format(
        cls=name,
        base=base,
        body=body,
        func_name=snake_case(name),
    )
    return cls_def.strip("\n")
def make_banner(text: str) -> str:
    """Return *text* framed in a three-line box of ``#`` characters."""
    label: str = "# " + text + " #"
    border: str = "#" * len(label)
    return f"{border}\n{label}\n{border}"
def make_section(full_name: str, base: str, param: str, body: str) -> str:
    """Render one section of the generated module.

    A section is the abstract base class *base* (with its ``Visitor``)
    followed by one concrete class per definition found in *body*.

    Args:
        full_name: Section title shown in the banner.
        base: Name of the abstract base class of every definition.
        param: Parameter name used in the generated ``visit_*`` methods.
        body: Raw section text holding class definitions separated by two
            blank lines, each starting with a ``class <name>:`` header.

    Returns:
        The rendered section source.

    Raises:
        ValueError: If a non-blank definition does not start with
            ``class <name>:``, or the section contains no definitions.
    """
    visitor_methods: list[str] = []
    classes: list[str] = []
    for chunk in body.strip("\n").split("\n\n\n"):
        definition: str = chunk.strip("\n")
        # Runs of extra blank lines leave empty chunks behind; skip them
        # rather than failing on the header match below.
        if not definition.strip():
            continue
        header_match = re.match("class (.*?):", definition)
        if header_match is None:
            first_line: str = definition.splitlines()[0]
            raise ValueError(
                f"section {full_name!r}: expected a definition starting with "
                f"'class <name>:', got {first_line!r}"
            )
        name: str = header_match.group(1)
        print(f"Processing {name}")
        visitor_methods.append(make_visitor_method(name, param))
        classes.append(make_class(name, definition, base))
    # An empty Visitor class body would make the generated module invalid.
    if not classes:
        raise ValueError(f"section {full_name!r} contains no class definitions")
    return SECTION_TEMPLATE.format(
        banner=make_banner(full_name),
        base=base,
        visitor_methods="\n\n".join(visitor_methods),
        classes="\n\n\n".join(classes),
    )
def generate(definitions_path: Path, out_path: Path):
    """Generate one AST module from a definitions file.

    The definitions file is scanned for an optional ``###> Imports`` block
    and any number of ``###> Base | Name [| param]`` sections. Each section
    is rendered with make_section(), and the assembled module is written to
    *out_path*, overwriting it.

    Args:
        definitions_path: Definitions file. It must lie under the project
            root, which is the parent of this script's directory.
        out_path: Destination of the generated module.

    Raises:
        ValueError: If *definitions_path* is not under the project root, or
            a section is malformed (see make_section()).
    """
    root_dir: Path = Path(__file__).parent.parent
    # POSIX separators keep the generated header identical across operating
    # systems (str(Path) would embed backslashes on Windows).
    defs_rel: str = definitions_path.relative_to(root_dir).as_posix()
    gen_rel: str = Path(__file__).relative_to(root_dir).as_posix()
    # Use explicit UTF-8 rather than the platform default encoding (e.g.
    # cp1252 on Windows), so non-ASCII text round-trips the same everywhere.
    src: str = definitions_path.read_text(encoding="utf-8")

    imports: str = ""
    if imports_match := IMPORTS_REGEX.search(src):
        imports = imports_match.group("body").strip("\n")

    sections: list[str] = []
    for section_match in SECTION_REGEX.finditer(src):
        base: str = section_match.group("base")
        sections.append(
            make_section(
                full_name=section_match.group("name"),
                base=base,
                # Without an explicit param, the lower-cased base name is used.
                param=section_match.group("param") or base.lower(),
                body=section_match.group("body"),
            )
        )

    result: str = TEMPLATE.format(
        header=HEADER.format(defs_path=defs_rel, gen_path=gen_rel),
        imports=imports,
        sections="\n\n\n".join(sections),
    )
    out_path.write_text(result, encoding="utf-8")
def main():
    """Regenerate every AST module under midas/ast from its definitions file."""
    project_root: Path = Path(__file__).parent.parent
    definitions_dir: Path = project_root / "gen"
    output_dir: Path = project_root / "midas" / "ast"
    # Each definitions file produces the output module of the same name.
    for module_name in ("midas.py", "python.py"):
        generate(definitions_dir / module_name, output_dir / module_name)


if __name__ == "__main__":
    main()