feat(parser): generate python AST classes
use the generation script to create Python AST node classes, also distinguish between Midas type annotation nodes and statements
This commit is contained in:
12
gen/gen.py
12
gen/gen.py
@@ -48,7 +48,7 @@ class {cls}({base}):
|
||||
"""
|
||||
|
||||
SECTION_REGEX = re.compile(
|
||||
r"^###>\s*(?P<base>[^\n]*?)\s*\|\s*(?P<name>[^\n]*?)\s*?\n(?P<body>.*?)\n###<$",
|
||||
r"^###>\s*(?P<base>[^\n]*?)\s*\|\s*(?P<name>[^\n]*?)(\s*\|\s*(?P<param>[^\n]*?))?\s*?\n(?P<body>.*?)\n###<$",
|
||||
re.MULTILINE | re.DOTALL,
|
||||
)
|
||||
|
||||
@@ -87,15 +87,15 @@ def make_banner(text: str) -> str:
|
||||
return "\n".join((rule, middle, rule))
|
||||
|
||||
|
||||
def make_section(full_name: str, base: str, body: str) -> str:
|
||||
def make_section(full_name: str, base: str, param: str, body: str) -> str:
|
||||
visitor_methods: list[str] = []
|
||||
classes: list[str] = []
|
||||
definitions: list[str] = body.strip("\n").split("\n\n")
|
||||
definitions: list[str] = body.strip("\n").split("\n\n\n")
|
||||
for cls in definitions:
|
||||
cls = cls.strip("\n")
|
||||
name: str = re.match("class (.*?):", cls).group(1) # type: ignore
|
||||
print(f"Processing {name}")
|
||||
visitor_methods.append(make_visitor_method(name, base.lower()))
|
||||
visitor_methods.append(make_visitor_method(name, param))
|
||||
classes.append(make_class(name, cls, base))
|
||||
|
||||
return SECTION_TEMPLATE.format(
|
||||
@@ -119,8 +119,9 @@ def generate(definitions_path: Path, out_path: Path):
|
||||
for section_m in SECTION_REGEX.finditer(src):
|
||||
full_name: str = section_m.group("name")
|
||||
base: str = section_m.group("base")
|
||||
param: str = section_m.group("param") or base.lower()
|
||||
body: str = section_m.group("body")
|
||||
sections.append(make_section(full_name, base, body))
|
||||
sections.append(make_section(full_name, base, param, body))
|
||||
|
||||
result: str = TEMPLATE.format(
|
||||
header=HEADER.format(
|
||||
@@ -138,6 +139,7 @@ def main():
|
||||
defs_dir: Path = root / "gen"
|
||||
ast_dir: Path = root / "midas" / "ast"
|
||||
generate(defs_dir / "midas.py", ast_dir / "midas.py")
|
||||
generate(defs_dir / "python.py", ast_dir / "python.py")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
53
gen/python.py
Normal file
53
gen/python.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# type: ignore
|
||||
# ruff: disable[F821, F401]
|
||||
|
||||
###> Imports
|
||||
import ast
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Generic, Optional, TypeVar
|
||||
|
||||
from midas.ast.location import Location
|
||||
|
||||
###<
|
||||
|
||||
|
||||
###> MidasType | Type annotations | node
|
||||
class BaseType:
|
||||
base: str
|
||||
param: Optional[MidasType]
|
||||
|
||||
|
||||
class ConstraintType:
|
||||
type: MidasType
|
||||
constraint: ast.expr
|
||||
|
||||
|
||||
class FrameColumn:
|
||||
name: Optional[str]
|
||||
type: Optional[MidasType]
|
||||
|
||||
|
||||
class FrameType:
|
||||
columns: list[FrameColumn]
|
||||
|
||||
|
||||
###<
|
||||
|
||||
|
||||
###> Stmt | Statements
|
||||
class Function:
|
||||
name: str
|
||||
posonlyargs: list[Argument]
|
||||
args: list[Argument]
|
||||
kwonlyargs: list[Argument]
|
||||
returns: Optional[MidasType]
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
location: Optional[Location] = None
|
||||
name: Optional[str]
|
||||
type: Optional[MidasType]
|
||||
|
||||
|
||||
###<
|
||||
@@ -350,7 +350,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
|
||||
return f"{expr.name.lexeme}{template}{'?' if expr.optional else ''}"
|
||||
|
||||
|
||||
class PythonAstPrinter(AstPrinter, p.Expr.Visitor[None]):
|
||||
class PythonAstPrinter(AstPrinter, p.MidasType.Visitor[None], p.Stmt.Visitor[None]):
|
||||
def visit_base_type(self, node: p.BaseType) -> None:
|
||||
self._write_line("BaseType")
|
||||
with self._child_level():
|
||||
@@ -382,39 +382,39 @@ class PythonAstPrinter(AstPrinter, p.Expr.Visitor[None]):
|
||||
self._mark_last()
|
||||
col.accept(self)
|
||||
|
||||
def visit_function(self, node: p.Function) -> None:
|
||||
def visit_function(self, stmt: p.Function) -> None:
|
||||
self._write_line("Function")
|
||||
with self._child_level():
|
||||
self._write_line(f"name: {node.name}")
|
||||
self._write_line(f"name: {stmt.name}")
|
||||
|
||||
self._write_line("posonlyargs")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(node.posonlyargs):
|
||||
for i, arg in enumerate(stmt.posonlyargs):
|
||||
self._idx = i
|
||||
if i == len(node.posonlyargs) - 1:
|
||||
if i == len(stmt.posonlyargs) - 1:
|
||||
self._mark_last()
|
||||
arg.accept(self)
|
||||
self._print_argument(arg)
|
||||
|
||||
self._write_line("args")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(node.args):
|
||||
for i, arg in enumerate(stmt.args):
|
||||
self._idx = i
|
||||
if i == len(node.args) - 1:
|
||||
if i == len(stmt.args) - 1:
|
||||
self._mark_last()
|
||||
arg.accept(self)
|
||||
self._print_argument(arg)
|
||||
|
||||
self._write_line("kwonlyargs")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(node.kwonlyargs):
|
||||
for i, arg in enumerate(stmt.kwonlyargs):
|
||||
self._idx = i
|
||||
if i == len(node.kwonlyargs) - 1:
|
||||
if i == len(stmt.kwonlyargs) - 1:
|
||||
self._mark_last()
|
||||
arg.accept(self)
|
||||
self._print_argument(arg)
|
||||
|
||||
self._write_optional_child("returns", node.returns, last=True)
|
||||
self._write_optional_child("returns", stmt.returns, last=True)
|
||||
|
||||
def visit_function_argument(self, node: p.FunctionArgument) -> None:
|
||||
def _print_argument(self, arg: p.Function.Argument) -> None:
|
||||
self._write_line("FunctionArgument")
|
||||
with self._child_level():
|
||||
self._write_line(f"name: {node.name}")
|
||||
self._write_optional_child("type", node.type, last=True)
|
||||
self._write_line(f"name: {arg.name}")
|
||||
self._write_optional_child("type", arg.type, last=True)
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
"""
|
||||
This file was generated by a script. Any manual changes might be overwritten.
|
||||
Please modify gen/python.py instead and run gen/gen.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import ast
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Generic, Optional, TypeVar
|
||||
|
||||
@@ -9,9 +14,13 @@ from midas.ast.location import Location
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
####################
|
||||
# Type annotations #
|
||||
####################
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Expr(ABC):
|
||||
class MidasType(ABC):
|
||||
location: Optional[Location] = None
|
||||
|
||||
@abstractmethod
|
||||
@@ -30,24 +39,13 @@ class Expr(ABC):
|
||||
@abstractmethod
|
||||
def visit_frame_type(self, node: FrameType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_function(self, node: Function) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_function_argument(self, node: FunctionArgument) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MidasType(Expr):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BaseType(MidasType):
|
||||
base: str
|
||||
param: Optional[MidasType]
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
def accept(self, visitor: MidasType.Visitor[T]) -> T:
|
||||
return visitor.visit_base_type(self)
|
||||
|
||||
|
||||
@@ -56,7 +54,7 @@ class ConstraintType(MidasType):
|
||||
type: MidasType
|
||||
constraint: ast.expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
def accept(self, visitor: MidasType.Visitor[T]) -> T:
|
||||
return visitor.visit_constraint_type(self)
|
||||
|
||||
|
||||
@@ -65,7 +63,7 @@ class FrameColumn(MidasType):
|
||||
name: Optional[str]
|
||||
type: Optional[MidasType]
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
def accept(self, visitor: MidasType.Visitor[T]) -> T:
|
||||
return visitor.visit_frame_column(self)
|
||||
|
||||
|
||||
@@ -73,26 +71,40 @@ class FrameColumn(MidasType):
|
||||
class FrameType(MidasType):
|
||||
columns: list[FrameColumn]
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
def accept(self, visitor: MidasType.Visitor[T]) -> T:
|
||||
return visitor.visit_frame_type(self)
|
||||
|
||||
|
||||
##############
|
||||
# Statements #
|
||||
##############
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Stmt(ABC):
|
||||
location: Optional[Location] = None
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_function(self, stmt: Function) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Function(Expr):
|
||||
class Function(Stmt):
|
||||
name: str
|
||||
posonlyargs: list[FunctionArgument]
|
||||
args: list[FunctionArgument]
|
||||
kwonlyargs: list[FunctionArgument]
|
||||
posonlyargs: list[Argument]
|
||||
args: list[Argument]
|
||||
kwonlyargs: list[Argument]
|
||||
returns: Optional[MidasType]
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_function(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FunctionArgument(Expr):
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
location: Optional[Location] = None
|
||||
name: Optional[str]
|
||||
type: Optional[MidasType]
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_function_argument(self)
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_function(self)
|
||||
|
||||
@@ -8,7 +8,6 @@ from midas.ast.python import (
|
||||
FrameColumn,
|
||||
FrameType,
|
||||
Function,
|
||||
FunctionArgument,
|
||||
MidasType,
|
||||
)
|
||||
|
||||
@@ -63,7 +62,7 @@ class PythonParser(ast.NodeVisitor):
|
||||
returns=returns,
|
||||
):
|
||||
|
||||
def parse_args(args_list: list[ast.arg]) -> list[FunctionArgument]:
|
||||
def parse_args(args_list: list[ast.arg]) -> list[Function.Argument]:
|
||||
return [self._parse_function_argument(arg) for arg in args_list]
|
||||
|
||||
return Function(
|
||||
@@ -75,13 +74,13 @@ class PythonParser(ast.NodeVisitor):
|
||||
returns=self._parse_type(returns) if returns is not None else None,
|
||||
)
|
||||
|
||||
def _parse_function_argument(self, arg: ast.arg) -> FunctionArgument:
|
||||
def _parse_function_argument(self, arg: ast.arg) -> Function.Argument:
|
||||
loc: Location = Location.from_ast(arg)
|
||||
name: str = arg.arg
|
||||
type: Optional[MidasType] = None
|
||||
if arg.annotation is not None:
|
||||
type = self._parse_type(arg.annotation)
|
||||
return FunctionArgument(
|
||||
return Function.Argument(
|
||||
location=loc,
|
||||
name=name,
|
||||
type=type,
|
||||
|
||||
Reference in New Issue
Block a user