diff --git a/midas/checker/builtins.py b/midas/checker/builtins.py index daf0015..78ba18f 100644 --- a/midas/checker/builtins.py +++ b/midas/checker/builtins.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: BUILTIN_SUBTYPES: dict[str, set[str]] = { - "object": {"float", "list", "dict"}, + "object": {"float", "list", "dict", "str"}, "float": {"int"}, "int": {"bool"}, } diff --git a/midas/checker/python.py b/midas/checker/python.py index 435f6f1..0ff9eed 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -1,11 +1,13 @@ import ast import logging from dataclasses import dataclass -from typing import Optional +from typing import Any, Optional import midas.ast.python as p from midas.ast.location import Location +from midas.ast.printer import MidasPrinter from midas.checker.environment import Environment +from midas.checker.evaluator import Evaluator from midas.checker.operators import ( PY_COMPARATOR_METHODS, PY_OPERATOR_METHODS, @@ -18,6 +20,8 @@ from midas.checker.resolver import Resolver from midas.checker.types import ( AliasType, AppliedType, + BaseType, + ConstraintType, Function, GenericType, OverloadedFunction, @@ -538,8 +542,12 @@ class PythonTyper( return UnknownType() def visit_cast_expr(self, expr: p.CastExpr) -> Type: - _ = self.type_of(expr.expr) - return self.resolve_type_expr(expr.type) + subject_type: Type = self.type_of(expr.expr) + target_type: Type = self.resolve_type_expr(expr.type) + is_lit, lit_value = self._get_literal(expr.expr) + if is_lit: + self._evaluate_cast_statically(expr, subject_type, target_type, lit_value) + return target_type def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type: test_type: Type = self.type_of(expr.test) @@ -1108,3 +1116,83 @@ class PythonTyper( return p.BaseType(location=location, base=name, param=None) case _: raise NotImplementedError + + def _get_literal(self, expr: p.Expr) -> tuple[bool, Any]: + match expr: + case p.LiteralExpr(value=value): + return True, value + + case p.ListExpr(items=items): + values: list[Any] = [] + for item in items: + is_lit, value = self._get_literal(item) + if not is_lit: + return False, None + values.append(value) + return True, values + + case p.DictExpr(keys=keys, values=values): + pairs: list[tuple[Any, Any]] = [] + for key, value in zip(keys, values): + key_val = None + if key is not None: + is_lit, key_val = self._get_literal(key) + if not is_lit: + return False, None + + is_lit, value_val = self._get_literal(value) + if not is_lit: + return False, None + + if key is None: + # TODO: check that value is always a dict + assert isinstance(value_val, dict) + pairs.extend(value_val.items()) + else: + pairs.append((key_val, value_val)) + return True, dict(pairs) + + case _: + return False, None + + def _evaluate_cast_statically( + self, expr: p.CastExpr, subject_type: Type, target_type: Type, lit_value: Any + ): + match target_type: + case AliasType(type=base): + self._evaluate_cast_statically(expr, subject_type, base, lit_value) + + case AppliedType(name=name, args=args, body=body): + generic: Type = self.types.get_type(name) + assert isinstance(generic, GenericType) + for param, arg in zip(generic.params, args): + pass + self._evaluate_cast_statically(expr, subject_type, body, lit_value) + + case ConstraintType(type=base, constraint=constraint): + self._evaluate_cast_statically(expr, subject_type, base, lit_value) + evaluator = Evaluator(self.types) + evaluator.set_value("_", lit_value) + res = evaluator.evaluate(constraint) + if not res: + printer = MidasPrinter() + constraint_str: str = printer.print(constraint) + self.reporter.error( + expr.location, + f"Value {lit_value!r} does not fit constraint '{constraint_str}'", + ) + + case BaseType(): + # TODO: do we want to allow cast(float, int)? would require runtime conversion + if not self.types.is_subtype( + subject_type, target_type + ) or not self.types.is_subtype(target_type, subject_type): + self.reporter.error( + expr.location, + f"Value {lit_value!r} of type {subject_type} cannot be cast as {target_type}", + ) + + case _: + self.reporter.info( + expr.location, f"Cannot evaluate cast to {target_type} statically" + )