feat(checker): handle overloaded function calls
This commit is contained in:
@@ -12,12 +12,16 @@ from midas.checker.reporter import FileReporter, Reporter
|
|||||||
from midas.checker.resolver import Resolver
|
from midas.checker.resolver import Resolver
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
Function,
|
Function,
|
||||||
|
OverloadedFunction,
|
||||||
Type,
|
Type,
|
||||||
UnitType,
|
UnitType,
|
||||||
UnknownType,
|
UnknownType,
|
||||||
|
unfold_type,
|
||||||
)
|
)
|
||||||
from midas.parser.python import PythonParser
|
from midas.parser.python import PythonParser
|
||||||
|
|
||||||
|
TypedExpr = tuple[p.Expr, Type]
|
||||||
|
|
||||||
|
|
||||||
class ReturnException(Exception):
|
class ReturnException(Exception):
|
||||||
pass
|
pass
|
||||||
@@ -354,26 +358,7 @@ class PythonTyper(
|
|||||||
)
|
)
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
match operation:
|
return self._get_call_result(location, operation, [(right_expr, right)], {})
|
||||||
case Function() as function:
|
|
||||||
if not self._check_arity(function, 1, 0, 0):
|
|
||||||
self.reporter.error(
|
|
||||||
location,
|
|
||||||
f"Wrong definition of binary operation. Expected function with 1 positional-only parameters, got {function}",
|
|
||||||
)
|
|
||||||
return UnknownType()
|
|
||||||
|
|
||||||
rhs: Function.Argument = function.pos_args[0]
|
|
||||||
if not self.is_subtype(right, rhs.type):
|
|
||||||
self.reporter.error(
|
|
||||||
location,
|
|
||||||
f"Wrong type for right-hand side, expected {rhs.type}, got {right}",
|
|
||||||
)
|
|
||||||
return UnknownType()
|
|
||||||
return function.returns
|
|
||||||
case _:
|
|
||||||
self.reporter.warning(location, f"Unsupported operation {operation}")
|
|
||||||
return UnknownType()
|
|
||||||
|
|
||||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
|
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
|
||||||
method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__)
|
method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__)
|
||||||
@@ -393,35 +378,24 @@ class PythonTyper(
|
|||||||
)
|
)
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
match operation:
|
return self._get_call_result(
|
||||||
case Function() as function:
|
expr.location, operation, [(expr.right, operand)], {}
|
||||||
if not self._check_arity(function, 0, 0, 0):
|
)
|
||||||
self.reporter.error(
|
|
||||||
expr.location,
|
|
||||||
f"Wrong definition of unary operation. Expected function with 0 parameters, got {function}",
|
|
||||||
)
|
|
||||||
return UnknownType()
|
|
||||||
return function.returns
|
|
||||||
case _:
|
|
||||||
self.reporter.warning(
|
|
||||||
expr.location, f"Unsupported operation {operation}"
|
|
||||||
)
|
|
||||||
return UnknownType()
|
|
||||||
|
|
||||||
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
||||||
callee: Type = self.type_of(expr.callee)
|
callee: Type = self.type_of(expr.callee)
|
||||||
if not isinstance(callee, Function):
|
positional: list[TypedExpr] = [
|
||||||
self.reporter.error(expr.callee.location, "Callee is not a function")
|
(arg, self.type_of(arg)) for arg in expr.arguments
|
||||||
return UnknownType()
|
]
|
||||||
function: Function = callee
|
keywords: dict[str, TypedExpr] = {
|
||||||
mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
|
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
|
||||||
for arg in mapped:
|
}
|
||||||
if not self.is_subtype(arg.type, arg.argument.type):
|
return self._get_call_result(
|
||||||
self.reporter.error(
|
location=expr.location,
|
||||||
arg.expr.location,
|
callee=callee,
|
||||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
positional=positional,
|
||||||
)
|
keywords=keywords,
|
||||||
return function.returns
|
)
|
||||||
|
|
||||||
def visit_get_expr(self, expr: p.GetExpr) -> Type:
|
def visit_get_expr(self, expr: p.GetExpr) -> Type:
|
||||||
object: Type = self.type_of(expr.object)
|
object: Type = self.type_of(expr.object)
|
||||||
@@ -572,9 +546,105 @@ class PythonTyper(
|
|||||||
self.reporter.warning(node.location, "FrameType not yet supported")
|
self.reporter.warning(node.location, "FrameType not yet supported")
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
|
def _get_call_result(
|
||||||
|
self,
|
||||||
|
location: Location,
|
||||||
|
callee: Type,
|
||||||
|
positional: list[TypedExpr],
|
||||||
|
keywords: dict[str, TypedExpr],
|
||||||
|
) -> Type:
|
||||||
|
match callee:
|
||||||
|
case Function() as function:
|
||||||
|
valid: bool
|
||||||
|
mapped: list[MappedArgument]
|
||||||
|
valid, mapped = self.map_call_arguments(
|
||||||
|
function, location, positional, keywords
|
||||||
|
)
|
||||||
|
valid = valid and self._are_arguments_valid(mapped)
|
||||||
|
if not valid:
|
||||||
|
return UnknownType()
|
||||||
|
return function.returns
|
||||||
|
|
||||||
|
case OverloadedFunction(overloads=overloads):
|
||||||
|
function = self._match_overload(
|
||||||
|
overloads, location, positional, keywords
|
||||||
|
)
|
||||||
|
if function is None:
|
||||||
|
return UnknownType()
|
||||||
|
return function.returns
|
||||||
|
case _:
|
||||||
|
self.reporter.error(location, f"{callee} is not callable")
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def _are_arguments_valid(
|
||||||
|
self,
|
||||||
|
arguments: list[MappedArgument],
|
||||||
|
report_errors: bool = True,
|
||||||
|
) -> bool:
|
||||||
|
valid: bool = True
|
||||||
|
for arg in arguments:
|
||||||
|
if not self.is_subtype(arg.type, arg.argument.type):
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(
|
||||||
|
arg.expr.location,
|
||||||
|
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||||
|
)
|
||||||
|
valid = False
|
||||||
|
return valid
|
||||||
|
|
||||||
|
def _match_overload(
|
||||||
|
self,
|
||||||
|
overloads: list[Type],
|
||||||
|
location: Location,
|
||||||
|
positional: list[TypedExpr],
|
||||||
|
keywords: dict[str, TypedExpr],
|
||||||
|
) -> Optional[Function]:
|
||||||
|
candidates: list[Function] = []
|
||||||
|
for overload in overloads:
|
||||||
|
function: Type = unfold_type(overload)
|
||||||
|
if not isinstance(function, Function):
|
||||||
|
self.logger.error(
|
||||||
|
f"Overload is not a function: {overload} is {function}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
valid, mapped = self.map_call_arguments(
|
||||||
|
function=function,
|
||||||
|
location=location,
|
||||||
|
positional=positional,
|
||||||
|
keywords=keywords,
|
||||||
|
report_errors=False,
|
||||||
|
)
|
||||||
|
if valid and self._are_arguments_valid(mapped, report_errors=False):
|
||||||
|
candidates.append(function)
|
||||||
|
|
||||||
|
pos_types: str = ", ".join(str(type) for _, type in positional)
|
||||||
|
kw_types: str = ", ".join(
|
||||||
|
f"{name}: {type}" for name, (_, type) in keywords.items()
|
||||||
|
)
|
||||||
|
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
|
||||||
|
|
||||||
|
if len(candidates) == 0:
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"No matching overload in {overloads} {for_args}",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
if len(candidates) > 1:
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Multiple matching overloads {for_args}: {', '.join(map(str, candidates))}",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
return candidates[0]
|
||||||
|
|
||||||
def map_call_arguments(
|
def map_call_arguments(
|
||||||
self, function: Function, call: p.CallExpr
|
self,
|
||||||
) -> list[MappedArgument]:
|
function: Function,
|
||||||
|
location: Location,
|
||||||
|
positional: list[TypedExpr],
|
||||||
|
keywords: dict[str, TypedExpr],
|
||||||
|
report_errors: bool = True,
|
||||||
|
) -> tuple[bool, list[MappedArgument]]:
|
||||||
"""Map call arguments to function parameters as defined in its signature
|
"""Map call arguments to function parameters as defined in its signature
|
||||||
|
|
||||||
This method maps positional-only, keyword-only and mixed parameter definitions
|
This method maps positional-only, keyword-only and mixed parameter definitions
|
||||||
@@ -589,12 +659,6 @@ class PythonTyper(
|
|||||||
Returns:
|
Returns:
|
||||||
list[MappedArgument]: the list of mapped arguments
|
list[MappedArgument]: the list of mapped arguments
|
||||||
"""
|
"""
|
||||||
positional: list[tuple[p.Expr, Type]] = [
|
|
||||||
(arg, self.type_of(arg)) for arg in call.arguments
|
|
||||||
]
|
|
||||||
keywords: dict[str, tuple[p.Expr, Type]] = {
|
|
||||||
name: (arg, self.type_of(arg)) for name, arg in call.keywords.items()
|
|
||||||
}
|
|
||||||
set_args: set[str] = set()
|
set_args: set[str] = set()
|
||||||
|
|
||||||
required_positional: list[str] = [
|
required_positional: list[str] = [
|
||||||
@@ -612,6 +676,8 @@ class PythonTyper(
|
|||||||
arg.name: arg for arg in function.kw_args
|
arg.name: arg for arg in function.kw_args
|
||||||
}
|
}
|
||||||
|
|
||||||
|
valid_call: bool = True
|
||||||
|
|
||||||
# TODO: handle *args and **kwargs sinks
|
# TODO: handle *args and **kwargs sinks
|
||||||
for arg in positional:
|
for arg in positional:
|
||||||
param: Function.Argument
|
param: Function.Argument
|
||||||
@@ -620,7 +686,11 @@ class PythonTyper(
|
|||||||
elif len(mixed_params) != 0:
|
elif len(mixed_params) != 0:
|
||||||
param = mixed_params.pop(0)
|
param = mixed_params.pop(0)
|
||||||
else:
|
else:
|
||||||
self.reporter.error(arg[0].location, "Too many positional arguments")
|
if report_errors:
|
||||||
|
self.reporter.error(
|
||||||
|
arg[0].location, "Too many positional arguments"
|
||||||
|
)
|
||||||
|
valid_call = False
|
||||||
break
|
break
|
||||||
name: str = param.name
|
name: str = param.name
|
||||||
if name in required_positional:
|
if name in required_positional:
|
||||||
@@ -640,14 +710,16 @@ class PythonTyper(
|
|||||||
for name, arg in keywords.items():
|
for name, arg in keywords.items():
|
||||||
param: Function.Argument
|
param: Function.Argument
|
||||||
if name not in kw_params:
|
if name not in kw_params:
|
||||||
if name in set_args:
|
if report_errors:
|
||||||
self.reporter.error(
|
if name in set_args:
|
||||||
arg[0].location, f"Multiple values for argument '{name}'"
|
self.reporter.error(
|
||||||
)
|
arg[0].location, f"Multiple values for argument '{name}'"
|
||||||
else:
|
)
|
||||||
self.reporter.error(
|
else:
|
||||||
arg[0].location, f"Unknown keyword argument '{name}'"
|
self.reporter.error(
|
||||||
)
|
arg[0].location, f"Unknown keyword argument '{name}'"
|
||||||
|
)
|
||||||
|
valid_call = False
|
||||||
continue
|
continue
|
||||||
param = kw_params.pop(name)
|
param = kw_params.pop(name)
|
||||||
if name in required_positional:
|
if name in required_positional:
|
||||||
@@ -674,20 +746,24 @@ class PythonTyper(
|
|||||||
if len(required_positional) != 0:
|
if len(required_positional) != 0:
|
||||||
plural: str = "" if len(required_positional) == 1 else "s"
|
plural: str = "" if len(required_positional) == 1 else "s"
|
||||||
args: str = join_args(required_positional)
|
args: str = join_args(required_positional)
|
||||||
self.reporter.error(
|
if report_errors:
|
||||||
call.location,
|
self.reporter.error(
|
||||||
f"Missing required positional argument{plural}: {args}",
|
location,
|
||||||
)
|
f"Missing required positional argument{plural}: {args}",
|
||||||
|
)
|
||||||
|
valid_call = False
|
||||||
|
|
||||||
if len(required_keyword) != 0:
|
if len(required_keyword) != 0:
|
||||||
plural: str = "" if len(required_keyword) == 1 else "s"
|
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||||
args: str = join_args(required_keyword)
|
args: str = join_args(required_keyword)
|
||||||
self.reporter.error(
|
if report_errors:
|
||||||
call.location,
|
self.reporter.error(
|
||||||
f"Missing required keyword argument{plural}: {args}",
|
location,
|
||||||
)
|
f"Missing required keyword argument{plural}: {args}",
|
||||||
|
)
|
||||||
|
valid_call = False
|
||||||
|
|
||||||
return mapped
|
return valid_call, mapped
|
||||||
|
|
||||||
def _check_arity(
|
def _check_arity(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user