feat(checker): evaluate constraints statically on literals
This commit is contained in:
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
BUILTIN_SUBTYPES: dict[str, set[str]] = {
|
BUILTIN_SUBTYPES: dict[str, set[str]] = {
|
||||||
"object": {"float", "list", "dict"},
|
"object": {"float", "list", "dict", "str"},
|
||||||
"float": {"int"},
|
"float": {"int"},
|
||||||
"int": {"bool"},
|
"int": {"bool"},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
import ast
|
import ast
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
|
from midas.ast.printer import MidasPrinter
|
||||||
from midas.checker.environment import Environment
|
from midas.checker.environment import Environment
|
||||||
|
from midas.checker.evaluator import Evaluator
|
||||||
from midas.checker.operators import (
|
from midas.checker.operators import (
|
||||||
PY_COMPARATOR_METHODS,
|
PY_COMPARATOR_METHODS,
|
||||||
PY_OPERATOR_METHODS,
|
PY_OPERATOR_METHODS,
|
||||||
@@ -18,6 +20,8 @@ from midas.checker.resolver import Resolver
|
|||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AliasType,
|
AliasType,
|
||||||
AppliedType,
|
AppliedType,
|
||||||
|
BaseType,
|
||||||
|
ConstraintType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
OverloadedFunction,
|
OverloadedFunction,
|
||||||
@@ -538,8 +542,12 @@ class PythonTyper(
|
|||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
|
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
|
||||||
_ = self.type_of(expr.expr)
|
subject_type: Type = self.type_of(expr.expr)
|
||||||
return self.resolve_type_expr(expr.type)
|
target_type: Type = self.resolve_type_expr(expr.type)
|
||||||
|
is_lit, lit_value = self._get_literal(expr.expr)
|
||||||
|
if is_lit:
|
||||||
|
self._evaluate_cast_statically(expr, subject_type, target_type, lit_value)
|
||||||
|
return target_type
|
||||||
|
|
||||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
|
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
|
||||||
test_type: Type = self.type_of(expr.test)
|
test_type: Type = self.type_of(expr.test)
|
||||||
@@ -1108,3 +1116,83 @@ class PythonTyper(
|
|||||||
return p.BaseType(location=location, base=name, param=None)
|
return p.BaseType(location=location, base=name, param=None)
|
||||||
case _:
|
case _:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _get_literal(self, expr: p.Expr) -> tuple[bool, Any]:
|
||||||
|
match expr:
|
||||||
|
case p.LiteralExpr(value=value):
|
||||||
|
return True, value
|
||||||
|
|
||||||
|
case p.ListExpr(items=items):
|
||||||
|
values: list[Any] = []
|
||||||
|
for item in items:
|
||||||
|
is_lit, value = self._get_literal(item)
|
||||||
|
if not is_lit:
|
||||||
|
return False, None
|
||||||
|
values.append(value)
|
||||||
|
return True, values
|
||||||
|
|
||||||
|
case p.DictExpr(keys=keys, values=values):
|
||||||
|
pairs: list[tuple[Any, Any]] = []
|
||||||
|
for key, value in zip(keys, values):
|
||||||
|
key_val = None
|
||||||
|
if key is not None:
|
||||||
|
is_lit, key_val = self._get_literal(key)
|
||||||
|
if not is_lit:
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
is_lit, value_val = self._get_literal(value)
|
||||||
|
if not is_lit:
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
if key is None:
|
||||||
|
# TODO: check that value is always a dict
|
||||||
|
assert isinstance(value_val, dict)
|
||||||
|
pairs.extend(value_val.items())
|
||||||
|
else:
|
||||||
|
pairs.append((key_val, value_val))
|
||||||
|
return True, dict(pairs)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
def _evaluate_cast_statically(
|
||||||
|
self, expr: p.CastExpr, subject_type: Type, target_type: Type, lit_value: Any
|
||||||
|
):
|
||||||
|
match target_type:
|
||||||
|
case AliasType(type=base):
|
||||||
|
self._evaluate_cast_statically(expr, subject_type, base, lit_value)
|
||||||
|
|
||||||
|
case AppliedType(name=name, args=args, body=body):
|
||||||
|
generic: Type = self.types.get_type(name)
|
||||||
|
assert isinstance(generic, GenericType)
|
||||||
|
for param, arg in zip(generic.params, args):
|
||||||
|
pass
|
||||||
|
self._evaluate_cast_statically(expr, subject_type, body, lit_value)
|
||||||
|
|
||||||
|
case ConstraintType(type=base, constraint=constraint):
|
||||||
|
self._evaluate_cast_statically(expr, subject_type, base, lit_value)
|
||||||
|
evaluator = Evaluator(self.types)
|
||||||
|
evaluator.set_value("_", lit_value)
|
||||||
|
res = evaluator.evaluate(constraint)
|
||||||
|
if not res:
|
||||||
|
printer = MidasPrinter()
|
||||||
|
constraint_str: str = printer.print(constraint)
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Value {lit_value!r} does not fit constraint '{constraint_str}'",
|
||||||
|
)
|
||||||
|
|
||||||
|
case BaseType():
|
||||||
|
# TODO: do we want to allow cast(float, int)? would require runtime conversion
|
||||||
|
if not self.types.is_subtype(
|
||||||
|
subject_type, target_type
|
||||||
|
) or not self.types.is_subtype(target_type, subject_type):
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Value {lit_value!r} of type {subject_type} cannot be cast as {target_type}",
|
||||||
|
)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
self.reporter.info(
|
||||||
|
expr.location, f"Cannot evaluate cast to {target_type} statically"
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user