feat(checker): handle overloaded function calls
This commit is contained in:
@@ -12,12 +12,16 @@ from midas.checker.reporter import FileReporter, Reporter
|
||||
from midas.checker.resolver import Resolver
|
||||
from midas.checker.types import (
|
||||
Function,
|
||||
OverloadedFunction,
|
||||
Type,
|
||||
UnitType,
|
||||
UnknownType,
|
||||
unfold_type,
|
||||
)
|
||||
from midas.parser.python import PythonParser
|
||||
|
||||
TypedExpr = tuple[p.Expr, Type]
|
||||
|
||||
|
||||
class ReturnException(Exception):
|
||||
pass
|
||||
@@ -354,26 +358,7 @@ class PythonTyper(
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
match operation:
|
||||
case Function() as function:
|
||||
if not self._check_arity(function, 1, 0, 0):
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Wrong definition of binary operation. Expected function with 1 positional-only parameters, got {function}",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
rhs: Function.Argument = function.pos_args[0]
|
||||
if not self.is_subtype(right, rhs.type):
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Wrong type for right-hand side, expected {rhs.type}, got {right}",
|
||||
)
|
||||
return UnknownType()
|
||||
return function.returns
|
||||
case _:
|
||||
self.reporter.warning(location, f"Unsupported operation {operation}")
|
||||
return UnknownType()
|
||||
return self._get_call_result(location, operation, [(right_expr, right)], {})
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
|
||||
method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__)
|
||||
@@ -393,35 +378,24 @@ class PythonTyper(
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
match operation:
|
||||
case Function() as function:
|
||||
if not self._check_arity(function, 0, 0, 0):
|
||||
self.reporter.error(
|
||||
expr.location,
|
||||
f"Wrong definition of unary operation. Expected function with 0 parameters, got {function}",
|
||||
)
|
||||
return UnknownType()
|
||||
return function.returns
|
||||
case _:
|
||||
self.reporter.warning(
|
||||
expr.location, f"Unsupported operation {operation}"
|
||||
)
|
||||
return UnknownType()
|
||||
return self._get_call_result(
|
||||
expr.location, operation, [(expr.right, operand)], {}
|
||||
)
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
||||
callee: Type = self.type_of(expr.callee)
|
||||
if not isinstance(callee, Function):
|
||||
self.reporter.error(expr.callee.location, "Callee is not a function")
|
||||
return UnknownType()
|
||||
function: Function = callee
|
||||
mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
|
||||
for arg in mapped:
|
||||
if not self.is_subtype(arg.type, arg.argument.type):
|
||||
self.reporter.error(
|
||||
arg.expr.location,
|
||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||
)
|
||||
return function.returns
|
||||
positional: list[TypedExpr] = [
|
||||
(arg, self.type_of(arg)) for arg in expr.arguments
|
||||
]
|
||||
keywords: dict[str, TypedExpr] = {
|
||||
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
|
||||
}
|
||||
return self._get_call_result(
|
||||
location=expr.location,
|
||||
callee=callee,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> Type:
|
||||
object: Type = self.type_of(expr.object)
|
||||
@@ -572,9 +546,105 @@ class PythonTyper(
|
||||
self.reporter.warning(node.location, "FrameType not yet supported")
|
||||
return UnknownType()
|
||||
|
||||
def _get_call_result(
|
||||
self,
|
||||
location: Location,
|
||||
callee: Type,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Type:
|
||||
match callee:
|
||||
case Function() as function:
|
||||
valid: bool
|
||||
mapped: list[MappedArgument]
|
||||
valid, mapped = self.map_call_arguments(
|
||||
function, location, positional, keywords
|
||||
)
|
||||
valid = valid and self._are_arguments_valid(mapped)
|
||||
if not valid:
|
||||
return UnknownType()
|
||||
return function.returns
|
||||
|
||||
case OverloadedFunction(overloads=overloads):
|
||||
function = self._match_overload(
|
||||
overloads, location, positional, keywords
|
||||
)
|
||||
if function is None:
|
||||
return UnknownType()
|
||||
return function.returns
|
||||
case _:
|
||||
self.reporter.error(location, f"{callee} is not callable")
|
||||
return UnknownType()
|
||||
|
||||
def _are_arguments_valid(
|
||||
self,
|
||||
arguments: list[MappedArgument],
|
||||
report_errors: bool = True,
|
||||
) -> bool:
|
||||
valid: bool = True
|
||||
for arg in arguments:
|
||||
if not self.is_subtype(arg.type, arg.argument.type):
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
arg.expr.location,
|
||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||
)
|
||||
valid = False
|
||||
return valid
|
||||
|
||||
def _match_overload(
|
||||
self,
|
||||
overloads: list[Type],
|
||||
location: Location,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Optional[Function]:
|
||||
candidates: list[Function] = []
|
||||
for overload in overloads:
|
||||
function: Type = unfold_type(overload)
|
||||
if not isinstance(function, Function):
|
||||
self.logger.error(
|
||||
f"Overload is not a function: {overload} is {function}"
|
||||
)
|
||||
continue
|
||||
valid, mapped = self.map_call_arguments(
|
||||
function=function,
|
||||
location=location,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
report_errors=False,
|
||||
)
|
||||
if valid and self._are_arguments_valid(mapped, report_errors=False):
|
||||
candidates.append(function)
|
||||
|
||||
pos_types: str = ", ".join(str(type) for _, type in positional)
|
||||
kw_types: str = ", ".join(
|
||||
f"{name}: {type}" for name, (_, type) in keywords.items()
|
||||
)
|
||||
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
|
||||
|
||||
if len(candidates) == 0:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"No matching overload in {overloads} {for_args}",
|
||||
)
|
||||
return None
|
||||
if len(candidates) > 1:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Multiple matching overloads {for_args}: {', '.join(map(str, candidates))}",
|
||||
)
|
||||
return None
|
||||
return candidates[0]
|
||||
|
||||
def map_call_arguments(
|
||||
self, function: Function, call: p.CallExpr
|
||||
) -> list[MappedArgument]:
|
||||
self,
|
||||
function: Function,
|
||||
location: Location,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
report_errors: bool = True,
|
||||
) -> tuple[bool, list[MappedArgument]]:
|
||||
"""Map call arguments to function parameters as defined in its signature
|
||||
|
||||
This method maps positional-only, keyword-only and mixed parameter definitions
|
||||
@@ -589,12 +659,6 @@ class PythonTyper(
|
||||
Returns:
|
||||
list[MappedArgument]: the list of mapped arguments
|
||||
"""
|
||||
positional: list[tuple[p.Expr, Type]] = [
|
||||
(arg, self.type_of(arg)) for arg in call.arguments
|
||||
]
|
||||
keywords: dict[str, tuple[p.Expr, Type]] = {
|
||||
name: (arg, self.type_of(arg)) for name, arg in call.keywords.items()
|
||||
}
|
||||
set_args: set[str] = set()
|
||||
|
||||
required_positional: list[str] = [
|
||||
@@ -612,6 +676,8 @@ class PythonTyper(
|
||||
arg.name: arg for arg in function.kw_args
|
||||
}
|
||||
|
||||
valid_call: bool = True
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
for arg in positional:
|
||||
param: Function.Argument
|
||||
@@ -620,7 +686,11 @@ class PythonTyper(
|
||||
elif len(mixed_params) != 0:
|
||||
param = mixed_params.pop(0)
|
||||
else:
|
||||
self.reporter.error(arg[0].location, "Too many positional arguments")
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
arg[0].location, "Too many positional arguments"
|
||||
)
|
||||
valid_call = False
|
||||
break
|
||||
name: str = param.name
|
||||
if name in required_positional:
|
||||
@@ -640,14 +710,16 @@ class PythonTyper(
|
||||
for name, arg in keywords.items():
|
||||
param: Function.Argument
|
||||
if name not in kw_params:
|
||||
if name in set_args:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Multiple values for argument '{name}'"
|
||||
)
|
||||
else:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Unknown keyword argument '{name}'"
|
||||
)
|
||||
if report_errors:
|
||||
if name in set_args:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Multiple values for argument '{name}'"
|
||||
)
|
||||
else:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Unknown keyword argument '{name}'"
|
||||
)
|
||||
valid_call = False
|
||||
continue
|
||||
param = kw_params.pop(name)
|
||||
if name in required_positional:
|
||||
@@ -674,20 +746,24 @@ class PythonTyper(
|
||||
if len(required_positional) != 0:
|
||||
plural: str = "" if len(required_positional) == 1 else "s"
|
||||
args: str = join_args(required_positional)
|
||||
self.reporter.error(
|
||||
call.location,
|
||||
f"Missing required positional argument{plural}: {args}",
|
||||
)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Missing required positional argument{plural}: {args}",
|
||||
)
|
||||
valid_call = False
|
||||
|
||||
if len(required_keyword) != 0:
|
||||
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||
args: str = join_args(required_keyword)
|
||||
self.reporter.error(
|
||||
call.location,
|
||||
f"Missing required keyword argument{plural}: {args}",
|
||||
)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Missing required keyword argument{plural}: {args}",
|
||||
)
|
||||
valid_call = False
|
||||
|
||||
return mapped
|
||||
return valid_call, mapped
|
||||
|
||||
def _check_arity(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user