diff --git a/midas/checker/types.py b/midas/checker/types.py index 5cb1beb..c16db79 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Optional +from typing import Optional, assert_never @dataclass(frozen=True, kw_only=True) @@ -203,9 +203,13 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type: case UnknownType() | UnitType(): return type - case _: + case TopType() | GenericType(): raise NotImplementedError(f"Unsupported type {type}") + # Ensure exhaustiveness + case _: + assert_never(type) + def unfold_type(type: Type) -> Type: match type: diff --git a/midas/generator/generator.py b/midas/generator/generator.py index 7575ca5..c7b1323 100644 --- a/midas/generator/generator.py +++ b/midas/generator/generator.py @@ -2,7 +2,7 @@ import ast import shutil from dataclasses import dataclass, field from pathlib import Path -from typing import Optional +from typing import Optional, assert_never import midas.ast.python as p from midas.ast.location import Location @@ -19,6 +19,7 @@ from midas.checker.types import ( Type, TypeVar, UnitType, + UnknownType, ) from midas.utils import TypedAST @@ -276,6 +277,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type): match type: + case UnknownType(): + pass + case BaseType(name=name): self._add_assert( ast.Call( @@ -301,8 +305,11 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): self._make_cast_assert_message(src_location, expr, type), ) - case AppliedType(): - self._make_cast_asserts(src_location, expr, type.body) + case AppliedType(body=body): + self._make_cast_asserts(src_location, expr, body) + + case TypeVar(): + raise RuntimeError("Unexpected TypeVar") case ( TopType() @@ -314,8 +321,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): ): raise NotImplementedError(f"Can't make assertion for type {type}") - case TypeVar(): - raise RuntimeError("Unexpected TypeVar") + # Ensure exhaustiveness + case _: + assert_never(type) def _make_cast_assert_message( self, location: Location, expr: ast.expr, type: Type