diff --git a/gen/gen.py b/gen/gen.py
index 03f38d0..75e6100 100644
--- a/gen/gen.py
+++ b/gen/gen.py
@@ -48,7 +48,7 @@ class {cls}({base}):
"""
SECTION_REGEX = re.compile(
- r"^###>\s*(?P[^\n]*?)\s*\|\s*(?P[^\n]*?)\s*?\n(?P.*?)\n###<$",
+ r"^###>\s*(?P[^\n]*?)\s*\|\s*(?P[^\n]*?)(\s*\|\s*(?P[^\n]*?))?\s*?\n(?P.*?)\n###<$",
re.MULTILINE | re.DOTALL,
)
@@ -87,15 +87,15 @@ def make_banner(text: str) -> str:
return "\n".join((rule, middle, rule))
-def make_section(full_name: str, base: str, body: str) -> str:
+def make_section(full_name: str, base: str, param: str, body: str) -> str:
visitor_methods: list[str] = []
classes: list[str] = []
- definitions: list[str] = body.strip("\n").split("\n\n")
+ definitions: list[str] = body.strip("\n").split("\n\n\n")
for cls in definitions:
cls = cls.strip("\n")
name: str = re.match("class (.*?):", cls).group(1) # type: ignore
print(f"Processing {name}")
- visitor_methods.append(make_visitor_method(name, base.lower()))
+ visitor_methods.append(make_visitor_method(name, param))
classes.append(make_class(name, cls, base))
return SECTION_TEMPLATE.format(
@@ -119,8 +119,9 @@ def generate(definitions_path: Path, out_path: Path):
for section_m in SECTION_REGEX.finditer(src):
full_name: str = section_m.group("name")
base: str = section_m.group("base")
+ param: str = section_m.group("param") or base.lower()
body: str = section_m.group("body")
- sections.append(make_section(full_name, base, body))
+ sections.append(make_section(full_name, base, param, body))
result: str = TEMPLATE.format(
header=HEADER.format(
@@ -138,6 +139,7 @@ def main():
defs_dir: Path = root / "gen"
ast_dir: Path = root / "midas" / "ast"
generate(defs_dir / "midas.py", ast_dir / "midas.py")
+ generate(defs_dir / "python.py", ast_dir / "python.py")
if __name__ == "__main__":
diff --git a/gen/python.py b/gen/python.py
new file mode 100644
index 0000000..15df1d9
--- /dev/null
+++ b/gen/python.py
@@ -0,0 +1,53 @@
+# type: ignore
+# ruff: disable[F821, F401]
+
+###> Imports
+import ast
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Generic, Optional, TypeVar
+
+from midas.ast.location import Location
+
+###<
+
+
+###> MidasType | Type annotations | node
+class BaseType:
+ base: str
+ param: Optional[MidasType]
+
+
+class ConstraintType:
+ type: MidasType
+ constraint: ast.expr
+
+
+class FrameColumn:
+ name: Optional[str]
+ type: Optional[MidasType]
+
+
+class FrameType:
+ columns: list[FrameColumn]
+
+
+###<
+
+
+###> Stmt | Statements
+class Function:
+ name: str
+ posonlyargs: list[Argument]
+ args: list[Argument]
+ kwonlyargs: list[Argument]
+ returns: Optional[MidasType]
+
+ @dataclass(frozen=True, kw_only=True)
+ class Argument:
+ location: Optional[Location] = None
+ name: Optional[str]
+ type: Optional[MidasType]
+
+
+###<
diff --git a/midas/ast/printer.py b/midas/ast/printer.py
index b92e40f..2d38241 100644
--- a/midas/ast/printer.py
+++ b/midas/ast/printer.py
@@ -350,7 +350,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
return f"{expr.name.lexeme}{template}{'?' if expr.optional else ''}"
-class PythonAstPrinter(AstPrinter, p.Expr.Visitor[None]):
+class PythonAstPrinter(AstPrinter, p.MidasType.Visitor[None], p.Stmt.Visitor[None]):
def visit_base_type(self, node: p.BaseType) -> None:
self._write_line("BaseType")
with self._child_level():
@@ -382,39 +382,39 @@ class PythonAstPrinter(AstPrinter, p.Expr.Visitor[None]):
self._mark_last()
col.accept(self)
- def visit_function(self, node: p.Function) -> None:
+ def visit_function(self, stmt: p.Function) -> None:
self._write_line("Function")
with self._child_level():
- self._write_line(f"name: {node.name}")
+ self._write_line(f"name: {stmt.name}")
self._write_line("posonlyargs")
with self._child_level():
- for i, arg in enumerate(node.posonlyargs):
+ for i, arg in enumerate(stmt.posonlyargs):
self._idx = i
- if i == len(node.posonlyargs) - 1:
+ if i == len(stmt.posonlyargs) - 1:
self._mark_last()
- arg.accept(self)
+ self._print_argument(arg)
self._write_line("args")
with self._child_level():
- for i, arg in enumerate(node.args):
+ for i, arg in enumerate(stmt.args):
self._idx = i
- if i == len(node.args) - 1:
+ if i == len(stmt.args) - 1:
self._mark_last()
- arg.accept(self)
+ self._print_argument(arg)
self._write_line("kwonlyargs")
with self._child_level():
- for i, arg in enumerate(node.kwonlyargs):
+ for i, arg in enumerate(stmt.kwonlyargs):
self._idx = i
- if i == len(node.kwonlyargs) - 1:
+ if i == len(stmt.kwonlyargs) - 1:
self._mark_last()
- arg.accept(self)
+ self._print_argument(arg)
- self._write_optional_child("returns", node.returns, last=True)
+ self._write_optional_child("returns", stmt.returns, last=True)
- def visit_function_argument(self, node: p.FunctionArgument) -> None:
+ def _print_argument(self, arg: p.Function.Argument) -> None:
self._write_line("FunctionArgument")
with self._child_level():
- self._write_line(f"name: {node.name}")
- self._write_optional_child("type", node.type, last=True)
+ self._write_line(f"name: {arg.name}")
+ self._write_optional_child("type", arg.type, last=True)
diff --git a/midas/ast/python.py b/midas/ast/python.py
index c25b438..cd120ee 100644
--- a/midas/ast/python.py
+++ b/midas/ast/python.py
@@ -1,7 +1,12 @@
+"""
+This file was generated by a script. Any manual changes might be overwritten.
+Please modify gen/python.py instead and run gen/gen.py
+"""
+
from __future__ import annotations
-from abc import ABC, abstractmethod
import ast
+from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, Optional, TypeVar
@@ -9,9 +14,13 @@ from midas.ast.location import Location
T = TypeVar("T")
+####################
+# Type annotations #
+####################
+
@dataclass(frozen=True, kw_only=True)
-class Expr(ABC):
+class MidasType(ABC):
location: Optional[Location] = None
@abstractmethod
@@ -30,24 +39,13 @@ class Expr(ABC):
@abstractmethod
def visit_frame_type(self, node: FrameType) -> T: ...
- @abstractmethod
- def visit_function(self, node: Function) -> T: ...
-
- @abstractmethod
- def visit_function_argument(self, node: FunctionArgument) -> T: ...
-
-
-@dataclass(frozen=True)
-class MidasType(Expr):
- pass
-
@dataclass(frozen=True)
class BaseType(MidasType):
base: str
param: Optional[MidasType]
- def accept(self, visitor: Expr.Visitor[T]) -> T:
+ def accept(self, visitor: MidasType.Visitor[T]) -> T:
return visitor.visit_base_type(self)
@@ -56,7 +54,7 @@ class ConstraintType(MidasType):
type: MidasType
constraint: ast.expr
- def accept(self, visitor: Expr.Visitor[T]) -> T:
+ def accept(self, visitor: MidasType.Visitor[T]) -> T:
return visitor.visit_constraint_type(self)
@@ -65,7 +63,7 @@ class FrameColumn(MidasType):
name: Optional[str]
type: Optional[MidasType]
- def accept(self, visitor: Expr.Visitor[T]) -> T:
+ def accept(self, visitor: MidasType.Visitor[T]) -> T:
return visitor.visit_frame_column(self)
@@ -73,26 +71,40 @@ class FrameColumn(MidasType):
class FrameType(MidasType):
columns: list[FrameColumn]
- def accept(self, visitor: Expr.Visitor[T]) -> T:
+ def accept(self, visitor: MidasType.Visitor[T]) -> T:
return visitor.visit_frame_type(self)
+##############
+# Statements #
+##############
+
+
+@dataclass(frozen=True, kw_only=True)
+class Stmt(ABC):
+ location: Optional[Location] = None
+
+ @abstractmethod
+ def accept(self, visitor: Visitor[T]) -> T: ...
+
+ class Visitor(ABC, Generic[T]):
+ @abstractmethod
+ def visit_function(self, stmt: Function) -> T: ...
+
+
@dataclass(frozen=True)
-class Function(Expr):
+class Function(Stmt):
name: str
- posonlyargs: list[FunctionArgument]
- args: list[FunctionArgument]
- kwonlyargs: list[FunctionArgument]
+ posonlyargs: list[Argument]
+ args: list[Argument]
+ kwonlyargs: list[Argument]
returns: Optional[MidasType]
- def accept(self, visitor: Expr.Visitor[T]) -> T:
+ @dataclass(frozen=True, kw_only=True)
+ class Argument:
+ location: Optional[Location] = None
+ name: Optional[str]
+ type: Optional[MidasType]
+
+ def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_function(self)
-
-
-@dataclass(frozen=True)
-class FunctionArgument(Expr):
- name: Optional[str]
- type: Optional[MidasType]
-
- def accept(self, visitor: Expr.Visitor[T]) -> T:
- return visitor.visit_function_argument(self)
diff --git a/midas/parser/python.py b/midas/parser/python.py
index 6e0ffe1..082cab1 100644
--- a/midas/parser/python.py
+++ b/midas/parser/python.py
@@ -8,7 +8,6 @@ from midas.ast.python import (
FrameColumn,
FrameType,
Function,
- FunctionArgument,
MidasType,
)
@@ -63,7 +62,7 @@ class PythonParser(ast.NodeVisitor):
returns=returns,
):
- def parse_args(args_list: list[ast.arg]) -> list[FunctionArgument]:
+ def parse_args(args_list: list[ast.arg]) -> list[Function.Argument]:
return [self._parse_function_argument(arg) for arg in args_list]
return Function(
@@ -75,13 +74,13 @@ class PythonParser(ast.NodeVisitor):
returns=self._parse_type(returns) if returns is not None else None,
)
- def _parse_function_argument(self, arg: ast.arg) -> FunctionArgument:
+ def _parse_function_argument(self, arg: ast.arg) -> Function.Argument:
loc: Location = Location.from_ast(arg)
name: str = arg.arg
type: Optional[MidasType] = None
if arg.annotation is not None:
type = self._parse_type(arg.annotation)
- return FunctionArgument(
+ return Function.Argument(
location=loc,
name=name,
type=type,