diff --git a/gen/gen.py b/gen/gen.py
index 34781b3..03f38d0 100644
--- a/gen/gen.py
+++ b/gen/gen.py
@@ -3,58 +3,34 @@ import re
HEADER = '''"""
This file was generated by a script. Any manual changes might be overwritten.
-Please modify gen/ast.py instead and run gen/gen.py
+Please modify {defs_path} instead and run {gen_path}
"""'''
+SECTION_TEMPLATE = """{banner}
+
+
+@dataclass(frozen=True, kw_only=True)
+class {base}(ABC):
+ location: Optional[Location] = None
+
+ @abstractmethod
+ def accept(self, visitor: Visitor[T]) -> T: ...
+
+ class Visitor(ABC, Generic[T]):
+{visitor_methods}
+
+
+{classes}"""
+
TEMPLATE = """{header}
from __future__ import annotations
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
-from typing import Any, Generic, Optional, TypeVar
-
-from midas.ast.location import Location
-from midas.lexer.token import Token
+{imports}
T = TypeVar("T")
-##############
-# Statements #
-##############
-
-
-@dataclass(frozen=True, kw_only=True)
-class Stmt(ABC):
- location: Optional[Location] = None
-
- @abstractmethod
- def accept(self, visitor: Visitor[T]) -> T: ...
-
- class Visitor(ABC, Generic[T]):
-{stmt_visitor_methods}
-
-
-{statements}
-
-
-###############
-# Expressions #
-###############
-
-
-@dataclass(frozen=True, kw_only=True)
-class Expr(ABC):
- location: Optional[Location] = None
-
- @abstractmethod
- def accept(self, visitor: Visitor[T]) -> T: ...
-
- class Visitor(ABC, Generic[T]):
-{expr_visitor_methods}
-
-
-{expressions}
+{sections}
"""
VISITOR_METHOD_TEMPLATE = """
@@ -71,6 +47,16 @@ class {cls}({base}):
return visitor.visit_{func_name}(self)
"""
+SECTION_REGEX = re.compile(
+ r"^###>\s*(?P[^\n]*?)\s*\|\s*(?P[^\n]*?)\s*?\n(?P.*?)\n###<$",
+ re.MULTILINE | re.DOTALL,
+)
+
+IMPORTS_REGEX = re.compile(
+ r"^###>\s*Imports\s*?\n(?P.*?)\n###<$",
+ re.MULTILINE | re.DOTALL,
+)
+
def snake_case(text: str) -> str:
return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_")
@@ -95,41 +81,63 @@ def make_class(name: str, cls: str, base: str):
return cls_def.strip("\n")
-def generate(src: str):
- classes: list[str] = src.split("\n\n")
- stmt_visitor_methods: list[str] = []
- expr_visitor_methods: list[str] = []
- statements: list[str] = []
- expressions: list[str] = []
+def make_banner(text: str) -> str:
+ middle: str = f"# {text} #"
+ rule: str = "#" * len(middle)
+ return "\n".join((rule, middle, rule))
- for cls in classes:
+
+def make_section(full_name: str, base: str, body: str) -> str:
+ visitor_methods: list[str] = []
+ classes: list[str] = []
+ definitions: list[str] = body.strip("\n").split("\n\n")
+ for cls in definitions:
cls = cls.strip("\n")
name: str = re.match("class (.*?):", cls).group(1) # type: ignore
print(f"Processing {name}")
- if name.endswith("Stmt"):
- stmt_visitor_methods.append(make_visitor_method(name, "stmt"))
- statements.append(make_class(name, cls, "Stmt"))
- elif name.endswith("Expr"):
- expr_visitor_methods.append(make_visitor_method(name, "expr"))
- expressions.append(make_class(name, cls, "Expr"))
+ visitor_methods.append(make_visitor_method(name, base.lower()))
+ classes.append(make_class(name, cls, base))
- return TEMPLATE.format(
- header=HEADER,
- stmt_visitor_methods="\n\n".join(stmt_visitor_methods),
- expr_visitor_methods="\n\n".join(expr_visitor_methods),
- statements="\n\n\n".join(statements),
- expressions="\n\n\n".join(expressions),
+ return SECTION_TEMPLATE.format(
+ banner=make_banner(full_name),
+ base=base,
+ visitor_methods="\n\n".join(visitor_methods),
+ classes="\n\n\n".join(classes),
)
+def generate(definitions_path: Path, out_path: Path):
+ root_dir: Path = Path(__file__).parent.parent
+ rel_path: Path = definitions_path.relative_to(root_dir)
+ src: str = definitions_path.read_text()
+ sections: list[str] = []
+
+ imports: str = ""
+ if m := IMPORTS_REGEX.search(src):
+ imports = m.group("body").strip("\n")
+
+ for section_m in SECTION_REGEX.finditer(src):
+ full_name: str = section_m.group("name")
+ base: str = section_m.group("base")
+ body: str = section_m.group("body")
+ sections.append(make_section(full_name, base, body))
+
+ result: str = TEMPLATE.format(
+ header=HEADER.format(
+ defs_path=rel_path,
+ gen_path=Path(__file__).relative_to(root_dir),
+ ),
+ imports=imports,
+ sections="\n\n\n".join(sections),
+ )
+ out_path.write_text(result)
+
+
def main():
root: Path = Path(__file__).parent.parent
- in_path: Path = root / "gen" / "ast.py"
- out_path: Path = root / "midas" / "ast" / "midas.py"
-
- src: str = in_path.read_text()
- generated: str = generate(src)
- out_path.write_text(generated)
+ defs_dir: Path = root / "gen"
+ ast_dir: Path = root / "midas" / "ast"
+ generate(defs_dir / "midas.py", ast_dir / "midas.py")
if __name__ == "__main__":
diff --git a/gen/ast.py b/gen/midas.py
similarity index 76%
rename from gen/ast.py
rename to gen/midas.py
index 6fca631..7187554 100644
--- a/gen/ast.py
+++ b/gen/midas.py
@@ -1,72 +1,110 @@
+# type: ignore
+# ruff: disable[F821, F401]
+
+###> Imports
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Generic, Optional, TypeVar
+
+from midas.ast.location import Location
+from midas.lexer.token import Token
+
+###<
+
+
+###> Stmt | Statements
class SimpleTypeStmt:
name: Token
template: Optional[TemplateExpr]
base: TypeExpr
constraint: Optional[Expr]
-class SimpleTypeExpr:
- name: Token
- optional: bool
-
-class LogicalExpr:
- left: Expr
- operator: Token
- right: Expr
-
-class BinaryExpr:
- left: Expr
- operator: Token
- right: Expr
-
-class UnaryExpr:
- operator: Token
- right: Expr
-
-class GetExpr:
- expr: Expr
- name: Token
-
-class VariableExpr:
- name: Token
-
-class GroupingExpr:
- expr: Expr
-
-class LiteralExpr:
- value: Any
-
-class WildcardExpr:
- token: Token
-
-class TemplateExpr:
- type: TypeExpr
-
-class TypeExpr:
- name: Token
- template: Optional[TemplateExpr]
- optional: bool
class ComplexTypeStmt:
name: Token
template: Optional[TemplateExpr]
properties: list[PropertyStmt]
+
class PropertyStmt:
name: Token
type: TypeExpr
constraint: Optional[Expr]
+
class ExtendStmt:
type: TypeExpr
operations: list[OpStmt]
+
class OpStmt:
name: Token
operand: TypeExpr
result: TypeExpr
+
class PredicateStmt:
name: Token
subject: Token
type: TypeExpr
condition: Expr
+
+
+###<
+
+
+###> Expr | Expressions
+class SimpleTypeExpr:
+ name: Token
+ optional: bool
+
+
+class LogicalExpr:
+ left: Expr
+ operator: Token
+ right: Expr
+
+
+class BinaryExpr:
+ left: Expr
+ operator: Token
+ right: Expr
+
+
+class UnaryExpr:
+ operator: Token
+ right: Expr
+
+
+class GetExpr:
+ expr: Expr
+ name: Token
+
+
+class VariableExpr:
+ name: Token
+
+
+class GroupingExpr:
+ expr: Expr
+
+
+class LiteralExpr:
+ value: Any
+
+
+class WildcardExpr:
+ token: Token
+
+
+class TemplateExpr:
+ type: TypeExpr
+
+
+class TypeExpr:
+ name: Token
+ template: Optional[TemplateExpr]
+ optional: bool
+
+
+###<
diff --git a/midas/ast/midas.py b/midas/ast/midas.py
index 1ff503d..9cea8c2 100644
--- a/midas/ast/midas.py
+++ b/midas/ast/midas.py
@@ -1,6 +1,6 @@
"""
This file was generated by a script. Any manual changes might be overwritten.
-Please modify gen/ast.py instead and run gen/gen.py
+Please modify gen/midas.py instead and run gen/gen.py
"""
from __future__ import annotations