fix(parser): add location in all AST nodes

This commit is contained in:
2026-05-28 18:18:16 +02:00
parent 928901ef9c
commit 218b0c5b78
5 changed files with 48 additions and 38 deletions

View File

@@ -11,7 +11,7 @@ SECTION_TEMPLATE = """{banner}
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class {base}(ABC): class {base}(ABC):
location: Optional[Location] = None location: Location
@abstractmethod @abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ... def accept(self, visitor: Visitor[T]) -> T: ...

View File

@@ -21,7 +21,7 @@ T = TypeVar("T")
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class Stmt(ABC): class Stmt(ABC):
location: Optional[Location] = None location: Location
@abstractmethod @abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ... def accept(self, visitor: Visitor[T]) -> T: ...
@@ -114,7 +114,7 @@ class PredicateStmt(Stmt):
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class Expr(ABC): class Expr(ABC):
location: Optional[Location] = None location: Location
@abstractmethod @abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ... def accept(self, visitor: Visitor[T]) -> T: ...

View File

@@ -21,7 +21,7 @@ T = TypeVar("T")
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class MidasType(ABC): class MidasType(ABC):
location: Optional[Location] = None location: Location
@abstractmethod @abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ... def accept(self, visitor: Visitor[T]) -> T: ...
@@ -82,7 +82,7 @@ class FrameType(MidasType):
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class Stmt(ABC): class Stmt(ABC):
location: Optional[Location] = None location: Location
@abstractmethod @abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ... def accept(self, visitor: Visitor[T]) -> T: ...
@@ -157,7 +157,7 @@ class AssignStmt(Stmt):
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class Expr(ABC): class Expr(ABC):
location: Optional[Location] = None location: Location
@abstractmethod @abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ... def accept(self, visitor: Visitor[T]) -> T: ...

View File

@@ -205,9 +205,7 @@ class MidasParser(Parser):
while self.match(TokenType.AND): while self.match(TokenType.AND):
operator: Token = self.previous() operator: Token = self.previous()
right: Expr = self.equality() right: Expr = self.equality()
location: Optional[Location] = None location: Location = Location.span(expr.location, right.location)
if expr.location and right.location:
location = Location.span(expr.location, right.location)
expr = LogicalExpr( expr = LogicalExpr(
location=location, left=expr, operator=operator, right=right location=location, left=expr, operator=operator, right=right
) )
@@ -223,9 +221,7 @@ class MidasParser(Parser):
while self.match(TokenType.BANG_EQUAL, TokenType.EQUAL_EQUAL): while self.match(TokenType.BANG_EQUAL, TokenType.EQUAL_EQUAL):
operator: Token = self.previous() operator: Token = self.previous()
right: Expr = self.comparison() right: Expr = self.comparison()
location: Optional[Location] = None location: Location = Location.span(expr.location, right.location)
if expr.location and right.location:
location = Location.span(expr.location, right.location)
expr = BinaryExpr( expr = BinaryExpr(
location=location, left=expr, operator=operator, right=right location=location, left=expr, operator=operator, right=right
) )
@@ -246,9 +242,7 @@ class MidasParser(Parser):
): ):
operator: Token = self.previous() operator: Token = self.previous()
right: Expr = self.unary() right: Expr = self.unary()
location: Optional[Location] = None location: Location = Location.span(expr.location, right.location)
if expr.location and right.location:
location = Location.span(expr.location, right.location)
expr = BinaryExpr( expr = BinaryExpr(
location=location, left=expr, operator=operator, right=right location=location, left=expr, operator=operator, right=right
) )
@@ -263,9 +257,7 @@ class MidasParser(Parser):
if self.match(TokenType.MINUS): if self.match(TokenType.MINUS):
operator: Token = self.previous() operator: Token = self.previous()
right: Expr = self.unary() right: Expr = self.unary()
location: Optional[Location] = None location: Location = Location.span(operator.get_location(), right.location)
if right.location:
location = Location.span(operator.get_location(), right.location)
return UnaryExpr(location=location, operator=operator, right=right) return UnaryExpr(location=location, operator=operator, right=right)
return self.reference() return self.reference()
@@ -280,9 +272,7 @@ class MidasParser(Parser):
name: Token = self.consume( name: Token = self.consume(
TokenType.IDENTIFIER, "Expected property name after '.'" TokenType.IDENTIFIER, "Expected property name after '.'"
) )
location: Optional[Location] = None location: Location = Location.span(expr.location, name.get_location())
if expr.location:
location = Location.span(expr.location, name.get_location())
expr = GetExpr(location=location, expr=expr, name=name) expr = GetExpr(location=location, expr=expr, name=name)
return expr return expr
@@ -370,9 +360,7 @@ class MidasParser(Parser):
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE): while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
operations.append(self.op_declaration()) operations.append(self.op_declaration())
self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body") self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body")
location: Optional[Location] = None location: Location = keyword.location_to(self.previous())
if type.location:
location = keyword.location_to(self.previous())
return ExtendStmt(location=location, type=type, operations=operations) return ExtendStmt(location=location, type=type, operations=operations)
def op_declaration(self) -> OpStmt: def op_declaration(self) -> OpStmt:

View File

@@ -53,6 +53,7 @@ class PythonParser:
return statements return statements
def parse_stmt(self, node: ast.stmt) -> None | Stmt | list[Stmt]: def parse_stmt(self, node: ast.stmt) -> None | Stmt | list[Stmt]:
location: Location = Location.from_ast(node)
match node: match node:
case ast.AnnAssign(): case ast.AnnAssign():
return self.parse_annotation_assign(node) return self.parse_annotation_assign(node)
@@ -64,7 +65,10 @@ class PythonParser:
return self.parse_function(node) return self.parse_function(node)
case ast.Expr(value=expr): case ast.Expr(value=expr):
return ExpressionStmt(expr=self.parse_expr(expr)) return ExpressionStmt(
location=location,
expr=self.parse_expr(expr),
)
case _: case _:
print(f"Unsupported statement: {ast.unparse(node)}") print(f"Unsupported statement: {ast.unparse(node)}")
@@ -266,12 +270,14 @@ class PythonParser:
raise UnsupportedSyntaxError(column) raise UnsupportedSyntaxError(column)
def parse_expr(self, node: ast.expr) -> Expr: def parse_expr(self, node: ast.expr) -> Expr:
location: Location = Location.from_ast(node)
match node: match node:
case ast.BoolOp(): case ast.BoolOp():
return self.parse_bool_op(node) return self.parse_bool_op(node)
case ast.BinOp(left=left, op=op, right=right): case ast.BinOp(left=left, op=op, right=right):
return BinaryExpr( return BinaryExpr(
location=location,
left=self.parse_expr(left), left=self.parse_expr(left),
operator=op, operator=op,
right=self.parse_expr(right), right=self.parse_expr(right),
@@ -279,6 +285,7 @@ class PythonParser:
case ast.UnaryOp(op=op, operand=right): case ast.UnaryOp(op=op, operand=right):
return UnaryExpr( return UnaryExpr(
location=location,
operator=op, operator=op,
right=self.parse_expr(right), right=self.parse_expr(right),
) )
@@ -290,58 +297,73 @@ class PythonParser:
return self.parse_call(node) return self.parse_call(node)
case ast.Constant(value=value): case ast.Constant(value=value):
return LiteralExpr(value=value) return LiteralExpr(location=location, value=value)
case ast.Attribute(value=object, attr=name): case ast.Attribute(value=object, attr=name):
return GetExpr( return GetExpr(
location=location,
object=self.parse_expr(object), object=self.parse_expr(object),
name=name, name=name,
) )
case ast.Name(id=name): case ast.Name(id=name):
return VariableExpr(name=name) return VariableExpr(location=location, name=name)
case _: case _:
raise UnsupportedSyntaxError(node) raise UnsupportedSyntaxError(node)
def parse_bool_op(self, node: ast.BoolOp) -> LogicalExpr: def parse_bool_op(self, node: ast.BoolOp) -> LogicalExpr:
op: ast.boolop = node.op op: ast.boolop = node.op
values: list[ast.expr] = node.values rights: list[Expr] = [self.parse_expr(expr) for expr in node.values]
expr: LogicalExpr = LogicalExpr( expr: LogicalExpr = LogicalExpr(
left=self.parse_expr(values[0]), location=Location.span(
rights[0].location,
rights[1].location,
),
left=rights[0],
operator=op, operator=op,
right=self.parse_expr(values[1]), right=rights[1],
) )
for value in values[2:]: for right in rights[2:]:
expr = LogicalExpr( expr = LogicalExpr(
location=Location.span(expr.location, right.location),
left=expr, left=expr,
operator=op, operator=op,
right=self.parse_expr(value), right=right,
) )
return expr return expr
def parse_compare(self, node: ast.Compare) -> Expr: def parse_compare(self, node: ast.Compare) -> Expr:
ops: list[ast.cmpop] = node.ops ops: list[ast.cmpop] = node.ops
left: Expr = self.parse_expr(node.left)
rights: list[Expr] = [self.parse_expr(expr) for expr in node.comparators] rights: list[Expr] = [self.parse_expr(expr) for expr in node.comparators]
expr: Expr = CompareExpr( expr: Expr = CompareExpr(
left=self.parse_expr(node.left), location=Location.span(
left.location,
rights[0].location,
),
left=left,
operator=ops[0], operator=ops[0],
right=rights[0], right=rights[0],
) )
for i, right in enumerate(rights[1:]): for i, right in enumerate(rights[1:]):
comparison = CompareExpr(
location=Location.span(rights[i].location, right.location),
left=rights[i],
operator=ops[i],
right=right,
)
expr = LogicalExpr( expr = LogicalExpr(
location=Location.span(expr.location, comparison.location),
left=expr, left=expr,
operator=ast.And(), operator=ast.And(),
right=CompareExpr( right=comparison,
left=rights[i],
operator=ops[i],
right=right,
),
) )
return expr return expr
def parse_call(self, node: ast.Call) -> CallExpr: def parse_call(self, node: ast.Call) -> CallExpr:
return CallExpr( return CallExpr(
location=Location.from_ast(node),
callee=self.parse_expr(node.func), callee=self.parse_expr(node.func),
arguments=[self.parse_expr(arg) for arg in node.args], arguments=[self.parse_expr(arg) for arg in node.args],
keywords={ keywords={