feat(parser): store locations in parsed nodes

This commit is contained in:
2026-05-22 22:11:44 +02:00
parent 5aedddfabb
commit d0c54db33a
2 changed files with 63 additions and 9 deletions

View File

@@ -3,13 +3,39 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import ast import ast
from dataclasses import dataclass from dataclasses import dataclass
from typing import Generic, Optional, TypeVar from typing import Generic, Optional, Protocol, TypeVar
T = TypeVar("T") T = TypeVar("T")
@dataclass(frozen=True) class HasLocation(Protocol):
lineno: int
col_offset: int
end_lineno: Optional[int]
end_col_offset: Optional[int]
@dataclass(frozen=True, kw_only=True)
class Location:
lineno: int
col_offset: int
end_lineno: Optional[int]
end_col_offset: Optional[int]
@staticmethod
def from_ast(obj: HasLocation) -> Location:
return Location(
lineno=obj.lineno,
col_offset=obj.col_offset,
end_lineno=obj.end_lineno,
end_col_offset=obj.end_col_offset,
)
@dataclass(frozen=True, kw_only=True)
class Expr(ABC): class Expr(ABC):
location: Optional[Location] = None
@abstractmethod @abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ... def accept(self, visitor: Visitor[T]) -> T: ...

View File

@@ -8,6 +8,7 @@ from midas.ast.python import (
FrameType, FrameType,
Function, Function,
FunctionArgument, FunctionArgument,
Location,
MidasType, MidasType,
) )
@@ -50,6 +51,7 @@ class PythonParser(ast.NodeVisitor):
self.generic_visit(node) self.generic_visit(node)
def _parse_function(self, node: ast.FunctionDef) -> Function: def _parse_function(self, node: ast.FunctionDef) -> Function:
loc: Location = Location.from_ast(node)
match node: match node:
case ast.FunctionDef( case ast.FunctionDef(
name=name, name=name,
@@ -65,6 +67,7 @@ class PythonParser(ast.NodeVisitor):
return [self._parse_function_argument(arg) for arg in args_list] return [self._parse_function_argument(arg) for arg in args_list]
return Function( return Function(
location=loc,
name=name, name=name,
posonlyargs=parse_args(posonlyargs), posonlyargs=parse_args(posonlyargs),
args=parse_args(args), args=parse_args(args),
@@ -73,24 +76,38 @@ class PythonParser(ast.NodeVisitor):
) )
def _parse_function_argument(self, arg: ast.arg) -> FunctionArgument: def _parse_function_argument(self, arg: ast.arg) -> FunctionArgument:
loc: Location = Location.from_ast(arg)
name: str = arg.arg name: str = arg.arg
type: Optional[MidasType] = None type: Optional[MidasType] = None
if arg.annotation is not None: if arg.annotation is not None:
type = self._parse_type(arg.annotation) type = self._parse_type(arg.annotation)
return FunctionArgument(name=name, type=type) return FunctionArgument(
location=loc,
name=name,
type=type,
)
def _parse_type( def _parse_type(
self, type_expr: ast.expr, root: bool = False self, type_expr: ast.expr, root: bool = False
) -> Optional[MidasType]: ) -> Optional[MidasType]:
loc: Location = Location.from_ast(type_expr)
match type_expr: match type_expr:
case ast.Subscript(value=ast.Name(id="Frame"), slice=schema): case ast.Subscript(value=ast.Name(id="Frame"), slice=schema):
return self._parse_frame_type(schema) return self._parse_frame_type(schema)
case ast.Subscript(value=ast.Name(id=name), slice=param): case ast.Subscript(value=ast.Name(id=name), slice=param):
return BaseType(base=name, param=self._parse_type(param)) return BaseType(
location=loc,
base=name,
param=self._parse_type(param),
)
case ast.Name(id=name): case ast.Name(id=name):
return BaseType(base=name, param=None) return BaseType(
location=loc,
base=name,
param=None,
)
case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr): case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr):
left = self._parse_type(left_expr) left = self._parse_type(left_expr)
@@ -107,12 +124,17 @@ class PythonParser(ast.NodeVisitor):
) )
ast.copy_location(constraint, type_expr) ast.copy_location(constraint, type_expr)
return ConstraintType( return ConstraintType(
location=loc,
type=left_type, type=left_type,
constraint=constraint, constraint=constraint,
) )
case _: case _:
return ConstraintType(type=left, constraint=right_expr) return ConstraintType(
location=loc,
type=left,
constraint=right_expr,
)
case _: case _:
if root: if root:
@@ -120,6 +142,7 @@ class PythonParser(ast.NodeVisitor):
raise UnsupportedSyntaxError(type_expr) raise UnsupportedSyntaxError(type_expr)
def _parse_frame_type(self, schema: ast.expr) -> FrameType: def _parse_frame_type(self, schema: ast.expr) -> FrameType:
loc: Location = Location.from_ast(schema)
columns: list[FrameColumn] = [] columns: list[FrameColumn] = []
match schema: match schema:
@@ -133,12 +156,17 @@ class PythonParser(ast.NodeVisitor):
case _: case _:
raise UnsupportedSyntaxError(schema) raise UnsupportedSyntaxError(schema)
return FrameType(columns=columns) return FrameType(location=loc, columns=columns)
def _parse_frame_column(self, column: ast.expr) -> FrameColumn: def _parse_frame_column(self, column: ast.expr) -> FrameColumn:
loc: Location = Location.from_ast(column)
match column: match column:
case ast.Name(): case ast.Name():
return FrameColumn(name=None, type=self._parse_type(column)) return FrameColumn(
location=loc,
name=None,
type=self._parse_type(column),
)
case ast.Slice(lower=ast.Name(id=name), upper=type_expr): case ast.Slice(lower=ast.Name(id=name), upper=type_expr):
if name == "_": if name == "_":
@@ -154,7 +182,7 @@ class PythonParser(ast.NodeVisitor):
type = self._parse_type(type_expr) type = self._parse_type(type_expr)
case _: case _:
raise UnsupportedSyntaxError(type_expr) raise UnsupportedSyntaxError(type_expr)
return FrameColumn(name=name, type=type) return FrameColumn(location=loc, name=name, type=type)
case _: case _:
raise UnsupportedSyntaxError(column) raise UnsupportedSyntaxError(column)