353 lines
11 KiB
Python
353 lines
11 KiB
Python
import ast
|
|
from typing import Optional
|
|
|
|
from midas.ast.location import Location
|
|
|
|
from midas.ast.python import (
|
|
AssignStmt,
|
|
BaseType,
|
|
BinaryExpr,
|
|
CallExpr,
|
|
CompareExpr,
|
|
ConstraintType,
|
|
Expr,
|
|
ExpressionStmt,
|
|
FrameColumn,
|
|
FrameType,
|
|
Function,
|
|
GetExpr,
|
|
LiteralExpr,
|
|
LogicalExpr,
|
|
MidasType,
|
|
Stmt,
|
|
TypeAssign,
|
|
UnaryExpr,
|
|
VariableExpr,
|
|
)
|
|
|
|
|
|
class InvalidSyntaxError(Exception):
    """Raised when valid Python does not form valid Midas syntax.

    For example, the frame-column parser raises it for a column written as
    ``name:`` with no column type.
    """
|
|
|
|
|
|
class UnsupportedSyntaxError(Exception):
    """Raised when a Python expression has no Midas counterpart.

    The message gives the expression's line and column, followed by its
    source text as rebuilt by ``ast.unparse``.
    """

    def __init__(self, expr: ast.expr) -> None:
        line, column = expr.lineno, expr.col_offset
        snippet = ast.unparse(expr)
        super().__init__(f"Unsupported syntax at L{line}:{column}: {snippet}")
|
|
|
|
|
|
class PythonParser:
|
|
def parse_module(self, node: ast.Module) -> list[Stmt]:
|
|
statements: list[Stmt] = []
|
|
for stmt in node.body:
|
|
try:
|
|
parsed: None | Stmt | list[Stmt] = self.parse_stmt(stmt)
|
|
if isinstance(parsed, Stmt):
|
|
statements.append(parsed)
|
|
elif parsed is not None:
|
|
statements.extend(parsed)
|
|
except UnsupportedSyntaxError as e:
|
|
print(f"{e}, skipping")
|
|
continue
|
|
return statements
|
|
|
|
def parse_stmt(self, node: ast.stmt) -> None | Stmt | list[Stmt]:
|
|
match node:
|
|
case ast.AnnAssign():
|
|
return self.parse_annotation_assign(node)
|
|
|
|
case ast.Assign():
|
|
return self.parse_assign(node)
|
|
|
|
case ast.FunctionDef():
|
|
return self.parse_function(node)
|
|
|
|
case ast.Expr(value=expr):
|
|
return ExpressionStmt(expr=self.parse_expr(expr))
|
|
|
|
case _:
|
|
print(f"Unsupported statement: {ast.unparse(node)}")
|
|
return None
|
|
|
|
def parse_annotation_assign(self, node: ast.AnnAssign) -> list[Stmt]:
|
|
statements: list[Stmt] = []
|
|
loc: Location = Location.from_ast(node)
|
|
match node:
|
|
case ast.AnnAssign(
|
|
target=ast.Name(id=target),
|
|
annotation=annotation,
|
|
value=value,
|
|
simple=1,
|
|
):
|
|
type = self._parse_type(annotation, root=True)
|
|
if type is not None:
|
|
statements.append(
|
|
TypeAssign(
|
|
location=loc,
|
|
name=target,
|
|
type=type,
|
|
)
|
|
)
|
|
|
|
if value is not None:
|
|
statements.append(
|
|
AssignStmt(
|
|
location=loc,
|
|
targets=[
|
|
VariableExpr(
|
|
location=Location.from_ast(node.target), name=target
|
|
),
|
|
],
|
|
value=self.parse_expr(value),
|
|
),
|
|
)
|
|
case _:
|
|
print(f"Unsupported annotation: {ast.unparse(node)}")
|
|
return statements
|
|
|
|
def parse_assign(self, node: ast.Assign) -> AssignStmt:
|
|
targets: list[Expr] = []
|
|
for target in node.targets:
|
|
targets.append(self.parse_expr(target))
|
|
value: Expr = self.parse_expr(node.value)
|
|
return AssignStmt(
|
|
location=Location.from_ast(node),
|
|
targets=targets,
|
|
value=value,
|
|
)
|
|
|
|
def parse_function(self, node: ast.FunctionDef) -> Function:
|
|
loc: Location = Location.from_ast(node)
|
|
match node:
|
|
case ast.FunctionDef(
|
|
name=name,
|
|
args=ast.arguments(
|
|
posonlyargs=posonlyargs,
|
|
args=args,
|
|
kwonlyargs=kwonlyargs,
|
|
),
|
|
returns=returns,
|
|
body=raw_body,
|
|
):
|
|
|
|
def parse_args(args_list: list[ast.arg]) -> list[Function.Argument]:
|
|
return [self._parse_function_argument(arg) for arg in args_list]
|
|
|
|
body: list[Stmt] = []
|
|
for stmt in raw_body:
|
|
stmts = self.parse_stmt(stmt)
|
|
if isinstance(stmts, Stmt):
|
|
body.append(stmts)
|
|
elif stmts is not None:
|
|
body.extend(stmts)
|
|
return Function(
|
|
location=loc,
|
|
name=name,
|
|
posonlyargs=parse_args(posonlyargs),
|
|
args=parse_args(args),
|
|
kwonlyargs=parse_args(kwonlyargs),
|
|
returns=self._parse_type(returns) if returns is not None else None,
|
|
body=body,
|
|
)
|
|
case _:
|
|
print(f"Unsupported function definition: {ast.unparse(node)}")
|
|
|
|
def _parse_function_argument(self, arg: ast.arg) -> Function.Argument:
|
|
loc: Location = Location.from_ast(arg)
|
|
name: str = arg.arg
|
|
type: Optional[MidasType] = None
|
|
if arg.annotation is not None:
|
|
type = self._parse_type(arg.annotation)
|
|
return Function.Argument(
|
|
location=loc,
|
|
name=name,
|
|
type=type,
|
|
)
|
|
|
|
def _parse_type(
|
|
self, type_expr: ast.expr, root: bool = False
|
|
) -> Optional[MidasType]:
|
|
loc: Location = Location.from_ast(type_expr)
|
|
match type_expr:
|
|
case ast.Subscript(value=ast.Name(id="Frame"), slice=schema):
|
|
return self._parse_frame_type(schema)
|
|
|
|
case ast.Subscript(value=ast.Name(id=name), slice=param):
|
|
return BaseType(
|
|
location=loc,
|
|
base=name,
|
|
param=self._parse_type(param),
|
|
)
|
|
|
|
case ast.Name(id=name):
|
|
return BaseType(
|
|
location=loc,
|
|
base=name,
|
|
param=None,
|
|
)
|
|
|
|
case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr):
|
|
left = self._parse_type(left_expr)
|
|
match left:
|
|
case None:
|
|
raise InvalidSyntaxError()
|
|
|
|
# If chained constraints, separate base type and rebuild constraint
|
|
case ConstraintType(type=left_type, constraint=left_constraint):
|
|
constraint = ast.BinOp(
|
|
left=left_constraint,
|
|
op=ast.Add(),
|
|
right=right_expr,
|
|
)
|
|
ast.copy_location(constraint, type_expr)
|
|
return ConstraintType(
|
|
location=loc,
|
|
type=left_type,
|
|
constraint=constraint,
|
|
)
|
|
|
|
case _:
|
|
return ConstraintType(
|
|
location=loc,
|
|
type=left,
|
|
constraint=right_expr,
|
|
)
|
|
|
|
case _:
|
|
if root:
|
|
return None
|
|
raise UnsupportedSyntaxError(type_expr)
|
|
|
|
def _parse_frame_type(self, schema: ast.expr) -> FrameType:
|
|
loc: Location = Location.from_ast(schema)
|
|
columns: list[FrameColumn] = []
|
|
|
|
match schema:
|
|
case ast.Tuple(elts=cols):
|
|
for col in cols:
|
|
columns.append(self._parse_frame_column(col))
|
|
|
|
case ast.Slice() | ast.Name():
|
|
columns.append(self._parse_frame_column(schema))
|
|
|
|
case _:
|
|
raise UnsupportedSyntaxError(schema)
|
|
|
|
return FrameType(location=loc, columns=columns)
|
|
|
|
def _parse_frame_column(self, column: ast.expr) -> FrameColumn:
|
|
loc: Location = Location.from_ast(column)
|
|
match column:
|
|
case ast.Name():
|
|
return FrameColumn(
|
|
location=loc,
|
|
name=None,
|
|
type=self._parse_type(column),
|
|
)
|
|
|
|
case ast.Slice(lower=ast.Name(id=name), upper=type_expr):
|
|
if name == "_":
|
|
name = None
|
|
|
|
type: Optional[MidasType] = None
|
|
match type_expr:
|
|
case None:
|
|
raise InvalidSyntaxError("Missing column type")
|
|
case ast.Name(id="_"):
|
|
type = None
|
|
case ast.expr():
|
|
type = self._parse_type(type_expr)
|
|
case _:
|
|
raise UnsupportedSyntaxError(type_expr)
|
|
return FrameColumn(location=loc, name=name, type=type)
|
|
|
|
case _:
|
|
raise UnsupportedSyntaxError(column)
|
|
|
|
def parse_expr(self, node: ast.expr) -> Expr:
|
|
match node:
|
|
case ast.BoolOp():
|
|
return self.parse_bool_op(node)
|
|
|
|
case ast.BinOp(left=left, op=op, right=right):
|
|
return BinaryExpr(
|
|
left=self.parse_expr(left),
|
|
operator=op,
|
|
right=self.parse_expr(right),
|
|
)
|
|
|
|
case ast.UnaryOp(op=op, operand=right):
|
|
return UnaryExpr(
|
|
operator=op,
|
|
right=self.parse_expr(right),
|
|
)
|
|
|
|
case ast.Compare():
|
|
return self.parse_compare(node)
|
|
|
|
case ast.Call():
|
|
return self.parse_call(node)
|
|
|
|
case ast.Constant(value=value):
|
|
return LiteralExpr(value=value)
|
|
|
|
case ast.Attribute(value=object, attr=name):
|
|
return GetExpr(
|
|
object=self.parse_expr(object),
|
|
name=name,
|
|
)
|
|
|
|
case ast.Name(id=name):
|
|
return VariableExpr(name=name)
|
|
|
|
case _:
|
|
raise UnsupportedSyntaxError(node)
|
|
|
|
def parse_bool_op(self, node: ast.BoolOp) -> LogicalExpr:
|
|
op: ast.boolop = node.op
|
|
values: list[ast.expr] = node.values
|
|
expr: LogicalExpr = LogicalExpr(
|
|
left=self.parse_expr(values[0]),
|
|
operator=op,
|
|
right=self.parse_expr(values[1]),
|
|
)
|
|
for value in values[2:]:
|
|
expr = LogicalExpr(
|
|
left=expr,
|
|
operator=op,
|
|
right=self.parse_expr(value),
|
|
)
|
|
return expr
|
|
|
|
def parse_compare(self, node: ast.Compare) -> Expr:
|
|
ops: list[ast.cmpop] = node.ops
|
|
rights: list[Expr] = [self.parse_expr(expr) for expr in node.comparators]
|
|
expr: Expr = CompareExpr(
|
|
left=self.parse_expr(node.left),
|
|
operator=ops[0],
|
|
right=rights[0],
|
|
)
|
|
for i, right in enumerate(rights[1:]):
|
|
expr = LogicalExpr(
|
|
left=expr,
|
|
operator=ast.And(),
|
|
right=CompareExpr(
|
|
left=rights[i],
|
|
operator=ops[i],
|
|
right=right,
|
|
),
|
|
)
|
|
return expr
|
|
|
|
def parse_call(self, node: ast.Call) -> CallExpr:
|
|
return CallExpr(
|
|
callee=self.parse_expr(node.func),
|
|
arguments=[self.parse_expr(arg) for arg in node.args],
|
|
keywords={
|
|
arg.arg: self.parse_expr(arg.value)
|
|
for arg in node.keywords
|
|
if arg.arg is not None # Should always be True, type checker happy
|
|
},
|
|
)
|