fix(checker): handle all operations and calls in predicates

This commit is contained in:
2026-06-19 21:10:33 +02:00
parent 43d2118db7
commit 8461d05fa6
3 changed files with 480 additions and 39 deletions

View File

@@ -4,22 +4,27 @@ from pathlib import Path
from typing import Optional
import midas.ast.midas as m
from midas.ast.location import Location
from midas.checker.builtins import define_builtins
from midas.checker.environment import Environment
from midas.checker.operators import MIDAS_BINARY_METHODS, MIDAS_UNARY_METHODS
from midas.checker.preamble import Preamble
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter
from midas.checker.types import (
AliasType,
AppliedType,
ComplexType,
ConstraintType,
ExtensionType,
Function,
GenericType,
OverloadedFunction,
Predicate,
Type,
TypeVar,
UnknownType,
unfold_type,
)
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token
@@ -33,6 +38,26 @@ class TypedParamSpec:
kw: list[Function.Argument]
TypedExpr = tuple[m.Expr, Type]
class ReturnException(Exception):
pass
@dataclass(frozen=True, kw_only=True)
class MappedArgument:
expr: m.Expr
type: Type
argument: Function.Argument
@dataclass(frozen=True, kw_only=True)
class OverloadCandidate:
function: Function
mapped: list[MappedArgument]
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type]):
"""A resolver which evaluates Midas type definitions and build a registry"""
@@ -197,42 +222,81 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
return self._bool
def visit_binary_expr(self, expr: m.BinaryExpr) -> Type:
# TODO
self.reporter.warning(expr.location, "BinaryExpr not yet supported")
return UnknownType()
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
# TODO
self.reporter.warning(expr.location, "UnaryExpr not yet supported")
return UnknownType()
def visit_call_expr(self, expr: m.CallExpr) -> Type:
callee: Type = expr.callee.accept(self)
if not isinstance(callee, Function):
self.reporter.error(expr.location, f"Cannot call {callee}")
return UnknownType()
args: list[Type] = [arg.accept(self) for arg in expr.arguments]
n_args: int = len(args)
n_params: int = len(callee.args)
if n_args != n_params:
self.reporter.error(
expr.location,
f"Wrong number of argument, expected {n_params}, got {n_args}",
method: Optional[str] = MIDAS_BINARY_METHODS.get(expr.operator.type)
if method is None:
self.logger.warning(f"Unsupported operator {expr.operator}")
self.reporter.warning(
expr.location, f"Unsupported operator {expr.operator}"
)
return UnknownType()
valid: bool = True
for arg, param in zip(args, callee.args):
if not self.types.is_subtype(arg, param.type):
self.reporter.error(
expr.location,
f"Invalid argument type at pos {param.pos}, expected {param.type}, got {arg}",
)
valid = False
if not valid:
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
def _visit_binary_expr(
self, location: Location, left_expr: m.Expr, right_expr: m.Expr, method: str
) -> Type:
left: Type = self.type_of(left_expr)
right: Type = self.type_of(right_expr)
operation: Optional[Type] = self.types.lookup_member(left, method)
if operation is None:
self.reporter.error(
location,
f"Undefined operation {method} between {left} and {right}",
)
return UnknownType()
return callee.returns
result: Optional[Type] = self._get_call_result(
location,
operation,
[(right_expr, right)],
{},
)
return result or UnknownType()
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
method: Optional[str] = MIDAS_UNARY_METHODS.get(expr.operator.type)
if method is None:
self.logger.warning(f"Unsupported operator {expr.operator}")
self.reporter.warning(
expr.location, f"Unsupported operator {expr.operator}"
)
return UnknownType()
operand: Type = self.type_of(expr.right)
operation: Optional[Type] = self.types.lookup_member(operand, method)
if operation is None:
self.reporter.error(
expr.location,
f"Undefined operation {method} for {operand}",
)
return UnknownType()
result: Optional[Type] = self._get_call_result(
expr.location,
operation,
[],
{},
)
return result or UnknownType()
def visit_call_expr(self, expr: m.CallExpr) -> Type:
callee: Type = expr.callee.accept(self)
positional: list[TypedExpr] = [
(arg, self.type_of(arg)) for arg in expr.arguments
]
keywords: dict[str, TypedExpr] = {
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
}
return (
self._get_call_result(
expr.location,
callee,
positional,
keywords,
)
or UnknownType()
)
def visit_get_expr(self, expr: m.GetExpr) -> Type:
object: Type = expr.expr.accept(self)
@@ -344,3 +408,343 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
self._local_variables[name] = var
vars.append(var)
return vars
def _get_call_result(
self,
location: Location,
callee: Type,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
report_errors: bool = True,
) -> Optional[Type]:
"""Get the result type of a function call
If the function has overloads, the function will try to resolve the
appropriate signature.
Argument types are matched to the defined parameters.
The function doesn't take the raw expression as a parameter to accommodate
for desugared calls such as for operators.
Args:
location (Location): the call location
callee (Type): the called function
positional (list[TypedExpr]): the list positional arguments
keywords (dict[str, TypedExpr]): the map of keyword arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
Type: the return type of the call, or `None` if either
the call is invalid or no overload matched the arguments uniquely
"""
match callee:
case Function() as function:
valid: bool
mapped: list[MappedArgument]
valid, mapped = self.map_call_arguments(
function, location, positional, keywords
)
valid = valid and self._are_arguments_valid(mapped, report_errors)
if not valid:
return None
return function.returns
case OverloadedFunction(overloads=overloads):
function = self._match_overload(
overloads, location, positional, keywords, report_errors
)
if function is None:
return None
return function.returns
case AppliedType(body=body):
return self._get_call_result(
location, body, positional, keywords, report_errors
)
case UnknownType():
return UnknownType()
case _:
if report_errors:
self.reporter.error(location, f"{callee} is not callable")
return None
def _are_arguments_valid(
self,
arguments: list[MappedArgument],
report_errors: bool = True,
) -> bool:
"""Check whether the passed argument types correspond to their matched parameter definitions
Args:
arguments (list[MappedArgument]): the list of argument/parameter pairs
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
bool: True if all arguments fit the matching parameter definitions, False otherwise
"""
valid: bool = True
for arg in arguments:
if not self.types.is_subtype(arg.type, arg.argument.type):
if report_errors:
self.reporter.error(
arg.expr.location,
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
)
valid = False
return valid
def _match_overload(
self,
overloads: list[Type],
location: Location,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
report_errors: bool = True,
) -> Optional[Function]:
"""Try and resolve the appropriate overload for the given arguments
Args:
overloads (list[Type]): the list of possible overloads
location (Location): the call location
positional (list[TypedExpr]): the list of positional arguments
keywords (dict[str, TypedExpr]): the map of keywords arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
Optional[Function]: the resolved function signature if it can be
determined unambiguously, or `None`.
"""
candidates: list[OverloadCandidate] = []
for overload in overloads:
function: Type = unfold_type(overload)
if not isinstance(function, Function):
if report_errors:
self.logger.error(
f"Overload is not a function: {overload} is {function}"
)
continue
valid, mapped = self.map_call_arguments(
function=function,
location=location,
positional=positional,
keywords=keywords,
report_errors=False,
)
if valid and self._are_arguments_valid(mapped, report_errors=False):
candidates.append(
OverloadCandidate(
function=function,
mapped=mapped,
)
)
pos_types: str = ", ".join(str(type) for _, type in positional)
kw_types: str = ", ".join(
f"{name}: {type}" for name, (_, type) in keywords.items()
)
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
n_candidates: int = len(candidates)
# Exactly 1 match -> return it
if n_candidates == 1:
return candidates[0].function
# No match -> invalid call
if n_candidates == 0:
overloads_str: str = ", ".join(map(str, overloads))
if report_errors:
self.reporter.error(
location,
f"No matching overload in [{overloads_str}] {for_args}",
)
return None
# Multiple matches -> see if one <: all others (more specific)
for i1, c1 in enumerate(candidates):
mapped1: list[MappedArgument] = c1.mapped
best_match: bool = True
for i2, c2 in enumerate(candidates):
if i1 == i2:
continue
mapped2: list[MappedArgument] = c2.mapped
if not self._are_mapped_subtypes(mapped1, mapped2):
best_match = False
break
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
if best_match:
return c1.function
candidates_str: str = ", ".join(
str(candidate.function) for candidate in candidates
)
if report_errors:
self.reporter.error(
location,
f"Multiple matching overloads {for_args}: {candidates_str}",
)
return None
def map_call_arguments(
self,
function: Function,
location: Location,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
report_errors: bool = True,
) -> tuple[bool, list[MappedArgument]]:
"""Map call arguments to a function's parameters as defined in its signature
This method maps positional-only, keyword-only and mixed parameter definitions
with the arguments passed at the call site
Any mismatched, missing or unexpected argument is reported as a diagnostic,
unless `report_errors` is set to `False`
Args:
function (Function): the function definition
location (Location): the call location
positional (list[TypedExpr]): the list of positional arguments
keywords (dict[str, TypedExpr]): the map of keyword arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
tuple[bool, list[MappedArgument]]: a boolean reporting whether
the call is valid and the list of mapped arguments
"""
set_args: set[str] = set()
required_positional: list[str] = [
arg.name for arg in function.pos_args + function.args if arg.required
]
required_keyword: list[str] = [
arg.name for arg in function.kw_args if arg.required
]
mapped: list[MappedArgument] = []
pos_params: list[Function.Argument] = list(function.pos_args)
mixed_params: list[Function.Argument] = list(function.args)
kw_params: dict[str, Function.Argument] = {
arg.name: arg for arg in function.kw_args
}
valid_call: bool = True
# TODO: handle *args and **kwargs sinks
for arg in positional:
param: Function.Argument
if len(pos_params) != 0:
param = pos_params.pop(0)
elif len(mixed_params) != 0:
param = mixed_params.pop(0)
else:
if report_errors:
self.reporter.error(
arg[0].location, "Too many positional arguments"
)
valid_call = False
break
name: str = param.name
if name in required_positional:
required_positional.remove(name)
if name in required_keyword:
required_keyword.remove(name)
set_args.add(name)
mapped.append(
MappedArgument(
expr=arg[0],
type=arg[1],
argument=param,
)
)
kw_params.update({arg.name: arg for arg in mixed_params})
for name, arg in keywords.items():
param: Function.Argument
if name not in kw_params:
if report_errors:
if name in set_args:
self.reporter.error(
arg[0].location, f"Multiple values for argument '{name}'"
)
else:
self.reporter.error(
arg[0].location, f"Unknown keyword argument '{name}'"
)
valid_call = False
continue
param = kw_params.pop(name)
if name in required_positional:
required_positional.remove(name)
if name in required_keyword:
required_keyword.remove(name)
set_args.add(name)
mapped.append(
MappedArgument(
expr=arg[0],
type=arg[1],
argument=param,
)
)
def join_args(args: list[str]) -> str:
args = list(map(lambda a: f"'{a}'", args))
if len(args) == 0:
return ""
if len(args) == 1:
return args[0]
return ", ".join(args[:-1]) + " and " + args[-1]
if len(required_positional) != 0:
plural: str = "" if len(required_positional) == 1 else "s"
args: str = join_args(required_positional)
if report_errors:
self.reporter.error(
location,
f"Missing required positional argument{plural}: {args}",
)
valid_call = False
if len(required_keyword) != 0:
plural: str = "" if len(required_keyword) == 1 else "s"
args: str = join_args(required_keyword)
if report_errors:
self.reporter.error(
location,
f"Missing required keyword argument{plural}: {args}",
)
valid_call = False
return valid_call, mapped
def _are_mapped_subtypes(
self, mapped1: list[MappedArgument], mapped2: list[MappedArgument]
) -> bool:
"""Check whether the given argument mappings are subtype/supertype of one another
This function checks whether the argument mappings `mapped1` are subtypes
of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
of the corresponding parameter in `mapped2`, `False` is returned.
This is used to check whether a given overload is
a more specific function/ a subtype of another.
Args:
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
Returns:
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
"""
by_expr: dict[m.Expr, Type] = {}
for arg in mapped1:
by_expr[arg.expr] = arg.argument.type
for arg in mapped2:
type2: Type = arg.argument.type
type1: Type = by_expr[arg.expr]
if not self.types.is_subtype(type1, type2):
return False
return True

View File

@@ -1,7 +1,9 @@
import ast
from typing import Type
OPERATOR_METHODS: dict[Type[ast.operator], str] = {
from midas.lexer.token import TokenType
PY_OPERATOR_METHODS: dict[Type[ast.operator], str] = {
ast.Add: "__add__",
ast.Sub: "__sub__",
ast.Mult: "__mul__",
@@ -17,7 +19,7 @@ OPERATOR_METHODS: dict[Type[ast.operator], str] = {
ast.FloorDiv: "__floordiv__",
}
COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
PY_COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
ast.Eq: "__eq__",
# ast.NotEq: "__noteq__",
ast.Lt: "__lt__",
@@ -30,9 +32,40 @@ COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
# ast.NotIn: "__notin__",
}
UNARY_METHODS: dict[Type[ast.unaryop], str] = {
PY_UNARY_METHODS: dict[Type[ast.unaryop], str] = {
ast.Invert: "__invert__",
# ast.Not: "",
ast.UAdd: "__pos__",
ast.USub: "__neg__",
}
MIDAS_BINARY_METHODS: dict[TokenType, str] = {
# TokenType.PLUS: "__add__",
TokenType.MINUS: "__sub__",
TokenType.STAR: "__mul__",
TokenType.SLASH: "__truediv__",
# TokenType.MODULO: "__mod__",
# TokenType.POW: "__pow__",
# ast.BitOr: "__or__",
# ast.BitXor: "__xor__",
# ast.BitAnd: "__and__",
# ast.FloorDiv: "__floordiv__",
TokenType.EQUAL_EQUAL: "__eq__",
# ast.NotEq: "__noteq__",
TokenType.LESS: "__lt__",
TokenType.LESS_EQUAL: "__le__",
TokenType.GREATER: "__gt__",
TokenType.GREATER_EQUAL: "__ge__",
# ast.Is: "__is__",
# ast.IsNot: "__isnot__",
# ast.In: "__in__",
# ast.NotIn: "__notin__",
}
MIDAS_UNARY_METHODS: dict[TokenType, str] = {
# ast.Invert: "__invert__",
# ast.Not: "",
# TokenType.PLUS: "__pos__",
TokenType.MINUS: "__neg__",
}

View File

@@ -6,7 +6,11 @@ from typing import Optional
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.environment import Environment
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS
from midas.checker.operators import (
PY_COMPARATOR_METHODS,
PY_OPERATOR_METHODS,
PY_UNARY_METHODS,
)
from midas.checker.preamble import Preamble
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter
@@ -376,7 +380,7 @@ class PythonTyper(
pass
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
method: Optional[str] = PY_OPERATOR_METHODS.get(expr.operator.__class__)
if method is None:
self.logger.warning(f"Unsupported operator {expr.operator}")
self.reporter.warning(
@@ -387,7 +391,7 @@ class PythonTyper(
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
method: Optional[str] = PY_COMPARATOR_METHODS.get(expr.operator.__class__)
if method is None:
self.logger.warning(f"Unsupported operator {expr.operator}")
self.reporter.warning(
@@ -420,7 +424,7 @@ class PythonTyper(
return result or UnknownType()
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__)
method: Optional[str] = PY_UNARY_METHODS.get(expr.operator.__class__)
if method is None:
self.logger.warning(f"Unsupported operator {expr.operator}")
self.reporter.warning(