342 lines
11 KiB
Python
342 lines
11 KiB
Python
import ast
|
|
import shutil
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import midas.ast.python as p
|
|
from midas.ast.location import Location
|
|
from midas.checker.types import (
|
|
AliasType,
|
|
AppliedType,
|
|
BaseType,
|
|
ComplexType,
|
|
ExtensionType,
|
|
Function,
|
|
GenericType,
|
|
OverloadedFunction,
|
|
TopType,
|
|
Type,
|
|
TypeVar,
|
|
UnitType,
|
|
)
|
|
from midas.utils import TypedAST
|
|
|
|
|
|
@dataclass
class Scope:
    """Code produced for one statement while that statement is translated.

    ``Generator._visit_body`` pushes a fresh ``Scope`` for every statement it
    visits; casts inside the statement record their work here so it can be
    emitted around the translated statement.
    """

    # Alias assignments and ``assert`` statements emitted *before* the statement.
    pre_assertions: list[ast.stmt] = field(default_factory=list)
    # Alias variable names removed with a ``del`` emitted *after* the statement.
    aliases: list[str] = field(default_factory=list)
|
|
|
|
|
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|
def __init__(self, workdir: Path) -> None:
|
|
self.workdir: Path = workdir.resolve()
|
|
self.build_dir: Path = self.workdir / "build" / "midas"
|
|
if self.build_dir.exists():
|
|
shutil.rmtree(self.build_dir)
|
|
self.build_dir.mkdir(parents=True, exist_ok=True)
|
|
self.rel_src_path: Path = Path()
|
|
|
|
self._typed_ast: TypedAST = TypedAST(
|
|
stmts=[],
|
|
judgements=[],
|
|
)
|
|
self._alias_count: int = 0
|
|
self._scopes: list[Scope] = []
|
|
|
|
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
|
|
self.rel_src_path = src_path.relative_to(self.workdir)
|
|
self._typed_ast = typed_ast
|
|
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
|
|
module = ast.Module(body=body, type_ignores=[])
|
|
module = ast.fix_missing_locations(module)
|
|
return module
|
|
|
|
def generate(
|
|
self, typed_ast: TypedAST, src_path: Path, out_path: Optional[Path] = None
|
|
) -> Path:
|
|
module: ast.AST = self.generate_ast(typed_ast, src_path)
|
|
compiled: str = ast.unparse(module)
|
|
if out_path is None:
|
|
out_path = (self.build_dir / self.rel_src_path).resolve()
|
|
try:
|
|
_ = out_path.relative_to(self.build_dir)
|
|
except ValueError:
|
|
raise ValueError(
|
|
f"Directory traversal, {self.rel_src_path} points outside of parent directory"
|
|
)
|
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
|
out_path.write_text(compiled)
|
|
return out_path
|
|
|
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> ast.expr:
|
|
return ast.BinOp(
|
|
left=expr.left.accept(self),
|
|
op=expr.operator,
|
|
right=expr.right.accept(self),
|
|
)
|
|
|
|
def visit_compare_expr(self, expr: p.CompareExpr) -> ast.expr:
|
|
return ast.Compare(
|
|
left=expr.left.accept(self),
|
|
ops=[expr.operator],
|
|
comparators=[expr.right.accept(self)],
|
|
)
|
|
|
|
def visit_unary_expr(self, expr: p.UnaryExpr) -> ast.expr:
|
|
return ast.UnaryOp(
|
|
op=expr.operator,
|
|
operand=expr.right.accept(self),
|
|
)
|
|
|
|
def visit_call_expr(self, expr: p.CallExpr) -> ast.expr:
|
|
return ast.Call(
|
|
func=expr.callee.accept(self),
|
|
args=[arg.accept(self) for arg in expr.arguments],
|
|
keywords=[
|
|
ast.keyword(arg=name, value=arg.accept(self))
|
|
for name, arg in expr.keywords.items()
|
|
],
|
|
)
|
|
|
|
def visit_get_expr(self, expr: p.GetExpr) -> ast.expr:
|
|
return ast.Attribute(
|
|
value=expr.object.accept(self),
|
|
attr=expr.name,
|
|
)
|
|
|
|
def visit_literal_expr(self, expr: p.LiteralExpr) -> ast.expr:
|
|
return ast.Constant(value=expr.value)
|
|
|
|
def visit_variable_expr(self, expr: p.VariableExpr) -> ast.expr:
|
|
return ast.Name(id=expr.name)
|
|
|
|
def visit_logical_expr(self, expr: p.LogicalExpr) -> ast.expr:
|
|
return ast.BoolOp(
|
|
op=expr.operator,
|
|
values=[expr.left.accept(self), expr.right.accept(self)],
|
|
)
|
|
|
|
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
|
|
expr2: ast.expr = expr.expr.accept(self)
|
|
alias: ast.expr = self._make_alias(expr2)
|
|
|
|
type: Type = self._get_expr_type(expr)
|
|
self._make_cast_asserts(expr.location, alias, type)
|
|
|
|
return alias
|
|
|
|
def visit_ternary_expr(self, expr: p.TernaryExpr) -> ast.expr:
|
|
return ast.IfExp(
|
|
test=expr.test.accept(self),
|
|
body=expr.if_true.accept(self),
|
|
orelse=expr.if_false.accept(self),
|
|
)
|
|
|
|
def visit_list_expr(self, expr: p.ListExpr) -> ast.expr:
|
|
return ast.List(
|
|
elts=[item.accept(self) for item in expr.items],
|
|
)
|
|
|
|
def visit_dict_expr(self, expr: p.DictExpr) -> ast.expr:
|
|
return ast.Dict(
|
|
keys=[key.accept(self) if key is not None else None for key in expr.keys],
|
|
values=[value.accept(self) for value in expr.values],
|
|
)
|
|
|
|
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> ast.expr:
|
|
return ast.Subscript(
|
|
value=expr.object.accept(self),
|
|
slice=expr.index.accept(self),
|
|
)
|
|
|
|
def visit_slice_expr(self, expr: p.SliceExpr) -> ast.expr:
|
|
return ast.Slice(
|
|
lower=expr.lower.accept(self) if expr.lower is not None else None,
|
|
upper=expr.upper.accept(self) if expr.upper is not None else None,
|
|
step=expr.step.accept(self) if expr.step is not None else None,
|
|
)
|
|
|
|
def visit_raw_expr(self, expr: p.RawExpr) -> ast.expr:
|
|
return expr.expr
|
|
|
|
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> ast.stmt:
|
|
return ast.Expr(
|
|
value=stmt.expr.accept(self),
|
|
)
|
|
|
|
def visit_function(self, stmt: p.Function) -> ast.stmt:
|
|
return ast.FunctionDef(
|
|
name=stmt.name,
|
|
args=ast.arguments(
|
|
posonlyargs=[ast.arg(arg=arg.name) for arg in stmt.posonlyargs],
|
|
vararg=None,
|
|
args=[ast.arg(arg=arg.name) for arg in stmt.args],
|
|
kwonlyargs=[ast.arg(arg=arg.name) for arg in stmt.kwonlyargs],
|
|
kwarg=None,
|
|
defaults=[
|
|
arg.default.accept(self)
|
|
for arg in stmt.posonlyargs + stmt.args
|
|
if arg.default is not None
|
|
],
|
|
kw_defaults=[
|
|
arg.default.accept(self) if arg.default is not None else None
|
|
for arg in stmt.kwonlyargs
|
|
],
|
|
),
|
|
body=self._visit_body(stmt.body),
|
|
decorator_list=[],
|
|
)
|
|
|
|
def visit_type_assign(self, stmt: p.TypeAssign) -> ast.stmt:
|
|
# TODO: is that ok?
|
|
return ast.Pass()
|
|
|
|
def visit_assign_stmt(self, stmt: p.AssignStmt) -> ast.stmt:
|
|
return ast.Assign(
|
|
targets=[target.accept(self) for target in stmt.targets],
|
|
value=stmt.value.accept(self),
|
|
)
|
|
|
|
def visit_return_stmt(self, stmt: p.ReturnStmt) -> ast.stmt:
|
|
return ast.Return(
|
|
value=stmt.value.accept(self) if stmt.value is not None else None,
|
|
)
|
|
|
|
def visit_if_stmt(self, stmt: p.IfStmt) -> ast.stmt:
|
|
return ast.If(
|
|
test=stmt.test.accept(self),
|
|
body=self._visit_body(stmt.body),
|
|
orelse=self._visit_body(stmt.orelse),
|
|
)
|
|
|
|
def visit_pass(self, stmt: p.Pass) -> ast.stmt:
|
|
return ast.Pass()
|
|
|
|
def visit_for_stmt(self, stmt: p.ForStmt) -> ast.stmt:
|
|
return ast.For(
|
|
target=stmt.target.accept(self),
|
|
iter=stmt.iterator.accept(self),
|
|
body=self._visit_body(stmt.body),
|
|
orelse=[],
|
|
)
|
|
|
|
def visit_raw_stmt(self, stmt: p.RawStmt) -> ast.stmt:
|
|
return stmt.stmt
|
|
|
|
def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]:
|
|
generated: list[ast.stmt] = []
|
|
for stmt in stmts:
|
|
scope = Scope()
|
|
self._scopes.append(scope)
|
|
|
|
stmt2 = stmt.accept(self)
|
|
generated.extend(scope.pre_assertions)
|
|
generated.append(stmt2)
|
|
if len(scope.aliases) != 0:
|
|
generated.append(
|
|
ast.Delete(targets=[ast.Name(id=alias) for alias in scope.aliases])
|
|
)
|
|
self._scopes.pop()
|
|
|
|
# Remove redundant pass statements
|
|
if len(generated) > 1:
|
|
generated = [stmt for stmt in generated if not isinstance(stmt, ast.Pass)]
|
|
return generated
|
|
|
|
def _make_alias(self, expr: ast.expr) -> ast.expr:
|
|
name: str = f"__midas_alias_{self._alias_count}__"
|
|
alias = ast.Name(id=name)
|
|
self._alias_count += 1
|
|
self._scopes[-1].aliases.append(name)
|
|
self._scopes[-1].pre_assertions.append(
|
|
ast.Assign(
|
|
targets=[alias],
|
|
value=expr,
|
|
)
|
|
)
|
|
return alias
|
|
|
|
def _add_assert(self, expr: ast.expr, message: str | ast.expr):
|
|
if isinstance(message, str):
|
|
message = ast.Constant(value=message)
|
|
self._scopes[-1].pre_assertions.append(
|
|
ast.Assert(
|
|
test=expr,
|
|
msg=message,
|
|
)
|
|
)
|
|
|
|
def _get_expr_type(self, query: p.Expr) -> Type:
|
|
for expr, type in self._typed_ast.judgements:
|
|
if expr == query:
|
|
return type
|
|
raise RuntimeError(f"Cannot get type judgement for {query}")
|
|
|
|
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type):
|
|
match type:
|
|
case BaseType(name=name):
|
|
self._add_assert(
|
|
ast.Call(
|
|
func=ast.Name(id="isinstance"),
|
|
args=[expr, ast.Name(id=name)],
|
|
keywords=[],
|
|
),
|
|
self._make_cast_assert_message(src_location, expr, type),
|
|
)
|
|
|
|
case AliasType(type=base):
|
|
self._make_cast_asserts(src_location, expr, base)
|
|
|
|
case UnitType():
|
|
self._add_assert(
|
|
ast.Compare(
|
|
left=expr,
|
|
ops=[ast.Is()],
|
|
comparators=[
|
|
ast.Constant(value=None),
|
|
],
|
|
),
|
|
self._make_cast_assert_message(src_location, expr, type),
|
|
)
|
|
|
|
case AppliedType():
|
|
self._make_cast_asserts(src_location, expr, type.body)
|
|
|
|
case (
|
|
TopType()
|
|
| Function()
|
|
| OverloadedFunction()
|
|
| ComplexType()
|
|
| ExtensionType()
|
|
| GenericType()
|
|
):
|
|
raise NotImplementedError(f"Can't make assertion for type {type}")
|
|
|
|
case TypeVar():
|
|
raise RuntimeError("Unexpected TypeVar")
|
|
|
|
def _make_cast_assert_message(
|
|
self, location: Location, expr: ast.expr, type: Type
|
|
) -> ast.expr:
|
|
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
|
|
# f"file.py:L1:1: CastError: Cannot cast {type(expr).__name__} to Type"
|
|
return ast.JoinedStr(
|
|
values=[
|
|
ast.Constant(f"{loc_str}: CastError: Cannot cast "),
|
|
ast.FormattedValue(
|
|
value=ast.Attribute(
|
|
value=ast.Call(
|
|
func=ast.Name(id="type"),
|
|
args=[expr],
|
|
keywords=[],
|
|
),
|
|
attr="__name__",
|
|
),
|
|
conversion=-1,
|
|
),
|
|
ast.Constant(f" to {type}"),
|
|
]
|
|
)
|