feat(parser): parse type constraints in python

This commit is contained in:
2026-05-22 18:46:06 +02:00
parent 832c350b61
commit 8d7c115432
3 changed files with 51 additions and 14 deletions

View File

@@ -355,11 +355,15 @@ class PythonAstPrinter(AstPrinter, p.MidasType.Visitor[None]):
self._write_line("BaseType")
with self._child_level():
self._write_line(f"base: {node.base}")
self._write_optional_child("param", node.param)
constraint_str: str = "None"
if node.constraint is not None:
constraint_str = ast.unparse(node.constraint)
self._write_line(f"constraint: {constraint_str}", last=True)
self._write_optional_child("param", node.param, last=True)
def visit_constraint_type(self, node: p.ConstraintType) -> None:
self._write_line("ConstraintType")
with self._child_level():
self._write_line("type")
with self._child_level(single=True):
node.type.accept(self)
self._write_line(f"constraint: {ast.unparse(node.constraint)}", last=True)
def visit_frame_column(self, node: p.FrameColumn) -> None:
self._write_line("FrameColumn")

View File

@@ -17,6 +17,9 @@ class MidasType(ABC):
@abstractmethod
def visit_base_type(self, node: BaseType) -> T: ...
@abstractmethod
def visit_constraint_type(self, node: ConstraintType) -> T: ...
@abstractmethod
def visit_frame_column(self, node: FrameColumn) -> T: ...
@@ -28,12 +31,20 @@ class MidasType(ABC):
class BaseType(MidasType):
base: str
param: Optional[MidasType]
constraint: Optional[ast.expr] = None
def accept(self, visitor: MidasType.Visitor[T]) -> T:
return visitor.visit_base_type(self)
@dataclass(frozen=True)
class ConstraintType(MidasType):
type: MidasType
constraint: ast.expr
def accept(self, visitor: MidasType.Visitor[T]) -> T:
return visitor.visit_constraint_type(self)
@dataclass(frozen=True)
class FrameColumn(MidasType):
name: Optional[str]

View File

@@ -1,7 +1,7 @@
import ast
from typing import Any, Optional
from midas.ast.python import BaseType, FrameColumn, FrameType, MidasType
from midas.ast.python import BaseType, ConstraintType, FrameColumn, FrameType, MidasType
class InvalidSyntaxError(Exception):
@@ -10,7 +10,9 @@ class InvalidSyntaxError(Exception):
class UnsupportedSyntaxError(Exception):
def __init__(self, expr: ast.expr) -> None:
super().__init__(f"Unsupported syntax: {ast.unparse(expr)}")
super().__init__(
f"Unsupported syntax at L{expr.lineno}:{expr.col_offset}: {ast.unparse(expr)}"
)
class PythonParser(ast.NodeVisitor):
@@ -39,16 +41,32 @@ class PythonParser(ast.NodeVisitor):
return self._parse_frame_type(schema)
case ast.Subscript(value=ast.Name(id=name), slice=param):
return BaseType(
base=name, param=self._parse_type(param), constraint=None
)
return BaseType(base=name, param=self._parse_type(param))
case ast.Name(id=name):
return BaseType(base=name, param=None, constraint=None)
return BaseType(base=name, param=None)
case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr):
print("Constraints not implemented yet")
return None
left = self._parse_type(left_expr)
match left:
case None:
raise InvalidSyntaxError("")
# If chained constraints, separate base type and rebuild constraint
case ConstraintType(type=left_type, constraint=left_constraint):
constraint = ast.BinOp(
left=left_constraint,
op=ast.Add(),
right=right_expr,
)
ast.copy_location(constraint, type_expr)
return ConstraintType(
type=left_type,
constraint=constraint,
)
case _:
return ConstraintType(type=left, constraint=right_expr)
case _:
if root:
@@ -62,8 +80,10 @@ class PythonParser(ast.NodeVisitor):
case ast.Tuple(elts=cols):
for col in cols:
columns.append(self._parse_frame_column(col))
case ast.Slice() | ast.Name():
columns.append(self._parse_frame_column(schema))
case _:
raise UnsupportedSyntaxError(schema)
@@ -73,6 +93,7 @@ class PythonParser(ast.NodeVisitor):
match column:
case ast.Name():
return FrameColumn(name=None, type=self._parse_type(column))
case ast.Slice(lower=ast.Name(id=name), upper=type_expr):
if name == "_":
name = None
@@ -88,5 +109,6 @@ class PythonParser(ast.NodeVisitor):
case _:
raise UnsupportedSyntaxError(type_expr)
return FrameColumn(name=name, type=type)
case _:
raise UnsupportedSyntaxError(column)