diff --git a/gen/gen.py b/gen/gen.py index 34781b3..03f38d0 100644 --- a/gen/gen.py +++ b/gen/gen.py @@ -3,58 +3,34 @@ import re HEADER = '''""" This file was generated by a script. Any manual changes might be overwritten. -Please modify gen/ast.py instead and run gen/gen.py +Please modify {defs_path} instead and run {gen_path} """''' +SECTION_TEMPLATE = """{banner} + + +@dataclass(frozen=True, kw_only=True) +class {base}(ABC): + location: Optional[Location] = None + + @abstractmethod + def accept(self, visitor: Visitor[T]) -> T: ... + + class Visitor(ABC, Generic[T]): +{visitor_methods} + + +{classes}""" + TEMPLATE = """{header} from __future__ import annotations -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Any, Generic, Optional, TypeVar - -from midas.ast.location import Location -from midas.lexer.token import Token +{imports} T = TypeVar("T") -############## -# Statements # -############## - - -@dataclass(frozen=True, kw_only=True) -class Stmt(ABC): - location: Optional[Location] = None - - @abstractmethod - def accept(self, visitor: Visitor[T]) -> T: ... - - class Visitor(ABC, Generic[T]): -{stmt_visitor_methods} - - -{statements} - - -############### -# Expressions # -############### - - -@dataclass(frozen=True, kw_only=True) -class Expr(ABC): - location: Optional[Location] = None - - @abstractmethod - def accept(self, visitor: Visitor[T]) -> T: ... - - class Visitor(ABC, Generic[T]): -{expr_visitor_methods} - - -{expressions} +{sections} """ VISITOR_METHOD_TEMPLATE = """ @@ -71,6 +47,16 @@ class {cls}({base}): return visitor.visit_{func_name}(self) """ +SECTION_REGEX = re.compile( + r"^###>\s*(?P[^\n]*?)\s*\|\s*(?P[^\n]*?)\s*?\n(?P.*?)\n###<$", + re.MULTILINE | re.DOTALL, +) + +IMPORTS_REGEX = re.compile( + r"^###>\s*Imports\s*?\n(?P.*?)\n###<$", + re.MULTILINE | re.DOTALL, +) + def snake_case(text: str) -> str: return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_") @@ -95,41 +81,63 @@ def make_class(name: str, cls: str, base: str): return cls_def.strip("\n") -def generate(src: str): - classes: list[str] = src.split("\n\n") - stmt_visitor_methods: list[str] = [] - expr_visitor_methods: list[str] = [] - statements: list[str] = [] - expressions: list[str] = [] +def make_banner(text: str) -> str: + middle: str = f"# {text} #" + rule: str = "#" * len(middle) + return "\n".join((rule, middle, rule)) - for cls in classes: + +def make_section(full_name: str, base: str, body: str) -> str: + visitor_methods: list[str] = [] + classes: list[str] = [] + definitions: list[str] = body.strip("\n").split("\n\n") + for cls in definitions: cls = cls.strip("\n") name: str = re.match("class (.*?):", cls).group(1) # type: ignore print(f"Processing {name}") - if name.endswith("Stmt"): - stmt_visitor_methods.append(make_visitor_method(name, "stmt")) - statements.append(make_class(name, cls, "Stmt")) - elif name.endswith("Expr"): - expr_visitor_methods.append(make_visitor_method(name, "expr")) - expressions.append(make_class(name, cls, "Expr")) + visitor_methods.append(make_visitor_method(name, base.lower())) + classes.append(make_class(name, cls, base)) - return TEMPLATE.format( - header=HEADER, - stmt_visitor_methods="\n\n".join(stmt_visitor_methods), - expr_visitor_methods="\n\n".join(expr_visitor_methods), - statements="\n\n\n".join(statements), - expressions="\n\n\n".join(expressions), + return SECTION_TEMPLATE.format( + banner=make_banner(full_name), + base=base, + visitor_methods="\n\n".join(visitor_methods), + classes="\n\n\n".join(classes), ) +def generate(definitions_path: Path, out_path: Path): + root_dir: Path = Path(__file__).parent.parent + rel_path: Path = definitions_path.relative_to(root_dir) + src: str = definitions_path.read_text() + sections: list[str] = [] + + imports: str = "" + if m := IMPORTS_REGEX.search(src): + imports = m.group("body").strip("\n") + + for section_m in SECTION_REGEX.finditer(src): + full_name: str = section_m.group("name") + base: str = section_m.group("base") + body: str = section_m.group("body") + sections.append(make_section(full_name, base, body)) + + result: str = TEMPLATE.format( + header=HEADER.format( + defs_path=rel_path, + gen_path=Path(__file__).relative_to(root_dir), + ), + imports=imports, + sections="\n\n\n".join(sections), + ) + out_path.write_text(result) + + def main(): root: Path = Path(__file__).parent.parent - in_path: Path = root / "gen" / "ast.py" - out_path: Path = root / "midas" / "ast" / "midas.py" - - src: str = in_path.read_text() - generated: str = generate(src) - out_path.write_text(generated) + defs_dir: Path = root / "gen" + ast_dir: Path = root / "midas" / "ast" + generate(defs_dir / "midas.py", ast_dir / "midas.py") if __name__ == "__main__": diff --git a/gen/ast.py b/gen/midas.py similarity index 76% rename from gen/ast.py rename to gen/midas.py index 6fca631..7187554 100644 --- a/gen/ast.py +++ b/gen/midas.py @@ -1,72 +1,110 @@ +# type: ignore +# ruff: disable[F821, F401] + +###> Imports +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Generic, Optional, TypeVar + +from midas.ast.location import Location +from midas.lexer.token import Token + +###< + + +###> Stmt | Statements class SimpleTypeStmt: name: Token template: Optional[TemplateExpr] base: TypeExpr constraint: Optional[Expr] -class SimpleTypeExpr: - name: Token - optional: bool - -class LogicalExpr: - left: Expr - operator: Token - right: Expr - -class BinaryExpr: - left: Expr - operator: Token - right: Expr - -class UnaryExpr: - operator: Token - right: Expr - -class GetExpr: - expr: Expr - name: Token - -class VariableExpr: - name: Token - -class GroupingExpr: - expr: Expr - -class LiteralExpr: - value: Any - -class WildcardExpr: - token: Token - -class TemplateExpr: - type: TypeExpr - -class TypeExpr: - name: Token - template: Optional[TemplateExpr] - optional: bool class ComplexTypeStmt: name: Token template: Optional[TemplateExpr] properties: list[PropertyStmt] + class PropertyStmt: name: Token type: TypeExpr constraint: Optional[Expr] + class ExtendStmt: type: TypeExpr operations: list[OpStmt] + class OpStmt: name: Token operand: TypeExpr result: TypeExpr + class PredicateStmt: name: Token subject: Token type: TypeExpr condition: Expr + + +###< + + +###> Expr | Expressions +class SimpleTypeExpr: + name: Token + optional: bool + + +class LogicalExpr: + left: Expr + operator: Token + right: Expr + + +class BinaryExpr: + left: Expr + operator: Token + right: Expr + + +class UnaryExpr: + operator: Token + right: Expr + + +class GetExpr: + expr: Expr + name: Token + + +class VariableExpr: + name: Token + + +class GroupingExpr: + expr: Expr + + +class LiteralExpr: + value: Any + + +class WildcardExpr: + token: Token + + +class TemplateExpr: + type: TypeExpr + + +class TypeExpr: + name: Token + template: Optional[TemplateExpr] + optional: bool + + +###< diff --git a/midas/ast/midas.py b/midas/ast/midas.py index 1ff503d..9cea8c2 100644 --- a/midas/ast/midas.py +++ b/midas/ast/midas.py @@ -1,6 +1,6 @@ """ This file was generated by a script. Any manual changes might be overwritten. -Please modify gen/ast.py instead and run gen/gen.py +Please modify gen/midas.py instead and run gen/gen.py """ from __future__ import annotations