diff --git a/midas/checker/checker.py b/midas/checker/checker.py index dcf90b0..13b7855 100644 --- a/midas/checker/checker.py +++ b/midas/checker/checker.py @@ -215,7 +215,120 @@ class Checker( return False return True + case (Function(returns=return1), Function(returns=return2)): + if not self.is_func_subtype(type1, type2): + return False + if not self.is_subtype(return1, return2): + return False + return True + return False + + # TODO: verify the logic in here + def is_func_subtype(self, func1: Function, func2: Function) -> bool: + """Check whether a function is a subtype of another + + Args: + func1 (Function): the potential function subtype + func2 (Function): the potential function supertype + + Returns: + bool: whether `func1` is a subtype of `func2` + """ + if not self.is_subtype(func1.returns, func2.returns): + return False + + pos1: list[Function.Argument] = func1.pos_args + mixed1: list[Function.Argument] = func1.args + kw1: dict[str, Function.Argument] = {a.name: a for a in func1.kw_args} + pos2: list[Function.Argument] = func2.pos_args + mixed2: list[Function.Argument] = func2.args + kw2: dict[str, Function.Argument] = {a.name: a for a in func2.kw_args} + + mixed_by_pos: dict[int, Function.Argument] = {arg.pos: arg for arg in mixed2} + mixed_by_name: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2} + + def is_arg_subtype(sub: Function.Argument, sup: Function.Argument) -> bool: + if not self.is_subtype(sub.type, sup.type): + return False + if not sup.required and sub.required: + return False + return True + + for arg1 in pos1: + arg2: Function.Argument + if arg1.pos < len(pos2): + arg2 = pos2[arg1.pos] + elif arg1.pos in mixed_by_pos: + arg2 = mixed_by_pos[arg1.pos] + elif not arg1.required: + continue + else: + return False + if not is_arg_subtype(arg2, arg1): + return False + + for name, arg1 in kw1.items(): + arg2: Function.Argument + if name in kw2: + arg2 = kw2[name] + elif name in mixed_by_name: + arg2 = mixed_by_name[name] + elif not arg1.required: + continue + else: + return False + if not is_arg_subtype(arg2, arg1): + return False + + for arg1 in mixed1: + pos_arg2: Optional[Function.Argument] = None + kw_arg2: Optional[Function.Argument] = None + if arg1.name in kw2: + kw_arg2 = kw2[arg1.name] + elif arg1.name in mixed_by_name: + kw_arg2 = mixed_by_name[arg1.name] + if arg1.pos < len(pos2): + pos_arg2 = pos2[arg1.pos] + elif arg1.pos in mixed_by_pos: + pos_arg2 = mixed_by_pos[arg1.pos] + + # No match in func2 and arg is required + if pos_arg2 is None and kw_arg2 is None and arg1.required: + return False + + # Matching keyword argument + if kw_arg2 is not None and not is_arg_subtype(kw_arg2, arg1): + return False + + # Matching positional argument + if pos_arg2 is not None and not is_arg_subtype(pos_arg2, arg1): + return False + + mixed_positions: set[int] = {a.pos for a in mixed1} + mixed_names: set[str] = {a.name for a in mixed1} + for arg2 in pos2: + if not arg2.required: + continue + if arg2.pos >= len(pos1) and arg2.pos not in mixed_positions: + return False + + for name, arg2 in kw2.items(): + if not arg2.required: + continue + if name not in kw1 and name not in mixed_names: + return False + + for arg2 in mixed2: + if arg2.required: + continue + pos_match: bool = arg2.pos < len(pos1) or arg2.pos in mixed_positions + kw_match: bool = arg2.name in kw1 or arg2.name in mixed_names + if not pos_match or not kw_match: + return False + + return True + def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None: self.type_of(stmt.expr)