refactor: ensure exhaustiveness in some match/case

This commit is contained in:
2026-06-18 12:51:28 +02:00
parent 04853eac70
commit 48c1ecc1c8
2 changed files with 19 additions and 7 deletions

View File

@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Optional, assert_never
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
@@ -203,9 +203,13 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
case UnknownType() | UnitType(): case UnknownType() | UnitType():
return type return type
case _: case TopType() | GenericType():
raise NotImplementedError(f"Unsupported type {type}") raise NotImplementedError(f"Unsupported type {type}")
# Ensure exhaustiveness
case _:
assert_never(type)
def unfold_type(type: Type) -> Type: def unfold_type(type: Type) -> Type:
match type: match type:

View File

@@ -2,7 +2,7 @@ import ast
import shutil import shutil
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional, assert_never
import midas.ast.python as p import midas.ast.python as p
from midas.ast.location import Location from midas.ast.location import Location
@@ -19,6 +19,7 @@ from midas.checker.types import (
Type, Type,
TypeVar, TypeVar,
UnitType, UnitType,
UnknownType,
) )
from midas.utils import TypedAST from midas.utils import TypedAST
@@ -276,6 +277,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type): def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type):
match type: match type:
case UnknownType():
pass
case BaseType(name=name): case BaseType(name=name):
self._add_assert( self._add_assert(
ast.Call( ast.Call(
@@ -301,8 +305,11 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self._make_cast_assert_message(src_location, expr, type), self._make_cast_assert_message(src_location, expr, type),
) )
case AppliedType(): case AppliedType(body=body):
self._make_cast_asserts(src_location, expr, type.body) self._make_cast_asserts(src_location, expr, body)
case TypeVar():
raise RuntimeError("Unexpected TypeVar")
case ( case (
TopType() TopType()
@@ -314,8 +321,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
): ):
raise NotImplementedError(f"Can't make assertion for type {type}") raise NotImplementedError(f"Can't make assertion for type {type}")
case TypeVar(): # Ensure exhaustiveness
raise RuntimeError("Unexpected TypeVar") case _:
assert_never(type)
def _make_cast_assert_message( def _make_cast_assert_message(
self, location: Location, expr: ast.expr, type: Type self, location: Location, expr: ast.expr, type: Type