diff --git a/midas/checker/builtins.py b/midas/checker/builtins.py index daf0015..78ba18f 100644 --- a/midas/checker/builtins.py +++ b/midas/checker/builtins.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: BUILTIN_SUBTYPES: dict[str, set[str]] = { - "object": {"float", "list", "dict"}, + "object": {"float", "list", "dict", "str"}, "float": {"int"}, "int": {"bool"}, } diff --git a/midas/checker/diagnostic.py b/midas/checker/diagnostic.py index f4b3d12..233dd74 100644 --- a/midas/checker/diagnostic.py +++ b/midas/checker/diagnostic.py @@ -9,6 +9,7 @@ class DiagnosticType(StrEnum): ERROR = "Error" WARNING = "Warning" INFO = "Info" + DEBUG = "Debug" @dataclass(frozen=True) diff --git a/midas/checker/evaluator.py b/midas/checker/evaluator.py new file mode 100644 index 0000000..c21f1c2 --- /dev/null +++ b/midas/checker/evaluator.py @@ -0,0 +1,172 @@ +from dataclasses import dataclass +from typing import Any, Callable, Optional + +import midas.ast.midas as m +from midas.checker.preamble import Preamble +from midas.checker.registry import TypesRegistry +from midas.checker.reporter import FileReporter +from midas.checker.types import Function, Predicate +from midas.lexer.token import TokenType + + +@dataclass(frozen=True, kw_only=True) +class PartialPredicate(Predicate): + scope: dict[str, Any] + + +class Evaluator(m.Expr.Visitor[Any]): + def __init__(self, types: TypesRegistry, reporter: Optional[FileReporter] = None): + self.types: TypesRegistry = types + self.reporter: Optional[FileReporter] = reporter + self.preamble: Preamble = Preamble(self.types) + self.scopes: list[dict[str, Any]] = [{}] + + def evaluate(self, expr: m.Expr) -> Any: + value: Any = expr.accept(self) + if self.reporter is not None: + self.reporter.debug(expr.location, f"Value: {value}") + return value + + def get_value(self, name: str) -> Any: + scope: dict[str, Any] = self.scopes[-1] + return scope[name] + + def set_value(self, name: str, value: Any, force_declare: bool = False): + if not force_declare: + for scope in reversed(self.scopes): + if name in scope: + scope[name] = value + return + self.scopes[-1][name] = value + + def visit_logical_expr(self, expr: m.LogicalExpr) -> Any: + def left(): + return self.evaluate(expr.left) + + def right(): + return self.evaluate(expr.right) + + match expr.operator.type: + case TokenType.AND: + return left() and right() + case _: + raise NotImplementedError + + def visit_binary_expr(self, expr: m.BinaryExpr) -> Any: + left: Any = self.evaluate(expr.left) + right: Any = self.evaluate(expr.right) + match expr.operator.type: + case TokenType.MINUS: + return left - right + case TokenType.STAR: + return left * right + case TokenType.SLASH: + return left / right + case TokenType.GREATER: + return left > right + case TokenType.GREATER_EQUAL: + return left >= right + case TokenType.LESS: + return left < right + case TokenType.LESS_EQUAL: + return left <= right + case TokenType.EQUAL_EQUAL: + return left == right + case TokenType.BANG_EQUAL: + return left != right + case _: + raise NotImplementedError + + def visit_unary_expr(self, expr: m.UnaryExpr) -> Any: + right: Any = self.evaluate(expr.right) + match expr.operator.type: + case TokenType.MINUS: + return -right + case _: + raise NotImplementedError + + def visit_call_expr(self, expr: m.CallExpr) -> Any: + callee: Any = self.evaluate(expr.callee) + args: list[Any] = [self.evaluate(arg) for arg in expr.arguments] + kwargs: dict[str, Any] = { + name: self.evaluate(arg) for name, arg in expr.keywords.items() + } + + match callee: + case Predicate(): + return self._evaluate_predicate(callee, args, kwargs) + case _ if callable(callee): + return callee(*args, **kwargs) + case _: + return NotImplementedError + + def visit_get_expr(self, expr: m.GetExpr) -> Any: + obj: Any = self.evaluate(expr.expr) + return getattr(obj, expr.name.lexeme) + + def visit_variable_expr(self, expr: m.VariableExpr) -> Any: + name: str = expr.name.lexeme + for scope in reversed(self.scopes): + if name in scope: + return scope[name] + + predicate: Optional[Predicate] = self.types.lookup_predicate(name) + if predicate is not None: + if predicate.alias: + return self.evaluate(predicate.body) + return predicate + + glob: Optional[Callable] = self.preamble.get_py_func(name) + if glob is not None: + return glob + raise NameError(f"Unknown variable '{name}'") + + def visit_grouping_expr(self, expr: m.GroupingExpr) -> Any: + return self.evaluate(expr.expr) + + def visit_literal_expr(self, expr: m.LiteralExpr) -> Any: + return expr.value + + def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Any: + return self.get_value("_") + + def _evaluate_predicate( + self, predicate: Predicate, args: list[Any], kwargs: dict[str, Any] + ) -> Any: + res: Any = None + if isinstance(predicate, PartialPredicate): + self.scopes.append(predicate.scope) + else: + self.scopes.append({}) + match predicate.type: + case Function(returns=Function() as inner): + self._map_args(predicate.type, args, kwargs) + res = PartialPredicate( + type=inner, + body=predicate.body, + alias=False, + scope=self.scopes[-1], + ) + + case Function(): + self._map_args(predicate.type, args, kwargs) + res = self.evaluate(predicate.body) + + case _: + raise NotImplementedError + self.scopes.pop() + return res + + def _map_args(self, function: Function, args: list[Any], kwargs: dict[str, Any]): + positional: list[Function.Argument] = function.pos_args + function.args + keywords: dict[str, Function.Argument] = { + arg.name: arg for arg in function.args + function.kw_args + } + + for i, arg in enumerate(args): + param: Function.Argument = positional[i] + self.set_value(param.name, arg) + + for name, arg in kwargs.items(): + param: Function.Argument = keywords[name] + self.set_value(param.name, arg) diff --git a/midas/checker/preamble.py b/midas/checker/preamble.py index 1dcd157..feea8b7 100644 --- a/midas/checker/preamble.py +++ b/midas/checker/preamble.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Callable, Optional from midas.checker.environment import Environment from midas.checker.registry import TypesRegistry @@ -16,16 +17,18 @@ class Preamble(Environment): def __init__(self, types: TypesRegistry) -> None: super().__init__() self._types: TypesRegistry = types + self._python_funcs: dict[str, Callable] = {} - self._def_type_constructor("object") - self._def_type_constructor("float") - self._def_type_constructor("int") - self._def_type_constructor("bool") - self._def_type_constructor("str") + self._def_type_constructor("object", object) + self._def_type_constructor("float", float) + self._def_type_constructor("int", int) + self._def_type_constructor("bool", bool) + self._def_type_constructor("str", str) self._def_function( name="list", pos=[Param("object", TopType())], returns=self._list_of(TopType()), + py_function=list, ) # TODO: use sink @@ -33,6 +36,7 @@ class Preamble(Environment): name="print", pos=[Param("object", TopType())], returns=UnitType(), + py_function=print, ) map_in = TypeVar(name="T", bound=None) @@ -53,6 +57,7 @@ class Preamble(Environment): ], returns=self._list_of(map_out), # TODO: replace with Iterable[U] type_vars=[map_in, map_out], + py_function=map, ) self._def_function( name="input", @@ -63,12 +68,13 @@ class Preamble(Environment): def _list_of(self, item_type: Type) -> Type: return self._types.apply_generic(self._types.get_type("list"), [item_type]) - def _def_type_constructor(self, name: str): + def _def_type_constructor(self, name: str, py_function: Optional[Callable] = None): # TODO: more specific arg types self._def_function( name=name, pos=[Param("object", TopType(), required=False)], returns=self._types.get_type(name), + py_function=py_function, ) def _make_function( @@ -115,6 +121,7 @@ class Preamble(Environment): kw: list[Param] = [], returns: Type = UnitType(), type_vars: list[TypeVar] = [], + py_function: Optional[Callable] = None, ): function: Type = self._make_function( name=name, @@ -125,3 +132,8 @@ class Preamble(Environment): type_vars=type_vars, ) self.define(name, function) + if py_function is not None: + self._python_funcs[name] = py_function + + def get_py_func(self, name: str) -> Optional[Callable]: + return self._python_funcs.get(name) diff --git a/midas/checker/python.py b/midas/checker/python.py index 435f6f1..ffa1600 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -1,11 +1,13 @@ import ast import logging from dataclasses import dataclass -from typing import Optional +from typing import Any, Optional import midas.ast.python as p from midas.ast.location import Location +from midas.ast.printer import MidasPrinter from midas.checker.environment import Environment +from midas.checker.evaluator import Evaluator from midas.checker.operators import ( PY_COMPARATOR_METHODS, PY_OPERATOR_METHODS, @@ -18,6 +20,8 @@ from midas.checker.resolver import Resolver from midas.checker.types import ( AliasType, AppliedType, + BaseType, + ConstraintType, Function, GenericType, OverloadedFunction, @@ -71,6 +75,7 @@ class PythonTyper( self.env: Environment = self.global_env self.locals: dict[p.Expr, int] = {} self.judgements: list[tuple[p.Expr, Type]] = [] + self.evaluated_casts: list[p.CastExpr] = [] def process(self, source: str, path: Optional[str]) -> TypedAST: self.reporter = self.reporter.for_file(path) @@ -84,10 +89,15 @@ class PythonTyper( self.env = self.global_env self.locals = resolver.locals self.judgements = [] + self.evaluated_casts = [] self.check(stmts) - return TypedAST(stmts=stmts, judgements=self.judgements) + return TypedAST( + stmts=stmts, + judgements=self.judgements, + evaluated_casts=self.evaluated_casts, + ) def judge(self, expr: p.Expr, type: Type): """Record a typing judgement @@ -538,8 +548,16 @@ class PythonTyper( return UnknownType() def visit_cast_expr(self, expr: p.CastExpr) -> Type: - _ = self.type_of(expr.expr) - return self.resolve_type_expr(expr.type) + subject_type: Type = self.type_of(expr.expr) + target_type: Type = self.resolve_type_expr(expr.type) + is_lit, lit_value = self._get_literal(expr.expr) + if is_lit: + evaluated: bool = self._evaluate_cast_statically( + expr, subject_type, target_type, lit_value + ) + if evaluated: + self.evaluated_casts.append(expr) + return target_type def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type: test_type: Type = self.type_of(expr.test) @@ -1108,3 +1126,93 @@ class PythonTyper( return p.BaseType(location=location, base=name, param=None) case _: raise NotImplementedError + + def _get_literal(self, expr: p.Expr) -> tuple[bool, Any]: + match expr: + case p.LiteralExpr(value=value): + return True, value + + case p.ListExpr(items=items): + values: list[Any] = [] + for item in items: + is_lit, value = self._get_literal(item) + if not is_lit: + return False, None + values.append(value) + return True, values + + case p.DictExpr(keys=keys, values=values): + pairs: list[tuple[Any, Any]] = [] + for key, value in zip(keys, values): + key_val = None + if key is not None: + is_lit, key_val = self._get_literal(key) + if not is_lit: + return False, None + + is_lit, value_val = self._get_literal(value) + if not is_lit: + return False, None + + if key is None: + # TODO: check that value is always a dict + assert isinstance(value_val, dict) + pairs.extend(value_val.items()) + else: + pairs.append((key_val, value_val)) + return True, dict(pairs) + + case _: + return False, None + + def _evaluate_cast_statically( + self, expr: p.CastExpr, subject_type: Type, target_type: Type, lit_value: Any + ) -> bool: + match target_type: + case AliasType(type=base): + return self._evaluate_cast_statically( + expr, subject_type, base, lit_value + ) + + case AppliedType(body=body): + return self._evaluate_cast_statically( + expr, subject_type, body, lit_value + ) + + case ConstraintType(type=base, constraint=constraint): + evaluated: bool = True + if not self._evaluate_cast_statically( + expr, subject_type, base, lit_value + ): + evaluated = False + + evaluator = Evaluator(self.types) + evaluator.set_value("_", lit_value) + res = evaluator.evaluate(constraint) + if not res: + printer = MidasPrinter() + constraint_str: str = printer.print(constraint) + self.reporter.error( + expr.location, + f"Value {lit_value!r} does not fit constraint '{constraint_str}'", + ) + evaluated = False + return evaluated + + case BaseType(): + # TODO: do we want to allow cast(float, int)? would require runtime conversion + if not self.types.is_subtype( + subject_type, target_type + ) or not self.types.is_subtype(target_type, subject_type): + self.reporter.error( + expr.location, + f"Value {lit_value!r} of type {subject_type} cannot be cast as {target_type}", + ) + return False + return True + + case _: + self.reporter.info( + expr.location, f"Cannot evaluate cast to {target_type} statically" + ) + return False diff --git a/midas/checker/reporter.py b/midas/checker/reporter.py index b68766a..b61b8f3 100644 --- a/midas/checker/reporter.py +++ b/midas/checker/reporter.py @@ -61,3 +61,10 @@ class FileReporter: location=location, message=message, ) + + def debug(self, location: Location, message: str): + self.report( + type=DiagnosticType.DEBUG, + location=location, + message=message, + ) diff --git a/midas/cli/utils.py b/midas/cli/utils.py index 793d4e1..6cc7d38 100644 --- a/midas/cli/utils.py +++ b/midas/cli/utils.py @@ -59,6 +59,7 @@ class DiagnosticPrinter: DiagnosticType.ERROR: Ansi.RED, DiagnosticType.WARNING: Ansi.YELLOW, DiagnosticType.INFO: Ansi.CYAN, + DiagnosticType.DEBUG: Ansi.MAGENTA, }.get(diagnostic.type, Ansi.WHITE) subject: str = Ansi.FG(color) + line[start_offset:end_offset] + Ansi.RESET diff --git a/midas/generator/generator.py b/midas/generator/generator.py index e66f532..0fba91e 100644 --- a/midas/generator/generator.py +++ b/midas/generator/generator.py @@ -44,6 +44,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): self._typed_ast: TypedAST = TypedAST( stmts=[], judgements=[], + evaluated_casts=[], ) self._alias_count: int = 0 self._predicate_count: int = 0 @@ -131,6 +132,10 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr: expr2: ast.expr = expr.expr.accept(self) + + if expr in self._typed_ast.evaluated_casts: + return expr2 + alias: ast.expr = self._make_alias(expr2) type: Type = self._get_expr_type(expr) diff --git a/midas/utils.py b/midas/utils.py index f10bdb3..0ac90cf 100644 --- a/midas/utils.py +++ b/midas/utils.py @@ -62,3 +62,4 @@ class UniversalJSONDumper: class TypedAST: stmts: list[p.Stmt] judgements: list[tuple[p.Expr, Type]] + evaluated_casts: list[p.CastExpr] diff --git a/tests/cases/checker/04_custom_types.py.ref.json b/tests/cases/checker/04_custom_types.py.ref.json index 1802082..d502a97 100644 --- a/tests/cases/checker/04_custom_types.py.ref.json +++ b/tests/cases/checker/04_custom_types.py.ref.json @@ -1,6 +1,19 @@ { "diagnostics": [], "judgments": [ + { + "location": { + "from": "L4:30", + "to": "L4:36" + }, + "expr": { + "_type": "LiteralExpr", + "value": 123.45 + }, + "type": { + "name": "float" + } + }, { "location": { "from": "L4:18", @@ -25,6 +38,19 @@ } } }, + { + "location": { + "from": "L5:28", + "to": "L5:31" + }, + "expr": { + "_type": "LiteralExpr", + "value": 6.7 + }, + "type": { + "name": "float" + } + }, { "location": { "from": "L5:15", diff --git a/tests/cases/generator/01_simple_types.py.ref.txt b/tests/cases/generator/01_simple_types.py.ref.txt index f8da6a6..f434968 100644 --- a/tests/cases/generator/01_simple_types.py.ref.txt +++ b/tests/cases/generator/01_simple_types.py.ref.txt @@ -7,68 +7,14 @@ Module( alias(name='Meter'), alias(name='Second')], level=0), - Assign( - targets=[ - Name(id='__midas_a0__')], - value=Constant(value=123.45)), - Assert( - test=Call( - func=Name(id='isinstance'), - args=[ - Name(id='__midas_a0__'), - Name(id='float')], - keywords=[]), - msg=JoinedStr( - values=[ - Constant(value='01_simple_types.py:L3:19: CastError: Cannot cast '), - FormattedValue( - value=Attribute( - value=Call( - func=Name(id='type'), - args=[ - Name(id='__midas_a0__')], - keywords=[]), - attr='__name__'), - conversion=-1), - Constant(value=' to float')])), Assign( targets=[ Name(id='distance')], - value=Name(id='__midas_a0__')), - Delete( - targets=[ - Name(id='__midas_a0__')]), - Assign( - targets=[ - Name(id='__midas_a1__')], - value=Constant(value=6.7)), - Assert( - test=Call( - func=Name(id='isinstance'), - args=[ - Name(id='__midas_a1__'), - Name(id='float')], - keywords=[]), - msg=JoinedStr( - values=[ - Constant(value='01_simple_types.py:L4:16: CastError: Cannot cast '), - FormattedValue( - value=Attribute( - value=Call( - func=Name(id='type'), - args=[ - Name(id='__midas_a1__')], - keywords=[]), - attr='__name__'), - conversion=-1), - Constant(value=' to float')])), + value=Constant(value=123.45)), Assign( targets=[ Name(id='time')], - value=Name(id='__midas_a1__')), - Delete( - targets=[ - Name(id='__midas_a1__')]), + value=Constant(value=6.7)), Assign( targets=[ Name(id='speed')],