refactor: ensure exhaustiveness in some match/case
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional, assert_never
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
@@ -203,9 +203,13 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
|||||||
case UnknownType() | UnitType():
|
case UnknownType() | UnitType():
|
||||||
return type
|
return type
|
||||||
|
|
||||||
case _:
|
case TopType() | GenericType():
|
||||||
raise NotImplementedError(f"Unsupported type {type}")
|
raise NotImplementedError(f"Unsupported type {type}")
|
||||||
|
|
||||||
|
# Ensure exhaustiveness
|
||||||
|
case _:
|
||||||
|
assert_never(type)
|
||||||
|
|
||||||
|
|
||||||
def unfold_type(type: Type) -> Type:
|
def unfold_type(type: Type) -> Type:
|
||||||
match type:
|
match type:
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import ast
|
|||||||
import shutil
|
import shutil
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional, assert_never
|
||||||
|
|
||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
@@ -19,6 +19,7 @@ from midas.checker.types import (
|
|||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
UnitType,
|
UnitType,
|
||||||
|
UnknownType,
|
||||||
)
|
)
|
||||||
from midas.utils import TypedAST
|
from midas.utils import TypedAST
|
||||||
|
|
||||||
@@ -276,6 +277,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
|
|
||||||
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type):
|
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type):
|
||||||
match type:
|
match type:
|
||||||
|
case UnknownType():
|
||||||
|
pass
|
||||||
|
|
||||||
case BaseType(name=name):
|
case BaseType(name=name):
|
||||||
self._add_assert(
|
self._add_assert(
|
||||||
ast.Call(
|
ast.Call(
|
||||||
@@ -301,8 +305,11 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
self._make_cast_assert_message(src_location, expr, type),
|
self._make_cast_assert_message(src_location, expr, type),
|
||||||
)
|
)
|
||||||
|
|
||||||
case AppliedType():
|
case AppliedType(body=body):
|
||||||
self._make_cast_asserts(src_location, expr, type.body)
|
self._make_cast_asserts(src_location, expr, body)
|
||||||
|
|
||||||
|
case TypeVar():
|
||||||
|
raise RuntimeError("Unexpected TypeVar")
|
||||||
|
|
||||||
case (
|
case (
|
||||||
TopType()
|
TopType()
|
||||||
@@ -314,8 +321,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
):
|
):
|
||||||
raise NotImplementedError(f"Can't make assertion for type {type}")
|
raise NotImplementedError(f"Can't make assertion for type {type}")
|
||||||
|
|
||||||
case TypeVar():
|
# Ensure exhaustiveness
|
||||||
raise RuntimeError("Unexpected TypeVar")
|
case _:
|
||||||
|
assert_never(type)
|
||||||
|
|
||||||
def _make_cast_assert_message(
|
def _make_cast_assert_message(
|
||||||
self, location: Location, expr: ast.expr, type: Type
|
self, location: Location, expr: ast.expr, type: Type
|
||||||
|
|||||||
Reference in New Issue
Block a user