feat(parser): parse type constraints in python
This commit is contained in:
@@ -355,11 +355,15 @@ class PythonAstPrinter(AstPrinter, p.MidasType.Visitor[None]):
|
||||
self._write_line("BaseType")
|
||||
with self._child_level():
|
||||
self._write_line(f"base: {node.base}")
|
||||
self._write_optional_child("param", node.param)
|
||||
constraint_str: str = "None"
|
||||
if node.constraint is not None:
|
||||
constraint_str = ast.unparse(node.constraint)
|
||||
self._write_line(f"constraint: {constraint_str}", last=True)
|
||||
self._write_optional_child("param", node.param, last=True)
|
||||
|
||||
def visit_constraint_type(self, node: p.ConstraintType) -> None:
|
||||
self._write_line("ConstraintType")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
node.type.accept(self)
|
||||
self._write_line(f"constraint: {ast.unparse(node.constraint)}", last=True)
|
||||
|
||||
def visit_frame_column(self, node: p.FrameColumn) -> None:
|
||||
self._write_line("FrameColumn")
|
||||
|
||||
@@ -17,6 +17,9 @@ class MidasType(ABC):
|
||||
@abstractmethod
|
||||
def visit_base_type(self, node: BaseType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_constraint_type(self, node: ConstraintType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_frame_column(self, node: FrameColumn) -> T: ...
|
||||
|
||||
@@ -28,12 +31,20 @@ class MidasType(ABC):
|
||||
class BaseType(MidasType):
|
||||
base: str
|
||||
param: Optional[MidasType]
|
||||
constraint: Optional[ast.expr] = None
|
||||
|
||||
def accept(self, visitor: MidasType.Visitor[T]) -> T:
|
||||
return visitor.visit_base_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConstraintType(MidasType):
|
||||
type: MidasType
|
||||
constraint: ast.expr
|
||||
|
||||
def accept(self, visitor: MidasType.Visitor[T]) -> T:
|
||||
return visitor.visit_constraint_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FrameColumn(MidasType):
|
||||
name: Optional[str]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import ast
|
||||
from typing import Any, Optional
|
||||
|
||||
from midas.ast.python import BaseType, FrameColumn, FrameType, MidasType
|
||||
from midas.ast.python import BaseType, ConstraintType, FrameColumn, FrameType, MidasType
|
||||
|
||||
|
||||
class InvalidSyntaxError(Exception):
|
||||
@@ -10,7 +10,9 @@ class InvalidSyntaxError(Exception):
|
||||
|
||||
class UnsupportedSyntaxError(Exception):
|
||||
def __init__(self, expr: ast.expr) -> None:
|
||||
super().__init__(f"Unsupported syntax: {ast.unparse(expr)}")
|
||||
super().__init__(
|
||||
f"Unsupported syntax at L{expr.lineno}:{expr.col_offset}: {ast.unparse(expr)}"
|
||||
)
|
||||
|
||||
|
||||
class PythonParser(ast.NodeVisitor):
|
||||
@@ -39,16 +41,32 @@ class PythonParser(ast.NodeVisitor):
|
||||
return self._parse_frame_type(schema)
|
||||
|
||||
case ast.Subscript(value=ast.Name(id=name), slice=param):
|
||||
return BaseType(
|
||||
base=name, param=self._parse_type(param), constraint=None
|
||||
)
|
||||
return BaseType(base=name, param=self._parse_type(param))
|
||||
|
||||
case ast.Name(id=name):
|
||||
return BaseType(base=name, param=None, constraint=None)
|
||||
return BaseType(base=name, param=None)
|
||||
|
||||
case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr):
|
||||
print("Constraints not implemented yet")
|
||||
return None
|
||||
left = self._parse_type(left_expr)
|
||||
match left:
|
||||
case None:
|
||||
raise InvalidSyntaxError("")
|
||||
|
||||
# If chained constraints, separate base type and rebuild constraint
|
||||
case ConstraintType(type=left_type, constraint=left_constraint):
|
||||
constraint = ast.BinOp(
|
||||
left=left_constraint,
|
||||
op=ast.Add(),
|
||||
right=right_expr,
|
||||
)
|
||||
ast.copy_location(constraint, type_expr)
|
||||
return ConstraintType(
|
||||
type=left_type,
|
||||
constraint=constraint,
|
||||
)
|
||||
|
||||
case _:
|
||||
return ConstraintType(type=left, constraint=right_expr)
|
||||
|
||||
case _:
|
||||
if root:
|
||||
@@ -62,8 +80,10 @@ class PythonParser(ast.NodeVisitor):
|
||||
case ast.Tuple(elts=cols):
|
||||
for col in cols:
|
||||
columns.append(self._parse_frame_column(col))
|
||||
|
||||
case ast.Slice() | ast.Name():
|
||||
columns.append(self._parse_frame_column(schema))
|
||||
|
||||
case _:
|
||||
raise UnsupportedSyntaxError(schema)
|
||||
|
||||
@@ -73,6 +93,7 @@ class PythonParser(ast.NodeVisitor):
|
||||
match column:
|
||||
case ast.Name():
|
||||
return FrameColumn(name=None, type=self._parse_type(column))
|
||||
|
||||
case ast.Slice(lower=ast.Name(id=name), upper=type_expr):
|
||||
if name == "_":
|
||||
name = None
|
||||
@@ -88,5 +109,6 @@ class PythonParser(ast.NodeVisitor):
|
||||
case _:
|
||||
raise UnsupportedSyntaxError(type_expr)
|
||||
return FrameColumn(name=name, type=type)
|
||||
|
||||
case _:
|
||||
raise UnsupportedSyntaxError(column)
|
||||
|
||||
Reference in New Issue
Block a user