diff --git a/midas/generator/generator.py b/midas/generator/generator.py index ddab606..0c086d8 100644 --- a/midas/generator/generator.py +++ b/midas/generator/generator.py @@ -1,11 +1,33 @@ import ast import shutil +from dataclasses import dataclass, field from pathlib import Path +from midas.ast.location import Location import midas.ast.python as p +from midas.checker.types import ( + AliasType, + AppliedType, + BaseType, + ComplexType, + ExtensionType, + Function, + GenericType, + OverloadedFunction, + TopType, + Type, + TypeVar, + UnitType, +) from midas.utils import TypedAST +@dataclass +class Scope: + pre_assertions: list[ast.stmt] = field(default_factory=list) + aliases: list[str] = field(default_factory=list) + + class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): def __init__(self, workdir: Path) -> None: self.workdir: Path = workdir.resolve() @@ -13,19 +35,28 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): if self.build_dir.exists(): shutil.rmtree(self.build_dir) self.build_dir.mkdir(parents=True, exist_ok=True) + self.rel_src_path: Path = Path() + + self._typed_ast: TypedAST = TypedAST( + stmts=[], + judgements=[], + ) + self._alias_count: int = 0 + self._scopes: list[Scope] = [] def generate(self, typed_ast: TypedAST, src_path: Path) -> Path: + self.rel_src_path = src_path.relative_to(self.workdir) + self._typed_ast = typed_ast body: list[ast.stmt] = self._visit_body(typed_ast.stmts) module = ast.Module(body=body, type_ignores=[]) module = ast.fix_missing_locations(module) compiled: str = ast.unparse(module) - rel_src_path: Path = src_path.relative_to(self.workdir) - out_path: Path = (self.build_dir / rel_src_path).resolve() + out_path: Path = (self.build_dir / self.rel_src_path).resolve() try: _ = out_path.relative_to(self.build_dir) except ValueError: raise ValueError( - f"Directory traversal, {rel_src_path} points outside of parent directory" + f"Directory traversal, {self.rel_src_path} points outside of parent directory" ) out_path.parent.mkdir(parents=True, exist_ok=True) out_path.write_text(compiled) @@ -80,8 +111,13 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): ) def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr: - # TODO: insert assertion - return expr.expr.accept(self) + expr2: ast.expr = expr.expr.accept(self) + alias: ast.expr = self._make_alias(expr2) + + type: Type = self._get_expr_type(expr) + self._make_cast_asserts(expr.location, alias, type) + + return alias def visit_ternary_expr(self, expr: p.TernaryExpr) -> ast.expr: return ast.IfExp( @@ -176,4 +212,116 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): return stmt.stmt def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]: - return [stmt.accept(self) for stmt in stmts] + generated: list[ast.stmt] = [] + for stmt in stmts: + scope = Scope() + self._scopes.append(scope) + + stmt2 = stmt.accept(self) + generated.extend(scope.pre_assertions) + generated.append(stmt2) + if len(scope.aliases) != 0: + generated.append( + ast.Delete(targets=[ast.Name(id=alias) for alias in scope.aliases]) + ) + self._scopes.pop() + + # Remove redundant pass statements + if len(generated) > 1: + generated = [stmt for stmt in generated if not isinstance(stmt, ast.Pass)] + return generated + + def _make_alias(self, expr: ast.expr) -> ast.expr: + name: str = f"__midas_alias_{self._alias_count}__" + alias = ast.Name(id=name) + self._alias_count += 1 + self._scopes[-1].aliases.append(name) + self._scopes[-1].pre_assertions.append( + ast.Assign( + targets=[alias], + value=expr, + ) + ) + return alias + + def _add_assert(self, expr: ast.expr, message: str | ast.expr): + if isinstance(message, str): + message = ast.Constant(value=message) + self._scopes[-1].pre_assertions.append( + ast.Assert( + test=expr, + msg=message, + ) + ) + + def _get_expr_type(self, query: p.Expr) -> Type: + for expr, type in self._typed_ast.judgements: + if expr == query: + return type + raise RuntimeError(f"Cannot get type judgement for {query}") + + def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type): + match type: + case BaseType(name=name): + self._add_assert( + ast.Call( + func=ast.Name(id="isinstance"), + args=[expr, ast.Name(id=name)], + keywords=[], + ), + self._make_cast_assert_message(src_location, expr, type), + ) + + case AliasType(type=base): + self._make_cast_asserts(src_location, expr, base) + + case UnitType(): + self._add_assert( + ast.Compare( + left=expr, + ops=[ast.Is()], + comparators=[ + ast.Constant(value=None), + ], + ), + self._make_cast_assert_message(src_location, expr, type), + ) + + case AppliedType(): + self._make_cast_asserts(src_location, expr, type.body) + + case ( + TopType() + | Function() + | OverloadedFunction() + | ComplexType() + | ExtensionType() + | GenericType() + ): + raise NotImplementedError(f"Can't make assertion for type {type}") + + case TypeVar(): + raise RuntimeError("Unexpected TypeVar") + + def _make_cast_assert_message( + self, location: Location, expr: ast.expr, type: Type + ) -> ast.expr: + loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}" + # f"file.py:L1:1: CastError: Cannot cast {type(expr).__name__} to Type" + return ast.JoinedStr( + values=[ + ast.Constant(f"{loc_str}: CastError: Cannot cast "), + ast.FormattedValue( + value=ast.Attribute( + value=ast.Call( + func=ast.Name(id="type"), + args=[expr], + keywords=[], + ), + attr="__name__", + ), + conversion=-1, + ), + ast.Constant(f" to {type}"), + ] + )