diff --git a/midas/ast/printer.py b/midas/ast/printer.py index 513b823..e9033e2 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -355,11 +355,15 @@ class PythonAstPrinter(AstPrinter, p.MidasType.Visitor[None]): self._write_line("BaseType") with self._child_level(): self._write_line(f"base: {node.base}") - self._write_optional_child("param", node.param) - constraint_str: str = "None" - if node.constraint is not None: - constraint_str = ast.unparse(node.constraint) - self._write_line(f"constraint: {constraint_str}", last=True) + self._write_optional_child("param", node.param, last=True) + + def visit_constraint_type(self, node: p.ConstraintType) -> None: + self._write_line("ConstraintType") + with self._child_level(): + self._write_line("type") + with self._child_level(single=True): + node.type.accept(self) + self._write_line(f"constraint: {ast.unparse(node.constraint)}", last=True) def visit_frame_column(self, node: p.FrameColumn) -> None: self._write_line("FrameColumn") diff --git a/midas/ast/python.py b/midas/ast/python.py index 63307c4..8b7f03e 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -17,6 +17,9 @@ class MidasType(ABC): @abstractmethod def visit_base_type(self, node: BaseType) -> T: ... + @abstractmethod + def visit_constraint_type(self, node: ConstraintType) -> T: ... + @abstractmethod def visit_frame_column(self, node: FrameColumn) -> T: ... @@ -28,12 +31,20 @@ class MidasType(ABC): class BaseType(MidasType): base: str param: Optional[MidasType] - constraint: Optional[ast.expr] = None def accept(self, visitor: MidasType.Visitor[T]) -> T: return visitor.visit_base_type(self) +@dataclass(frozen=True) +class ConstraintType(MidasType): + type: MidasType + constraint: ast.expr + + def accept(self, visitor: MidasType.Visitor[T]) -> T: + return visitor.visit_constraint_type(self) + + @dataclass(frozen=True) class FrameColumn(MidasType): name: Optional[str] diff --git a/midas/parser/python.py b/midas/parser/python.py index e55d21e..6139801 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -1,7 +1,7 @@ import ast from typing import Any, Optional -from midas.ast.python import BaseType, FrameColumn, FrameType, MidasType +from midas.ast.python import BaseType, ConstraintType, FrameColumn, FrameType, MidasType class InvalidSyntaxError(Exception): @@ -10,7 +10,9 @@ class InvalidSyntaxError(Exception): class UnsupportedSyntaxError(Exception): def __init__(self, expr: ast.expr) -> None: - super().__init__(f"Unsupported syntax: {ast.unparse(expr)}") + super().__init__( + f"Unsupported syntax at L{expr.lineno}:{expr.col_offset}: {ast.unparse(expr)}" + ) class PythonParser(ast.NodeVisitor): @@ -39,16 +41,32 @@ class PythonParser(ast.NodeVisitor): return self._parse_frame_type(schema) case ast.Subscript(value=ast.Name(id=name), slice=param): - return BaseType( - base=name, param=self._parse_type(param), constraint=None - ) + return BaseType(base=name, param=self._parse_type(param)) case ast.Name(id=name): - return BaseType(base=name, param=None, constraint=None) + return BaseType(base=name, param=None) case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr): - print("Constraints not implemented yet") - return None + left = self._parse_type(left_expr) + match left: + case None: + raise InvalidSyntaxError("") + + # If chained constraints, separate base type and rebuild constraint + case ConstraintType(type=left_type, constraint=left_constraint): + constraint = ast.BinOp( + left=left_constraint, + op=ast.Add(), + right=right_expr, + ) + ast.copy_location(constraint, type_expr) + return ConstraintType( + type=left_type, + constraint=constraint, + ) + + case _: + return ConstraintType(type=left, constraint=right_expr) case _: if root: @@ -62,8 +80,10 @@ class PythonParser(ast.NodeVisitor): case ast.Tuple(elts=cols): for col in cols: columns.append(self._parse_frame_column(col)) + case ast.Slice() | ast.Name(): columns.append(self._parse_frame_column(schema)) + case _: raise UnsupportedSyntaxError(schema) @@ -73,6 +93,7 @@ class PythonParser(ast.NodeVisitor): match column: case ast.Name(): return FrameColumn(name=None, type=self._parse_type(column)) + case ast.Slice(lower=ast.Name(id=name), upper=type_expr): if name == "_": name = None @@ -88,5 +109,6 @@ class PythonParser(ast.NodeVisitor): case _: raise UnsupportedSyntaxError(type_expr) return FrameColumn(name=name, type=type) + case _: raise UnsupportedSyntaxError(column)