diff --git a/midas/ast/python.py b/midas/ast/python.py index 9350fd0..878b8b8 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -3,13 +3,39 @@ from __future__ import annotations from abc import ABC, abstractmethod import ast from dataclasses import dataclass -from typing import Generic, Optional, TypeVar +from typing import Generic, Optional, Protocol, TypeVar T = TypeVar("T") -@dataclass(frozen=True) +class HasLocation(Protocol): + lineno: int + col_offset: int + end_lineno: Optional[int] + end_col_offset: Optional[int] + + +@dataclass(frozen=True, kw_only=True) +class Location: + lineno: int + col_offset: int + end_lineno: Optional[int] + end_col_offset: Optional[int] + + @staticmethod + def from_ast(obj: HasLocation) -> Location: + return Location( + lineno=obj.lineno, + col_offset=obj.col_offset, + end_lineno=obj.end_lineno, + end_col_offset=obj.end_col_offset, + ) + + +@dataclass(frozen=True, kw_only=True) class Expr(ABC): + location: Optional[Location] = None + @abstractmethod def accept(self, visitor: Visitor[T]) -> T: ... diff --git a/midas/parser/python.py b/midas/parser/python.py index dc24022..51d68ca 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -8,6 +8,7 @@ from midas.ast.python import ( FrameType, Function, FunctionArgument, + Location, MidasType, ) @@ -50,6 +51,7 @@ class PythonParser(ast.NodeVisitor): self.generic_visit(node) def _parse_function(self, node: ast.FunctionDef) -> Function: + loc: Location = Location.from_ast(node) match node: case ast.FunctionDef( name=name, @@ -65,6 +67,7 @@ class PythonParser(ast.NodeVisitor): return [self._parse_function_argument(arg) for arg in args_list] return Function( + location=loc, name=name, posonlyargs=parse_args(posonlyargs), args=parse_args(args), @@ -73,24 +76,38 @@ class PythonParser(ast.NodeVisitor): ) def _parse_function_argument(self, arg: ast.arg) -> FunctionArgument: + loc: Location = Location.from_ast(arg) name: str = arg.arg type: Optional[MidasType] = None if arg.annotation is not None: type = self._parse_type(arg.annotation) - return FunctionArgument(name=name, type=type) + return FunctionArgument( + location=loc, + name=name, + type=type, + ) def _parse_type( self, type_expr: ast.expr, root: bool = False ) -> Optional[MidasType]: + loc: Location = Location.from_ast(type_expr) match type_expr: case ast.Subscript(value=ast.Name(id="Frame"), slice=schema): return self._parse_frame_type(schema) case ast.Subscript(value=ast.Name(id=name), slice=param): - return BaseType(base=name, param=self._parse_type(param)) + return BaseType( + location=loc, + base=name, + param=self._parse_type(param), + ) case ast.Name(id=name): - return BaseType(base=name, param=None) + return BaseType( + location=loc, + base=name, + param=None, + ) case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr): left = self._parse_type(left_expr) @@ -107,12 +124,17 @@ class PythonParser(ast.NodeVisitor): ) ast.copy_location(constraint, type_expr) return ConstraintType( + location=loc, type=left_type, constraint=constraint, ) case _: - return ConstraintType(type=left, constraint=right_expr) + return ConstraintType( + location=loc, + type=left, + constraint=right_expr, + ) case _: if root: @@ -120,6 +142,7 @@ class PythonParser(ast.NodeVisitor): raise UnsupportedSyntaxError(type_expr) def _parse_frame_type(self, schema: ast.expr) -> FrameType: + loc: Location = Location.from_ast(schema) columns: list[FrameColumn] = [] match schema: @@ -133,12 +156,17 @@ class PythonParser(ast.NodeVisitor): case _: raise UnsupportedSyntaxError(schema) - return FrameType(columns=columns) + return FrameType(location=loc, columns=columns) def _parse_frame_column(self, column: ast.expr) -> FrameColumn: + loc: Location = Location.from_ast(column) match column: case ast.Name(): - return FrameColumn(name=None, type=self._parse_type(column)) + return FrameColumn( + location=loc, + name=None, + type=self._parse_type(column), + ) case ast.Slice(lower=ast.Name(id=name), upper=type_expr): if name == "_": @@ -154,7 +182,7 @@ class PythonParser(ast.NodeVisitor): type = self._parse_type(type_expr) case _: raise UnsupportedSyntaxError(type_expr) - return FrameColumn(name=name, type=type) + return FrameColumn(location=loc, name=name, type=type) case _: raise UnsupportedSyntaxError(column)