diff --git a/midas/checker/midas.py b/midas/checker/midas.py index f989152..1fe1490 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -4,22 +4,27 @@ from pathlib import Path from typing import Optional import midas.ast.midas as m +from midas.ast.location import Location from midas.checker.builtins import define_builtins from midas.checker.environment import Environment +from midas.checker.operators import MIDAS_BINARY_METHODS, MIDAS_UNARY_METHODS from midas.checker.preamble import Preamble from midas.checker.registry import TypesRegistry from midas.checker.reporter import FileReporter, Reporter from midas.checker.types import ( AliasType, + AppliedType, ComplexType, ConstraintType, ExtensionType, Function, GenericType, + OverloadedFunction, Predicate, Type, TypeVar, UnknownType, + unfold_type, ) from midas.lexer.midas import MidasLexer from midas.lexer.token import Token @@ -33,6 +38,26 @@ class TypedParamSpec: kw: list[Function.Argument] +TypedExpr = tuple[m.Expr, Type] + + +class ReturnException(Exception): + pass + + +@dataclass(frozen=True, kw_only=True) +class MappedArgument: + expr: m.Expr + type: Type + argument: Function.Argument + + +@dataclass(frozen=True, kw_only=True) +class OverloadCandidate: + function: Function + mapped: list[MappedArgument] + + class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type]): """A resolver which evaluates Midas type definitions and build a registry""" @@ -197,42 +222,81 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type return self._bool def visit_binary_expr(self, expr: m.BinaryExpr) -> Type: - # TODO - self.reporter.warning(expr.location, "BinaryExpr not yet supported") - return UnknownType() - - def visit_unary_expr(self, expr: m.UnaryExpr) -> Type: - # TODO - self.reporter.warning(expr.location, "UnaryExpr not yet supported") - return UnknownType() - - def visit_call_expr(self, expr: m.CallExpr) -> Type: - callee: Type = expr.callee.accept(self) - if not isinstance(callee, Function): - self.reporter.error(expr.location, f"Cannot call {callee}") - return UnknownType() - args: list[Type] = [arg.accept(self) for arg in expr.arguments] - - n_args: int = len(args) - n_params: int = len(callee.args) - if n_args != n_params: - self.reporter.error( - expr.location, - f"Wrong number of argument, expected {n_params}, got {n_args}", + method: Optional[str] = MIDAS_BINARY_METHODS.get(expr.operator.type) + if method is None: + self.logger.warning(f"Unsupported operator {expr.operator}") + self.reporter.warning( + expr.location, f"Unsupported operator {expr.operator}" ) return UnknownType() - valid: bool = True - for arg, param in zip(args, callee.args): - if not self.types.is_subtype(arg, param.type): - self.reporter.error( - expr.location, - f"Invalid argument type at pos {param.pos}, expected {param.type}, got {arg}", - ) - valid = False - if not valid: + return self._visit_binary_expr(expr.location, expr.left, expr.right, method) + + def _visit_binary_expr( + self, location: Location, left_expr: m.Expr, right_expr: m.Expr, method: str + ) -> Type: + left: Type = self.type_of(left_expr) + right: Type = self.type_of(right_expr) + + operation: Optional[Type] = self.types.lookup_member(left, method) + if operation is None: + self.reporter.error( + location, + f"Undefined operation {method} between {left} and {right}", + ) return UnknownType() - return callee.returns + + result: Optional[Type] = self._get_call_result( + location, + operation, + [(right_expr, right)], + {}, + ) + return result or UnknownType() + + def visit_unary_expr(self, expr: m.UnaryExpr) -> Type: + method: Optional[str] = MIDAS_UNARY_METHODS.get(expr.operator.type) + if method is None: + self.logger.warning(f"Unsupported operator {expr.operator}") + self.reporter.warning( + expr.location, f"Unsupported operator {expr.operator}" + ) + return UnknownType() + + operand: Type = self.type_of(expr.right) + operation: Optional[Type] = self.types.lookup_member(operand, method) + if operation is None: + self.reporter.error( + expr.location, + f"Undefined operation {method} for {operand}", + ) + return UnknownType() + + result: Optional[Type] = self._get_call_result( + expr.location, + operation, + [], + {}, + ) + return result or UnknownType() + + def visit_call_expr(self, expr: m.CallExpr) -> Type: + callee: Type = expr.callee.accept(self) + positional: list[TypedExpr] = [ + (arg, self.type_of(arg)) for arg in expr.arguments + ] + keywords: dict[str, TypedExpr] = { + name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items() + } + return ( + self._get_call_result( + expr.location, + callee, + positional, + keywords, + ) + or UnknownType() + ) def visit_get_expr(self, expr: m.GetExpr) -> Type: object: Type = expr.expr.accept(self) @@ -344,3 +408,343 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type self._local_variables[name] = var vars.append(var) return vars + + def _get_call_result( + self, + location: Location, + callee: Type, + positional: list[TypedExpr], + keywords: dict[str, TypedExpr], + report_errors: bool = True, + ) -> Optional[Type]: + """Get the result type of a function call + + If the function has overloads, the function will try to resolve the + appropriate signature. + Argument types are matched to the defined parameters. + The function doesn't take the raw expression as a parameter to accommodate + for desugared calls such as for operators. + + Args: + location (Location): the call location + callee (Type): the called function + positional (list[TypedExpr]): the list positional arguments + keywords (dict[str, TypedExpr]): the map of keyword arguments + report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True. + + Returns: + Type: the return type of the call, or `None` if either + the call is invalid or no overload matched the arguments uniquely + """ + match callee: + case Function() as function: + valid: bool + mapped: list[MappedArgument] + valid, mapped = self.map_call_arguments( + function, location, positional, keywords + ) + valid = valid and self._are_arguments_valid(mapped, report_errors) + if not valid: + return None + return function.returns + + case OverloadedFunction(overloads=overloads): + function = self._match_overload( + overloads, location, positional, keywords, report_errors + ) + if function is None: + return None + return function.returns + + case AppliedType(body=body): + return self._get_call_result( + location, body, positional, keywords, report_errors + ) + + case UnknownType(): + return UnknownType() + + case _: + if report_errors: + self.reporter.error(location, f"{callee} is not callable") + return None + + def _are_arguments_valid( + self, + arguments: list[MappedArgument], + report_errors: bool = True, + ) -> bool: + """Check whether the passed argument types correspond to their matched parameter definitions + + Args: + arguments (list[MappedArgument]): the list of argument/parameter pairs + report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True. + + Returns: + bool: True if all arguments fit the matching parameter definitions, False otherwise + """ + valid: bool = True + for arg in arguments: + if not self.types.is_subtype(arg.type, arg.argument.type): + if report_errors: + self.reporter.error( + arg.expr.location, + f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}", + ) + valid = False + return valid + + def _match_overload( + self, + overloads: list[Type], + location: Location, + positional: list[TypedExpr], + keywords: dict[str, TypedExpr], + report_errors: bool = True, + ) -> Optional[Function]: + """Try and resolve the appropriate overload for the given arguments + + Args: + overloads (list[Type]): the list of possible overloads + location (Location): the call location + positional (list[TypedExpr]): the list of positional arguments + keywords (dict[str, TypedExpr]): the map of keywords arguments + report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True. + + Returns: + Optional[Function]: the resolved function signature if it can be + determined unambiguously, or `None`. + """ + candidates: list[OverloadCandidate] = [] + for overload in overloads: + function: Type = unfold_type(overload) + if not isinstance(function, Function): + if report_errors: + self.logger.error( + f"Overload is not a function: {overload} is {function}" + ) + continue + valid, mapped = self.map_call_arguments( + function=function, + location=location, + positional=positional, + keywords=keywords, + report_errors=False, + ) + if valid and self._are_arguments_valid(mapped, report_errors=False): + candidates.append( + OverloadCandidate( + function=function, + mapped=mapped, + ) + ) + + pos_types: str = ", ".join(str(type) for _, type in positional) + kw_types: str = ", ".join( + f"{name}: {type}" for name, (_, type) in keywords.items() + ) + for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}" + + n_candidates: int = len(candidates) + + # Exactly 1 match -> return it + if n_candidates == 1: + return candidates[0].function + + # No match -> invalid call + if n_candidates == 0: + overloads_str: str = ", ".join(map(str, overloads)) + if report_errors: + self.reporter.error( + location, + f"No matching overload in [{overloads_str}] {for_args}", + ) + return None + + # Multiple matches -> see if one <: all others (more specific) + for i1, c1 in enumerate(candidates): + mapped1: list[MappedArgument] = c1.mapped + best_match: bool = True + for i2, c2 in enumerate(candidates): + if i1 == i2: + continue + mapped2: list[MappedArgument] = c2.mapped + if not self._are_mapped_subtypes(mapped1, mapped2): + best_match = False + break + self.logger.debug(f"{c1.function} is a full overload of {c2.function}") + if best_match: + return c1.function + + candidates_str: str = ", ".join( + str(candidate.function) for candidate in candidates + ) + if report_errors: + self.reporter.error( + location, + f"Multiple matching overloads {for_args}: {candidates_str}", + ) + return None + + def map_call_arguments( + self, + function: Function, + location: Location, + positional: list[TypedExpr], + keywords: dict[str, TypedExpr], + report_errors: bool = True, + ) -> tuple[bool, list[MappedArgument]]: + """Map call arguments to a function's parameters as defined in its signature + + This method maps positional-only, keyword-only and mixed parameter definitions + with the arguments passed at the call site + + Any mismatched, missing or unexpected argument is reported as a diagnostic, + unless `report_errors` is set to `False` + + Args: + function (Function): the function definition + location (Location): the call location + positional (list[TypedExpr]): the list of positional arguments + keywords (dict[str, TypedExpr]): the map of keyword arguments + report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True. + + Returns: + tuple[bool, list[MappedArgument]]: a boolean reporting whether + the call is valid and the list of mapped arguments + """ + set_args: set[str] = set() + + required_positional: list[str] = [ + arg.name for arg in function.pos_args + function.args if arg.required + ] + required_keyword: list[str] = [ + arg.name for arg in function.kw_args if arg.required + ] + + mapped: list[MappedArgument] = [] + + pos_params: list[Function.Argument] = list(function.pos_args) + mixed_params: list[Function.Argument] = list(function.args) + kw_params: dict[str, Function.Argument] = { + arg.name: arg for arg in function.kw_args + } + + valid_call: bool = True + + # TODO: handle *args and **kwargs sinks + for arg in positional: + param: Function.Argument + if len(pos_params) != 0: + param = pos_params.pop(0) + elif len(mixed_params) != 0: + param = mixed_params.pop(0) + else: + if report_errors: + self.reporter.error( + arg[0].location, "Too many positional arguments" + ) + valid_call = False + break + name: str = param.name + if name in required_positional: + required_positional.remove(name) + if name in required_keyword: + required_keyword.remove(name) + set_args.add(name) + mapped.append( + MappedArgument( + expr=arg[0], + type=arg[1], + argument=param, + ) + ) + + kw_params.update({arg.name: arg for arg in mixed_params}) + for name, arg in keywords.items(): + param: Function.Argument + if name not in kw_params: + if report_errors: + if name in set_args: + self.reporter.error( + arg[0].location, f"Multiple values for argument '{name}'" + ) + else: + self.reporter.error( + arg[0].location, f"Unknown keyword argument '{name}'" + ) + valid_call = False + continue + param = kw_params.pop(name) + if name in required_positional: + required_positional.remove(name) + if name in required_keyword: + required_keyword.remove(name) + set_args.add(name) + mapped.append( + MappedArgument( + expr=arg[0], + type=arg[1], + argument=param, + ) + ) + + def join_args(args: list[str]) -> str: + args = list(map(lambda a: f"'{a}'", args)) + if len(args) == 0: + return "" + if len(args) == 1: + return args[0] + return ", ".join(args[:-1]) + " and " + args[-1] + + if len(required_positional) != 0: + plural: str = "" if len(required_positional) == 1 else "s" + args: str = join_args(required_positional) + if report_errors: + self.reporter.error( + location, + f"Missing required positional argument{plural}: {args}", + ) + valid_call = False + + if len(required_keyword) != 0: + plural: str = "" if len(required_keyword) == 1 else "s" + args: str = join_args(required_keyword) + if report_errors: + self.reporter.error( + location, + f"Missing required keyword argument{plural}: {args}", + ) + valid_call = False + + return valid_call, mapped + + def _are_mapped_subtypes( + self, mapped1: list[MappedArgument], mapped2: list[MappedArgument] + ) -> bool: + """Check whether the given argument mappings are subtype/supertype of one another + + This function checks whether the argument mappings `mapped1` are subtypes + of `mapped2`. If any of the parameter type in `mapped1` is not a subtype + of the corresponding parameter in `mapped2`, `False` is returned. + + This is used to check whether a given overload is + a more specific function/ a subtype of another. + + Args: + mapped1 (list[MappedArgument]): the first argument mappings (subtype) + mapped2 (list[MappedArgument]): the second argument mappings (supertype) + + Returns: + bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise + """ + by_expr: dict[m.Expr, Type] = {} + for arg in mapped1: + by_expr[arg.expr] = arg.argument.type + + for arg in mapped2: + type2: Type = arg.argument.type + type1: Type = by_expr[arg.expr] + if not self.types.is_subtype(type1, type2): + return False + return True diff --git a/midas/checker/operators.py b/midas/checker/operators.py index 58af88c..60b8b25 100644 --- a/midas/checker/operators.py +++ b/midas/checker/operators.py @@ -1,7 +1,9 @@ import ast from typing import Type -OPERATOR_METHODS: dict[Type[ast.operator], str] = { +from midas.lexer.token import TokenType + +PY_OPERATOR_METHODS: dict[Type[ast.operator], str] = { ast.Add: "__add__", ast.Sub: "__sub__", ast.Mult: "__mul__", @@ -17,7 +19,7 @@ OPERATOR_METHODS: dict[Type[ast.operator], str] = { ast.FloorDiv: "__floordiv__", } -COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = { +PY_COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = { ast.Eq: "__eq__", # ast.NotEq: "__noteq__", ast.Lt: "__lt__", @@ -30,9 +32,40 @@ COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = { # ast.NotIn: "__notin__", } -UNARY_METHODS: dict[Type[ast.unaryop], str] = { +PY_UNARY_METHODS: dict[Type[ast.unaryop], str] = { ast.Invert: "__invert__", # ast.Not: "", ast.UAdd: "__pos__", ast.USub: "__neg__", } + + +MIDAS_BINARY_METHODS: dict[TokenType, str] = { + # TokenType.PLUS: "__add__", + TokenType.MINUS: "__sub__", + TokenType.STAR: "__mul__", + TokenType.SLASH: "__truediv__", + # TokenType.MODULO: "__mod__", + # TokenType.POW: "__pow__", + # ast.BitOr: "__or__", + # ast.BitXor: "__xor__", + # ast.BitAnd: "__and__", + # ast.FloorDiv: "__floordiv__", + TokenType.EQUAL_EQUAL: "__eq__", + # ast.NotEq: "__noteq__", + TokenType.LESS: "__lt__", + TokenType.LESS_EQUAL: "__le__", + TokenType.GREATER: "__gt__", + TokenType.GREATER_EQUAL: "__ge__", + # ast.Is: "__is__", + # ast.IsNot: "__isnot__", + # ast.In: "__in__", + # ast.NotIn: "__notin__", +} + +MIDAS_UNARY_METHODS: dict[TokenType, str] = { + # ast.Invert: "__invert__", + # ast.Not: "", + # TokenType.PLUS: "__pos__", + TokenType.MINUS: "__neg__", +} diff --git a/midas/checker/python.py b/midas/checker/python.py index da54a96..6b2892b 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -6,7 +6,11 @@ from typing import Optional import midas.ast.python as p from midas.ast.location import Location from midas.checker.environment import Environment -from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS +from midas.checker.operators import ( + PY_COMPARATOR_METHODS, + PY_OPERATOR_METHODS, + PY_UNARY_METHODS, +) from midas.checker.preamble import Preamble from midas.checker.registry import TypesRegistry from midas.checker.reporter import FileReporter, Reporter @@ -376,7 +380,7 @@ class PythonTyper( pass def visit_binary_expr(self, expr: p.BinaryExpr) -> Type: - method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__) + method: Optional[str] = PY_OPERATOR_METHODS.get(expr.operator.__class__) if method is None: self.logger.warning(f"Unsupported operator {expr.operator}") self.reporter.warning( @@ -387,7 +391,7 @@ class PythonTyper( return self._visit_binary_expr(expr.location, expr.left, expr.right, method) def visit_compare_expr(self, expr: p.CompareExpr) -> Type: - method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__) + method: Optional[str] = PY_COMPARATOR_METHODS.get(expr.operator.__class__) if method is None: self.logger.warning(f"Unsupported operator {expr.operator}") self.reporter.warning( @@ -420,7 +424,7 @@ class PythonTyper( return result or UnknownType() def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: - method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__) + method: Optional[str] = PY_UNARY_METHODS.get(expr.operator.__class__) if method is None: self.logger.warning(f"Unsupported operator {expr.operator}") self.reporter.warning(