Files
midas/midas/checker/evaluator.py
LordBaryhobal 82666a4918 feat(checker): add evaluator
add an evaluator class to evaluate expressions using literal values
2026-06-24 11:08:15 +02:00

173 lines
5.9 KiB
Python

from dataclasses import dataclass
from typing import Any, Callable, Optional
import midas.ast.midas as m
from midas.checker.preamble import Preamble
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import Function, Predicate
from midas.lexer.token import TokenType
@dataclass(frozen=True, kw_only=True)
class PartialPredicate(Predicate):
    """Predicate that additionally carries a dictionary of variable bindings.

    The evaluator builds one of these when a predicate whose type returns
    another function is called: the bindings made for that call are stored
    in ``scope`` and pushed back onto the scope stack when the partial
    predicate is itself called.
    """

    # Variable name -> value bindings captured for this predicate.
    scope: dict[str, Any]
class Evaluator(m.Expr.Visitor[Any]):
def __init__(self, types: TypesRegistry, reporter: Optional[FileReporter] = None):
self.types: TypesRegistry = types
self.reporter: Optional[FileReporter] = reporter
self.preamble: Preamble = Preamble(self.types)
self.scopes: list[dict[str, Any]] = [{}]
def evaluate(self, expr: m.Expr) -> Any:
value: Any = expr.accept(self)
if self.reporter is not None:
self.reporter.debug(expr.location, f"Value: {value}")
return value
def get_value(self, name: str) -> Any:
scope: dict[str, Any] = self.scopes[-1]
return scope[name]
def set_value(self, name: str, value: Any, force_declare: bool = False):
if not force_declare:
for scope in reversed(self.scopes):
if name in scope:
scope[name] = value
return
self.scopes[-1][name] = value
def visit_logical_expr(self, expr: m.LogicalExpr) -> Any:
def left():
return self.evaluate(expr.left)
def right():
return self.evaluate(expr.right)
match expr.operator.type:
case TokenType.AND:
return left() and right()
case _:
raise NotImplementedError
def visit_binary_expr(self, expr: m.BinaryExpr) -> Any:
left: Any = self.evaluate(expr.left)
right: Any = self.evaluate(expr.right)
match expr.operator.type:
case TokenType.MINUS:
return left - right
case TokenType.STAR:
return left * right
case TokenType.SLASH:
return left / right
case TokenType.GREATER:
return left > right
case TokenType.GREATER_EQUAL:
return left >= right
case TokenType.LESS:
return left < right
case TokenType.LESS_EQUAL:
return left <= right
case TokenType.EQUAL_EQUAL:
return left == right
case TokenType.BANG_EQUAL:
return left != right
case _:
raise NotImplementedError
def visit_unary_expr(self, expr: m.UnaryExpr) -> Any:
right: Any = self.evaluate(expr.right)
match expr.operator.type:
case TokenType.MINUS:
return -right
case _:
raise NotImplementedError
def visit_call_expr(self, expr: m.CallExpr) -> Any:
callee: Any = self.evaluate(expr.callee)
args: list[Any] = [self.evaluate(arg) for arg in expr.arguments]
kwargs: dict[str, Any] = {
name: self.evaluate(arg) for name, arg in expr.keywords.items()
}
match callee:
case Predicate():
return self._evaluate_predicate(callee, args, kwargs)
case _ if callable(callee):
return callee(*args, **kwargs)
case _:
return NotImplementedError
def visit_get_expr(self, expr: m.GetExpr) -> Any:
obj: Any = self.evaluate(expr.expr)
return getattr(obj, expr.name.lexeme)
def visit_variable_expr(self, expr: m.VariableExpr) -> Any:
name: str = expr.name.lexeme
for scope in reversed(self.scopes):
if name in scope:
return scope[name]
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
if predicate is not None:
if predicate.alias:
return self.evaluate(predicate.body)
return predicate
glob: Optional[Callable] = self.preamble.get_py_func(name)
if glob is not None:
return glob
raise NameError(f"Unknown variable '{name}'")
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Any:
return self.evaluate(expr.expr)
def visit_literal_expr(self, expr: m.LiteralExpr) -> Any:
return expr.value
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Any:
return self.get_value("_")
def _evaluate_predicate(
self, predicate: Predicate, args: list[Any], kwargs: dict[str, Any]
) -> Any:
res: Any = None
if isinstance(predicate, PartialPredicate):
self.scopes.append(predicate.scope)
else:
self.scopes.append({})
match predicate.type:
case Function(returns=Function() as inner):
self._map_args(predicate.type, args, kwargs)
res = PartialPredicate(
type=inner,
body=predicate.body,
alias=False,
scope=self.scopes[-1],
)
case Function():
self._map_args(predicate.type, args, kwargs)
res = self.evaluate(predicate.body)
case _:
raise NotImplementedError
self.scopes.pop()
return res
def _map_args(self, function: Function, args: list[Any], kwargs: dict[str, Any]):
positional: list[Function.Argument] = function.pos_args + function.args
keywords: dict[str, Function.Argument] = {
arg.name: arg for arg in function.args + function.kw_args
}
for i, arg in enumerate(args):
param: Function.Argument = positional[i]
self.set_value(param.name, arg)
for name, arg in kwargs.items():
param: Function.Argument = keywords[name]
self.set_value(param.name, arg)