feat(parser): parse type constraints in python

This commit is contained in:
2026-05-22 18:46:06 +02:00
parent 832c350b61
commit 8d7c115432
3 changed files with 51 additions and 14 deletions

View File

@@ -355,11 +355,15 @@ class PythonAstPrinter(AstPrinter, p.MidasType.Visitor[None]):
self._write_line("BaseType") self._write_line("BaseType")
with self._child_level(): with self._child_level():
self._write_line(f"base: {node.base}") self._write_line(f"base: {node.base}")
self._write_optional_child("param", node.param) self._write_optional_child("param", node.param, last=True)
constraint_str: str = "None"
def visit_constraint_type(self, node: p.ConstraintType) -> None:
    """Print a ``ConstraintType`` node followed by its two children.

    Writes a ``ConstraintType`` header line, then a ``type`` entry whose
    subtree is produced by letting ``node.type`` accept this printer, and
    finally a ``constraint: ...`` line holding the constraint expression
    rendered back to source text.
    """
    self._write_line("ConstraintType")
    with self._child_level():
        # First child: the constrained type, printed as a nested subtree.
        self._write_line("type")
        with self._child_level(single=True):
            node.type.accept(self)

        # Final child: the constraint, unparsed only after the type subtree
        # has been written so output order matches evaluation order.
        constraint_source = ast.unparse(node.constraint)
        self._write_line(f"constraint: {constraint_source}", last=True)
def visit_frame_column(self, node: p.FrameColumn) -> None: def visit_frame_column(self, node: p.FrameColumn) -> None:
self._write_line("FrameColumn") self._write_line("FrameColumn")

View File

@@ -17,6 +17,9 @@ class MidasType(ABC):
@abstractmethod @abstractmethod
def visit_base_type(self, node: BaseType) -> T: ... def visit_base_type(self, node: BaseType) -> T: ...
@abstractmethod
def visit_constraint_type(self, node: ConstraintType) -> T:
    """Visit a ``ConstraintType`` node; concrete visitors must override this."""
@abstractmethod @abstractmethod
def visit_frame_column(self, node: FrameColumn) -> T: ... def visit_frame_column(self, node: FrameColumn) -> T: ...
@@ -28,12 +31,20 @@ class MidasType(ABC):
class BaseType(MidasType): class BaseType(MidasType):
base: str base: str
param: Optional[MidasType] param: Optional[MidasType]
constraint: Optional[ast.expr] = None
def accept(self, visitor: MidasType.Visitor[T]) -> T: def accept(self, visitor: MidasType.Visitor[T]) -> T:
return visitor.visit_base_type(self) return visitor.visit_base_type(self)
@dataclass(frozen=True)
class ConstraintType(MidasType):
    """A ``MidasType`` paired with the expression that constrains it.

    Instances are immutable (frozen dataclass).
    """

    # The type the constraint applies to.
    type: MidasType
    # The constraint itself, kept as a Python AST expression.
    constraint: ast.expr

    def accept(self, visitor: MidasType.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_constraint_type`` and return its result."""
        handler = visitor.visit_constraint_type
        return handler(self)
@dataclass(frozen=True) @dataclass(frozen=True)
class FrameColumn(MidasType): class FrameColumn(MidasType):
name: Optional[str] name: Optional[str]

View File

@@ -1,7 +1,7 @@
import ast import ast
from typing import Any, Optional from typing import Any, Optional
from midas.ast.python import BaseType, FrameColumn, FrameType, MidasType from midas.ast.python import BaseType, ConstraintType, FrameColumn, FrameType, MidasType
class InvalidSyntaxError(Exception): class InvalidSyntaxError(Exception):
@@ -10,7 +10,9 @@ class InvalidSyntaxError(Exception):
class UnsupportedSyntaxError(Exception):
    """Raised when an annotation uses syntax the parser does not support.

    The message contains the offending expression rendered back to source
    and, when the node carries location information, its line and column.
    """

    def __init__(self, expr: ast.expr) -> None:
        """Build the error message for ``expr``.

        Args:
            expr: The AST expression that could not be handled.
        """
        source = ast.unparse(expr)
        # Synthesized nodes (built by hand rather than by ``ast.parse``) may
        # lack location attributes; fall back to a position-less message
        # instead of failing with AttributeError while reporting the error.
        line = getattr(expr, "lineno", None)
        column = getattr(expr, "col_offset", None)
        if line is None or column is None:
            message = f"Unsupported syntax: {source}"
        else:
            message = f"Unsupported syntax at L{line}:{column}: {source}"
        super().__init__(message)
class PythonParser(ast.NodeVisitor): class PythonParser(ast.NodeVisitor):
@@ -39,16 +41,32 @@ class PythonParser(ast.NodeVisitor):
return self._parse_frame_type(schema) return self._parse_frame_type(schema)
case ast.Subscript(value=ast.Name(id=name), slice=param): case ast.Subscript(value=ast.Name(id=name), slice=param):
return BaseType( return BaseType(base=name, param=self._parse_type(param))
base=name, param=self._parse_type(param), constraint=None
)
case ast.Name(id=name): case ast.Name(id=name):
return BaseType(base=name, param=None, constraint=None) return BaseType(base=name, param=None)
case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr): case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr):
print("Constraints not implemented yet") left = self._parse_type(left_expr)
return None match left:
case None:
raise InvalidSyntaxError("")
# If chained constraints, separate base type and rebuild constraint
case ConstraintType(type=left_type, constraint=left_constraint):
constraint = ast.BinOp(
left=left_constraint,
op=ast.Add(),
right=right_expr,
)
ast.copy_location(constraint, type_expr)
return ConstraintType(
type=left_type,
constraint=constraint,
)
case _:
return ConstraintType(type=left, constraint=right_expr)
case _: case _:
if root: if root:
@@ -62,8 +80,10 @@ class PythonParser(ast.NodeVisitor):
case ast.Tuple(elts=cols): case ast.Tuple(elts=cols):
for col in cols: for col in cols:
columns.append(self._parse_frame_column(col)) columns.append(self._parse_frame_column(col))
case ast.Slice() | ast.Name(): case ast.Slice() | ast.Name():
columns.append(self._parse_frame_column(schema)) columns.append(self._parse_frame_column(schema))
case _: case _:
raise UnsupportedSyntaxError(schema) raise UnsupportedSyntaxError(schema)
@@ -73,6 +93,7 @@ class PythonParser(ast.NodeVisitor):
match column: match column:
case ast.Name(): case ast.Name():
return FrameColumn(name=None, type=self._parse_type(column)) return FrameColumn(name=None, type=self._parse_type(column))
case ast.Slice(lower=ast.Name(id=name), upper=type_expr): case ast.Slice(lower=ast.Name(id=name), upper=type_expr):
if name == "_": if name == "_":
name = None name = None
@@ -88,5 +109,6 @@ class PythonParser(ast.NodeVisitor):
case _: case _:
raise UnsupportedSyntaxError(type_expr) raise UnsupportedSyntaxError(type_expr)
return FrameColumn(name=name, type=type) return FrameColumn(name=name, type=type)
case _: case _:
raise UnsupportedSyntaxError(column) raise UnsupportedSyntaxError(column)