feat(gen): add basic cast assertions on base type
This commit is contained in:
@@ -1,11 +1,33 @@
|
||||
import ast
|
||||
import shutil
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from midas.ast.location import Location
|
||||
import midas.ast.python as p
|
||||
from midas.checker.types import (
|
||||
AliasType,
|
||||
AppliedType,
|
||||
BaseType,
|
||||
ComplexType,
|
||||
ExtensionType,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
TopType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnitType,
|
||||
)
|
||||
from midas.utils import TypedAST
|
||||
|
||||
|
||||
@dataclass
|
||||
class Scope:
|
||||
pre_assertions: list[ast.stmt] = field(default_factory=list)
|
||||
aliases: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
def __init__(self, workdir: Path) -> None:
|
||||
self.workdir: Path = workdir.resolve()
|
||||
@@ -13,19 +35,28 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
if self.build_dir.exists():
|
||||
shutil.rmtree(self.build_dir)
|
||||
self.build_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.rel_src_path: Path = Path()
|
||||
|
||||
self._typed_ast: TypedAST = TypedAST(
|
||||
stmts=[],
|
||||
judgements=[],
|
||||
)
|
||||
self._alias_count: int = 0
|
||||
self._scopes: list[Scope] = []
|
||||
|
||||
def generate(self, typed_ast: TypedAST, src_path: Path) -> Path:
|
||||
self.rel_src_path = src_path.relative_to(self.workdir)
|
||||
self._typed_ast = typed_ast
|
||||
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
|
||||
module = ast.Module(body=body, type_ignores=[])
|
||||
module = ast.fix_missing_locations(module)
|
||||
compiled: str = ast.unparse(module)
|
||||
rel_src_path: Path = src_path.relative_to(self.workdir)
|
||||
out_path: Path = (self.build_dir / rel_src_path).resolve()
|
||||
out_path: Path = (self.build_dir / self.rel_src_path).resolve()
|
||||
try:
|
||||
_ = out_path.relative_to(self.build_dir)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Directory traversal, {rel_src_path} points outside of parent directory"
|
||||
f"Directory traversal, {self.rel_src_path} points outside of parent directory"
|
||||
)
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
out_path.write_text(compiled)
|
||||
@@ -80,8 +111,13 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
)
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
|
||||
# TODO: insert assertion
|
||||
return expr.expr.accept(self)
|
||||
expr2: ast.expr = expr.expr.accept(self)
|
||||
alias: ast.expr = self._make_alias(expr2)
|
||||
|
||||
type: Type = self._get_expr_type(expr)
|
||||
self._make_cast_asserts(expr.location, alias, type)
|
||||
|
||||
return alias
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> ast.expr:
|
||||
return ast.IfExp(
|
||||
@@ -176,4 +212,116 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
return stmt.stmt
|
||||
|
||||
def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]:
|
||||
return [stmt.accept(self) for stmt in stmts]
|
||||
generated: list[ast.stmt] = []
|
||||
for stmt in stmts:
|
||||
scope = Scope()
|
||||
self._scopes.append(scope)
|
||||
|
||||
stmt2 = stmt.accept(self)
|
||||
generated.extend(scope.pre_assertions)
|
||||
generated.append(stmt2)
|
||||
if len(scope.aliases) != 0:
|
||||
generated.append(
|
||||
ast.Delete(targets=[ast.Name(id=alias) for alias in scope.aliases])
|
||||
)
|
||||
self._scopes.pop()
|
||||
|
||||
# Remove redundant pass statements
|
||||
if len(generated) > 1:
|
||||
generated = [stmt for stmt in generated if not isinstance(stmt, ast.Pass)]
|
||||
return generated
|
||||
|
||||
def _make_alias(self, expr: ast.expr) -> ast.expr:
|
||||
name: str = f"__midas_alias_{self._alias_count}__"
|
||||
alias = ast.Name(id=name)
|
||||
self._alias_count += 1
|
||||
self._scopes[-1].aliases.append(name)
|
||||
self._scopes[-1].pre_assertions.append(
|
||||
ast.Assign(
|
||||
targets=[alias],
|
||||
value=expr,
|
||||
)
|
||||
)
|
||||
return alias
|
||||
|
||||
def _add_assert(self, expr: ast.expr, message: str | ast.expr):
|
||||
if isinstance(message, str):
|
||||
message = ast.Constant(value=message)
|
||||
self._scopes[-1].pre_assertions.append(
|
||||
ast.Assert(
|
||||
test=expr,
|
||||
msg=message,
|
||||
)
|
||||
)
|
||||
|
||||
def _get_expr_type(self, query: p.Expr) -> Type:
|
||||
for expr, type in self._typed_ast.judgements:
|
||||
if expr == query:
|
||||
return type
|
||||
raise RuntimeError(f"Cannot get type judgement for {query}")
|
||||
|
||||
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type):
|
||||
match type:
|
||||
case BaseType(name=name):
|
||||
self._add_assert(
|
||||
ast.Call(
|
||||
func=ast.Name(id="isinstance"),
|
||||
args=[expr, ast.Name(id=name)],
|
||||
keywords=[],
|
||||
),
|
||||
self._make_cast_assert_message(src_location, expr, type),
|
||||
)
|
||||
|
||||
case AliasType(type=base):
|
||||
self._make_cast_asserts(src_location, expr, base)
|
||||
|
||||
case UnitType():
|
||||
self._add_assert(
|
||||
ast.Compare(
|
||||
left=expr,
|
||||
ops=[ast.Is()],
|
||||
comparators=[
|
||||
ast.Constant(value=None),
|
||||
],
|
||||
),
|
||||
self._make_cast_assert_message(src_location, expr, type),
|
||||
)
|
||||
|
||||
case AppliedType():
|
||||
self._make_cast_asserts(src_location, expr, type.body)
|
||||
|
||||
case (
|
||||
TopType()
|
||||
| Function()
|
||||
| OverloadedFunction()
|
||||
| ComplexType()
|
||||
| ExtensionType()
|
||||
| GenericType()
|
||||
):
|
||||
raise NotImplementedError(f"Can't make assertion for type {type}")
|
||||
|
||||
case TypeVar():
|
||||
raise RuntimeError("Unexpected TypeVar")
|
||||
|
||||
def _make_cast_assert_message(
|
||||
self, location: Location, expr: ast.expr, type: Type
|
||||
) -> ast.expr:
|
||||
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
|
||||
# f"file.py:L1:1: CastError: Cannot cast {type(expr).__name__} to Type"
|
||||
return ast.JoinedStr(
|
||||
values=[
|
||||
ast.Constant(f"{loc_str}: CastError: Cannot cast "),
|
||||
ast.FormattedValue(
|
||||
value=ast.Attribute(
|
||||
value=ast.Call(
|
||||
func=ast.Name(id="type"),
|
||||
args=[expr],
|
||||
keywords=[],
|
||||
),
|
||||
attr="__name__",
|
||||
),
|
||||
conversion=-1,
|
||||
),
|
||||
ast.Constant(f" to {type}"),
|
||||
]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user