diff --git a/midas/ast/printer.py b/midas/ast/printer.py index ed9e069..ee59279 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -283,7 +283,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] def indented(self, text: str) -> str: return " " * (self.level * self.indent) + text - def print(self, expr: m.Expr | m.Stmt): + def print(self, expr: m.Expr | m.Stmt | m.Type) -> str: self.level = 0 return expr.accept(self) @@ -314,13 +314,13 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] for op in stmt.operations: res += op.accept(self) self.level -= 1 - res += "\n" + self.indented("}") + res += self.indented("}") return res def visit_op_stmt(self, stmt: m.OpStmt): operand: str = stmt.operand.accept(self) result: str = stmt.result.accept(self) - return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}") + return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}\n") def visit_predicate_stmt(self, stmt: m.PredicateStmt): name: str = stmt.name.lexeme diff --git a/midas/cli/main.py b/midas/cli/main.py index 71635d0..a9833bb 100644 --- a/midas/cli/main.py +++ b/midas/cli/main.py @@ -8,7 +8,7 @@ import click import midas.ast.midas as m import midas.ast.python as p -from midas.ast.printer import MidasAstPrinter, PythonAstPrinter +from midas.ast.printer import MidasAstPrinter, MidasPrinter, PythonAstPrinter from midas.checker.checker import Checker from midas.checker.diagnostic import Diagnostic from midas.checker.types import Type @@ -167,5 +167,21 @@ def highlight(output: TextIO, file: TextIO): highlighter.dump(output) +@midas.command() +@click.option("-o", "--output", type=click.File("w"), default="-") +@click.argument("file", type=click.File("r")) +def format(output: TextIO, file: TextIO): + source: str = file.read() + printer = MidasPrinter() + lexer = MidasLexer(source, file=file.name) + tokens: list[Token] = lexer.process() + parser = MidasParser(tokens) + stmts: list[m.Stmt] = parser.parse() + for err in parser.errors: + print(err.get_report()) + for stmt in stmts: + output.write(printer.print(stmt) + "\n") + + if __name__ == "__main__": midas()