feat(checker): evaluate constraints statically on literals

This commit is contained in:
2026-06-24 11:09:03 +02:00
parent 82666a4918
commit e1d5eac8b8
2 changed files with 92 additions and 4 deletions

View File

@@ -15,7 +15,7 @@ if TYPE_CHECKING:
BUILTIN_SUBTYPES: dict[str, set[str]] = {
"object": {"float", "list", "dict"},
"object": {"float", "list", "dict", "str"},
"float": {"int"},
"int": {"bool"},
}

View File

@@ -1,11 +1,13 @@
import ast
import logging
from dataclasses import dataclass
from typing import Optional
from typing import Any, Optional
import midas.ast.python as p
from midas.ast.location import Location
from midas.ast.printer import MidasPrinter
from midas.checker.environment import Environment
from midas.checker.evaluator import Evaluator
from midas.checker.operators import (
PY_COMPARATOR_METHODS,
PY_OPERATOR_METHODS,
@@ -18,6 +20,8 @@ from midas.checker.resolver import Resolver
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ConstraintType,
Function,
GenericType,
OverloadedFunction,
@@ -538,8 +542,12 @@ class PythonTyper(
return UnknownType()
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
_ = self.type_of(expr.expr)
return self.resolve_type_expr(expr.type)
subject_type: Type = self.type_of(expr.expr)
target_type: Type = self.resolve_type_expr(expr.type)
is_lit, lit_value = self._get_literal(expr.expr)
if is_lit:
self._evaluate_cast_statically(expr, subject_type, target_type, lit_value)
return target_type
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
test_type: Type = self.type_of(expr.test)
@@ -1108,3 +1116,83 @@ class PythonTyper(
return p.BaseType(location=location, base=name, param=None)
case _:
raise NotImplementedError
def _get_literal(self, expr: p.Expr) -> tuple[bool, Any]:
match expr:
case p.LiteralExpr(value=value):
return True, value
case p.ListExpr(items=items):
values: list[Any] = []
for item in items:
is_lit, value = self._get_literal(item)
if not is_lit:
return False, None
values.append(value)
return True, values
case p.DictExpr(keys=keys, values=values):
pairs: list[tuple[Any, Any]] = []
for key, value in zip(keys, values):
key_val = None
if key is not None:
is_lit, key_val = self._get_literal(key)
if not is_lit:
return False, None
is_lit, value_val = self._get_literal(value)
if not is_lit:
return False, None
if key is None:
# TODO: check that value is always a dict
assert isinstance(value_val, dict)
pairs.extend(value_val.items())
else:
pairs.append((key_val, value_val))
return True, dict(pairs)
case _:
return False, None
def _evaluate_cast_statically(
self, expr: p.CastExpr, subject_type: Type, target_type: Type, lit_value: Any
):
match target_type:
case AliasType(type=base):
self._evaluate_cast_statically(expr, subject_type, base, lit_value)
case AppliedType(name=name, args=args, body=body):
generic: Type = self.types.get_type(name)
assert isinstance(generic, GenericType)
for param, arg in zip(generic.params, args):
pass
self._evaluate_cast_statically(expr, subject_type, body, lit_value)
case ConstraintType(type=base, constraint=constraint):
self._evaluate_cast_statically(expr, subject_type, base, lit_value)
evaluator = Evaluator(self.types)
evaluator.set_value("_", lit_value)
res = evaluator.evaluate(constraint)
if not res:
printer = MidasPrinter()
constraint_str: str = printer.print(constraint)
self.reporter.error(
expr.location,
f"Value {lit_value!r} does not fit constraint '{constraint_str}'",
)
case BaseType():
# TODO: do we want to allow cast(float, int)? would require runtime conversion
if not self.types.is_subtype(
subject_type, target_type
) or not self.types.is_subtype(target_type, subject_type):
self.reporter.error(
expr.location,
f"Value {lit_value!r} of type {subject_type} cannot be cast as {target_type}",
)
case _:
self.reporter.info(
expr.location, f"Cannot evaluate cast to {target_type} statically"
)