feat(checker): implement function subtyping
the logic for checking function subtypes is a WIP and has not been fully tested, there may be some errors and unhandled edge cases Claude helped lay out and verify the overall steps Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -215,7 +215,120 @@ class Checker(
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
case (Function(returns=return1), Function(returns=return2)):
|
||||||
|
if not self.is_func_subtype(type1, type2):
|
||||||
|
return False
|
||||||
|
if not self.is_subtype(return1, return2):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# TODO: verify the logic in here
|
||||||
|
def is_func_subtype(self, func1: Function, func2: Function) -> bool:
|
||||||
|
"""Check whether a function is a subtype of another
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func1 (Function): the potential function subtype
|
||||||
|
func2 (Function): the potential function supertype
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: whether `func1` is a subtype of `func2`
|
||||||
|
"""
|
||||||
|
if not self.is_subtype(func1.returns, func2.returns):
|
||||||
|
return False
|
||||||
|
|
||||||
|
pos1: list[Function.Argument] = func1.pos_args
|
||||||
|
mixed1: list[Function.Argument] = func1.args
|
||||||
|
kw1: dict[str, Function.Argument] = {a.name: a for a in func1.kw_args}
|
||||||
|
pos2: list[Function.Argument] = func2.pos_args
|
||||||
|
mixed2: list[Function.Argument] = func2.args
|
||||||
|
kw2: dict[str, Function.Argument] = {a.name: a for a in func2.kw_args}
|
||||||
|
|
||||||
|
mixed_by_pos: dict[int, Function.Argument] = {arg.pos: arg for arg in mixed2}
|
||||||
|
mixed_by_name: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2}
|
||||||
|
|
||||||
|
def is_arg_subtype(sub: Function.Argument, sup: Function.Argument) -> bool:
|
||||||
|
if not self.is_subtype(sub.type, sup.type):
|
||||||
|
return False
|
||||||
|
if not sup.required and sub.required:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
for arg1 in pos1:
|
||||||
|
arg2: Function.Argument
|
||||||
|
if arg1.pos < len(pos2):
|
||||||
|
arg2 = pos2[arg1.pos]
|
||||||
|
elif arg1.pos in mixed_by_pos:
|
||||||
|
arg2 = mixed_by_pos[arg1.pos]
|
||||||
|
elif not arg1.required:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
if not is_arg_subtype(arg2, arg1):
|
||||||
|
return False
|
||||||
|
|
||||||
|
for name, arg1 in kw1.items():
|
||||||
|
arg2: Function.Argument
|
||||||
|
if name in kw2:
|
||||||
|
arg2 = kw2[name]
|
||||||
|
elif name in mixed_by_name:
|
||||||
|
arg2 = mixed_by_name[name]
|
||||||
|
elif not arg1.required:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
if not is_arg_subtype(arg2, arg1):
|
||||||
|
return False
|
||||||
|
|
||||||
|
for arg1 in mixed1:
|
||||||
|
pos_arg2: Optional[Function.Argument] = None
|
||||||
|
kw_arg2: Optional[Function.Argument] = None
|
||||||
|
if arg1.name in kw2:
|
||||||
|
kw_arg2 = kw2[arg1.name]
|
||||||
|
elif arg1.name in mixed_by_name:
|
||||||
|
kw_arg2 = mixed_by_name[arg1.name]
|
||||||
|
if arg1.pos < len(pos2):
|
||||||
|
pos_arg2 = pos2[arg1.pos]
|
||||||
|
elif arg1.pos in mixed_by_pos:
|
||||||
|
pos_arg2 = mixed_by_pos[arg1.pos]
|
||||||
|
|
||||||
|
# No match in func2 and arg is required
|
||||||
|
if pos_arg2 is None and kw_arg2 is None and arg1.required:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Matching keyword argument
|
||||||
|
if kw_arg2 is not None and not is_arg_subtype(kw_arg2, arg1):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Matching positional argument
|
||||||
|
if pos_arg2 is not None and not is_arg_subtype(pos_arg2, arg1):
|
||||||
|
return False
|
||||||
|
|
||||||
|
mixed_positions: set[int] = {a.pos for a in mixed1}
|
||||||
|
mixed_names: set[str] = {a.name for a in mixed1}
|
||||||
|
for arg2 in pos2:
|
||||||
|
if not arg2.required:
|
||||||
|
continue
|
||||||
|
if arg2.pos >= len(pos1) and arg2.pos not in mixed_positions:
|
||||||
|
return False
|
||||||
|
|
||||||
|
for name, arg2 in kw2.items():
|
||||||
|
if not arg2.required:
|
||||||
|
continue
|
||||||
|
if name not in kw1 and name not in mixed_names:
|
||||||
|
return False
|
||||||
|
|
||||||
|
for arg2 in mixed2:
|
||||||
|
if arg2.required:
|
||||||
|
continue
|
||||||
|
pos_match: bool = arg2.pos < len(pos1) or arg2.pos in mixed_positions
|
||||||
|
kw_match: bool = arg2.name in kw1 or arg2.name in mixed_names
|
||||||
|
if not pos_match or not kw_match:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||||
self.type_of(stmt.expr)
|
self.type_of(stmt.expr)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user