From 6e717a3f9e627b8a3d02c40835616f49af7aff3e Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Wed, 1 Jul 2026 11:24:09 +0200 Subject: [PATCH] refactor: use CallDispatcher in Midas typer --- midas/checker/midas.py | 392 +++-------------------------------------- 1 file changed, 24 insertions(+), 368 deletions(-) diff --git a/midas/checker/midas.py b/midas/checker/midas.py index 01aa09f..b6048cb 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -6,13 +6,13 @@ from typing import Optional import midas.ast.midas as m from midas.ast.location import Location from midas.checker.builtins import define_builtins +from midas.checker.dispatcher import CallDispatcher, CallResult from midas.checker.environment import Environment from midas.checker.operators import MIDAS_BINARY_METHODS, MIDAS_UNARY_METHODS from midas.checker.preamble import Preamble from midas.checker.registry import TypesRegistry from midas.checker.reporter import FileReporter, Reporter from midas.checker.types import ( - AppliedType, ColumnType, ComplexType, ConstraintType, @@ -21,12 +21,10 @@ from midas.checker.types import ( ExtensionType, Function, GenericType, - OverloadedFunction, Predicate, Type, TypeVar, UnknownType, - unfold_type, ) from midas.checker.variance import VarianceInferrer from midas.lexer.midas import MidasLexer @@ -41,9 +39,6 @@ class TypedParamSpec: kw: list[Function.Argument] -TypedExpr = tuple[m.Expr, Type] - - class ReturnException(Exception): pass @@ -259,13 +254,14 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type ) return UnknownType() - result: Optional[Type] = self._get_call_result( - location, - operation, - [(right_expr, right)], - {}, + dispatcher = CallDispatcher(self.types, self.reporter) + result: CallResult = dispatcher.get_result( + location=location, + callee=operation, + positional=[(right_expr, right)], + keywords={}, ) - return result or UnknownType() + return result.result def visit_unary_expr(self, expr: m.UnaryExpr) -> Type: method: Optional[str] = MIDAS_UNARY_METHODS.get(expr.operator.type) @@ -285,31 +281,31 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type ) return UnknownType() - result: Optional[Type] = self._get_call_result( - expr.location, - operation, - [], - {}, + dispatcher = CallDispatcher(self.types, self.reporter) + result: CallResult = dispatcher.get_result( + location=expr.location, + callee=operation, + positional=[], + keywords={}, ) - return result or UnknownType() + return result.result def visit_call_expr(self, expr: m.CallExpr) -> Type: callee: Type = expr.callee.accept(self) - positional: list[TypedExpr] = [ + positional: list[tuple[m.Expr, Type]] = [ (arg, self.type_of(arg)) for arg in expr.arguments ] - keywords: dict[str, TypedExpr] = { + keywords: dict[str, tuple[m.Expr, Type]] = { name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items() } - return ( - self._get_call_result( - expr.location, - callee, - positional, - keywords, - ) - or UnknownType() + dispatcher = CallDispatcher(self.types, self.reporter) + result: CallResult = dispatcher.get_result( + location=expr.location, + callee=callee, + positional=positional, + keywords=keywords, ) + return result.result def visit_get_expr(self, expr: m.GetExpr) -> Type: object: Type = expr.expr.accept(self) @@ -433,343 +429,3 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type self._local_variables[name] = var vars.append(var) return vars - - def _get_call_result( - self, - location: Location, - callee: Type, - positional: list[TypedExpr], - keywords: dict[str, TypedExpr], - report_errors: bool = True, - ) -> Optional[Type]: - """Get the result type of a function call - - If the function has overloads, the function will try to resolve the - appropriate signature. - Argument types are matched to the defined parameters. - The function doesn't take the raw expression as a parameter to accommodate - for desugared calls such as for operators. - - Args: - location (Location): the call location - callee (Type): the called function - positional (list[TypedExpr]): the list positional arguments - keywords (dict[str, TypedExpr]): the map of keyword arguments - report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True. - - Returns: - Type: the return type of the call, or `None` if either - the call is invalid or no overload matched the arguments uniquely - """ - match callee: - case Function() as function: - valid: bool - mapped: list[MappedArgument] - valid, mapped = self.map_call_arguments( - function, location, positional, keywords - ) - valid = valid and self._are_arguments_valid(mapped, report_errors) - if not valid: - return None - return function.returns - - case OverloadedFunction(overloads=overloads): - function = self._match_overload( - overloads, location, positional, keywords, report_errors - ) - if function is None: - return None - return function.returns - - case AppliedType(body=body): - return self._get_call_result( - location, body, positional, keywords, report_errors - ) - - case UnknownType(): - return UnknownType() - - case _: - if report_errors: - self.reporter.error(location, f"{callee} is not callable") - return None - - def _are_arguments_valid( - self, - arguments: list[MappedArgument], - report_errors: bool = True, - ) -> bool: - """Check whether the passed argument types correspond to their matched parameter definitions - - Args: - arguments (list[MappedArgument]): the list of argument/parameter pairs - report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True. - - Returns: - bool: True if all arguments fit the matching parameter definitions, False otherwise - """ - valid: bool = True - for arg in arguments: - if not self.types.is_subtype(arg.type, arg.argument.type): - if report_errors: - self.reporter.error( - arg.expr.location, - f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}", - ) - valid = False - return valid - - def _match_overload( - self, - overloads: list[Type], - location: Location, - positional: list[TypedExpr], - keywords: dict[str, TypedExpr], - report_errors: bool = True, - ) -> Optional[Function]: - """Try and resolve the appropriate overload for the given arguments - - Args: - overloads (list[Type]): the list of possible overloads - location (Location): the call location - positional (list[TypedExpr]): the list of positional arguments - keywords (dict[str, TypedExpr]): the map of keywords arguments - report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True. - - Returns: - Optional[Function]: the resolved function signature if it can be - determined unambiguously, or `None`. - """ - candidates: list[OverloadCandidate] = [] - for overload in overloads: - function: Type = unfold_type(overload) - if not isinstance(function, Function): - if report_errors: - self.logger.error( - f"Overload is not a function: {overload} is {function}" - ) - continue - valid, mapped = self.map_call_arguments( - function=function, - location=location, - positional=positional, - keywords=keywords, - report_errors=False, - ) - if valid and self._are_arguments_valid(mapped, report_errors=False): - candidates.append( - OverloadCandidate( - function=function, - mapped=mapped, - ) - ) - - pos_types: str = ", ".join(str(type) for _, type in positional) - kw_types: str = ", ".join( - f"{name}: {type}" for name, (_, type) in keywords.items() - ) - for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}" - - n_candidates: int = len(candidates) - - # Exactly 1 match -> return it - if n_candidates == 1: - return candidates[0].function - - # No match -> invalid call - if n_candidates == 0: - overloads_str: str = ", ".join(map(str, overloads)) - if report_errors: - self.reporter.error( - location, - f"No matching overload in [{overloads_str}] {for_args}", - ) - return None - - # Multiple matches -> see if one <: all others (more specific) - for i1, c1 in enumerate(candidates): - mapped1: list[MappedArgument] = c1.mapped - best_match: bool = True - for i2, c2 in enumerate(candidates): - if i1 == i2: - continue - mapped2: list[MappedArgument] = c2.mapped - if not self._are_mapped_subtypes(mapped1, mapped2): - best_match = False - break - self.logger.debug(f"{c1.function} is a full overload of {c2.function}") - if best_match: - return c1.function - - candidates_str: str = ", ".join( - str(candidate.function) for candidate in candidates - ) - if report_errors: - self.reporter.error( - location, - f"Multiple matching overloads {for_args}: {candidates_str}", - ) - return None - - def map_call_arguments( - self, - function: Function, - location: Location, - positional: list[TypedExpr], - keywords: dict[str, TypedExpr], - report_errors: bool = True, - ) -> tuple[bool, list[MappedArgument]]: - """Map call arguments to a function's parameters as defined in its signature - - This method maps positional-only, keyword-only and mixed parameter definitions - with the arguments passed at the call site - - Any mismatched, missing or unexpected argument is reported as a diagnostic, - unless `report_errors` is set to `False` - - Args: - function (Function): the function definition - location (Location): the call location - positional (list[TypedExpr]): the list of positional arguments - keywords (dict[str, TypedExpr]): the map of keyword arguments - report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True. - - Returns: - tuple[bool, list[MappedArgument]]: a boolean reporting whether - the call is valid and the list of mapped arguments - """ - set_args: set[str] = set() - - required_positional: list[str] = [ - arg.name for arg in function.pos_args + function.args if arg.required - ] - required_keyword: list[str] = [ - arg.name for arg in function.kw_args if arg.required - ] - - mapped: list[MappedArgument] = [] - - pos_params: list[Function.Argument] = list(function.pos_args) - mixed_params: list[Function.Argument] = list(function.args) - kw_params: dict[str, Function.Argument] = { - arg.name: arg for arg in function.kw_args - } - - valid_call: bool = True - - # TODO: handle *args and **kwargs sinks - for arg in positional: - param: Function.Argument - if len(pos_params) != 0: - param = pos_params.pop(0) - elif len(mixed_params) != 0: - param = mixed_params.pop(0) - else: - if report_errors: - self.reporter.error( - arg[0].location, "Too many positional arguments" - ) - valid_call = False - break - name: str = param.name - if name in required_positional: - required_positional.remove(name) - if name in required_keyword: - required_keyword.remove(name) - set_args.add(name) - mapped.append( - MappedArgument( - expr=arg[0], - type=arg[1], - argument=param, - ) - ) - - kw_params.update({arg.name: arg for arg in mixed_params}) - for name, arg in keywords.items(): - param: Function.Argument - if name not in kw_params: - if report_errors: - if name in set_args: - self.reporter.error( - arg[0].location, f"Multiple values for argument '{name}'" - ) - else: - self.reporter.error( - arg[0].location, f"Unknown keyword argument '{name}'" - ) - valid_call = False - continue - param = kw_params.pop(name) - if name in required_positional: - required_positional.remove(name) - if name in required_keyword: - required_keyword.remove(name) - set_args.add(name) - mapped.append( - MappedArgument( - expr=arg[0], - type=arg[1], - argument=param, - ) - ) - - def join_args(args: list[str]) -> str: - args = list(map(lambda a: f"'{a}'", args)) - if len(args) == 0: - return "" - if len(args) == 1: - return args[0] - return ", ".join(args[:-1]) + " and " + args[-1] - - if len(required_positional) != 0: - plural: str = "" if len(required_positional) == 1 else "s" - args: str = join_args(required_positional) - if report_errors: - self.reporter.error( - location, - f"Missing required positional argument{plural}: {args}", - ) - valid_call = False - - if len(required_keyword) != 0: - plural: str = "" if len(required_keyword) == 1 else "s" - args: str = join_args(required_keyword) - if report_errors: - self.reporter.error( - location, - f"Missing required keyword argument{plural}: {args}", - ) - valid_call = False - - return valid_call, mapped - - def _are_mapped_subtypes( - self, mapped1: list[MappedArgument], mapped2: list[MappedArgument] - ) -> bool: - """Check whether the given argument mappings are subtype/supertype of one another - - This function checks whether the argument mappings `mapped1` are subtypes - of `mapped2`. If any of the parameter type in `mapped1` is not a subtype - of the corresponding parameter in `mapped2`, `False` is returned. - - This is used to check whether a given overload is - a more specific function/ a subtype of another. - - Args: - mapped1 (list[MappedArgument]): the first argument mappings (subtype) - mapped2 (list[MappedArgument]): the second argument mappings (supertype) - - Returns: - bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise - """ - by_expr: dict[m.Expr, Type] = {} - for arg in mapped1: - by_expr[arg.expr] = arg.argument.type - - for arg in mapped2: - type2: Type = arg.argument.type - type1: Type = by_expr[arg.expr] - if not self.types.is_subtype(type1, type2): - return False - return True