Merge pull request 'Static evalution of casts on literals' (#20) from feat/literal-static-constraints into main

Reviewed-on: #20
This commit was merged in pull request #20.
This commit is contained in:
2026-06-24 09:32:54 +00:00
11 changed files with 346 additions and 67 deletions

View File

@@ -15,7 +15,7 @@ if TYPE_CHECKING:
BUILTIN_SUBTYPES: dict[str, set[str]] = { BUILTIN_SUBTYPES: dict[str, set[str]] = {
"object": {"float", "list", "dict"}, "object": {"float", "list", "dict", "str"},
"float": {"int"}, "float": {"int"},
"int": {"bool"}, "int": {"bool"},
} }

View File

@@ -9,6 +9,7 @@ class DiagnosticType(StrEnum):
ERROR = "Error" ERROR = "Error"
WARNING = "Warning" WARNING = "Warning"
INFO = "Info" INFO = "Info"
DEBUG = "Debug"
@dataclass(frozen=True) @dataclass(frozen=True)

172
midas/checker/evaluator.py Normal file
View File

@@ -0,0 +1,172 @@
from dataclasses import dataclass
from typing import Any, Callable, Optional
import midas.ast.midas as m
from midas.checker.preamble import Preamble
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import Function, Predicate
from midas.lexer.token import TokenType
@dataclass(frozen=True, kw_only=True)
class PartialPredicate(Predicate):
scope: dict[str, Any]
class Evaluator(m.Expr.Visitor[Any]):
def __init__(self, types: TypesRegistry, reporter: Optional[FileReporter] = None):
self.types: TypesRegistry = types
self.reporter: Optional[FileReporter] = reporter
self.preamble: Preamble = Preamble(self.types)
self.scopes: list[dict[str, Any]] = [{}]
def evaluate(self, expr: m.Expr) -> Any:
value: Any = expr.accept(self)
if self.reporter is not None:
self.reporter.debug(expr.location, f"Value: {value}")
return value
def get_value(self, name: str) -> Any:
scope: dict[str, Any] = self.scopes[-1]
return scope[name]
def set_value(self, name: str, value: Any, force_declare: bool = False):
if not force_declare:
for scope in reversed(self.scopes):
if name in scope:
scope[name] = value
return
self.scopes[-1][name] = value
def visit_logical_expr(self, expr: m.LogicalExpr) -> Any:
def left():
return self.evaluate(expr.left)
def right():
return self.evaluate(expr.right)
match expr.operator.type:
case TokenType.AND:
return left() and right()
case _:
raise NotImplementedError
def visit_binary_expr(self, expr: m.BinaryExpr) -> Any:
left: Any = self.evaluate(expr.left)
right: Any = self.evaluate(expr.right)
match expr.operator.type:
case TokenType.MINUS:
return left - right
case TokenType.STAR:
return left * right
case TokenType.SLASH:
return left / right
case TokenType.GREATER:
return left > right
case TokenType.GREATER_EQUAL:
return left >= right
case TokenType.LESS:
return left < right
case TokenType.LESS_EQUAL:
return left <= right
case TokenType.EQUAL_EQUAL:
return left == right
case TokenType.BANG_EQUAL:
return left != right
case _:
raise NotImplementedError
def visit_unary_expr(self, expr: m.UnaryExpr) -> Any:
right: Any = self.evaluate(expr.right)
match expr.operator.type:
case TokenType.MINUS:
return -right
case _:
raise NotImplementedError
def visit_call_expr(self, expr: m.CallExpr) -> Any:
callee: Any = self.evaluate(expr.callee)
args: list[Any] = [self.evaluate(arg) for arg in expr.arguments]
kwargs: dict[str, Any] = {
name: self.evaluate(arg) for name, arg in expr.keywords.items()
}
match callee:
case Predicate():
return self._evaluate_predicate(callee, args, kwargs)
case _ if callable(callee):
return callee(*args, **kwargs)
case _:
return NotImplementedError
def visit_get_expr(self, expr: m.GetExpr) -> Any:
obj: Any = self.evaluate(expr.expr)
return getattr(obj, expr.name.lexeme)
def visit_variable_expr(self, expr: m.VariableExpr) -> Any:
name: str = expr.name.lexeme
for scope in reversed(self.scopes):
if name in scope:
return scope[name]
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
if predicate is not None:
if predicate.alias:
return self.evaluate(predicate.body)
return predicate
glob: Optional[Callable] = self.preamble.get_py_func(name)
if glob is not None:
return glob
raise NameError(f"Unknown variable '{name}'")
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Any:
return self.evaluate(expr.expr)
def visit_literal_expr(self, expr: m.LiteralExpr) -> Any:
return expr.value
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Any:
return self.get_value("_")
def _evaluate_predicate(
self, predicate: Predicate, args: list[Any], kwargs: dict[str, Any]
) -> Any:
res: Any = None
if isinstance(predicate, PartialPredicate):
self.scopes.append(predicate.scope)
else:
self.scopes.append({})
match predicate.type:
case Function(returns=Function() as inner):
self._map_args(predicate.type, args, kwargs)
res = PartialPredicate(
type=inner,
body=predicate.body,
alias=False,
scope=self.scopes[-1],
)
case Function():
self._map_args(predicate.type, args, kwargs)
res = self.evaluate(predicate.body)
case _:
raise NotImplementedError
self.scopes.pop()
return res
def _map_args(self, function: Function, args: list[Any], kwargs: dict[str, Any]):
positional: list[Function.Argument] = function.pos_args + function.args
keywords: dict[str, Function.Argument] = {
arg.name: arg for arg in function.args + function.kw_args
}
for i, arg in enumerate(args):
param: Function.Argument = positional[i]
self.set_value(param.name, arg)
for name, arg in kwargs.items():
param: Function.Argument = keywords[name]
self.set_value(param.name, arg)

View File

@@ -1,4 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, Optional
from midas.checker.environment import Environment from midas.checker.environment import Environment
from midas.checker.registry import TypesRegistry from midas.checker.registry import TypesRegistry
@@ -16,16 +17,18 @@ class Preamble(Environment):
def __init__(self, types: TypesRegistry) -> None: def __init__(self, types: TypesRegistry) -> None:
super().__init__() super().__init__()
self._types: TypesRegistry = types self._types: TypesRegistry = types
self._python_funcs: dict[str, Callable] = {}
self._def_type_constructor("object") self._def_type_constructor("object", object)
self._def_type_constructor("float") self._def_type_constructor("float", float)
self._def_type_constructor("int") self._def_type_constructor("int", int)
self._def_type_constructor("bool") self._def_type_constructor("bool", bool)
self._def_type_constructor("str") self._def_type_constructor("str", str)
self._def_function( self._def_function(
name="list", name="list",
pos=[Param("object", TopType())], pos=[Param("object", TopType())],
returns=self._list_of(TopType()), returns=self._list_of(TopType()),
py_function=list,
) )
# TODO: use sink # TODO: use sink
@@ -33,6 +36,7 @@ class Preamble(Environment):
name="print", name="print",
pos=[Param("object", TopType())], pos=[Param("object", TopType())],
returns=UnitType(), returns=UnitType(),
py_function=print,
) )
map_in = TypeVar(name="T", bound=None) map_in = TypeVar(name="T", bound=None)
@@ -53,6 +57,7 @@ class Preamble(Environment):
], ],
returns=self._list_of(map_out), # TODO: replace with Iterable[U] returns=self._list_of(map_out), # TODO: replace with Iterable[U]
type_vars=[map_in, map_out], type_vars=[map_in, map_out],
py_function=map,
) )
self._def_function( self._def_function(
name="input", name="input",
@@ -63,12 +68,13 @@ class Preamble(Environment):
def _list_of(self, item_type: Type) -> Type: def _list_of(self, item_type: Type) -> Type:
return self._types.apply_generic(self._types.get_type("list"), [item_type]) return self._types.apply_generic(self._types.get_type("list"), [item_type])
def _def_type_constructor(self, name: str): def _def_type_constructor(self, name: str, py_function: Optional[Callable] = None):
# TODO: more specific arg types # TODO: more specific arg types
self._def_function( self._def_function(
name=name, name=name,
pos=[Param("object", TopType(), required=False)], pos=[Param("object", TopType(), required=False)],
returns=self._types.get_type(name), returns=self._types.get_type(name),
py_function=py_function,
) )
def _make_function( def _make_function(
@@ -115,6 +121,7 @@ class Preamble(Environment):
kw: list[Param] = [], kw: list[Param] = [],
returns: Type = UnitType(), returns: Type = UnitType(),
type_vars: list[TypeVar] = [], type_vars: list[TypeVar] = [],
py_function: Optional[Callable] = None,
): ):
function: Type = self._make_function( function: Type = self._make_function(
name=name, name=name,
@@ -125,3 +132,8 @@ class Preamble(Environment):
type_vars=type_vars, type_vars=type_vars,
) )
self.define(name, function) self.define(name, function)
if py_function is not None:
self._python_funcs[name] = py_function
def get_py_func(self, name: str) -> Optional[Callable]:
return self._python_funcs.get(name)

View File

@@ -1,11 +1,13 @@
import ast import ast
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Any, Optional
import midas.ast.python as p import midas.ast.python as p
from midas.ast.location import Location from midas.ast.location import Location
from midas.ast.printer import MidasPrinter
from midas.checker.environment import Environment from midas.checker.environment import Environment
from midas.checker.evaluator import Evaluator
from midas.checker.operators import ( from midas.checker.operators import (
PY_COMPARATOR_METHODS, PY_COMPARATOR_METHODS,
PY_OPERATOR_METHODS, PY_OPERATOR_METHODS,
@@ -18,6 +20,8 @@ from midas.checker.resolver import Resolver
from midas.checker.types import ( from midas.checker.types import (
AliasType, AliasType,
AppliedType, AppliedType,
BaseType,
ConstraintType,
Function, Function,
GenericType, GenericType,
OverloadedFunction, OverloadedFunction,
@@ -71,6 +75,7 @@ class PythonTyper(
self.env: Environment = self.global_env self.env: Environment = self.global_env
self.locals: dict[p.Expr, int] = {} self.locals: dict[p.Expr, int] = {}
self.judgements: list[tuple[p.Expr, Type]] = [] self.judgements: list[tuple[p.Expr, Type]] = []
self.evaluated_casts: list[p.CastExpr] = []
def process(self, source: str, path: Optional[str]) -> TypedAST: def process(self, source: str, path: Optional[str]) -> TypedAST:
self.reporter = self.reporter.for_file(path) self.reporter = self.reporter.for_file(path)
@@ -84,10 +89,15 @@ class PythonTyper(
self.env = self.global_env self.env = self.global_env
self.locals = resolver.locals self.locals = resolver.locals
self.judgements = [] self.judgements = []
self.evaluated_casts = []
self.check(stmts) self.check(stmts)
return TypedAST(stmts=stmts, judgements=self.judgements) return TypedAST(
stmts=stmts,
judgements=self.judgements,
evaluated_casts=self.evaluated_casts,
)
def judge(self, expr: p.Expr, type: Type): def judge(self, expr: p.Expr, type: Type):
"""Record a typing judgement """Record a typing judgement
@@ -538,8 +548,16 @@ class PythonTyper(
return UnknownType() return UnknownType()
def visit_cast_expr(self, expr: p.CastExpr) -> Type: def visit_cast_expr(self, expr: p.CastExpr) -> Type:
_ = self.type_of(expr.expr) subject_type: Type = self.type_of(expr.expr)
return self.resolve_type_expr(expr.type) target_type: Type = self.resolve_type_expr(expr.type)
is_lit, lit_value = self._get_literal(expr.expr)
if is_lit:
evaluated: bool = self._evaluate_cast_statically(
expr, subject_type, target_type, lit_value
)
if evaluated:
self.evaluated_casts.append(expr)
return target_type
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type: def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
test_type: Type = self.type_of(expr.test) test_type: Type = self.type_of(expr.test)
@@ -1108,3 +1126,93 @@ class PythonTyper(
return p.BaseType(location=location, base=name, param=None) return p.BaseType(location=location, base=name, param=None)
case _: case _:
raise NotImplementedError raise NotImplementedError
def _get_literal(self, expr: p.Expr) -> tuple[bool, Any]:
match expr:
case p.LiteralExpr(value=value):
return True, value
case p.ListExpr(items=items):
values: list[Any] = []
for item in items:
is_lit, value = self._get_literal(item)
if not is_lit:
return False, None
values.append(value)
return True, values
case p.DictExpr(keys=keys, values=values):
pairs: list[tuple[Any, Any]] = []
for key, value in zip(keys, values):
key_val = None
if key is not None:
is_lit, key_val = self._get_literal(key)
if not is_lit:
return False, None
is_lit, value_val = self._get_literal(value)
if not is_lit:
return False, None
if key is None:
# TODO: check that value is always a dict
assert isinstance(value_val, dict)
pairs.extend(value_val.items())
else:
pairs.append((key_val, value_val))
return True, dict(pairs)
case _:
return False, None
def _evaluate_cast_statically(
self, expr: p.CastExpr, subject_type: Type, target_type: Type, lit_value: Any
) -> bool:
match target_type:
case AliasType(type=base):
return self._evaluate_cast_statically(
expr, subject_type, base, lit_value
)
case AppliedType(body=body):
return self._evaluate_cast_statically(
expr, subject_type, body, lit_value
)
case ConstraintType(type=base, constraint=constraint):
evaluated: bool = True
if not self._evaluate_cast_statically(
expr, subject_type, base, lit_value
):
evaluated = False
evaluator = Evaluator(self.types)
evaluator.set_value("_", lit_value)
res = evaluator.evaluate(constraint)
if not res:
printer = MidasPrinter()
constraint_str: str = printer.print(constraint)
self.reporter.error(
expr.location,
f"Value {lit_value!r} does not fit constraint '{constraint_str}'",
)
evaluated = False
return evaluated
case BaseType():
# TODO: do we want to allow cast(float, int)? would require runtime conversion
if not self.types.is_subtype(
subject_type, target_type
) or not self.types.is_subtype(target_type, subject_type):
self.reporter.error(
expr.location,
f"Value {lit_value!r} of type {subject_type} cannot be cast as {target_type}",
)
return False
return True
case _:
self.reporter.info(
expr.location, f"Cannot evaluate cast to {target_type} statically"
)
return False

View File

@@ -61,3 +61,10 @@ class FileReporter:
location=location, location=location,
message=message, message=message,
) )
def debug(self, location: Location, message: str):
self.report(
type=DiagnosticType.DEBUG,
location=location,
message=message,
)

View File

@@ -59,6 +59,7 @@ class DiagnosticPrinter:
DiagnosticType.ERROR: Ansi.RED, DiagnosticType.ERROR: Ansi.RED,
DiagnosticType.WARNING: Ansi.YELLOW, DiagnosticType.WARNING: Ansi.YELLOW,
DiagnosticType.INFO: Ansi.CYAN, DiagnosticType.INFO: Ansi.CYAN,
DiagnosticType.DEBUG: Ansi.MAGENTA,
}.get(diagnostic.type, Ansi.WHITE) }.get(diagnostic.type, Ansi.WHITE)
subject: str = Ansi.FG(color) + line[start_offset:end_offset] + Ansi.RESET subject: str = Ansi.FG(color) + line[start_offset:end_offset] + Ansi.RESET

View File

@@ -44,6 +44,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self._typed_ast: TypedAST = TypedAST( self._typed_ast: TypedAST = TypedAST(
stmts=[], stmts=[],
judgements=[], judgements=[],
evaluated_casts=[],
) )
self._alias_count: int = 0 self._alias_count: int = 0
self._predicate_count: int = 0 self._predicate_count: int = 0
@@ -131,6 +132,10 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr: def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
expr2: ast.expr = expr.expr.accept(self) expr2: ast.expr = expr.expr.accept(self)
if expr in self._typed_ast.evaluated_casts:
return expr2
alias: ast.expr = self._make_alias(expr2) alias: ast.expr = self._make_alias(expr2)
type: Type = self._get_expr_type(expr) type: Type = self._get_expr_type(expr)

View File

@@ -62,3 +62,4 @@ class UniversalJSONDumper:
class TypedAST: class TypedAST:
stmts: list[p.Stmt] stmts: list[p.Stmt]
judgements: list[tuple[p.Expr, Type]] judgements: list[tuple[p.Expr, Type]]
evaluated_casts: list[p.CastExpr]

View File

@@ -1,6 +1,19 @@
{ {
"diagnostics": [], "diagnostics": [],
"judgments": [ "judgments": [
{
"location": {
"from": "L4:30",
"to": "L4:36"
},
"expr": {
"_type": "LiteralExpr",
"value": 123.45
},
"type": {
"name": "float"
}
},
{ {
"location": { "location": {
"from": "L4:18", "from": "L4:18",
@@ -25,6 +38,19 @@
} }
} }
}, },
{
"location": {
"from": "L5:28",
"to": "L5:31"
},
"expr": {
"_type": "LiteralExpr",
"value": 6.7
},
"type": {
"name": "float"
}
},
{ {
"location": { "location": {
"from": "L5:15", "from": "L5:15",

View File

@@ -7,68 +7,14 @@ Module(
alias(name='Meter'), alias(name='Meter'),
alias(name='Second')], alias(name='Second')],
level=0), level=0),
Assign(
targets=[
Name(id='__midas_a0__')],
value=Constant(value=123.45)),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_a0__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
values=[
Constant(value='01_simple_types.py:L3:19: CastError: Cannot cast '),
FormattedValue(
value=Attribute(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_a0__')],
keywords=[]),
attr='__name__'),
conversion=-1),
Constant(value=' to float')])),
Assign( Assign(
targets=[ targets=[
Name(id='distance')], Name(id='distance')],
value=Name(id='__midas_a0__')), value=Constant(value=123.45)),
Delete(
targets=[
Name(id='__midas_a0__')]),
Assign(
targets=[
Name(id='__midas_a1__')],
value=Constant(value=6.7)),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_a1__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
values=[
Constant(value='01_simple_types.py:L4:16: CastError: Cannot cast '),
FormattedValue(
value=Attribute(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_a1__')],
keywords=[]),
attr='__name__'),
conversion=-1),
Constant(value=' to float')])),
Assign( Assign(
targets=[ targets=[
Name(id='time')], Name(id='time')],
value=Name(id='__midas_a1__')), value=Constant(value=6.7)),
Delete(
targets=[
Name(id='__midas_a1__')]),
Assign( Assign(
targets=[ targets=[
Name(id='speed')], Name(id='speed')],