diff --git a/midas/checker/checker.py b/midas/checker/checker.py index 3100587..0c54fa4 100644 --- a/midas/checker/checker.py +++ b/midas/checker/checker.py @@ -6,13 +6,10 @@ from typing import Optional import midas.ast.midas as m import midas.ast.python as p from midas.ast.location import Location -from midas.checker.builtins import BUILTIN_SUBTYPES from midas.checker.diagnostic import Diagnostic, DiagnosticType from midas.checker.environment import Environment from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS from midas.checker.types import ( - AliasType, - BaseType, ComplexType, Function, Operation, @@ -180,149 +177,7 @@ class Checker( self.ctx.resolve(stmts) def is_subtype(self, type1: Type, type2: Type) -> bool: - """Check whether `type1` is a subtype of `type2` - - For more details on the rules checked here, see TAPL Chap. 15-16-17 - - Args: - type1 (Type): the potential subtype - type2 (Type): the potential supertype - - Returns: - bool: whether `type1` is a subtype of `type2` - """ - - if type1 == type2: - return True - - match (type1, type2): - case (AliasType(type=base1), _): - return self.is_subtype(base1, type2) - - case (BaseType(name=name1), BaseType(name=name2)): - return name1 in BUILTIN_SUBTYPES.get(name2, set()) - - case (ComplexType(properties=props1), ComplexType(properties=props2)): - for k, t in props2.items(): - if k not in props1: - return False - if not self.is_subtype(props1[k], t): - return False - return True - - case (Function(returns=return1), Function(returns=return2)): - if not self.is_func_subtype(type1, type2): - return False - if not self.is_subtype(return1, return2): - return False - return True - - return False - - # TODO: verify the logic in here - def is_func_subtype(self, func1: Function, func2: Function) -> bool: - """Check whether a function is a subtype of another - - Args: - func1 (Function): the potential function subtype - func2 (Function): the potential function supertype - - Returns: - bool: whether `func1` is a subtype of `func2` - """ - if not self.is_subtype(func1.returns, func2.returns): - return False - - pos1: list[Function.Argument] = func1.pos_args - mixed1: list[Function.Argument] = func1.args - kw1: dict[str, Function.Argument] = {a.name: a for a in func1.kw_args} - pos2: list[Function.Argument] = func2.pos_args - mixed2: list[Function.Argument] = func2.args - kw2: dict[str, Function.Argument] = {a.name: a for a in func2.kw_args} - - mixed_by_pos: dict[int, Function.Argument] = {arg.pos: arg for arg in mixed2} - mixed_by_name: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2} - - def is_arg_subtype(sub: Function.Argument, sup: Function.Argument) -> bool: - if not self.is_subtype(sub.type, sup.type): - return False - if not sup.required and sub.required: - return False - return True - - for arg1 in pos1: - arg2: Function.Argument - if arg1.pos < len(pos2): - arg2 = pos2[arg1.pos] - elif arg1.pos in mixed_by_pos: - arg2 = mixed_by_pos[arg1.pos] - elif not arg1.required: - continue - else: - return False - if not is_arg_subtype(arg2, arg1): - return False - - for name, arg1 in kw1.items(): - arg2: Function.Argument - if name in kw2: - arg2 = kw2[name] - elif name in mixed_by_name: - arg2 = mixed_by_name[name] - elif not arg1.required: - continue - else: - return False - if not is_arg_subtype(arg2, arg1): - return False - - for arg1 in mixed1: - pos_arg2: Optional[Function.Argument] = None - kw_arg2: Optional[Function.Argument] = None - if arg1.name in kw2: - kw_arg2 = kw2[arg1.name] - elif arg1.name in mixed_by_name: - kw_arg2 = mixed_by_name[arg1.name] - if arg1.pos < len(pos2): - pos_arg2 = pos2[arg1.pos] - elif arg1.pos in mixed_by_pos: - pos_arg2 = mixed_by_pos[arg1.pos] - - # No match in func2 and arg is required - if pos_arg2 is None and kw_arg2 is None and arg1.required: - return False - - # Matching keyword argument - if kw_arg2 is not None and not is_arg_subtype(kw_arg2, arg1): - return False - - # Matching positional argument - if pos_arg2 is not None and not is_arg_subtype(pos_arg2, arg1): - return False - - mixed_positions: set[int] = {a.pos for a in mixed1} - mixed_names: set[str] = {a.name for a in mixed1} - for arg2 in pos2: - if not arg2.required: - continue - if arg2.pos >= len(pos1) and arg2.pos not in mixed_positions: - return False - - for name, arg2 in kw2.items(): - if not arg2.required: - continue - if name not in kw1 and name not in mixed_names: - return False - - for arg2 in mixed2: - if arg2.required: - continue - pos_match: bool = arg2.pos < len(pos1) or arg2.pos in mixed_positions - kw_match: bool = arg2.name in kw1 or arg2.name in mixed_names - if not pos_match or not kw_match: - return False - - return True + return self.ctx.is_subtype(type1, type2) def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None: self.type_of(stmt.expr) diff --git a/midas/resolver/midas.py b/midas/resolver/midas.py index 063a752..c7b168d 100644 --- a/midas/resolver/midas.py +++ b/midas/resolver/midas.py @@ -1,9 +1,12 @@ from typing import Optional import midas.ast.midas as m +from midas.checker.builtins import BUILTIN_SUBTYPES from midas.checker.types import ( AliasType, + BaseType, ComplexType, + Function, GenericType, Operation, Type, @@ -198,3 +201,144 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T prop.name.lexeme: prop.type.accept(self) for prop in type.properties } ) + + def is_subtype(self, type1: Type, type2: Type) -> bool: + """Check whether `type1` is a subtype of `type2` + + For more details on the rules checked here, see TAPL Chap. 15-16-17 + + Args: + type1 (Type): the potential subtype + type2 (Type): the potential supertype + + Returns: + bool: whether `type1` is a subtype of `type2` + """ + + if type1 == type2: + return True + + match (type1, type2): + case (AliasType(type=base1), _): + return self.is_subtype(base1, type2) + + case (BaseType(name=name1), BaseType(name=name2)): + return name1 in BUILTIN_SUBTYPES.get(name2, set()) + + case (ComplexType(properties=props1), ComplexType(properties=props2)): + for k, t in props2.items(): + if k not in props1: + return False + if not self.is_subtype(props1[k], t): + return False + return True + + case (Function(), Function()): + return self.is_func_subtype(type1, type2) + + return False + + # TODO: verify the logic in here + def is_func_subtype(self, func1: Function, func2: Function) -> bool: + """Check whether a function is a subtype of another + + Args: + func1 (Function): the potential function subtype + func2 (Function): the potential function supertype + + Returns: + bool: whether `func1` is a subtype of `func2` + """ + if not self.is_subtype(func1.returns, func2.returns): + return False + + pos1: list[Function.Argument] = func1.pos_args + mixed1: list[Function.Argument] = func1.args + kw1: dict[str, Function.Argument] = {a.name: a for a in func1.kw_args} + pos2: list[Function.Argument] = func2.pos_args + mixed2: list[Function.Argument] = func2.args + kw2: dict[str, Function.Argument] = {a.name: a for a in func2.kw_args} + + mixed_by_pos: dict[int, Function.Argument] = {arg.pos: arg for arg in mixed2} + mixed_by_name: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2} + + def is_arg_subtype(sub: Function.Argument, sup: Function.Argument) -> bool: + if not self.is_subtype(sub.type, sup.type): + return False + if not sup.required and sub.required: + return False + return True + + for arg1 in pos1: + arg2: Function.Argument + if arg1.pos < len(pos2): + arg2 = pos2[arg1.pos] + elif arg1.pos in mixed_by_pos: + arg2 = mixed_by_pos[arg1.pos] + elif not arg1.required: + continue + else: + return False + if not is_arg_subtype(arg2, arg1): + return False + + for name, arg1 in kw1.items(): + arg2: Function.Argument + if name in kw2: + arg2 = kw2[name] + elif name in mixed_by_name: + arg2 = mixed_by_name[name] + elif not arg1.required: + continue + else: + return False + if not is_arg_subtype(arg2, arg1): + return False + + for arg1 in mixed1: + pos_arg2: Optional[Function.Argument] = None + kw_arg2: Optional[Function.Argument] = None + if arg1.name in kw2: + kw_arg2 = kw2[arg1.name] + elif arg1.name in mixed_by_name: + kw_arg2 = mixed_by_name[arg1.name] + if arg1.pos < len(pos2): + pos_arg2 = pos2[arg1.pos] + elif arg1.pos in mixed_by_pos: + pos_arg2 = mixed_by_pos[arg1.pos] + + # No match in func2 and arg is required + if pos_arg2 is None and kw_arg2 is None and arg1.required: + return False + + # Matching keyword argument + if kw_arg2 is not None and not is_arg_subtype(kw_arg2, arg1): + return False + + # Matching positional argument + if pos_arg2 is not None and not is_arg_subtype(pos_arg2, arg1): + return False + + mixed_positions: set[int] = {a.pos for a in mixed1} + mixed_names: set[str] = {a.name for a in mixed1} + for arg2 in pos2: + if not arg2.required: + continue + if arg2.pos >= len(pos1) and arg2.pos not in mixed_positions: + return False + + for name, arg2 in kw2.items(): + if not arg2.required: + continue + if name not in kw1 and name not in mixed_names: + return False + + for arg2 in mixed2: + if arg2.required: + continue + pos_match: bool = arg2.pos < len(pos1) or arg2.pos in mixed_positions + kw_match: bool = arg2.name in kw1 or arg2.name in mixed_names + if not pos_match or not kw_match: + return False + + return True