use the generation script to create Python AST node classes, also distinguish between Midas type annotation nodes and statements
147 lines
3.7 KiB
Python
147 lines
3.7 KiB
Python
from pathlib import Path
|
|
import re
|
|
|
|
HEADER = '''"""
|
|
This file was generated by a script. Any manual changes might be overwritten.
|
|
Please modify {defs_path} instead and run {gen_path}
|
|
"""'''
|
|
|
|
SECTION_TEMPLATE = """{banner}
|
|
|
|
|
|
@dataclass(frozen=True, kw_only=True)
|
|
class {base}(ABC):
|
|
location: Optional[Location] = None
|
|
|
|
@abstractmethod
|
|
def accept(self, visitor: Visitor[T]) -> T: ...
|
|
|
|
class Visitor(ABC, Generic[T]):
|
|
{visitor_methods}
|
|
|
|
|
|
{classes}"""
|
|
|
|
TEMPLATE = """{header}
|
|
|
|
from __future__ import annotations
|
|
|
|
{imports}
|
|
|
|
T = TypeVar("T")
|
|
|
|
{sections}
|
|
"""
|
|
|
|
VISITOR_METHOD_TEMPLATE = """
|
|
@abstractmethod
|
|
def visit_{func_name}(self, {param}: {cls}) -> T: ...
|
|
"""
|
|
|
|
CLASS_TEMPLATE = """
|
|
@dataclass(frozen=True)
|
|
class {cls}({base}):
|
|
{body}
|
|
|
|
def accept(self, visitor: {base}.Visitor[T]) -> T:
|
|
return visitor.visit_{func_name}(self)
|
|
"""
|
|
|
|
# Matches one section of a definitions file. A section opens with a line
#     ###> <base> | <name>            (optionally followed by "| <param>")
# and runs until the next line consisting solely of `###<`.
# Named groups: `base`, `name`, the optional `param`, and `body` (everything
# between the opening and closing marker lines).
# MULTILINE lets `^`/`$` anchor at line boundaries; DOTALL lets `body` span
# several lines.
SECTION_REGEX = re.compile(
    r"^###>\s*(?P<base>[^\n]*?)\s*\|\s*(?P<name>[^\n]*?)(\s*\|\s*(?P<param>[^\n]*?))?\s*?\n(?P<body>.*?)\n###<$",
    re.MULTILINE | re.DOTALL,
)
|
|
|
|
# Matches the imports section of a definitions file: a `###> Imports` marker
# line, then the `body` group, terminated by a line consisting solely of
# `###<`. Flags have the same purpose as in SECTION_REGEX (line-wise anchors,
# multi-line body).
IMPORTS_REGEX = re.compile(
    r"^###>\s*Imports\s*?\n(?P<body>.*?)\n###<$",
    re.MULTILINE | re.DOTALL,
)
|
|
|
|
|
|
def snake_case(text: str) -> str:
    """Convert a CamelCase identifier into snake_case.

    An underscore is inserted in front of every ASCII uppercase letter, the
    whole string is lower-cased, and any leading or trailing underscores are
    removed. Each capital is treated separately, so runs of capitals are split
    letter by letter (``"HTTPServer"`` becomes ``"h_t_t_p_server"``).
    """
    underscored = re.sub(r"([A-Z])", r"_\1", text)
    return underscored.lower().strip("_")
|
|
|
|
|
|
def make_visitor_method(cls: str, param: str) -> str:
    """Render the abstract visitor method stub for the node class *cls*.

    The method is named ``visit_<snake_case(cls)>`` and takes one argument
    called *param*, annotated with *cls*. Leading and trailing newlines of
    the rendered template are removed; its indentation is kept.
    """
    rendered = VISITOR_METHOD_TEMPLATE.format(
        cls=cls,
        param=param,
        func_name=snake_case(cls),
    )
    return rendered.strip("\n")
|
|
|
|
|
|
def make_class(name: str, cls: str, base: str) -> str:
    """Render the generated dataclass for one node definition.

    Args:
        name: Class name of the node.
        cls: Full source text of the definition; its first line (the class
            header) is discarded and the remainder is used as the class body.
        base: Name of the base class the generated class derives from.

    Returns:
        The rendered class source without leading or trailing newlines.

    Raises:
        IndexError: If *cls* consists of a single line (no body).
    """
    # Everything after the first line break is the class body.
    body = cls.split("\n", maxsplit=1)[1]
    rendered = CLASS_TEMPLATE.format(
        cls=name,
        base=base,
        body=body,
        func_name=snake_case(name),
    )
    return rendered.strip("\n")
|
|
|
|
|
|
def make_banner(text: str) -> str:
    """Return *text* framed as a three-line ``#`` comment banner.

    The middle line is ``# <text> #``; the lines above and below are runs of
    ``#`` of the same length.
    """
    title = f"# {text} #"
    border = "#" * len(title)
    return f"{border}\n{title}\n{border}"
|
|
|
|
|
|
def make_section(full_name: str, base: str, param: str, body: str) -> str:
    """Render one section: banner, base class, visitor and node classes.

    *body* is split into individual definitions on two consecutive blank
    lines. Each definition must start with a ``class <Name>:`` header; the
    name is used for both the visitor method and the generated class.

    Args:
        full_name: Human-readable section title shown in the banner.
        base: Name of the section's abstract base class.
        param: Parameter name used in the generated ``visit_*`` methods.
        body: Raw text of the section taken from the definitions file.

    Returns:
        The rendered section source.

    Raises:
        ValueError: If a definition (including an empty section body) does
            not start with a ``class <Name>:`` header.
    """
    visitor_methods: list[str] = []
    classes: list[str] = []

    for chunk in body.strip("\n").split("\n\n\n"):
        definition = chunk.strip("\n")
        header = re.match("class (.*?):", definition)
        if header is None:
            # Fail with the offending line instead of an opaque
            # AttributeError on None.
            first_line = definition.split("\n", 1)[0]
            raise ValueError(
                f"Section {full_name!r}: expected a definition starting with "
                f"'class <Name>:', got {first_line!r}"
            )
        name = header.group(1)
        print(f"Processing {name}")
        visitor_methods.append(make_visitor_method(name, param))
        classes.append(make_class(name, definition, base))

    return SECTION_TEMPLATE.format(
        banner=make_banner(full_name),
        base=base,
        visitor_methods="\n\n".join(visitor_methods),
        classes="\n\n\n".join(classes),
    )
|
|
|
|
|
|
def generate(definitions_path: Path, out_path: Path) -> None:
    """Generate one AST module from a definitions file.

    The definitions file is scanned for an optional imports section and any
    number of node sections; the rendered module is written to *out_path*.
    Both files are read and written as UTF-8 so the result does not depend on
    the platform's default encoding.

    Args:
        definitions_path: Definitions file to read. Must be located below the
            parent of the directory containing this script.
        out_path: Destination of the generated module (overwritten).

    Raises:
        ValueError: If *definitions_path* is not below the project root
            (raised by ``Path.relative_to``).
    """
    root_dir: Path = Path(__file__).parent.parent
    rel_path: Path = definitions_path.relative_to(root_dir)
    src: str = definitions_path.read_text(encoding="utf-8")

    # The imports section is optional; without it nothing is inserted.
    imports: str = ""
    if m := IMPORTS_REGEX.search(src):
        imports = m.group("body").strip("\n")

    sections: list[str] = []
    for section_m in SECTION_REGEX.finditer(src):
        base: str = section_m.group("base")
        # The visitor parameter name defaults to the lower-cased base name.
        param: str = section_m.group("param") or base.lower()
        sections.append(
            make_section(
                section_m.group("name"),
                base,
                param,
                section_m.group("body"),
            )
        )

    result: str = TEMPLATE.format(
        header=HEADER.format(
            defs_path=rel_path,
            gen_path=Path(__file__).relative_to(root_dir),
        ),
        imports=imports,
        sections="\n\n\n".join(sections),
    )
    out_path.write_text(result, encoding="utf-8")
|
|
|
|
|
|
def main() -> None:
    """Regenerate the Midas and Python AST modules.

    Definitions are read from ``gen/`` and the generated modules are written
    to ``midas/ast/``, both relative to the parent of this script's directory.
    """
    root = Path(__file__).parent.parent
    defs_dir = root / "gen"
    ast_dir = root / "midas" / "ast"
    # Input and output files share the same name.
    for filename in ("midas.py", "python.py"):
        generate(defs_dir / filename, ast_dir / filename)


if __name__ == "__main__":
    main()
|