feat(parser): add location to midas AST nodes

This commit is contained in:
2026-05-25 12:14:14 +02:00
parent 9b59058881
commit e94db2181f
6 changed files with 161 additions and 53 deletions

37
midas/ast/location.py Normal file
View File

@@ -0,0 +1,37 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Protocol
class HasLocation(Protocol):
    """Structural type for any object that exposes source-span attributes.

    Anything carrying these four attributes can be handed to
    ``Location.from_ast``.
    """

    # Start of the span.
    lineno: int
    col_offset: int

    # End of the span; either value may be None.
    end_lineno: Optional[int]
    end_col_offset: Optional[int]
@dataclass(frozen=True, kw_only=True)
class Location:
lineno: int
col_offset: int
end_lineno: Optional[int]
end_col_offset: Optional[int]
@staticmethod
def from_ast(obj: HasLocation) -> Location:
return Location(
lineno=obj.lineno,
col_offset=obj.col_offset,
end_lineno=obj.end_lineno,
end_col_offset=obj.end_col_offset,
)
@staticmethod
def span(start: Location, end: Location) -> Location:
return Location(
lineno=start.lineno,
col_offset=start.col_offset,
end_lineno=end.lineno,
end_col_offset=end.end_col_offset,
)

View File

@@ -9,6 +9,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Generic, Optional, TypeVar from typing import Any, Generic, Optional, TypeVar
from midas.ast.location import Location
from midas.lexer.token import Token from midas.lexer.token import Token
T = TypeVar("T") T = TypeVar("T")
@@ -18,8 +19,10 @@ T = TypeVar("T")
############## ##############
@dataclass(frozen=True) @dataclass(frozen=True, kw_only=True)
class Stmt(ABC): class Stmt(ABC):
location: Optional[Location] = None
@abstractmethod @abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ... def accept(self, visitor: Visitor[T]) -> T: ...
@@ -109,8 +112,10 @@ class PredicateStmt(Stmt):
############### ###############
@dataclass(frozen=True) @dataclass(frozen=True, kw_only=True)
class Expr(ABC): class Expr(ABC):
location: Optional[Location] = None
@abstractmethod @abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ... def accept(self, visitor: Visitor[T]) -> T: ...

View File

@@ -3,35 +3,13 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import ast import ast
from dataclasses import dataclass from dataclasses import dataclass
from typing import Generic, Optional, Protocol, TypeVar from typing import Generic, Optional, TypeVar
from midas.ast.location import Location
T = TypeVar("T") T = TypeVar("T")
class HasLocation(Protocol):
lineno: int
col_offset: int
end_lineno: Optional[int]
end_col_offset: Optional[int]
@dataclass(frozen=True, kw_only=True)
class Location:
lineno: int
col_offset: int
end_lineno: Optional[int]
end_col_offset: Optional[int]
@staticmethod
def from_ast(obj: HasLocation) -> Location:
return Location(
lineno=obj.lineno,
col_offset=obj.col_offset,
end_lineno=obj.end_lineno,
end_col_offset=obj.end_col_offset,
)
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class Expr(ABC): class Expr(ABC):
location: Optional[Location] = None location: Optional[Location] = None

View File

@@ -1,7 +1,10 @@
from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
from typing import Any from typing import Any
from midas.ast.location import Location
from midas.lexer.position import Position from midas.lexer.position import Position
@@ -63,3 +66,23 @@ class Token:
lexeme: str lexeme: str
value: Any value: Any
position: Position position: Position
def get_location(self) -> Location:
    """Compute the source span covered by this token's lexeme.

    The span starts at the token's position and is extended by the
    lexeme: each newline moves the end one line down and restarts the
    end column at 0, every other character advances the end column by one.

    Returns:
        Location: the span from the first to just past the last character
            of the lexeme.
    """
    text = self.lexeme
    first_line: int = self.position.line
    # NOTE(review): presumably Position columns are 1-based while Location
    # offsets are 0-based, hence the shift - confirm against the lexer.
    first_col: int = self.position.column - 1

    newline_count = text.count("\n")
    if newline_count:
        # Only the characters after the last newline sit on the final line.
        last_col = len(text) - text.rfind("\n") - 1
    else:
        last_col = first_col + len(text)

    return Location(
        lineno=first_line,
        col_offset=first_col,
        end_lineno=first_line + newline_count,
        end_col_offset=last_col,
    )
def location_to(self, to: Token) -> Location:
    """Return the location spanning from this token to the token ``to``.

    Args:
        to: the token that closes the span.

    Returns:
        Location: the result of ``Location.span`` applied to both tokens'
            locations.
    """
    start: Location = self.get_location()
    stop: Location = to.get_location()
    return Location.span(start, stop)

View File

@@ -1,5 +1,6 @@
from typing import Optional from typing import Optional
from midas.ast.location import Location
from midas.ast.midas import ( from midas.ast.midas import (
BinaryExpr, BinaryExpr,
ComplexTypeStmt, ComplexTypeStmt,
@@ -104,6 +105,7 @@ class MidasParser(Parser):
Returns: Returns:
TypeStmt: the parsed type declaration statement TypeStmt: the parsed type declaration statement
""" """
keyword: Token = self.previous()
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
template: Optional[TemplateExpr] = None template: Optional[TemplateExpr] = None
if self.check(TokenType.LEFT_BRACKET): if self.check(TokenType.LEFT_BRACKET):
@@ -116,11 +118,20 @@ class MidasParser(Parser):
if self.match(TokenType.WHERE): if self.match(TokenType.WHERE):
constraint = self.constraint() constraint = self.constraint()
return SimpleTypeStmt( return SimpleTypeStmt(
name=name, template=template, base=base, constraint=constraint location=keyword.location_to(self.previous()),
name=name,
template=template,
base=base,
constraint=constraint,
) )
else: else:
properties: list[PropertyStmt] = self.type_properties() properties: list[PropertyStmt] = self.type_properties()
return ComplexTypeStmt(name=name, template=template, properties=properties) return ComplexTypeStmt(
location=keyword.location_to(self.previous()),
name=name,
template=template,
properties=properties,
)
def template_expr(self) -> TemplateExpr: def template_expr(self) -> TemplateExpr:
"""Parse a generic template expression """Parse a generic template expression
@@ -130,10 +141,14 @@ class MidasParser(Parser):
Returns: Returns:
TemplateExpr: the parsed template expression TemplateExpr: the parsed template expression
""" """
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression") left: Token = self.consume(
TokenType.LEFT_BRACKET, "Missing '[' before template expression"
)
type: TypeExpr = self.type_expr() type: TypeExpr = self.type_expr()
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression") right: Token = self.consume(
return TemplateExpr(type=type) TokenType.RIGHT_BRACKET, "Missing ']' after template expression"
)
return TemplateExpr(location=left.location_to(right), type=type)
def type_expr(self) -> TypeExpr: def type_expr(self) -> TypeExpr:
"""Parse a type expression """Parse a type expression
@@ -149,7 +164,12 @@ class MidasParser(Parser):
if self.check(TokenType.LEFT_BRACKET): if self.check(TokenType.LEFT_BRACKET):
template = self.template_expr() template = self.template_expr()
optional: bool = self.match(TokenType.QMARK) optional: bool = self.match(TokenType.QMARK)
return TypeExpr(name=name, template=template, optional=optional) return TypeExpr(
location=name.location_to(self.previous()),
name=name,
template=template,
optional=optional,
)
def simple_type_expr(self) -> SimpleTypeExpr: def simple_type_expr(self) -> SimpleTypeExpr:
"""Parse a simple type expression """Parse a simple type expression
@@ -161,7 +181,9 @@ class MidasParser(Parser):
""" """
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
optional: bool = self.match(TokenType.QMARK) optional: bool = self.match(TokenType.QMARK)
return SimpleTypeExpr(name=name, optional=optional) return SimpleTypeExpr(
location=name.location_to(self.previous()), name=name, optional=optional
)
def constraint(self) -> Expr: def constraint(self) -> Expr:
"""Parse a constraint """Parse a constraint
@@ -183,7 +205,12 @@ class MidasParser(Parser):
while self.match(TokenType.AND): while self.match(TokenType.AND):
operator: Token = self.previous() operator: Token = self.previous()
right: Expr = self.equality() right: Expr = self.equality()
expr = LogicalExpr(left=expr, operator=operator, right=right) location: Optional[Location] = None
if expr.location and right.location:
location = Location.span(expr.location, right.location)
expr = LogicalExpr(
location=location, left=expr, operator=operator, right=right
)
return expr return expr
def equality(self) -> Expr: def equality(self) -> Expr:
@@ -196,7 +223,12 @@ class MidasParser(Parser):
while self.match(TokenType.BANG_EQUAL, TokenType.EQUAL_EQUAL): while self.match(TokenType.BANG_EQUAL, TokenType.EQUAL_EQUAL):
operator: Token = self.previous() operator: Token = self.previous()
right: Expr = self.comparison() right: Expr = self.comparison()
expr = BinaryExpr(left=expr, operator=operator, right=right) location: Optional[Location] = None
if expr.location and right.location:
location = Location.span(expr.location, right.location)
expr = BinaryExpr(
location=location, left=expr, operator=operator, right=right
)
return expr return expr
def comparison(self) -> Expr: def comparison(self) -> Expr:
@@ -214,7 +246,12 @@ class MidasParser(Parser):
): ):
operator: Token = self.previous() operator: Token = self.previous()
right: Expr = self.unary() right: Expr = self.unary()
expr = BinaryExpr(left=expr, operator=operator, right=right) location: Optional[Location] = None
if expr.location and right.location:
location = Location.span(expr.location, right.location)
expr = BinaryExpr(
location=location, left=expr, operator=operator, right=right
)
return expr return expr
def unary(self) -> Expr: def unary(self) -> Expr:
@@ -226,7 +263,10 @@ class MidasParser(Parser):
if self.match(TokenType.MINUS): if self.match(TokenType.MINUS):
operator: Token = self.previous() operator: Token = self.previous()
right: Expr = self.unary() right: Expr = self.unary()
return UnaryExpr(operator=operator, right=right) location: Optional[Location] = None
if right.location:
location = Location.span(operator.get_location(), right.location)
return UnaryExpr(location=location, operator=operator, right=right)
return self.reference() return self.reference()
def reference(self) -> Expr: def reference(self) -> Expr:
@@ -240,7 +280,10 @@ class MidasParser(Parser):
name: Token = self.consume( name: Token = self.consume(
TokenType.IDENTIFIER, "Expected property name after '.'" TokenType.IDENTIFIER, "Expected property name after '.'"
) )
expr = GetExpr(expr=expr, name=name) location: Optional[Location] = None
if expr.location:
location = Location.span(expr.location, name.get_location())
expr = GetExpr(location=location, expr=expr, name=name)
return expr return expr
def primary(self) -> Expr: def primary(self) -> Expr:
@@ -251,26 +294,27 @@ class MidasParser(Parser):
Returns: Returns:
Expr: the parsed expression Expr: the parsed expression
""" """
token: Token = self.peek()
if self.match(TokenType.FALSE): if self.match(TokenType.FALSE):
return LiteralExpr(False) return LiteralExpr(location=token.get_location(), value=False)
if self.match(TokenType.TRUE): if self.match(TokenType.TRUE):
return LiteralExpr(True) return LiteralExpr(location=token.get_location(), value=True)
if self.match(TokenType.NONE): if self.match(TokenType.NONE):
return LiteralExpr(None) return LiteralExpr(location=token.get_location(), value=None)
if self.match(TokenType.NUMBER): if self.match(TokenType.NUMBER):
return LiteralExpr(self.previous().value) return LiteralExpr(location=token.get_location(), value=token.value)
if self.match(TokenType.IDENTIFIER): if self.match(TokenType.IDENTIFIER):
return VariableExpr(self.previous()) return VariableExpr(location=token.get_location(), name=token)
if self.match(TokenType.UNDERSCORE): if self.match(TokenType.UNDERSCORE):
return WildcardExpr(self.previous()) return WildcardExpr(location=token.get_location(), token=token)
if self.match(TokenType.LEFT_PAREN): if self.match(TokenType.LEFT_PAREN):
expr: Expr = self.constraint() expr: Expr = self.constraint()
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis") right: Token = self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
return GroupingExpr(expr) return GroupingExpr(location=token.location_to(right), expr=expr)
raise self.error(self.peek(), "Expected expression") raise self.error(self.peek(), "Expected expression")
@@ -304,7 +348,12 @@ class MidasParser(Parser):
constraint: Optional[Expr] = None constraint: Optional[Expr] = None
if self.match(TokenType.WHERE): if self.match(TokenType.WHERE):
constraint = self.constraint() constraint = self.constraint()
return PropertyStmt(name=name, type=type, constraint=constraint) return PropertyStmt(
location=name.location_to(self.previous()),
name=name,
type=type,
constraint=constraint,
)
def extend_declaration(self) -> ExtendStmt: def extend_declaration(self) -> ExtendStmt:
"""Parse an extension definition """Parse an extension definition
@@ -314,13 +363,17 @@ class MidasParser(Parser):
Returns: Returns:
ExtendStmt: the parsed extension statement ExtendStmt: the parsed extension statement
""" """
keyword: Token = self.previous()
type: TypeExpr = self.type_expr() type: TypeExpr = self.type_expr()
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body") self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
operations: list[OpStmt] = [] operations: list[OpStmt] = []
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE): while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
operations.append(self.op_declaration()) operations.append(self.op_declaration())
self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body") self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body")
return ExtendStmt(type=type, operations=operations) location: Optional[Location] = None
if type.location:
location = keyword.location_to(self.previous())
return ExtendStmt(location=location, type=type, operations=operations)
def op_declaration(self) -> OpStmt: def op_declaration(self) -> OpStmt:
"""Parse an operation definition """Parse an operation definition
@@ -330,7 +383,7 @@ class MidasParser(Parser):
Returns: Returns:
OpStmt: the parsed operation statement OpStmt: the parsed operation statement
""" """
self.consume(TokenType.OP, "Expected 'op' keyword") keyword: Token = self.consume(TokenType.OP, "Expected 'op' keyword")
name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name") name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name")
self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type") self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type")
@@ -340,7 +393,12 @@ class MidasParser(Parser):
self.consume(TokenType.ARROW, "Expected '->' before result type") self.consume(TokenType.ARROW, "Expected '->' before result type")
result: TypeExpr = self.type_expr() result: TypeExpr = self.type_expr()
return OpStmt(name=name, operand=operand, result=result) return OpStmt(
location=keyword.location_to(self.previous()),
name=name,
operand=operand,
result=result,
)
def predicate_declaration(self) -> PredicateStmt: def predicate_declaration(self) -> PredicateStmt:
"""Parse a predicate declaration """Parse a predicate declaration
@@ -350,6 +408,7 @@ class MidasParser(Parser):
Returns: Returns:
PredicateStmt: the parsed predicate declaration statement PredicateStmt: the parsed predicate declaration statement
""" """
keyword: Token = self.previous()
name: Token = self.consume(TokenType.IDENTIFIER, "Expected predicate name") name: Token = self.consume(TokenType.IDENTIFIER, "Expected predicate name")
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject") self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name") subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name")
@@ -358,4 +417,10 @@ class MidasParser(Parser):
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject") self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject") self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
condition: Expr = self.constraint() condition: Expr = self.constraint()
return PredicateStmt(name=name, subject=subject, type=type, condition=condition) return PredicateStmt(
location=keyword.location_to(self.previous()),
name=name,
subject=subject,
type=type,
condition=condition,
)

View File

@@ -1,6 +1,7 @@
import ast import ast
from typing import Any, Optional from typing import Any, Optional
from midas.ast.location import Location
from midas.ast.python import ( from midas.ast.python import (
BaseType, BaseType,
ConstraintType, ConstraintType,
@@ -8,7 +9,6 @@ from midas.ast.python import (
FrameType, FrameType,
Function, Function,
FunctionArgument, FunctionArgument,
Location,
MidasType, MidasType,
) )