feat(gen): add basic cast assertions on base type

This commit is contained in:
2026-06-16 12:49:36 +02:00
parent 0a8e0fb6c2
commit c3229b557c

View File

@@ -1,11 +1,33 @@
import ast import ast
import shutil import shutil
from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from midas.ast.location import Location
import midas.ast.python as p import midas.ast.python as p
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ComplexType,
ExtensionType,
Function,
GenericType,
OverloadedFunction,
TopType,
Type,
TypeVar,
UnitType,
)
from midas.utils import TypedAST from midas.utils import TypedAST
@dataclass
class Scope:
pre_assertions: list[ast.stmt] = field(default_factory=list)
aliases: list[str] = field(default_factory=list)
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def __init__(self, workdir: Path) -> None: def __init__(self, workdir: Path) -> None:
self.workdir: Path = workdir.resolve() self.workdir: Path = workdir.resolve()
@@ -13,19 +35,28 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
if self.build_dir.exists(): if self.build_dir.exists():
shutil.rmtree(self.build_dir) shutil.rmtree(self.build_dir)
self.build_dir.mkdir(parents=True, exist_ok=True) self.build_dir.mkdir(parents=True, exist_ok=True)
self.rel_src_path: Path = Path()
self._typed_ast: TypedAST = TypedAST(
stmts=[],
judgements=[],
)
self._alias_count: int = 0
self._scopes: list[Scope] = []
def generate(self, typed_ast: TypedAST, src_path: Path) -> Path: def generate(self, typed_ast: TypedAST, src_path: Path) -> Path:
self.rel_src_path = src_path.relative_to(self.workdir)
self._typed_ast = typed_ast
body: list[ast.stmt] = self._visit_body(typed_ast.stmts) body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
module = ast.Module(body=body, type_ignores=[]) module = ast.Module(body=body, type_ignores=[])
module = ast.fix_missing_locations(module) module = ast.fix_missing_locations(module)
compiled: str = ast.unparse(module) compiled: str = ast.unparse(module)
rel_src_path: Path = src_path.relative_to(self.workdir) out_path: Path = (self.build_dir / self.rel_src_path).resolve()
out_path: Path = (self.build_dir / rel_src_path).resolve()
try: try:
_ = out_path.relative_to(self.build_dir) _ = out_path.relative_to(self.build_dir)
except ValueError: except ValueError:
raise ValueError( raise ValueError(
f"Directory traversal, {rel_src_path} points outside of parent directory" f"Directory traversal, {self.rel_src_path} points outside of parent directory"
) )
out_path.parent.mkdir(parents=True, exist_ok=True) out_path.parent.mkdir(parents=True, exist_ok=True)
out_path.write_text(compiled) out_path.write_text(compiled)
@@ -80,8 +111,13 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
) )
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr: def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
# TODO: insert assertion expr2: ast.expr = expr.expr.accept(self)
return expr.expr.accept(self) alias: ast.expr = self._make_alias(expr2)
type: Type = self._get_expr_type(expr)
self._make_cast_asserts(expr.location, alias, type)
return alias
def visit_ternary_expr(self, expr: p.TernaryExpr) -> ast.expr: def visit_ternary_expr(self, expr: p.TernaryExpr) -> ast.expr:
return ast.IfExp( return ast.IfExp(
@@ -176,4 +212,116 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
return stmt.stmt return stmt.stmt
def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]: def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]:
return [stmt.accept(self) for stmt in stmts] generated: list[ast.stmt] = []
for stmt in stmts:
scope = Scope()
self._scopes.append(scope)
stmt2 = stmt.accept(self)
generated.extend(scope.pre_assertions)
generated.append(stmt2)
if len(scope.aliases) != 0:
generated.append(
ast.Delete(targets=[ast.Name(id=alias) for alias in scope.aliases])
)
self._scopes.pop()
# Remove redundant pass statements
if len(generated) > 1:
generated = [stmt for stmt in generated if not isinstance(stmt, ast.Pass)]
return generated
def _make_alias(self, expr: ast.expr) -> ast.expr:
name: str = f"__midas_alias_{self._alias_count}__"
alias = ast.Name(id=name)
self._alias_count += 1
self._scopes[-1].aliases.append(name)
self._scopes[-1].pre_assertions.append(
ast.Assign(
targets=[alias],
value=expr,
)
)
return alias
def _add_assert(self, expr: ast.expr, message: str | ast.expr):
if isinstance(message, str):
message = ast.Constant(value=message)
self._scopes[-1].pre_assertions.append(
ast.Assert(
test=expr,
msg=message,
)
)
def _get_expr_type(self, query: p.Expr) -> Type:
for expr, type in self._typed_ast.judgements:
if expr == query:
return type
raise RuntimeError(f"Cannot get type judgement for {query}")
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type):
match type:
case BaseType(name=name):
self._add_assert(
ast.Call(
func=ast.Name(id="isinstance"),
args=[expr, ast.Name(id=name)],
keywords=[],
),
self._make_cast_assert_message(src_location, expr, type),
)
case AliasType(type=base):
self._make_cast_asserts(src_location, expr, base)
case UnitType():
self._add_assert(
ast.Compare(
left=expr,
ops=[ast.Is()],
comparators=[
ast.Constant(value=None),
],
),
self._make_cast_assert_message(src_location, expr, type),
)
case AppliedType():
self._make_cast_asserts(src_location, expr, type.body)
case (
TopType()
| Function()
| OverloadedFunction()
| ComplexType()
| ExtensionType()
| GenericType()
):
raise NotImplementedError(f"Can't make assertion for type {type}")
case TypeVar():
raise RuntimeError("Unexpected TypeVar")
def _make_cast_assert_message(
self, location: Location, expr: ast.expr, type: Type
) -> ast.expr:
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
# f"file.py:L1:1: CastError: Cannot cast {type(expr).__name__} to Type"
return ast.JoinedStr(
values=[
ast.Constant(f"{loc_str}: CastError: Cannot cast "),
ast.FormattedValue(
value=ast.Attribute(
value=ast.Call(
func=ast.Name(id="type"),
args=[expr],
keywords=[],
),
attr="__name__",
),
conversion=-1,
),
ast.Constant(f" to {type}"),
]
)