From e94db2181fc26e772bf45799ef270200884c7355 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Mon, 25 May 2026 12:14:14 +0200 Subject: [PATCH] feat(parser): add location to midas AST nodes --- midas/ast/location.py | 37 +++++++++++++ midas/ast/midas.py | 9 +++- midas/ast/python.py | 28 ++-------- midas/lexer/token.py | 23 +++++++++ midas/parser/midas.py | 115 ++++++++++++++++++++++++++++++++--------- midas/parser/python.py | 2 +- 6 files changed, 161 insertions(+), 53 deletions(-) create mode 100644 midas/ast/location.py diff --git a/midas/ast/location.py b/midas/ast/location.py new file mode 100644 index 0000000..47fe360 --- /dev/null +++ b/midas/ast/location.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Protocol + + +class HasLocation(Protocol): + lineno: int + col_offset: int + end_lineno: Optional[int] + end_col_offset: Optional[int] + + +@dataclass(frozen=True, kw_only=True) +class Location: + lineno: int + col_offset: int + end_lineno: Optional[int] + end_col_offset: Optional[int] + + @staticmethod + def from_ast(obj: HasLocation) -> Location: + return Location( + lineno=obj.lineno, + col_offset=obj.col_offset, + end_lineno=obj.end_lineno, + end_col_offset=obj.end_col_offset, + ) + + @staticmethod + def span(start: Location, end: Location) -> Location: + return Location( + lineno=start.lineno, + col_offset=start.col_offset, + end_lineno=end.end_lineno, + end_col_offset=end.end_col_offset, + ) diff --git a/midas/ast/midas.py b/midas/ast/midas.py index 28a7819..1ff503d 100644 --- a/midas/ast/midas.py +++ b/midas/ast/midas.py @@ -9,6 +9,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Any, Generic, Optional, TypeVar +from midas.ast.location import Location from midas.lexer.token import Token T = TypeVar("T") @@ -18,8 +19,10 @@ T = TypeVar("T") ############## -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class Stmt(ABC): +
location: Optional[Location] = None + @abstractmethod def accept(self, visitor: Visitor[T]) -> T: ... @@ -109,8 +112,10 @@ class PredicateStmt(Stmt): ############### -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class Expr(ABC): + location: Optional[Location] = None + @abstractmethod def accept(self, visitor: Visitor[T]) -> T: ... diff --git a/midas/ast/python.py b/midas/ast/python.py index 878b8b8..c25b438 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -3,35 +3,13 @@ from __future__ import annotations from abc import ABC, abstractmethod import ast from dataclasses import dataclass -from typing import Generic, Optional, Protocol, TypeVar +from typing import Generic, Optional, TypeVar + +from midas.ast.location import Location T = TypeVar("T") -class HasLocation(Protocol): - lineno: int - col_offset: int - end_lineno: Optional[int] - end_col_offset: Optional[int] - - -@dataclass(frozen=True, kw_only=True) -class Location: - lineno: int - col_offset: int - end_lineno: Optional[int] - end_col_offset: Optional[int] - - @staticmethod - def from_ast(obj: HasLocation) -> Location: - return Location( - lineno=obj.lineno, - col_offset=obj.col_offset, - end_lineno=obj.end_lineno, - end_col_offset=obj.end_col_offset, - ) - - @dataclass(frozen=True, kw_only=True) class Expr(ABC): location: Optional[Location] = None diff --git a/midas/lexer/token.py b/midas/lexer/token.py index 76a0fb1..052d8a6 100644 --- a/midas/lexer/token.py +++ b/midas/lexer/token.py @@ -1,7 +1,10 @@ +from __future__ import annotations + from dataclasses import dataclass from enum import Enum, auto from typing import Any +from midas.ast.location import Location from midas.lexer.position import Position @@ -63,3 +66,23 @@ class Token: lexeme: str value: Any position: Position + + def get_location(self) -> Location: + lineno: int = self.position.line + col_offset: int = self.position.column - 1 + end_lineno = lineno + end_col_offset = col_offset + for c in self.lexeme: + 
end_col_offset += 1 + if c == "\n": + end_lineno += 1 + end_col_offset = 0 + return Location( + lineno=lineno, + col_offset=col_offset, + end_lineno=end_lineno, + end_col_offset=end_col_offset, + ) + + def location_to(self, to: Token) -> Location: + return Location.span(self.get_location(), to.get_location()) diff --git a/midas/parser/midas.py b/midas/parser/midas.py index a919994..4998c51 100644 --- a/midas/parser/midas.py +++ b/midas/parser/midas.py @@ -1,5 +1,6 @@ from typing import Optional +from midas.ast.location import Location from midas.ast.midas import ( BinaryExpr, ComplexTypeStmt, @@ -104,6 +105,7 @@ class MidasParser(Parser): Returns: TypeStmt: the parsed type declaration statement """ + keyword: Token = self.previous() name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") template: Optional[TemplateExpr] = None if self.check(TokenType.LEFT_BRACKET): @@ -116,11 +118,20 @@ class MidasParser(Parser): if self.match(TokenType.WHERE): constraint = self.constraint() return SimpleTypeStmt( - name=name, template=template, base=base, constraint=constraint + location=keyword.location_to(self.previous()), + name=name, + template=template, + base=base, + constraint=constraint, ) else: properties: list[PropertyStmt] = self.type_properties() - return ComplexTypeStmt(name=name, template=template, properties=properties) + return ComplexTypeStmt( + location=keyword.location_to(self.previous()), + name=name, + template=template, + properties=properties, + ) def template_expr(self) -> TemplateExpr: """Parse a generic template expression @@ -130,10 +141,14 @@ class MidasParser(Parser): Returns: TemplateExpr: the parsed template expression """ - self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression") + left: Token = self.consume( + TokenType.LEFT_BRACKET, "Missing '[' before template expression" + ) type: TypeExpr = self.type_expr() - self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression") - return 
TemplateExpr(type=type) + right: Token = self.consume( + TokenType.RIGHT_BRACKET, "Missing ']' after template expression" + ) + return TemplateExpr(location=left.location_to(right), type=type) def type_expr(self) -> TypeExpr: """Parse a type expression @@ -149,7 +164,12 @@ class MidasParser(Parser): if self.check(TokenType.LEFT_BRACKET): template = self.template_expr() optional: bool = self.match(TokenType.QMARK) - return TypeExpr(name=name, template=template, optional=optional) + return TypeExpr( + location=name.location_to(self.previous()), + name=name, + template=template, + optional=optional, + ) def simple_type_expr(self) -> SimpleTypeExpr: """Parse a simple type expression @@ -161,7 +181,9 @@ class MidasParser(Parser): """ name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") optional: bool = self.match(TokenType.QMARK) - return SimpleTypeExpr(name=name, optional=optional) + return SimpleTypeExpr( + location=name.location_to(self.previous()), name=name, optional=optional + ) def constraint(self) -> Expr: """Parse a constraint @@ -183,7 +205,12 @@ class MidasParser(Parser): while self.match(TokenType.AND): operator: Token = self.previous() right: Expr = self.equality() - expr = LogicalExpr(left=expr, operator=operator, right=right) + location: Optional[Location] = None + if expr.location and right.location: + location = Location.span(expr.location, right.location) + expr = LogicalExpr( + location=location, left=expr, operator=operator, right=right + ) return expr def equality(self) -> Expr: @@ -196,7 +223,12 @@ class MidasParser(Parser): while self.match(TokenType.BANG_EQUAL, TokenType.EQUAL_EQUAL): operator: Token = self.previous() right: Expr = self.comparison() - expr = BinaryExpr(left=expr, operator=operator, right=right) + location: Optional[Location] = None + if expr.location and right.location: + location = Location.span(expr.location, right.location) + expr = BinaryExpr( + location=location, left=expr, operator=operator, right=right + 
) return expr def comparison(self) -> Expr: @@ -214,7 +246,12 @@ class MidasParser(Parser): ): operator: Token = self.previous() right: Expr = self.unary() - expr = BinaryExpr(left=expr, operator=operator, right=right) + location: Optional[Location] = None + if expr.location and right.location: + location = Location.span(expr.location, right.location) + expr = BinaryExpr( + location=location, left=expr, operator=operator, right=right + ) return expr def unary(self) -> Expr: @@ -226,7 +263,10 @@ class MidasParser(Parser): if self.match(TokenType.MINUS): operator: Token = self.previous() right: Expr = self.unary() - return UnaryExpr(operator=operator, right=right) + location: Optional[Location] = None + if right.location: + location = Location.span(operator.get_location(), right.location) + return UnaryExpr(location=location, operator=operator, right=right) return self.reference() def reference(self) -> Expr: @@ -240,7 +280,10 @@ class MidasParser(Parser): name: Token = self.consume( TokenType.IDENTIFIER, "Expected property name after '.'" ) - expr = GetExpr(expr=expr, name=name) + location: Optional[Location] = None + if expr.location: + location = Location.span(expr.location, name.get_location()) + expr = GetExpr(location=location, expr=expr, name=name) return expr def primary(self) -> Expr: @@ -251,26 +294,27 @@ class MidasParser(Parser): Returns: Expr: the parsed expression """ + token: Token = self.peek() if self.match(TokenType.FALSE): - return LiteralExpr(False) + return LiteralExpr(location=token.get_location(), value=False) if self.match(TokenType.TRUE): - return LiteralExpr(True) + return LiteralExpr(location=token.get_location(), value=True) if self.match(TokenType.NONE): - return LiteralExpr(None) + return LiteralExpr(location=token.get_location(), value=None) if self.match(TokenType.NUMBER): - return LiteralExpr(self.previous().value) + return LiteralExpr(location=token.get_location(), value=token.value) if self.match(TokenType.IDENTIFIER): - return 
VariableExpr(self.previous()) + return VariableExpr(location=token.get_location(), name=token) if self.match(TokenType.UNDERSCORE): - return WildcardExpr(self.previous()) + return WildcardExpr(location=token.get_location(), token=token) if self.match(TokenType.LEFT_PAREN): expr: Expr = self.constraint() - self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis") - return GroupingExpr(expr) + right: Token = self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis") + return GroupingExpr(location=token.location_to(right), expr=expr) raise self.error(self.peek(), "Expected expression") @@ -304,7 +348,12 @@ class MidasParser(Parser): constraint: Optional[Expr] = None if self.match(TokenType.WHERE): constraint = self.constraint() - return PropertyStmt(name=name, type=type, constraint=constraint) + return PropertyStmt( + location=name.location_to(self.previous()), + name=name, + type=type, + constraint=constraint, + ) def extend_declaration(self) -> ExtendStmt: """Parse an extension definition @@ -314,13 +363,17 @@ class MidasParser(Parser): Returns: ExtendStmt: the parsed extension statement """ + keyword: Token = self.previous() type: TypeExpr = self.type_expr() self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body") operations: list[OpStmt] = [] while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE): operations.append(self.op_declaration()) self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body") - return ExtendStmt(type=type, operations=operations) + location: Optional[Location] = None + if type.location: + location = keyword.location_to(self.previous()) + return ExtendStmt(location=location, type=type, operations=operations) def op_declaration(self) -> OpStmt: """Parse an operation definition @@ -330,7 +383,7 @@ class MidasParser(Parser): Returns: OpStmt: the parsed operation statement """ - self.consume(TokenType.OP, "Expected 'op' keyword") + keyword: Token = self.consume(TokenType.OP, "Expected 'op' keyword") name: Token = 
self.consume(TokenType.IDENTIFIER, "Expected operation name") self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type") @@ -340,7 +393,12 @@ class MidasParser(Parser): self.consume(TokenType.ARROW, "Expected '->' before result type") result: TypeExpr = self.type_expr() - return OpStmt(name=name, operand=operand, result=result) + return OpStmt( + location=keyword.location_to(self.previous()), + name=name, + operand=operand, + result=result, + ) def predicate_declaration(self) -> PredicateStmt: """Parse a predicate declaration @@ -350,6 +408,7 @@ class MidasParser(Parser): Returns: PredicateStmt: the parsed predicate declaration statement """ + keyword: Token = self.previous() name: Token = self.consume(TokenType.IDENTIFIER, "Expected predicate name") self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject") subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name") @@ -358,4 +417,10 @@ class MidasParser(Parser): self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject") self.consume(TokenType.EQUAL, "Expected '=' after predicate subject") condition: Expr = self.constraint() - return PredicateStmt(name=name, subject=subject, type=type, condition=condition) + return PredicateStmt( + location=keyword.location_to(self.previous()), + name=name, + subject=subject, + type=type, + condition=condition, + ) diff --git a/midas/parser/python.py b/midas/parser/python.py index 51d68ca..6e0ffe1 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -1,6 +1,7 @@ import ast from typing import Any, Optional +from midas.ast.location import Location from midas.ast.python import ( BaseType, ConstraintType, @@ -8,7 +9,6 @@ from midas.ast.python import ( FrameType, Function, FunctionArgument, - Location, MidasType, )