Merge pull request 'Simple code generator and CLI redesign' (#10) from feat/code-generator into main

Reviewed-on: #10
This commit was merged in pull request #10.
This commit is contained in:
2026-06-15 12:29:22 +00:00
21 changed files with 663 additions and 263 deletions

View File

@@ -82,6 +82,10 @@ class IfStmt:
orelse: list[Stmt]
class Pass:
pass
###<

View File

@@ -593,6 +593,9 @@ class PythonAstPrinter(
self._mark_last()
else_stmt.accept(self)
def visit_pass(self, stmt: p.Pass) -> None:
self._write_line("Pass")
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
self._write_line("BinaryExpr")
with self._child_level():

View File

@@ -107,6 +107,9 @@ class Stmt(ABC):
@abstractmethod
def visit_if_stmt(self, stmt: IfStmt) -> T: ...
@abstractmethod
def visit_pass(self, stmt: Pass) -> T: ...
@dataclass(frozen=True)
class ExpressionStmt(Stmt):
@@ -178,6 +181,14 @@ class IfStmt(Stmt):
return visitor.visit_if_stmt(self)
@dataclass(frozen=True)
class Pass(Stmt):
pass
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_pass(self)
###############
# Expressions #
###############

View File

@@ -6,6 +6,7 @@ from midas.checker.midas import MidasTyper
from midas.checker.python import PythonTyper
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import Reporter
from midas.utils import TypedAST
class TypeChecker:
@@ -23,12 +24,12 @@ class TypeChecker:
def import_midas_source(self, source: str, path: Optional[str] = None):
self.midas_typer.process(source, path)
def type_check(self, path: Path):
def type_check(self, path: Path) -> TypedAST:
source: str = path.read_text()
return self.type_check_source(source, path=str(path))
def type_check_source(self, source: str, path: Optional[str] = None):
self.python_typer.process(source, path)
def type_check_source(self, source: str, path: Optional[str] = None) -> TypedAST:
return self.python_typer.process(source, path)
@property
def diagnostics(self) -> list[Diagnostic]:

View File

@@ -19,6 +19,7 @@ from midas.checker.types import (
unfold_type,
)
from midas.parser.python import PythonParser
from midas.utils import TypedAST
TypedExpr = tuple[p.Expr, Type]
@@ -60,7 +61,7 @@ class PythonTyper(
self.locals: dict[p.Expr, int] = {}
self.judgements: list[tuple[p.Expr, Type]] = []
def process(self, source: str, path: Optional[str]):
def process(self, source: str, path: Optional[str]) -> TypedAST:
self.reporter = self.reporter.for_file(path)
tree: ast.Module = ast.parse(source, filename=path or "<unknown>")
@@ -75,6 +76,8 @@ class PythonTyper(
self.check(stmts)
return TypedAST(stmts=stmts, judgements=self.judgements)
def type_of(self, expr: p.Expr) -> Type:
"""Evaluate the type of an expression
@@ -328,6 +331,9 @@ class PythonTyper(
if body_returned and else_returned:
raise ReturnException()
def visit_pass(self, stmt: p.Pass) -> None:
pass
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
if method is None:

View File

@@ -150,6 +150,9 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
self.resolve(*stmt.orelse)
self.end_scope()
def visit_pass(self, stmt: p.Pass) -> None:
pass
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
self.resolve(expr.left)
self.resolve(expr.right)

View File

@@ -0,0 +1,8 @@
from .check import check as check
from .compile import compile as compile
from .format import format as format
from .highlight import highlight as highlight
from .parse import parse as parse
from .registry import dump_registry as dump_registry
from .types import types as types
from .validate import validate as validate

View File

@@ -0,0 +1,41 @@
# **Run type checker and report diagnostics**
# ```shell
# midas check <file.py> [--types <file.midas>]
# ```
from pathlib import Path
from typing import Optional, TextIO
import click
from midas.checker.checker import TypeChecker
from midas.checker.diagnostic import Diagnostic
from midas.cli.highlighter import DiagnosticsHighlighter
from midas.cli.utils import DiagnosticPrinter
@click.command()
@click.argument("file", type=click.File("r"))
@click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.option("-l", "--highlight", type=click.File("w"))
def check(
file: TextIO,
types: tuple[TextIO],
highlight: Optional[TextIO],
):
source_path: Path = Path(file.name).resolve()
checker = TypeChecker()
for types_file in types:
checker.import_midas(Path(types_file.name).resolve())
checker.type_check(source_path)
diagnostics: list[Diagnostic] = checker.diagnostics.copy()
printer = DiagnosticPrinter()
printer.print_all(diagnostics)
if highlight is not None:
source: str = file.read()
highlighter = DiagnosticsHighlighter(source)
highlighter.highlight(diagnostics)
highlighter.dump(highlight)

View File

@@ -0,0 +1,38 @@
# **Compile source**
# ```shell
# midas compile <file.py> [--types <file.midas>] [-o <output>] [--assertions|--strict|--no-checks]
# ```
from pathlib import Path
from typing import TextIO
import click
from midas.checker.checker import TypeChecker
from midas.checker.diagnostic import Diagnostic
from midas.cli.utils import DiagnosticPrinter
from midas.generator.generator import Generator
from midas.utils import TypedAST
@click.command()
@click.argument("file", type=click.File("r"))
@click.option("-t", "--types", type=click.File("r"), multiple=True)
def compile(
file: TextIO,
types: tuple[TextIO],
):
source: str = file.read()
source_path: Path = Path(file.name).resolve()
checker = TypeChecker()
for types_file in types:
checker.import_midas(Path(types_file.name).resolve())
typed_ast: TypedAST = checker.type_check_source(source, str(source_path))
diagnostics: list[Diagnostic] = checker.diagnostics.copy()
printer = DiagnosticPrinter()
printer.print_all(diagnostics)
generator = Generator(workdir=source_path.parent)
generator.generate(typed_ast, source_path)

View File

@@ -0,0 +1,25 @@
from typing import TextIO
import click
import midas.ast.midas as m
from midas.ast.printer import MidasPrinter
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token
from midas.parser.midas import MidasParser
@click.command()
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-")
def format(file: TextIO, output: TextIO):
source: str = file.read()
printer = MidasPrinter()
lexer = MidasLexer(source, file=file.name)
tokens: list[Token] = lexer.process()
parser = MidasParser(tokens)
stmts: list[m.Stmt] = parser.parse()
for err in parser.errors:
print(err.get_report())
for stmt in stmts:
output.write(printer.print(stmt) + "\n")

View File

@@ -0,0 +1,63 @@
import ast
from typing import TextIO
import click
import midas.ast.midas as m
import midas.ast.python as p
from midas.cli.highlighter import (
Highlighter,
LocatableToken,
MidasHighlighter,
PythonHighlighter,
)
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token, TokenType
from midas.parser.midas import MidasParser
from midas.parser.python import PythonParser
def highlight_python(source: str, path: str) -> Highlighter:
tree: ast.Module = ast.parse(source, filename=path)
parser = PythonParser()
stmts: list[p.Stmt] = parser.parse_module(tree)
highlighter = PythonHighlighter(source)
for stmt in stmts:
highlighter.highlight(stmt)
return highlighter
def highlight_midas(source: str, path: str) -> Highlighter:
lexer = MidasLexer(source, file=path)
tokens: list[Token] = lexer.process()
parser = MidasParser(tokens)
stmts: list[m.Stmt] = parser.parse()
highlighter = MidasHighlighter(source)
for err in parser.errors:
print(err.get_report())
for stmt in stmts:
highlighter.highlight(stmt)
for token in tokens:
if token.type == TokenType.COMMENT:
highlighter.wrap(LocatableToken(token), "comment")
elif token.is_keyword:
highlighter.wrap(LocatableToken(token), "keyword")
return highlighter
@click.command()
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-")
def highlight(output: TextIO, file: TextIO):
source: str = file.read()
highlighter: Highlighter
if file.name.endswith(".py"):
highlighter = highlight_python(source, file.name)
elif file.name.endswith(".midas"):
highlighter = highlight_midas(source, file.name)
else:
raise ValueError("Unsupported file type")
highlighter.dump(output)

View File

@@ -0,0 +1,66 @@
# **Parse and pretty-print AST**
# ```shell
# midas parse <file.midas / file.py>
# ```
import ast
from typing import TextIO
import click
import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.printer import MidasAstPrinter, PythonAstPrinter
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token
from midas.parser.midas import MidasParser
from midas.parser.python import PythonParser
def dump_python_ast(tree: ast.Module) -> str:
parser = PythonParser()
stmts: list[p.Stmt] = parser.parse_module(tree)
printer = PythonAstPrinter()
dump: str = ""
for stmt in stmts:
dump += printer.print(stmt)
dump += "\n"
return dump
def dump_midas_ast(source: str, filename: str) -> str:
lexer = MidasLexer(source, file=filename)
tokens: list[Token] = lexer.process()
parser = MidasParser(tokens)
stmts: list[m.Stmt] = parser.parse()
if len(parser.errors) != 0:
for err in parser.errors:
print(err.get_report())
raise RuntimeError("A parsing error occurred")
printer = MidasAstPrinter()
dump: str = ""
for stmt in stmts:
dump += printer.print(stmt)
dump += "\n"
return dump
@click.command()
@click.argument("file", type=click.File("r"))
@click.option("--raw", is_flag=True)
def parse(file: TextIO, raw: bool):
source: str = file.read()
dump: str
if file.name.endswith(".py"):
tree: ast.Module = ast.parse(source, filename=file.name)
if raw:
dump = ast.dump(tree, indent=4)
else:
dump = dump_python_ast(tree)
elif file.name.endswith(".midas"):
dump = dump_midas_ast(source, file.name)
else:
raise ValueError("Unsupported file type")
click.echo(dump)

View File

@@ -0,0 +1,30 @@
# **Dump types registry**
# ```shell
# midas dump-registry [--types <file.midas>]
# ```
from pathlib import Path
from typing import TextIO
import click
from midas.checker.checker import TypeChecker
from midas.checker.types import Type
@click.command()
@click.option("-t", "--types", type=click.File("r"), multiple=True)
def dump_registry(
types: tuple[TextIO],
):
checker = TypeChecker()
for types_file in types:
checker.import_midas(Path(types_file.name).resolve())
for name, type in checker.types._types.items():
members: dict[str, Type] = checker.types._members.get(name, {})
print(f"{name} = {type}")
if len(members) != 0:
print(" " * 4 + "Members:")
for member_name, member_type in members.items():
print(" " * 8 + f"{member_name}: {member_type}")

View File

@@ -0,0 +1,51 @@
# **Print judgements**
# ```shell
# midas types <file.py> [--types <file.midas>]
# ```
from pathlib import Path
from typing import Optional, TextIO
import click
from midas.checker.checker import TypeChecker
from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.cli.highlighter import DiagnosticsHighlighter
from midas.cli.utils import DiagnosticPrinter
@click.command()
@click.argument("file", type=click.File("r"))
@click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.option("-l", "--highlight", type=click.File("w"))
def types(
file: TextIO,
types: tuple[TextIO],
highlight: Optional[TextIO],
):
source_path: Path = Path(file.name).resolve()
checker = TypeChecker()
for types_file in types:
checker.import_midas(Path(types_file.name).resolve())
checker.type_check(source_path)
diagnostics: list[Diagnostic] = []
for expr, type in checker.python_typer.judgements:
diagnostics.append(
Diagnostic(
file_path=str(source_path),
location=expr.location,
type=DiagnosticType.INFO,
message=f"Type: {type}",
)
)
printer = DiagnosticPrinter()
printer.print_all(diagnostics)
if highlight is not None:
source: str = file.read()
highlighter = DiagnosticsHighlighter(source)
highlighter.highlight(diagnostics)
highlighter.dump(highlight)

View File

@@ -0,0 +1,37 @@
# **Validate midas definitions**
# ```shell
# midas validate <file.midas>
# ```
from pathlib import Path
from typing import Optional, TextIO
import click
from midas.checker.checker import TypeChecker
from midas.checker.diagnostic import Diagnostic
from midas.cli.highlighter import DiagnosticsHighlighter
from midas.cli.utils import DiagnosticPrinter
@click.command()
@click.argument("file", type=click.File("r"))
@click.option("-l", "--highlight", type=click.File("w"))
def validate(
file: TextIO,
highlight: Optional[TextIO],
):
source_path: Path = Path(file.name).resolve()
checker = TypeChecker()
checker.import_midas(source_path)
diagnostics: list[Diagnostic] = checker.diagnostics.copy()
printer = DiagnosticPrinter()
printer.print_all(diagnostics)
if highlight is not None:
source: str = file.read()
highlighter = DiagnosticsHighlighter(source)
highlighter.highlight(diagnostics)
highlighter.dump(highlight)

View File

@@ -188,6 +188,9 @@ class PythonHighlighter(
for else_stmt in stmt.orelse:
else_stmt.accept(self)
def visit_pass(self, stmt: p.Pass) -> None:
pass
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ...
def visit_compare_expr(self, expr: p.CompareExpr) -> None: ...

View File

@@ -1,273 +1,24 @@
import ast
import json
import logging
from pathlib import Path
from typing import Optional, TextIO, get_args
import click
import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.ast.printer import MidasAstPrinter, MidasPrinter, PythonAstPrinter
from midas.checker.checker import TypeChecker
from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.checker.types import Type
from midas.cli.ansi import Ansi
from midas.cli.highlighter import (
DiagnosticsHighlighter,
Highlighter,
LocatableToken,
MidasHighlighter,
PythonHighlighter,
)
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token, TokenType
from midas.parser.midas import MidasParser
from midas.parser.python import PythonParser
from midas.utils import UniversalJSONDumper
from midas.cli import commands
@click.group()
def midas():
pass
def print_diagnostic(lines: list[str], diagnostic: Diagnostic, indent: int = 4):
"""Pretty-print a diagnostic, showing some context if possible
If the diagnostic concerns a specific part of one line, the line is shown
with the affected part highlighted. The message is clearly printed under the
line with an underline further indicating the target expression.
If multiple lines are concerned, no context is shown, only the
diagnostic type, location and message
Args:
lines (list[str]): source code lines
diagnostic (Diagnostic): the diagnostic to print
indent (int, optional): the number of spaces added before the target line to indent if from the location header. Defaults to 4.
"""
loc: Location = diagnostic.location
if loc.lineno != loc.end_lineno:
print(diagnostic)
return
start_offset: int = loc.col_offset
end_offset: int = loc.end_col_offset or (start_offset + 1)
line: str = lines[loc.lineno - 1]
before: str = line[:start_offset]
after: str = line[end_offset:]
color: int = {
DiagnosticType.ERROR: Ansi.RED,
DiagnosticType.WARNING: Ansi.YELLOW,
DiagnosticType.INFO: Ansi.CYAN,
}.get(diagnostic.type, Ansi.WHITE)
subject: str = Ansi.FG(color) + line[start_offset:end_offset] + Ansi.RESET
cursor: str = (
" " * start_offset
+ Ansi.FG(color)
+ "~" * (end_offset - start_offset)
+ "> "
+ diagnostic.message
+ Ansi.RESET
)
indent_str: str = " " * indent
print(diagnostic.location_str + ":")
print(indent_str + before + subject + after)
print(indent_str + cursor)
print()
@midas.command()
@click.option("-l", "--highlight", type=click.File("w"))
@click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.option("-v", "--verbose", is_flag=True)
@click.option("-j", "--show-judgements", is_flag=True)
@click.argument("file", type=click.File("r"))
def compile(
highlight: Optional[TextIO],
types: tuple[TextIO],
verbose: bool,
show_judgements: bool,
file: TextIO,
):
def midas(verbose: bool):
logging.basicConfig(level=logging.DEBUG if verbose else logging.WARN)
source: str = file.read()
source_path: Path = Path(file.name).resolve()
checker = TypeChecker()
for types_file in types:
checker.import_midas(Path(types_file.name).resolve())
checker.type_check_source(source, str(source_path))
diagnostics: list[Diagnostic] = checker.diagnostics.copy()
lines: list[str] = source.split("\n")
files: dict[Optional[str], list[str]] = {None: []}
if show_judgements:
for expr, type in checker.python_typer.judgements:
print(f"Judged that {expr} at {expr.location} is of type {type}")
diagnostics.append(
Diagnostic(
file_path=str(source_path),
location=expr.location,
type=DiagnosticType.INFO,
message=f"Type: {type}",
)
)
for diagnostic in diagnostics:
filename: Optional[str] = diagnostic.file_path
if filename is not None and filename not in files:
path: Path = Path(filename)
if path.exists() and path.is_file():
files[filename] = path.read_text().split("\n")
else:
files[filename] = []
lines: list[str] = files[filename]
print_diagnostic(lines, diagnostic)
if verbose:
print(
json.dumps(
UniversalJSONDumper.dump(
checker.python_typer.global_env,
[("Environment", "_children")],
lambda obj: isinstance(obj, get_args(Type)),
),
indent=4,
)
)
if highlight is not None:
highlighter = DiagnosticsHighlighter(source)
highlighter.highlight(diagnostics)
highlighter.dump(highlight)
@midas.group()
def utils():
pass
def dump_python_ast(tree: ast.Module) -> str:
parser = PythonParser()
stmts: list[p.Stmt] = parser.parse_module(tree)
printer = PythonAstPrinter()
dump: str = ""
for stmt in stmts:
dump += printer.print(stmt)
dump += "\n"
return dump
def dump_midas_ast(source: str, filename: str) -> str:
lexer = MidasLexer(source, file=filename)
tokens: list[Token] = lexer.process()
parser = MidasParser(tokens)
stmts: list[m.Stmt] = parser.parse()
if len(parser.errors) != 0:
for err in parser.errors:
print(err.get_report())
raise RuntimeError("A parsing error occurred")
printer = MidasAstPrinter()
dump: str = ""
for stmt in stmts:
dump += printer.print(stmt)
dump += "\n"
return dump
@utils.command()
@click.option("-o", "--output", type=click.File("w"))
@click.option("-p", "--parse", is_flag=True)
@click.argument("file", type=click.File("r"))
def dump_ast(output: Optional[TextIO], parse: bool, file: TextIO):
source: str = file.read()
dump: str
if file.name.endswith(".py"):
tree: ast.Module = ast.parse(source, filename=file.name)
if parse:
dump = dump_python_ast(tree)
else:
dump = ast.dump(tree, indent=4)
elif file.name.endswith(".midas"):
dump = dump_midas_ast(source, file.name)
else:
raise ValueError("Unsupported file type")
if output is None:
click.echo(dump)
else:
output.write(dump)
def highlight_python(source: str, path: str) -> Highlighter:
tree: ast.Module = ast.parse(source, filename=path)
parser = PythonParser()
stmts: list[p.Stmt] = parser.parse_module(tree)
highlighter = PythonHighlighter(source)
for stmt in stmts:
highlighter.highlight(stmt)
return highlighter
def highlight_midas(source: str, path: str) -> Highlighter:
lexer = MidasLexer(source, file=path)
tokens: list[Token] = lexer.process()
parser = MidasParser(tokens)
stmts: list[m.Stmt] = parser.parse()
highlighter = MidasHighlighter(source)
for err in parser.errors:
print(err.get_report())
for stmt in stmts:
highlighter.highlight(stmt)
for token in tokens:
if token.type == TokenType.COMMENT:
highlighter.wrap(LocatableToken(token), "comment")
elif token.is_keyword:
highlighter.wrap(LocatableToken(token), "keyword")
return highlighter
@utils.command()
@click.option("-o", "--output", type=click.File("w"), default="-")
@click.argument("file", type=click.File("r"))
def highlight(output: TextIO, file: TextIO):
source: str = file.read()
highlighter: Highlighter
if file.name.endswith(".py"):
highlighter = highlight_python(source, file.name)
elif file.name.endswith(".midas"):
highlighter = highlight_midas(source, file.name)
else:
raise ValueError("Unsupported file type")
highlighter.dump(output)
@midas.command()
@click.option("-o", "--output", type=click.File("w"), default="-")
@click.argument("file", type=click.File("r"))
def format(output: TextIO, file: TextIO):
source: str = file.read()
printer = MidasPrinter()
lexer = MidasLexer(source, file=file.name)
tokens: list[Token] = lexer.process()
parser = MidasParser(tokens)
stmts: list[m.Stmt] = parser.parse()
for err in parser.errors:
print(err.get_report())
for stmt in stmts:
output.write(printer.print(stmt) + "\n")
midas.add_command(commands.check)
midas.add_command(commands.compile)
midas.add_command(commands.format)
midas.add_command(commands.highlight)
midas.add_command(commands.parse)
midas.add_command(commands.dump_registry)
midas.add_command(commands.types)
midas.add_command(commands.validate)
if __name__ == "__main__":

78
midas/cli/utils.py Normal file
View File

@@ -0,0 +1,78 @@
from pathlib import Path
from typing import Optional
from midas.ast.location import Location
from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.cli.ansi import Ansi
class DiagnosticPrinter:
def __init__(self) -> None:
self.files: dict[Optional[str], list[str]] = {}
def get_lines(self, filename: Optional[str]) -> list[str]:
if filename is None:
return []
if filename not in self.files:
path: Path = Path(filename)
if path.exists() and path.is_file():
self.files[filename] = path.read_text().split("\n")
else:
self.files[filename] = []
return self.files[filename]
def print_all(self, diagnostics: list[Diagnostic], indent: int = 4):
for diagnostic in diagnostics:
filename: Optional[str] = diagnostic.file_path
lines = self.get_lines(filename)
self.print(lines, diagnostic, indent=indent)
def print(self, lines: list[str], diagnostic: Diagnostic, indent: int = 4):
"""Pretty-print a diagnostic, showing some context if possible
If the diagnostic concerns a specific part of one line, the line is shown
with the affected part highlighted. The message is clearly printed under the
line with an underline further indicating the target expression.
If multiple lines are concerned, no context is shown, only the
diagnostic type, location and message
Args:
lines (list[str]): source code lines
diagnostic (Diagnostic): the diagnostic to print
indent (int, optional): the number of spaces added before the target line to indent if from the location header. Defaults to 4.
"""
loc: Location = diagnostic.location
if loc.lineno != loc.end_lineno:
print(diagnostic)
return
start_offset: int = loc.col_offset
end_offset: int = loc.end_col_offset or (start_offset + 1)
line: str = lines[loc.lineno - 1]
before: str = line[:start_offset]
after: str = line[end_offset:]
color: int = {
DiagnosticType.ERROR: Ansi.RED,
DiagnosticType.WARNING: Ansi.YELLOW,
DiagnosticType.INFO: Ansi.CYAN,
}.get(diagnostic.type, Ansi.WHITE)
subject: str = Ansi.FG(color) + line[start_offset:end_offset] + Ansi.RESET
cursor: str = (
" " * start_offset
+ Ansi.FG(color)
+ "~" * (end_offset - start_offset)
+ "> "
+ diagnostic.message
+ Ansi.RESET
)
indent_str: str = " " * indent
print(diagnostic.location_str + ":")
print(indent_str + before + subject + after)
print(indent_str + cursor)
print()

View File

@@ -0,0 +1,165 @@
import ast
import shutil
from pathlib import Path
import midas.ast.python as p
from midas.utils import TypedAST
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def __init__(self, workdir: Path) -> None:
self.workdir: Path = workdir.resolve()
self.build_dir: Path = self.workdir / "build" / "midas"
if self.build_dir.exists():
shutil.rmtree(self.build_dir)
self.build_dir.mkdir(parents=True, exist_ok=True)
def generate(self, typed_ast: TypedAST, src_path: Path) -> Path:
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
module = ast.Module(body=body, type_ignores=[])
module = ast.fix_missing_locations(module)
compiled: str = ast.unparse(module)
rel_src_path: Path = src_path.relative_to(self.workdir)
out_path: Path = (self.build_dir / rel_src_path).resolve()
try:
_ = out_path.relative_to(self.build_dir)
except ValueError:
raise ValueError(
f"Directory traversal, {rel_src_path} points outside of parent directory"
)
out_path.parent.mkdir(parents=True, exist_ok=True)
out_path.write_text(compiled)
return out_path
def visit_binary_expr(self, expr: p.BinaryExpr) -> ast.expr:
return ast.BinOp(
left=expr.left.accept(self),
op=expr.operator,
right=expr.right.accept(self),
)
def visit_compare_expr(self, expr: p.CompareExpr) -> ast.expr:
return ast.Compare(
left=expr.left.accept(self),
ops=[expr.operator],
comparators=[expr.right.accept(self)],
)
def visit_unary_expr(self, expr: p.UnaryExpr) -> ast.expr:
return ast.UnaryOp(
op=expr.operator,
operand=expr.right.accept(self),
)
def visit_call_expr(self, expr: p.CallExpr) -> ast.expr:
return ast.Call(
func=expr.callee.accept(self),
args=[arg.accept(self) for arg in expr.arguments],
keywords=[
ast.keyword(arg=name, value=arg.accept(self))
for name, arg in expr.keywords.items()
],
)
def visit_get_expr(self, expr: p.GetExpr) -> ast.expr:
return ast.Attribute(
value=expr.object.accept(self),
attr=expr.name,
)
def visit_literal_expr(self, expr: p.LiteralExpr) -> ast.expr:
return ast.Constant(value=expr.value)
def visit_variable_expr(self, expr: p.VariableExpr) -> ast.expr:
return ast.Name(id=expr.name)
def visit_logical_expr(self, expr: p.LogicalExpr) -> ast.expr:
return ast.BoolOp(
op=expr.operator,
values=[expr.left.accept(self), expr.right.accept(self)],
)
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
# TODO: insert assertion
return expr.expr.accept(self)
def visit_ternary_expr(self, expr: p.TernaryExpr) -> ast.expr:
return ast.IfExp(
test=expr.test.accept(self),
body=expr.if_true.accept(self),
orelse=expr.if_false.accept(self),
)
def visit_list_expr(self, expr: p.ListExpr) -> ast.expr:
return ast.List(
elts=[item.accept(self) for item in expr.items],
)
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> ast.expr:
return ast.Subscript(
value=expr.object.accept(self),
slice=expr.index.accept(self),
)
def visit_slice_expr(self, expr: p.SliceExpr) -> ast.expr:
return ast.Slice(
lower=expr.lower.accept(self) if expr.lower is not None else None,
upper=expr.upper.accept(self) if expr.upper is not None else None,
step=expr.step.accept(self) if expr.step is not None else None,
)
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> ast.stmt:
return ast.Expr(
value=stmt.expr.accept(self),
)
def visit_function(self, stmt: p.Function) -> ast.stmt:
return ast.FunctionDef(
name=stmt.name,
args=ast.arguments(
posonlyargs=[ast.arg(arg=arg.name) for arg in stmt.posonlyargs],
vararg=None,
args=[ast.arg(arg=arg.name) for arg in stmt.args],
kwonlyargs=[ast.arg(arg=arg.name) for arg in stmt.kwonlyargs],
kwarg=None,
defaults=[
arg.default.accept(self)
for arg in stmt.posonlyargs + stmt.args
if arg.default is not None
],
kw_defaults=[
arg.default.accept(self) if arg.default is not None else None
for arg in stmt.kwonlyargs
],
),
body=self._visit_body(stmt.body),
decorator_list=[],
)
def visit_type_assign(self, stmt: p.TypeAssign) -> ast.stmt:
# TODO: is that ok?
return ast.Pass()
def visit_assign_stmt(self, stmt: p.AssignStmt) -> ast.stmt:
return ast.Assign(
targets=[target.accept(self) for target in stmt.targets],
value=stmt.value.accept(self),
)
def visit_return_stmt(self, stmt: p.ReturnStmt) -> ast.stmt:
return ast.Return(
value=stmt.value.accept(self) if stmt.value is not None else None,
)
def visit_if_stmt(self, stmt: p.IfStmt) -> ast.stmt:
return ast.If(
test=stmt.test.accept(self),
body=self._visit_body(stmt.body),
orelse=self._visit_body(stmt.orelse),
)
def visit_pass(self, stmt: p.Pass) -> ast.stmt:
return ast.Pass()
def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]:
return [stmt.accept(self) for stmt in stmts]

View File

@@ -1,5 +1,9 @@
from dataclasses import dataclass
from typing import Any, Callable, Optional
import midas.ast.python as p
from midas.checker.types import Type
AllowRepeat = Callable[[object], bool]
@@ -52,3 +56,9 @@ class UniversalJSONDumper:
}
case _:
raise ValueError(f"Unsupported value: {obj}")
@dataclass(frozen=True, kw_only=True)
class TypedAST:
stmts: list[p.Stmt]
judgements: list[tuple[p.Expr, Type]]

View File

@@ -20,6 +20,7 @@ from midas.ast.python import (
LiteralExpr,
LogicalExpr,
MidasType,
Pass,
ReturnStmt,
SliceExpr,
Stmt,
@@ -176,6 +177,11 @@ class PythonAstJsonSerializer(
"orelse": self._serialize_list(stmt.orelse),
}
def visit_pass(self, stmt: Pass) -> dict:
return {
"_type": "Pass",
}
def visit_binary_expr(self, expr: BinaryExpr) -> dict:
return {
"_type": "BinaryExpr",