diff --git a/gen/gen.py b/gen/gen.py index 75e6100..4f15521 100644 --- a/gen/gen.py +++ b/gen/gen.py @@ -11,7 +11,7 @@ SECTION_TEMPLATE = """{banner} @dataclass(frozen=True, kw_only=True) class {base}(ABC): - location: Optional[Location] = None + location: Location @abstractmethod def accept(self, visitor: Visitor[T]) -> T: ... diff --git a/midas/ast/midas.py b/midas/ast/midas.py index 9cea8c2..c307e85 100644 --- a/midas/ast/midas.py +++ b/midas/ast/midas.py @@ -21,7 +21,7 @@ T = TypeVar("T") @dataclass(frozen=True, kw_only=True) class Stmt(ABC): - location: Optional[Location] = None + location: Location @abstractmethod def accept(self, visitor: Visitor[T]) -> T: ... @@ -114,7 +114,7 @@ class PredicateStmt(Stmt): @dataclass(frozen=True, kw_only=True) class Expr(ABC): - location: Optional[Location] = None + location: Location @abstractmethod def accept(self, visitor: Visitor[T]) -> T: ... diff --git a/midas/ast/python.py b/midas/ast/python.py index 96c218e..48f4643 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -21,7 +21,7 @@ T = TypeVar("T") @dataclass(frozen=True, kw_only=True) class MidasType(ABC): - location: Optional[Location] = None + location: Location @abstractmethod def accept(self, visitor: Visitor[T]) -> T: ... @@ -82,7 +82,7 @@ class FrameType(MidasType): @dataclass(frozen=True, kw_only=True) class Stmt(ABC): - location: Optional[Location] = None + location: Location @abstractmethod def accept(self, visitor: Visitor[T]) -> T: ... @@ -157,7 +157,7 @@ class AssignStmt(Stmt): @dataclass(frozen=True, kw_only=True) class Expr(ABC): - location: Optional[Location] = None + location: Location @abstractmethod def accept(self, visitor: Visitor[T]) -> T: ... diff --git a/midas/parser/midas.py b/midas/parser/midas.py index 4998c51..db0efdf 100644 --- a/midas/parser/midas.py +++ b/midas/parser/midas.py @@ -205,9 +205,7 @@ class MidasParser(Parser): while self.match(TokenType.AND): operator: Token = self.previous() right: Expr = self.equality() - location: Optional[Location] = None - if expr.location and right.location: - location = Location.span(expr.location, right.location) + location: Location = Location.span(expr.location, right.location) expr = LogicalExpr( location=location, left=expr, operator=operator, right=right ) @@ -223,9 +221,7 @@ class MidasParser(Parser): while self.match(TokenType.BANG_EQUAL, TokenType.EQUAL_EQUAL): operator: Token = self.previous() right: Expr = self.comparison() - location: Optional[Location] = None - if expr.location and right.location: - location = Location.span(expr.location, right.location) + location: Location = Location.span(expr.location, right.location) expr = BinaryExpr( location=location, left=expr, operator=operator, right=right ) @@ -246,9 +242,7 @@ class MidasParser(Parser): ): operator: Token = self.previous() right: Expr = self.unary() - location: Optional[Location] = None - if expr.location and right.location: - location = Location.span(expr.location, right.location) + location: Location = Location.span(expr.location, right.location) expr = BinaryExpr( location=location, left=expr, operator=operator, right=right ) @@ -263,9 +257,7 @@ class MidasParser(Parser): if self.match(TokenType.MINUS): operator: Token = self.previous() right: Expr = self.unary() - location: Optional[Location] = None - if right.location: - location = Location.span(operator.get_location(), right.location) + location: Location = Location.span(operator.get_location(), right.location) return UnaryExpr(location=location, operator=operator, right=right) return self.reference() @@ -280,9 +272,7 @@ class MidasParser(Parser): name: Token = self.consume( TokenType.IDENTIFIER, "Expected property name after '.'" ) - location: Optional[Location] = None - if expr.location: - location = Location.span(expr.location, name.get_location()) + location: Location = Location.span(expr.location, name.get_location()) expr = GetExpr(location=location, expr=expr, name=name) return expr @@ -370,9 +360,7 @@ class MidasParser(Parser): while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE): operations.append(self.op_declaration()) self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body") - location: Optional[Location] = None - if type.location: - location = keyword.location_to(self.previous()) + location: Location = keyword.location_to(self.previous()) return ExtendStmt(location=location, type=type, operations=operations) def op_declaration(self) -> OpStmt: diff --git a/midas/parser/python.py b/midas/parser/python.py index 470f666..892725a 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -53,6 +53,7 @@ class PythonParser: return statements def parse_stmt(self, node: ast.stmt) -> None | Stmt | list[Stmt]: + location: Location = Location.from_ast(node) match node: case ast.AnnAssign(): return self.parse_annotation_assign(node) @@ -64,7 +65,10 @@ class PythonParser: return self.parse_function(node) case ast.Expr(value=expr): - return ExpressionStmt(expr=self.parse_expr(expr)) + return ExpressionStmt( + location=location, + expr=self.parse_expr(expr), + ) case _: print(f"Unsupported statement: {ast.unparse(node)}") @@ -266,12 +270,14 @@ class PythonParser: raise UnsupportedSyntaxError(column) def parse_expr(self, node: ast.expr) -> Expr: + location: Location = Location.from_ast(node) match node: case ast.BoolOp(): return self.parse_bool_op(node) case ast.BinOp(left=left, op=op, right=right): return BinaryExpr( + location=location, left=self.parse_expr(left), operator=op, right=self.parse_expr(right), @@ -279,6 +285,7 @@ class PythonParser: case ast.UnaryOp(op=op, operand=right): return UnaryExpr( + location=location, operator=op, right=self.parse_expr(right), ) @@ -290,58 +297,73 @@ class PythonParser: return self.parse_call(node) case ast.Constant(value=value): - return LiteralExpr(value=value) + return LiteralExpr(location=location, value=value) case ast.Attribute(value=object, attr=name): return GetExpr( + location=location, object=self.parse_expr(object), name=name, ) case ast.Name(id=name): - return VariableExpr(name=name) + return VariableExpr(location=location, name=name) case _: raise UnsupportedSyntaxError(node) def parse_bool_op(self, node: ast.BoolOp) -> LogicalExpr: op: ast.boolop = node.op - values: list[ast.expr] = node.values + rights: list[Expr] = [self.parse_expr(expr) for expr in node.values] expr: LogicalExpr = LogicalExpr( - left=self.parse_expr(values[0]), + location=Location.span( + rights[0].location, + rights[1].location, + ), + left=rights[0], operator=op, - right=self.parse_expr(values[1]), + right=rights[1], ) - for value in values[2:]: + for right in rights[2:]: expr = LogicalExpr( + location=Location.span(expr.location, right.location), left=expr, operator=op, - right=self.parse_expr(value), + right=right, ) return expr def parse_compare(self, node: ast.Compare) -> Expr: ops: list[ast.cmpop] = node.ops + left: Expr = self.parse_expr(node.left) rights: list[Expr] = [self.parse_expr(expr) for expr in node.comparators] expr: Expr = CompareExpr( - left=self.parse_expr(node.left), + location=Location.span( + left.location, + rights[0].location, + ), + left=left, operator=ops[0], right=rights[0], ) for i, right in enumerate(rights[1:]): + comparison = CompareExpr( + location=Location.span(rights[i].location, right.location), + left=rights[i], + operator=ops[i], + right=right, + ) expr = LogicalExpr( + location=Location.span(expr.location, comparison.location), left=expr, operator=ast.And(), - right=CompareExpr( - left=rights[i], - operator=ops[i], - right=right, - ), + right=comparison, ) return expr def parse_call(self, node: ast.Call) -> CallExpr: return CallExpr( + location=Location.from_ast(node), callee=self.parse_expr(node.func), arguments=[self.parse_expr(arg) for arg in node.args], keywords={