diff --git a/gen/python.py b/gen/python.py index 2e6c7d3..99b926d 100644 --- a/gen/python.py +++ b/gen/python.py @@ -86,6 +86,12 @@ class Pass: pass +class ForStmt: + target: Expr + iterator: Expr + body: list[Stmt] + + ###< diff --git a/midas/ast/printer.py b/midas/ast/printer.py index 77a6069..364da35 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -596,6 +596,23 @@ class PythonAstPrinter( def visit_pass(self, stmt: p.Pass) -> None: self._write_line("Pass") + def visit_for_stmt(self, stmt: p.ForStmt) -> None: + self._write_line("ForStmt") + with self._child_level(): + self._write_line("target") + with self._child_level(single=True): + stmt.target.accept(self) + self._write_line("iterator") + with self._child_level(single=True): + stmt.iterator.accept(self) + self._write_line("body", last=True) + with self._child_level(): + for i, body_stmt in enumerate(stmt.body): + self._idx = i + if i == len(stmt.body) - 1: + self._mark_last() + body_stmt.accept(self) + def visit_binary_expr(self, expr: p.BinaryExpr) -> None: self._write_line("BinaryExpr") with self._child_level(): diff --git a/midas/ast/python.py b/midas/ast/python.py index f3c9e1c..b781d2e 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -110,6 +110,9 @@ class Stmt(ABC): @abstractmethod def visit_pass(self, stmt: Pass) -> T: ... + @abstractmethod + def visit_for_stmt(self, stmt: ForStmt) -> T: ... + @dataclass(frozen=True) class ExpressionStmt(Stmt): @@ -189,6 +192,16 @@ class Pass(Stmt): return visitor.visit_pass(self) +@dataclass(frozen=True) +class ForStmt(Stmt): + target: Expr + iterator: Expr + body: list[Stmt] + + def accept(self, visitor: Stmt.Visitor[T]) -> T: + return visitor.visit_for_stmt(self) + + ############### # Expressions # ############### diff --git a/midas/checker/python.py b/midas/checker/python.py index d667164..316836e 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -78,17 +78,37 @@ class PythonTyper( return TypedAST(stmts=stmts, judgements=self.judgements) - def type_of(self, expr: p.Expr) -> Type: + def judge(self, expr: p.Expr, type: Type): + """Record a typing judgement + + Args: + expr (p.Expr): the judged expression + type (Type): the type of the expression + """ + self.judgements.append((expr, type)) + + def compute_type(self, expr: p.Expr) -> Type: """Evaluate the type of an expression + Args: + expr (p.Expr): the expression to type + + Returns: + Type: the type of the given expression + """ + return expr.accept(self) + + def type_of(self, expr: p.Expr) -> Type: + """Evaluate the type of an expression and record the judgement + Args: expr (p.Expr): the expression to evaluate Returns: Type: the type of the given expression """ - type: Type = expr.accept(self) - self.judgements.append((expr, type)) + type: Type = self.compute_type(expr) + self.judge(expr, type) return type def resolve_type_expr(self, expr: p.MidasType) -> Type: @@ -334,6 +354,22 @@ class PythonTyper( def visit_pass(self, stmt: p.Pass) -> None: pass + def visit_for_stmt(self, stmt: p.ForStmt) -> None: + item_type: Optional[Type] = self._get_iterator_type(stmt.iterator) + if item_type is None: + iterator_type: Type = self.compute_type(stmt.iterator) + self.reporter.error( + stmt.iterator.location, f"{iterator_type} is not iterable" + ) + item_type = UnknownType() + + self._assign(stmt.location, stmt.target, item_type) + self.judge(stmt.target, item_type) + env: Environment = Environment(self.env) + body_returned: bool = self.process_block(stmt.body, env) + if body_returned: + raise ReturnException() + def visit_binary_expr(self, expr: p.BinaryExpr) -> Type: method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__) if method is None: @@ -370,7 +406,13 @@ class PythonTyper( ) return UnknownType() - return self._get_call_result(location, operation, [(right_expr, right)], {}) + result: Optional[Type] = self._get_call_result( + location, + operation, + [(right_expr, right)], + {}, + ) + return result or UnknownType() def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__) @@ -390,9 +432,13 @@ class PythonTyper( ) return UnknownType() - return self._get_call_result( - expr.location, operation, [(expr.right, operand)], {} + result: Optional[Type] = self._get_call_result( + expr.location, + operation, + [], + {}, ) + return result or UnknownType() def visit_call_expr(self, expr: p.CallExpr) -> Type: callee: Type = self.type_of(expr.callee) @@ -402,11 +448,14 @@ class PythonTyper( keywords: dict[str, TypedExpr] = { name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items() } - return self._get_call_result( - location=expr.location, - callee=callee, - positional=positional, - keywords=keywords, + return ( + self._get_call_result( + location=expr.location, + callee=callee, + positional=positional, + keywords=keywords, + ) + or UnknownType() ) def visit_get_expr(self, expr: p.GetExpr) -> Type: @@ -509,8 +558,9 @@ class PythonTyper( return UnknownType() index: Type = self.type_of(expr.index) - return self._get_call_result( - expr.location, operation, [(expr.index, index)], {} + return ( + self._get_call_result(expr.location, operation, [(expr.index, index)], {}) + or UnknownType() ) def visit_slice_expr(self, expr: p.SliceExpr) -> Type: @@ -547,7 +597,8 @@ class PythonTyper( callee: Type, positional: list[TypedExpr], keywords: dict[str, TypedExpr], - ) -> Type: + report_errors: bool = True, + ) -> Optional[Type]: """Get the result type of a function call If the function has overloads, the function will try to resolve the @@ -561,9 +612,10 @@ class PythonTyper( callee (Type): the called function positional (list[TypedExpr]): the list positional arguments keywords (dict[str, TypedExpr]): the map of keyword arguments + report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True. Returns: - Type: the return type of the call, or `UnknownType` if either + Type: the return type of the call, or `None` if either the call is invalid or no overload matched the arguments uniquely """ match callee: @@ -573,21 +625,22 @@ class PythonTyper( valid, mapped = self.map_call_arguments( function, location, positional, keywords ) - valid = valid and self._are_arguments_valid(mapped) + valid = valid and self._are_arguments_valid(mapped, report_errors) if not valid: - return UnknownType() + return None return function.returns case OverloadedFunction(overloads=overloads): function = self._match_overload( - overloads, location, positional, keywords + overloads, location, positional, keywords, report_errors ) if function is None: - return UnknownType() + return None return function.returns case _: - self.reporter.error(location, f"{callee} is not callable") - return UnknownType() + if report_errors: + self.reporter.error(location, f"{callee} is not callable") + return None def _are_arguments_valid( self, @@ -620,6 +673,7 @@ class PythonTyper( location: Location, positional: list[TypedExpr], keywords: dict[str, TypedExpr], + report_errors: bool = True, ) -> Optional[Function]: """Try and resolve the appropriate overload for the given arguments @@ -628,6 +682,7 @@ class PythonTyper( location (Location): the call location positional (list[TypedExpr]): the list of positional arguments keywords (dict[str, TypedExpr]): the map of keywords arguments + report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True. Returns: Optional[Function]: the resolved function signature if it can be @@ -637,9 +692,10 @@ class PythonTyper( for overload in overloads: function: Type = unfold_type(overload) if not isinstance(function, Function): - self.logger.error( - f"Overload is not a function: {overload} is {function}" - ) + if report_errors: + self.logger.error( + f"Overload is not a function: {overload} is {function}" + ) continue valid, mapped = self.map_call_arguments( function=function, @@ -671,10 +727,11 @@ class PythonTyper( # No match -> invalid call if n_candidates == 0: overloads_str: str = ", ".join(map(str, overloads)) - self.reporter.error( - location, - f"No matching overload in [{overloads_str}] {for_args}", - ) + if report_errors: + self.reporter.error( + location, + f"No matching overload in [{overloads_str}] {for_args}", + ) return None # Multiple matches -> see if one <: all others (more specific) @@ -695,10 +752,11 @@ class PythonTyper( candidates_str: str = ", ".join( str(candidate.function) for candidate in candidates ) - self.reporter.error( - location, - f"Multiple matching overloads {for_args}: {candidates_str}", - ) + if report_errors: + self.reporter.error( + location, + f"Multiple matching overloads {for_args}: {candidates_str}", + ) return None def map_call_arguments( @@ -863,3 +921,21 @@ class PythonTyper( if not self.is_subtype(type1, type2): return False return True + + def _get_iterator_type(self, expr: p.Expr) -> Optional[Type]: + # TODO: lookup __iter__ + type: Type = self.type_of(expr) + getitem: Optional[Type] = self.types.lookup_member(type, "__getitem__") + if getitem is None: + return None + + index: p.Expr = p.LiteralExpr(location=expr.location, value=0) + index_type: Type = self.compute_type(index) + result: Optional[Type] = self._get_call_result( + location=expr.location, + callee=getitem, + positional=[(index, index_type)], + keywords={}, + report_errors=False, + ) + return result diff --git a/midas/checker/resolver.py b/midas/checker/resolver.py index eb0a6e8..c99a18d 100644 --- a/midas/checker/resolver.py +++ b/midas/checker/resolver.py @@ -116,17 +116,20 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]): def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: self.resolve(stmt.value) for target in stmt.targets: - match target: - case p.VariableExpr(name=name): - if not self.is_defined(name): - self.declare(name) - self.define(name) - target.accept(self) + self._visit_assign(target) - case p.GetExpr(): - target.accept(self) - case _: - raise Exception(f"Unsupported assignment to {target}") + def _visit_assign(self, target: p.Expr): + match target: + case p.VariableExpr(name=name): + if not self.is_defined(name): + self.declare(name) + self.define(name) + target.accept(self) + + case p.GetExpr(): + target.accept(self) + case _: + raise Exception(f"Unsupported assignment to {target}") def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: if stmt.value is not None: @@ -153,6 +156,13 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]): def visit_pass(self, stmt: p.Pass) -> None: pass + def visit_for_stmt(self, stmt: p.ForStmt) -> None: + self.resolve(stmt.iterator) + self._visit_assign(stmt.target) + self.begin_scope() + self.resolve(*stmt.body) + self.end_scope() + def visit_binary_expr(self, expr: p.BinaryExpr) -> None: self.resolve(expr.left) self.resolve(expr.right) diff --git a/midas/cli/commands/check.py b/midas/cli/commands/check.py index f59c479..22f654e 100644 --- a/midas/cli/commands/check.py +++ b/midas/cli/commands/check.py @@ -14,7 +14,7 @@ from midas.cli.highlighter import DiagnosticsHighlighter from midas.cli.utils import DiagnosticPrinter -@click.command() +@click.command(help="Run type checker and report diagnostics") @click.argument("file", type=click.File("r")) @click.option("-t", "--types", type=click.File("r"), multiple=True) @click.option("-l", "--highlight", type=click.File("w")) diff --git a/midas/cli/commands/compile.py b/midas/cli/commands/compile.py index 7d495fc..5a410f7 100644 --- a/midas/cli/commands/compile.py +++ b/midas/cli/commands/compile.py @@ -3,19 +3,20 @@ # midas compile [--types ] [-o ] [--assertions|--strict|--no-checks] # ``` +import sys from pathlib import Path from typing import TextIO import click from midas.checker.checker import TypeChecker -from midas.checker.diagnostic import Diagnostic +from midas.checker.diagnostic import Diagnostic, DiagnosticType from midas.cli.utils import DiagnosticPrinter from midas.generator.generator import Generator from midas.utils import TypedAST -@click.command() +@click.command(help="Compile source") @click.argument("file", type=click.File("r")) @click.option("-t", "--types", type=click.File("r"), multiple=True) def compile( @@ -34,5 +35,8 @@ def compile( printer = DiagnosticPrinter() printer.print_all(diagnostics) + if any(map(lambda d: d.type == DiagnosticType.ERROR, diagnostics)): + sys.exit(1) + generator = Generator(workdir=source_path.parent) generator.generate(typed_ast, source_path) diff --git a/midas/cli/commands/format.py b/midas/cli/commands/format.py index ce570c0..1bd3dbf 100644 --- a/midas/cli/commands/format.py +++ b/midas/cli/commands/format.py @@ -9,7 +9,7 @@ from midas.lexer.token import Token from midas.parser.midas import MidasParser -@click.command() +@click.command(help="Parse and pretty print a Midas file") @click.argument("file", type=click.File("r")) @click.option("-o", "--output", type=click.File("w"), default="-") def format(file: TextIO, output: TextIO): diff --git a/midas/cli/commands/highlight.py b/midas/cli/commands/highlight.py index acb05e9..cf04a8d 100644 --- a/midas/cli/commands/highlight.py +++ b/midas/cli/commands/highlight.py @@ -46,7 +46,10 @@ def highlight_midas(source: str, path: str) -> Highlighter: return highlighter -@click.command() +@click.command( + help="Parse a Python or Midas file and produce a highlighted version showing AST node types inline", + short_help="Parse and highlight a Python or Midas file", +) @click.argument("file", type=click.File("r")) @click.option("-o", "--output", type=click.File("w"), default="-") def highlight(output: TextIO, file: TextIO): diff --git a/midas/cli/commands/parse.py b/midas/cli/commands/parse.py index c2a2f1a..d2f5338 100644 --- a/midas/cli/commands/parse.py +++ b/midas/cli/commands/parse.py @@ -45,7 +45,7 @@ def dump_midas_ast(source: str, filename: str) -> str: return dump -@click.command() +@click.command(help="Parse a Python or Midas file and pretty-print its AST") @click.argument("file", type=click.File("r")) @click.option("--raw", is_flag=True) def parse(file: TextIO, raw: bool): diff --git a/midas/cli/commands/registry.py b/midas/cli/commands/registry.py index 502be75..4e830be 100644 --- a/midas/cli/commands/registry.py +++ b/midas/cli/commands/registry.py @@ -12,7 +12,7 @@ from midas.checker.checker import TypeChecker from midas.checker.types import Type -@click.command() +@click.command(help="Dump types registry") @click.option("-t", "--types", type=click.File("r"), multiple=True) def dump_registry( types: tuple[TextIO], diff --git a/midas/cli/commands/types.py b/midas/cli/commands/types.py index 439611d..d17a0a8 100644 --- a/midas/cli/commands/types.py +++ b/midas/cli/commands/types.py @@ -14,7 +14,7 @@ from midas.cli.highlighter import DiagnosticsHighlighter from midas.cli.utils import DiagnosticPrinter -@click.command() +@click.command(help="Print typing judgements") @click.argument("file", type=click.File("r")) @click.option("-t", "--types", type=click.File("r"), multiple=True) @click.option("-l", "--highlight", type=click.File("w")) diff --git a/midas/cli/commands/validate.py b/midas/cli/commands/validate.py index 44e33bb..931c3e5 100644 --- a/midas/cli/commands/validate.py +++ b/midas/cli/commands/validate.py @@ -14,7 +14,7 @@ from midas.cli.highlighter import DiagnosticsHighlighter from midas.cli.utils import DiagnosticPrinter -@click.command() +@click.command(help="Validate Midas definitions") @click.argument("file", type=click.File("r")) @click.option("-l", "--highlight", type=click.File("w")) def validate( diff --git a/midas/cli/highlighter.py b/midas/cli/highlighter.py index ce63d62..1303f94 100644 --- a/midas/cli/highlighter.py +++ b/midas/cli/highlighter.py @@ -191,6 +191,13 @@ class PythonHighlighter( def visit_pass(self, stmt: p.Pass) -> None: pass + def visit_for_stmt(self, stmt: p.ForStmt) -> None: + self.wrap(stmt, "for") + stmt.iterator.accept(self) + stmt.target.accept(self) + for body_stmt in stmt.body: + body_stmt.accept(self) + def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ... def visit_compare_expr(self, expr: p.CompareExpr) -> None: ... diff --git a/midas/generator/generator.py b/midas/generator/generator.py index ceabdd5..bef6e16 100644 --- a/midas/generator/generator.py +++ b/midas/generator/generator.py @@ -161,5 +161,13 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): def visit_pass(self, stmt: p.Pass) -> ast.stmt: return ast.Pass() + def visit_for_stmt(self, stmt: p.ForStmt) -> ast.stmt: + return ast.For( + target=stmt.target.accept(self), + iter=stmt.iterator.accept(self), + body=self._visit_body(stmt.body), + orelse=[], + ) + def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]: return [stmt.accept(self) for stmt in stmts] diff --git a/midas/parser/python.py b/midas/parser/python.py index a0726da..55edd34 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -12,6 +12,7 @@ from midas.ast.python import ( ConstraintType, Expr, ExpressionStmt, + ForStmt, FrameColumn, FrameType, Function, @@ -93,6 +94,9 @@ class PythonParser: case ast.Pass(): return None + case ast.For(orelse=[]): + return self.parse_for(node) + case _: print(f"Unsupported statement: {ast.unparse(node)}") return None @@ -182,6 +186,22 @@ class PythonParser: orelse=orelse, ) + def parse_for(self, node: ast.For) -> ForStmt: + body: list[Stmt] = [] + for stmt in node.body: + stmts = self.parse_stmt(stmt) + if isinstance(stmts, Stmt): + body.append(stmts) + elif stmts is not None: + body.extend(stmts) + + return ForStmt( + location=Location.from_ast(node), + target=self.parse_expr(node.target), + iterator=self.parse_expr(node.iter), + body=body, + ) + def parse_function(self, node: ast.FunctionDef) -> Function: loc: Location = Location.from_ast(node) match node: diff --git a/tests/serializer/python.py b/tests/serializer/python.py index e1f3fd4..56171b8 100644 --- a/tests/serializer/python.py +++ b/tests/serializer/python.py @@ -11,6 +11,7 @@ from midas.ast.python import ( ConstraintType, Expr, ExpressionStmt, + ForStmt, FrameColumn, FrameType, Function, @@ -182,6 +183,14 @@ class PythonAstJsonSerializer( "_type": "Pass", } + def visit_for_stmt(self, stmt: ForStmt) -> dict: + return { + "_type": "ForStmt", + "target": stmt.target.accept(self), + "iterator": stmt.iterator.accept(self), + "body": self._serialize_list(stmt.body), + } + def visit_binary_expr(self, expr: BinaryExpr) -> dict: return { "_type": "BinaryExpr",