refactor: ensure exhaustiveness in some match/case

This commit is contained in:
2026-06-18 12:51:28 +02:00
parent 04853eac70
commit 48c1ecc1c8
2 changed files with 19 additions and 7 deletions

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional
from typing import Optional, assert_never
@dataclass(frozen=True, kw_only=True)
@@ -203,9 +203,13 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
case UnknownType() | UnitType():
return type
case _:
case TopType() | GenericType():
raise NotImplementedError(f"Unsupported type {type}")
# Ensure exhaustiveness
case _:
assert_never(type)
def unfold_type(type: Type) -> Type:
match type:

View File

@@ -2,7 +2,7 @@ import ast
import shutil
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
from typing import Optional, assert_never
import midas.ast.python as p
from midas.ast.location import Location
@@ -19,6 +19,7 @@ from midas.checker.types import (
Type,
TypeVar,
UnitType,
UnknownType,
)
from midas.utils import TypedAST
@@ -276,6 +277,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type):
match type:
case UnknownType():
pass
case BaseType(name=name):
self._add_assert(
ast.Call(
@@ -301,8 +305,11 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self._make_cast_assert_message(src_location, expr, type),
)
case AppliedType():
self._make_cast_asserts(src_location, expr, type.body)
case AppliedType(body=body):
self._make_cast_asserts(src_location, expr, body)
case TypeVar():
raise RuntimeError("Unexpected TypeVar")
case (
TopType()
@@ -314,8 +321,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
):
raise NotImplementedError(f"Can't make assertion for type {type}")
case TypeVar():
raise RuntimeError("Unexpected TypeVar")
# Ensure exhaustiveness
case _:
assert_never(type)
def _make_cast_assert_message(
self, location: Location, expr: ast.expr, type: Type