Compare commits
10 Commits
c81287df7f
...
e2abc04fe4
| Author | SHA1 | Date | |
|---|---|---|---|
|
e2abc04fe4
|
|||
|
a4016b55ce
|
|||
|
1ea5da7024
|
|||
|
a017a8cf1f
|
|||
|
8fc5ab623e
|
|||
|
14007db846
|
|||
|
6ad2ce4b68
|
|||
|
9a276c34c7
|
|||
|
6e717a3f9e
|
|||
|
77aadfa264
|
484
midas/checker/dispatcher.py
Normal file
484
midas/checker/dispatcher.py
Normal file
@@ -0,0 +1,484 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Generic, Optional, Protocol, TypeVar, Union
|
||||
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter
|
||||
from midas.checker.types import (
|
||||
AppliedType,
|
||||
DerivedType,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
Type,
|
||||
UnknownType,
|
||||
)
|
||||
from midas.checker.unifier import Unifier
|
||||
|
||||
|
||||
class HasLocation(Protocol):
|
||||
@property
|
||||
def location(self) -> Location: ...
|
||||
|
||||
|
||||
E = TypeVar("E", bound=HasLocation)
|
||||
|
||||
TypedExpr = tuple[E, Type]
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class MappedArgument(Generic[E]):
|
||||
expr: E
|
||||
type: Type
|
||||
argument: Function.Argument
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class OverloadCandidate:
|
||||
function: Function
|
||||
mapped: list[MappedArgument]
|
||||
|
||||
|
||||
class CallError(StrEnum):
|
||||
INVALID_ARGS = "Invalid arguments"
|
||||
NO_MATCHING_OVERLOAD = "No matching overload"
|
||||
IMPOSSIBLE_UNIFICATION = "Parameters unification failed"
|
||||
NOT_CALLABLE = "Not callable"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class CallResult:
|
||||
error: Optional[CallError] = None
|
||||
result: Type = UnknownType()
|
||||
message: Optional[str] = None
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
return self.error is None
|
||||
|
||||
@property
|
||||
def error_message(self) -> str:
|
||||
if self.message is not None:
|
||||
return self.message
|
||||
if self.error is not None:
|
||||
return str(self.error)
|
||||
return ""
|
||||
|
||||
|
||||
class CallDispatcher(Generic[E]):
|
||||
def __init__(self, types: TypesRegistry, reporter: FileReporter) -> None:
|
||||
self.types: TypesRegistry = types
|
||||
self.reporter: FileReporter = reporter
|
||||
self.logger: logging.Logger = logging.getLogger("CallDispatcher")
|
||||
|
||||
def set_reporter(self, reporter: FileReporter):
|
||||
self.reporter = reporter
|
||||
|
||||
def get_result(
|
||||
self,
|
||||
location: Location,
|
||||
callee: Type,
|
||||
positional: list[TypedExpr[E]],
|
||||
keywords: dict[str, TypedExpr[E]],
|
||||
report_errors: bool = True,
|
||||
) -> CallResult:
|
||||
"""Get the result type of a function call
|
||||
|
||||
If the function has overloads, the function will try to resolve the
|
||||
appropriate signature.
|
||||
Argument types are matched to the defined parameters.
|
||||
The function doesn't take the raw expression as a parameter to accommodate
|
||||
for desugared calls such as for operators.
|
||||
|
||||
Args:
|
||||
location (Location): the call location
|
||||
callee (Type): the called function
|
||||
positional (list[TypedExpr]): the list positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Type: the return type of the call, or `None` if either
|
||||
the call is invalid or no overload matched the arguments uniquely
|
||||
"""
|
||||
match callee:
|
||||
case Function() as function:
|
||||
valid: bool
|
||||
mapped: list[MappedArgument[E]]
|
||||
valid, mapped = self.map_call_arguments(
|
||||
function, location, positional, keywords
|
||||
)
|
||||
valid = valid and self._are_arguments_valid(mapped, report_errors)
|
||||
if not valid:
|
||||
return CallResult(error=CallError.INVALID_ARGS)
|
||||
return CallResult(result=function.returns)
|
||||
|
||||
case OverloadedFunction(overloads=overloads):
|
||||
res = self._match_overload(
|
||||
overloads, location, positional, keywords, report_errors
|
||||
)
|
||||
if res[0] is None:
|
||||
return CallResult(
|
||||
error=CallError.NO_MATCHING_OVERLOAD,
|
||||
message=res[1],
|
||||
)
|
||||
return CallResult(result=res[0].returns)
|
||||
|
||||
case AppliedType(body=body):
|
||||
return self.get_result(
|
||||
location, body, positional, keywords, report_errors
|
||||
)
|
||||
|
||||
case UnknownType():
|
||||
return CallResult(result=UnknownType())
|
||||
|
||||
case DerivedType(type=base):
|
||||
return self.get_result(
|
||||
location, base, positional, keywords, report_errors
|
||||
)
|
||||
|
||||
case GenericType():
|
||||
unifier: Unifier = Unifier(self.types)
|
||||
pos: list[Type] = [a[1] for a in positional]
|
||||
kw: dict[str, Type] = {k: v[1] for k, v in keywords.items()}
|
||||
unified: Optional[Type] = unifier.unify_call(callee, pos, kw)
|
||||
if unified is None:
|
||||
pos_str: str = ", ".join(str(t) for t in pos)
|
||||
kw_str: str = ", ".join(f"{k}: {v}" for k, v in kw.items())
|
||||
message: str = (
|
||||
f"Could not unify {callee}={callee.body} with pos=[{pos_str}] and kw={{{kw_str}}}"
|
||||
)
|
||||
if report_errors:
|
||||
self.reporter.error(location, message)
|
||||
return CallResult(
|
||||
error=CallError.IMPOSSIBLE_UNIFICATION,
|
||||
message=message,
|
||||
)
|
||||
return self.get_result(
|
||||
location,
|
||||
unified,
|
||||
positional,
|
||||
keywords,
|
||||
report_errors,
|
||||
)
|
||||
|
||||
case _:
|
||||
message: str = f"{callee} ({callee.__class__.__name__}) is not callable"
|
||||
if report_errors:
|
||||
self.reporter.error(location, message)
|
||||
return CallResult(
|
||||
error=CallError.NOT_CALLABLE,
|
||||
message=message,
|
||||
)
|
||||
|
||||
def _unwrap_function(
|
||||
self,
|
||||
callee: Type,
|
||||
positional: list[TypedExpr[E]],
|
||||
keywords: dict[str, TypedExpr[E]],
|
||||
) -> Union[tuple[Function, None], tuple[None, CallError]]:
|
||||
match callee:
|
||||
case DerivedType(type=base):
|
||||
return self._unwrap_function(base, positional, keywords)
|
||||
|
||||
case GenericType():
|
||||
unifier: Unifier = Unifier(self.types)
|
||||
unified: Optional[Type] = unifier.unify_call(
|
||||
callee,
|
||||
[a[1] for a in positional],
|
||||
{k: v[1] for k, v in keywords.items()},
|
||||
)
|
||||
if unified is None:
|
||||
return None, CallError.IMPOSSIBLE_UNIFICATION
|
||||
return self._unwrap_function(unified, positional, keywords)
|
||||
|
||||
case Function():
|
||||
return callee, None
|
||||
|
||||
case AppliedType(body=body):
|
||||
return self._unwrap_function(body, positional, keywords)
|
||||
|
||||
case _:
|
||||
return None, CallError.NOT_CALLABLE
|
||||
|
||||
def _are_arguments_valid(
|
||||
self,
|
||||
arguments: list[MappedArgument[E]],
|
||||
report_errors: bool = True,
|
||||
) -> bool:
|
||||
"""Check whether the passed argument types correspond to their matched parameter definitions
|
||||
|
||||
Args:
|
||||
arguments (list[MappedArgument]): the list of argument/parameter pairs
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
bool: True if all arguments fit the matching parameter definitions, False otherwise
|
||||
"""
|
||||
valid: bool = True
|
||||
for arg in arguments:
|
||||
if not self.types.is_subtype(arg.type, arg.argument.type):
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
arg.expr.location,
|
||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||
)
|
||||
valid = False
|
||||
return valid
|
||||
|
||||
def _match_overload(
|
||||
self,
|
||||
overloads: list[Type],
|
||||
location: Location,
|
||||
positional: list[TypedExpr[E]],
|
||||
keywords: dict[str, TypedExpr[E]],
|
||||
report_errors: bool = True,
|
||||
) -> Union[tuple[Function, None], tuple[None, str]]:
|
||||
"""Try and resolve the appropriate overload for the given arguments
|
||||
|
||||
Args:
|
||||
overloads (list[Type]): the list of possible overloads
|
||||
location (Location): the call location
|
||||
positional (list[TypedExpr]): the list of positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keywords arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Optional[Function]: the resolved function signature if it can be
|
||||
determined unambiguously, or `None`.
|
||||
"""
|
||||
candidates: list[OverloadCandidate] = []
|
||||
errors: list[CallError] = []
|
||||
for overload in overloads:
|
||||
function, unwrap_error = self._unwrap_function(
|
||||
overload, positional, keywords
|
||||
)
|
||||
if function is None:
|
||||
errors.append(unwrap_error) # type: ignore
|
||||
continue
|
||||
|
||||
valid, mapped = self.map_call_arguments(
|
||||
function=function,
|
||||
location=location,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
report_errors=False,
|
||||
)
|
||||
if valid and self._are_arguments_valid(mapped, report_errors=False):
|
||||
candidates.append(
|
||||
OverloadCandidate(
|
||||
function=function,
|
||||
mapped=mapped,
|
||||
)
|
||||
)
|
||||
|
||||
pos_types: str = ", ".join(str(type) for _, type in positional)
|
||||
kw_types: str = ", ".join(
|
||||
f"{name}: {type}" for name, (_, type) in keywords.items()
|
||||
)
|
||||
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
|
||||
|
||||
n_candidates: int = len(candidates)
|
||||
|
||||
# Exactly 1 match -> return it
|
||||
if n_candidates == 1:
|
||||
return candidates[0].function, None
|
||||
|
||||
# No match -> invalid call
|
||||
if n_candidates == 0:
|
||||
overloads_str: str = ", ".join(map(str, overloads))
|
||||
errors_str: str = ", ".join(errors)
|
||||
message: str = (
|
||||
f"No matching overload in [{overloads_str}] {for_args} (errors: {errors_str})"
|
||||
)
|
||||
if report_errors:
|
||||
self.reporter.error(location, message)
|
||||
return None, message
|
||||
|
||||
# Multiple matches -> see if one <: all others (more specific)
|
||||
for i1, c1 in enumerate(candidates):
|
||||
mapped1: list[MappedArgument[E]] = c1.mapped
|
||||
best_match: bool = True
|
||||
for i2, c2 in enumerate(candidates):
|
||||
if i1 == i2:
|
||||
continue
|
||||
mapped2: list[MappedArgument[E]] = c2.mapped
|
||||
if not self._are_mapped_subtypes(mapped1, mapped2):
|
||||
best_match = False
|
||||
break
|
||||
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
|
||||
if best_match:
|
||||
return c1.function, None
|
||||
|
||||
candidates_str: str = ", ".join(
|
||||
str(candidate.function) for candidate in candidates
|
||||
)
|
||||
message: str = f"Multiple matching overloads {for_args}: {candidates_str}"
|
||||
if report_errors:
|
||||
self.reporter.error(location, message)
|
||||
return None, message
|
||||
|
||||
def map_call_arguments(
|
||||
self,
|
||||
function: Function,
|
||||
location: Location,
|
||||
positional: list[TypedExpr[E]],
|
||||
keywords: dict[str, TypedExpr[E]],
|
||||
report_errors: bool = True,
|
||||
) -> tuple[bool, list[MappedArgument]]:
|
||||
"""Map call arguments to a function's parameters as defined in its signature
|
||||
|
||||
This method maps positional-only, keyword-only and mixed parameter definitions
|
||||
with the arguments passed at the call site
|
||||
|
||||
Any mismatched, missing or unexpected argument is reported as a diagnostic,
|
||||
unless `report_errors` is set to `False`
|
||||
|
||||
Args:
|
||||
function (Function): the function definition
|
||||
location (Location): the call location
|
||||
positional (list[TypedExpr]): the list of positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
tuple[bool, list[MappedArgument]]: a boolean reporting whether
|
||||
the call is valid and the list of mapped arguments
|
||||
"""
|
||||
set_args: set[str] = set()
|
||||
|
||||
required_positional: list[str] = [
|
||||
arg.name for arg in function.pos_args + function.args if arg.required
|
||||
]
|
||||
required_keyword: list[str] = [
|
||||
arg.name for arg in function.kw_args if arg.required
|
||||
]
|
||||
|
||||
mapped: list[MappedArgument[E]] = []
|
||||
|
||||
pos_params: list[Function.Argument] = list(function.pos_args)
|
||||
mixed_params: list[Function.Argument] = list(function.args)
|
||||
kw_params: dict[str, Function.Argument] = {
|
||||
arg.name: arg for arg in function.kw_args
|
||||
}
|
||||
|
||||
valid_call: bool = True
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
for arg in positional:
|
||||
param: Function.Argument
|
||||
if len(pos_params) != 0:
|
||||
param = pos_params.pop(0)
|
||||
elif len(mixed_params) != 0:
|
||||
param = mixed_params.pop(0)
|
||||
else:
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
arg[0].location, "Too many positional arguments"
|
||||
)
|
||||
valid_call = False
|
||||
break
|
||||
name: str = param.name
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
kw_params.update({arg.name: arg for arg in mixed_params})
|
||||
for name, arg in keywords.items():
|
||||
param: Function.Argument
|
||||
if name not in kw_params:
|
||||
if report_errors:
|
||||
if name in set_args:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Multiple values for argument '{name}'"
|
||||
)
|
||||
else:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Unknown keyword argument '{name}'"
|
||||
)
|
||||
valid_call = False
|
||||
continue
|
||||
param = kw_params.pop(name)
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
def join_args(args: list[str]) -> str:
|
||||
args = list(map(lambda a: f"'{a}'", args))
|
||||
if len(args) == 0:
|
||||
return ""
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return ", ".join(args[:-1]) + " and " + args[-1]
|
||||
|
||||
if len(required_positional) != 0:
|
||||
plural: str = "" if len(required_positional) == 1 else "s"
|
||||
args: str = join_args(required_positional)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Missing required positional argument{plural}: {args}",
|
||||
)
|
||||
valid_call = False
|
||||
|
||||
if len(required_keyword) != 0:
|
||||
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||
args: str = join_args(required_keyword)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Missing required keyword argument{plural}: {args}",
|
||||
)
|
||||
valid_call = False
|
||||
|
||||
return valid_call, mapped
|
||||
|
||||
def _are_mapped_subtypes(
|
||||
self, mapped1: list[MappedArgument[E]], mapped2: list[MappedArgument[E]]
|
||||
) -> bool:
|
||||
"""Check whether the given argument mappings are subtype/supertype of one another
|
||||
|
||||
This function checks whether the argument mappings `mapped1` are subtypes
|
||||
of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
|
||||
of the corresponding parameter in `mapped2`, `False` is returned.
|
||||
|
||||
This is used to check whether a given overload is
|
||||
a more specific function/ a subtype of another.
|
||||
|
||||
Args:
|
||||
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
|
||||
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
|
||||
|
||||
Returns:
|
||||
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
|
||||
"""
|
||||
by_expr: dict[E, Type] = {}
|
||||
for arg in mapped1:
|
||||
by_expr[arg.expr] = arg.argument.type
|
||||
|
||||
for arg in mapped2:
|
||||
type2: Type = arg.argument.type
|
||||
type1: Type = by_expr[arg.expr]
|
||||
if not self.types.is_subtype(type1, type2):
|
||||
return False
|
||||
return True
|
||||
@@ -6,13 +6,13 @@ from typing import Optional
|
||||
import midas.ast.midas as m
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.builtins import define_builtins
|
||||
from midas.checker.dispatcher import CallDispatcher, CallResult
|
||||
from midas.checker.environment import Environment
|
||||
from midas.checker.operators import MIDAS_BINARY_METHODS, MIDAS_UNARY_METHODS
|
||||
from midas.checker.preamble import Preamble
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter, Reporter
|
||||
from midas.checker.types import (
|
||||
AppliedType,
|
||||
ColumnType,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
@@ -21,12 +21,10 @@ from midas.checker.types import (
|
||||
ExtensionType,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
Predicate,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnknownType,
|
||||
unfold_type,
|
||||
)
|
||||
from midas.checker.variance import VarianceInferrer
|
||||
from midas.lexer.midas import MidasLexer
|
||||
@@ -41,9 +39,6 @@ class TypedParamSpec:
|
||||
kw: list[Function.Argument]
|
||||
|
||||
|
||||
TypedExpr = tuple[m.Expr, Type]
|
||||
|
||||
|
||||
class ReturnException(Exception):
|
||||
pass
|
||||
|
||||
@@ -67,8 +62,11 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
|
||||
self.logger: logging.Logger = logging.getLogger("MidasTyper")
|
||||
self.reporter: FileReporter = reporter.for_file(None)
|
||||
|
||||
self.types: TypesRegistry = types
|
||||
self.dispatcher: CallDispatcher[m.Expr] = CallDispatcher[m.Expr](
|
||||
self.types, self.reporter
|
||||
)
|
||||
|
||||
self._local_variables: dict[str, TypeVar] = {}
|
||||
|
||||
self._predicate_params: dict[str, Type] = {}
|
||||
@@ -83,8 +81,14 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
|
||||
self._preamble: Environment = Preamble(self.types)
|
||||
|
||||
def set_reporter(self, reporter: FileReporter):
|
||||
self.reporter = reporter
|
||||
self.dispatcher.set_reporter(reporter)
|
||||
|
||||
def process(self, source: str, path: Optional[str]):
|
||||
self.reporter = self.reporter.for_file(path)
|
||||
reporter: FileReporter = self.reporter.for_file(path)
|
||||
self.set_reporter(reporter)
|
||||
|
||||
lexer: MidasLexer = MidasLexer(source)
|
||||
tokens: list[Token] = lexer.process()
|
||||
parser: MidasParser = MidasParser(tokens)
|
||||
@@ -259,13 +263,13 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
result: Optional[Type] = self._get_call_result(
|
||||
location,
|
||||
operation,
|
||||
[(right_expr, right)],
|
||||
{},
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=location,
|
||||
callee=operation,
|
||||
positional=[(right_expr, right)],
|
||||
keywords={},
|
||||
)
|
||||
return result or UnknownType()
|
||||
return result.result
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
|
||||
method: Optional[str] = MIDAS_UNARY_METHODS.get(expr.operator.type)
|
||||
@@ -285,31 +289,29 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
result: Optional[Type] = self._get_call_result(
|
||||
expr.location,
|
||||
operation,
|
||||
[],
|
||||
{},
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=expr.location,
|
||||
callee=operation,
|
||||
positional=[],
|
||||
keywords={},
|
||||
)
|
||||
return result or UnknownType()
|
||||
return result.result
|
||||
|
||||
def visit_call_expr(self, expr: m.CallExpr) -> Type:
|
||||
callee: Type = expr.callee.accept(self)
|
||||
positional: list[TypedExpr] = [
|
||||
positional: list[tuple[m.Expr, Type]] = [
|
||||
(arg, self.type_of(arg)) for arg in expr.arguments
|
||||
]
|
||||
keywords: dict[str, TypedExpr] = {
|
||||
keywords: dict[str, tuple[m.Expr, Type]] = {
|
||||
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
|
||||
}
|
||||
return (
|
||||
self._get_call_result(
|
||||
expr.location,
|
||||
callee,
|
||||
positional,
|
||||
keywords,
|
||||
)
|
||||
or UnknownType()
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=expr.location,
|
||||
callee=callee,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> Type:
|
||||
object: Type = expr.expr.accept(self)
|
||||
@@ -433,343 +435,3 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
self._local_variables[name] = var
|
||||
vars.append(var)
|
||||
return vars
|
||||
|
||||
def _get_call_result(
|
||||
self,
|
||||
location: Location,
|
||||
callee: Type,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
report_errors: bool = True,
|
||||
) -> Optional[Type]:
|
||||
"""Get the result type of a function call
|
||||
|
||||
If the function has overloads, the function will try to resolve the
|
||||
appropriate signature.
|
||||
Argument types are matched to the defined parameters.
|
||||
The function doesn't take the raw expression as a parameter to accommodate
|
||||
for desugared calls such as for operators.
|
||||
|
||||
Args:
|
||||
location (Location): the call location
|
||||
callee (Type): the called function
|
||||
positional (list[TypedExpr]): the list positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Type: the return type of the call, or `None` if either
|
||||
the call is invalid or no overload matched the arguments uniquely
|
||||
"""
|
||||
match callee:
|
||||
case Function() as function:
|
||||
valid: bool
|
||||
mapped: list[MappedArgument]
|
||||
valid, mapped = self.map_call_arguments(
|
||||
function, location, positional, keywords
|
||||
)
|
||||
valid = valid and self._are_arguments_valid(mapped, report_errors)
|
||||
if not valid:
|
||||
return None
|
||||
return function.returns
|
||||
|
||||
case OverloadedFunction(overloads=overloads):
|
||||
function = self._match_overload(
|
||||
overloads, location, positional, keywords, report_errors
|
||||
)
|
||||
if function is None:
|
||||
return None
|
||||
return function.returns
|
||||
|
||||
case AppliedType(body=body):
|
||||
return self._get_call_result(
|
||||
location, body, positional, keywords, report_errors
|
||||
)
|
||||
|
||||
case UnknownType():
|
||||
return UnknownType()
|
||||
|
||||
case _:
|
||||
if report_errors:
|
||||
self.reporter.error(location, f"{callee} is not callable")
|
||||
return None
|
||||
|
||||
def _are_arguments_valid(
|
||||
self,
|
||||
arguments: list[MappedArgument],
|
||||
report_errors: bool = True,
|
||||
) -> bool:
|
||||
"""Check whether the passed argument types correspond to their matched parameter definitions
|
||||
|
||||
Args:
|
||||
arguments (list[MappedArgument]): the list of argument/parameter pairs
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
bool: True if all arguments fit the matching parameter definitions, False otherwise
|
||||
"""
|
||||
valid: bool = True
|
||||
for arg in arguments:
|
||||
if not self.types.is_subtype(arg.type, arg.argument.type):
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
arg.expr.location,
|
||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||
)
|
||||
valid = False
|
||||
return valid
|
||||
|
||||
def _match_overload(
|
||||
self,
|
||||
overloads: list[Type],
|
||||
location: Location,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
report_errors: bool = True,
|
||||
) -> Optional[Function]:
|
||||
"""Try and resolve the appropriate overload for the given arguments
|
||||
|
||||
Args:
|
||||
overloads (list[Type]): the list of possible overloads
|
||||
location (Location): the call location
|
||||
positional (list[TypedExpr]): the list of positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keywords arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Optional[Function]: the resolved function signature if it can be
|
||||
determined unambiguously, or `None`.
|
||||
"""
|
||||
candidates: list[OverloadCandidate] = []
|
||||
for overload in overloads:
|
||||
function: Type = unfold_type(overload)
|
||||
if not isinstance(function, Function):
|
||||
if report_errors:
|
||||
self.logger.error(
|
||||
f"Overload is not a function: {overload} is {function}"
|
||||
)
|
||||
continue
|
||||
valid, mapped = self.map_call_arguments(
|
||||
function=function,
|
||||
location=location,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
report_errors=False,
|
||||
)
|
||||
if valid and self._are_arguments_valid(mapped, report_errors=False):
|
||||
candidates.append(
|
||||
OverloadCandidate(
|
||||
function=function,
|
||||
mapped=mapped,
|
||||
)
|
||||
)
|
||||
|
||||
pos_types: str = ", ".join(str(type) for _, type in positional)
|
||||
kw_types: str = ", ".join(
|
||||
f"{name}: {type}" for name, (_, type) in keywords.items()
|
||||
)
|
||||
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
|
||||
|
||||
n_candidates: int = len(candidates)
|
||||
|
||||
# Exactly 1 match -> return it
|
||||
if n_candidates == 1:
|
||||
return candidates[0].function
|
||||
|
||||
# No match -> invalid call
|
||||
if n_candidates == 0:
|
||||
overloads_str: str = ", ".join(map(str, overloads))
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"No matching overload in [{overloads_str}] {for_args}",
|
||||
)
|
||||
return None
|
||||
|
||||
# Multiple matches -> see if one <: all others (more specific)
|
||||
for i1, c1 in enumerate(candidates):
|
||||
mapped1: list[MappedArgument] = c1.mapped
|
||||
best_match: bool = True
|
||||
for i2, c2 in enumerate(candidates):
|
||||
if i1 == i2:
|
||||
continue
|
||||
mapped2: list[MappedArgument] = c2.mapped
|
||||
if not self._are_mapped_subtypes(mapped1, mapped2):
|
||||
best_match = False
|
||||
break
|
||||
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
|
||||
if best_match:
|
||||
return c1.function
|
||||
|
||||
candidates_str: str = ", ".join(
|
||||
str(candidate.function) for candidate in candidates
|
||||
)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Multiple matching overloads {for_args}: {candidates_str}",
|
||||
)
|
||||
return None
|
||||
|
||||
def map_call_arguments(
|
||||
self,
|
||||
function: Function,
|
||||
location: Location,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
report_errors: bool = True,
|
||||
) -> tuple[bool, list[MappedArgument]]:
|
||||
"""Map call arguments to a function's parameters as defined in its signature
|
||||
|
||||
This method maps positional-only, keyword-only and mixed parameter definitions
|
||||
with the arguments passed at the call site
|
||||
|
||||
Any mismatched, missing or unexpected argument is reported as a diagnostic,
|
||||
unless `report_errors` is set to `False`
|
||||
|
||||
Args:
|
||||
function (Function): the function definition
|
||||
location (Location): the call location
|
||||
positional (list[TypedExpr]): the list of positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
tuple[bool, list[MappedArgument]]: a boolean reporting whether
|
||||
the call is valid and the list of mapped arguments
|
||||
"""
|
||||
set_args: set[str] = set()
|
||||
|
||||
required_positional: list[str] = [
|
||||
arg.name for arg in function.pos_args + function.args if arg.required
|
||||
]
|
||||
required_keyword: list[str] = [
|
||||
arg.name for arg in function.kw_args if arg.required
|
||||
]
|
||||
|
||||
mapped: list[MappedArgument] = []
|
||||
|
||||
pos_params: list[Function.Argument] = list(function.pos_args)
|
||||
mixed_params: list[Function.Argument] = list(function.args)
|
||||
kw_params: dict[str, Function.Argument] = {
|
||||
arg.name: arg for arg in function.kw_args
|
||||
}
|
||||
|
||||
valid_call: bool = True
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
for arg in positional:
|
||||
param: Function.Argument
|
||||
if len(pos_params) != 0:
|
||||
param = pos_params.pop(0)
|
||||
elif len(mixed_params) != 0:
|
||||
param = mixed_params.pop(0)
|
||||
else:
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
arg[0].location, "Too many positional arguments"
|
||||
)
|
||||
valid_call = False
|
||||
break
|
||||
name: str = param.name
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
kw_params.update({arg.name: arg for arg in mixed_params})
|
||||
for name, arg in keywords.items():
|
||||
param: Function.Argument
|
||||
if name not in kw_params:
|
||||
if report_errors:
|
||||
if name in set_args:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Multiple values for argument '{name}'"
|
||||
)
|
||||
else:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Unknown keyword argument '{name}'"
|
||||
)
|
||||
valid_call = False
|
||||
continue
|
||||
param = kw_params.pop(name)
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
def join_args(args: list[str]) -> str:
|
||||
args = list(map(lambda a: f"'{a}'", args))
|
||||
if len(args) == 0:
|
||||
return ""
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return ", ".join(args[:-1]) + " and " + args[-1]
|
||||
|
||||
if len(required_positional) != 0:
|
||||
plural: str = "" if len(required_positional) == 1 else "s"
|
||||
args: str = join_args(required_positional)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Missing required positional argument{plural}: {args}",
|
||||
)
|
||||
valid_call = False
|
||||
|
||||
if len(required_keyword) != 0:
|
||||
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||
args: str = join_args(required_keyword)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Missing required keyword argument{plural}: {args}",
|
||||
)
|
||||
valid_call = False
|
||||
|
||||
return valid_call, mapped
|
||||
|
||||
def _are_mapped_subtypes(
|
||||
self, mapped1: list[MappedArgument], mapped2: list[MappedArgument]
|
||||
) -> bool:
|
||||
"""Check whether the given argument mappings are subtype/supertype of one another
|
||||
|
||||
This function checks whether the argument mappings `mapped1` are subtypes
|
||||
of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
|
||||
of the corresponding parameter in `mapped2`, `False` is returned.
|
||||
|
||||
This is used to check whether a given overload is
|
||||
a more specific function/ a subtype of another.
|
||||
|
||||
Args:
|
||||
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
|
||||
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
|
||||
|
||||
Returns:
|
||||
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
|
||||
"""
|
||||
by_expr: dict[m.Expr, Type] = {}
|
||||
for arg in mapped1:
|
||||
by_expr[arg.expr] = arg.argument.type
|
||||
|
||||
for arg in mapped2:
|
||||
type2: Type = arg.argument.type
|
||||
type1: Type = by_expr[arg.expr]
|
||||
if not self.types.is_subtype(type1, type2):
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -41,7 +41,7 @@ PY_UNARY_METHODS: dict[Type[ast.unaryop], str] = {
|
||||
|
||||
|
||||
MIDAS_BINARY_METHODS: dict[TokenType, str] = {
|
||||
# TokenType.PLUS: "__add__",
|
||||
TokenType.PLUS: "__add__",
|
||||
TokenType.MINUS: "__sub__",
|
||||
TokenType.STAR: "__mul__",
|
||||
TokenType.SLASH: "__truediv__",
|
||||
|
||||
@@ -3,7 +3,15 @@ from typing import Any, Callable, Optional
|
||||
|
||||
from midas.checker.environment import Environment
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import Function, GenericType, TopType, Type, TypeVar, UnitType
|
||||
from midas.checker.types import (
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
TopType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnitType,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -70,6 +78,36 @@ class Preamble(Environment):
|
||||
returns=self._types.get_type("int"),
|
||||
)
|
||||
|
||||
T = TypeVar(name="T", bound=None)
|
||||
self._def_overloads(
|
||||
name="max",
|
||||
py_function=max,
|
||||
signatures=[
|
||||
(
|
||||
[Param("arg1", T), Param("arg2", T)],
|
||||
[],
|
||||
[],
|
||||
T,
|
||||
[T],
|
||||
),
|
||||
([Param("iterable", self._list_of(T))], [], [], T, [T]),
|
||||
],
|
||||
)
|
||||
self._def_overloads(
|
||||
name="min",
|
||||
py_function=min,
|
||||
signatures=[
|
||||
(
|
||||
[Param("arg1", T), Param("arg2", T)],
|
||||
[],
|
||||
[],
|
||||
T,
|
||||
[T],
|
||||
),
|
||||
([Param("iterable", self._list_of(T))], [], [], T, [T]),
|
||||
],
|
||||
)
|
||||
|
||||
def _list_of(self, item_type: Type) -> Type:
|
||||
return self._types.apply_generic(self._types.get_type("list"), [item_type])
|
||||
|
||||
@@ -142,5 +180,31 @@ class Preamble(Environment):
|
||||
if py_function is not None:
|
||||
self._python_funcs[name] = py_function
|
||||
|
||||
def _def_overloads(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
signatures: list[
|
||||
tuple[list[Param], list[Param], list[Param], Type, list[TypeVar]]
|
||||
],
|
||||
py_function: Optional[Callable[..., Any]] = None,
|
||||
):
|
||||
overloads: list[Type] = []
|
||||
for pos, mixed, kw, returns, type_vars in signatures:
|
||||
overloads.append(
|
||||
self._make_function(
|
||||
name=name,
|
||||
pos=pos,
|
||||
mixed=mixed,
|
||||
kw=kw,
|
||||
returns=returns,
|
||||
type_vars=type_vars,
|
||||
)
|
||||
)
|
||||
function: Type = OverloadedFunction(overloads=overloads)
|
||||
self.define(name, function)
|
||||
if py_function is not None:
|
||||
self._python_funcs[name] = py_function
|
||||
|
||||
def get_py_func(self, name: str) -> Optional[Callable[..., Any]]:
|
||||
return self._python_funcs.get(name)
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any, Optional
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.ast.printer import MidasPrinter
|
||||
from midas.checker.dispatcher import CallDispatcher, CallResult
|
||||
from midas.checker.environment import Environment
|
||||
from midas.checker.evaluator import Evaluator
|
||||
from midas.checker.frames import FrameManager
|
||||
@@ -27,7 +28,6 @@ from midas.checker.types import (
|
||||
DerivedType,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
TupleType,
|
||||
Type,
|
||||
TypeVar,
|
||||
@@ -36,7 +36,6 @@ from midas.checker.types import (
|
||||
Variance,
|
||||
unfold_type,
|
||||
)
|
||||
from midas.checker.unifier import Unifier
|
||||
from midas.parser.python import PythonParser
|
||||
from midas.utils import TypedAST
|
||||
|
||||
@@ -85,9 +84,17 @@ class PythonTyper(
|
||||
self.locals: dict[p.Expr, int] = {}
|
||||
self.judgements: list[tuple[p.Expr, Type]] = []
|
||||
self.evaluated_casts: list[p.CastExpr] = []
|
||||
self.dispatcher: CallDispatcher[p.Expr] = CallDispatcher[p.Expr](
|
||||
self.types, self.reporter
|
||||
)
|
||||
|
||||
def set_reporter(self, reporter: FileReporter):
|
||||
self.reporter = reporter
|
||||
self.dispatcher.set_reporter(self.reporter)
|
||||
|
||||
def process(self, source: str, path: Optional[str]) -> TypedAST:
|
||||
self.reporter = self.reporter.for_file(path)
|
||||
reporter: FileReporter = self.reporter.for_file(path)
|
||||
self.set_reporter(reporter)
|
||||
|
||||
tree: ast.Module = ast.parse(source, filename=path or "<unknown>")
|
||||
parser = PythonParser()
|
||||
@@ -222,12 +229,13 @@ class PythonTyper(
|
||||
if method is None:
|
||||
raise UndefinedMethodException
|
||||
|
||||
return self._get_call_result(
|
||||
location,
|
||||
method,
|
||||
positional,
|
||||
keywords,
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=location,
|
||||
callee=method,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||
return self.types.is_subtype(type1, type2)
|
||||
@@ -571,15 +579,13 @@ class PythonTyper(
|
||||
)
|
||||
|
||||
callee: Type = self.type_of(expr.callee)
|
||||
return (
|
||||
self._get_call_result(
|
||||
location=expr.location,
|
||||
callee=callee,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
or UnknownType()
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=expr.location,
|
||||
callee=callee,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> Type:
|
||||
object: Type = self.type_of(expr.object)
|
||||
@@ -742,10 +748,13 @@ class PythonTyper(
|
||||
return UnknownType()
|
||||
|
||||
index: Type = self.type_of(expr.index)
|
||||
return (
|
||||
self._get_call_result(expr.location, operation, [(expr.index, index)], {})
|
||||
or UnknownType()
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=expr.location,
|
||||
callee=operation,
|
||||
positional=[(expr.index, index)],
|
||||
keywords={},
|
||||
)
|
||||
return result.result
|
||||
|
||||
def visit_slice_expr(self, expr: p.SliceExpr) -> Type:
|
||||
return self.types.get_type("slice")
|
||||
@@ -796,376 +805,6 @@ class PythonTyper(
|
||||
]
|
||||
)
|
||||
|
||||
def _get_call_result(
|
||||
self,
|
||||
location: Location,
|
||||
callee: Type,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
report_errors: bool = True,
|
||||
) -> Optional[Type]:
|
||||
"""Get the result type of a function call
|
||||
|
||||
If the function has overloads, the function will try to resolve the
|
||||
appropriate signature.
|
||||
Argument types are matched to the defined parameters.
|
||||
The function doesn't take the raw expression as a parameter to accommodate
|
||||
for desugared calls such as for operators.
|
||||
|
||||
Args:
|
||||
location (Location): the call location
|
||||
callee (Type): the called function
|
||||
positional (list[TypedExpr]): the list positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Type: the return type of the call, or `None` if either
|
||||
the call is invalid or no overload matched the arguments uniquely
|
||||
"""
|
||||
match callee:
|
||||
case Function() as function:
|
||||
valid: bool
|
||||
mapped: list[MappedArgument]
|
||||
valid, mapped = self.map_call_arguments(
|
||||
function, location, positional, keywords
|
||||
)
|
||||
valid = valid and self._are_arguments_valid(mapped, report_errors)
|
||||
if not valid:
|
||||
return None
|
||||
return function.returns
|
||||
|
||||
case OverloadedFunction(overloads=overloads):
|
||||
function = self._match_overload(
|
||||
overloads, location, positional, keywords, report_errors
|
||||
)
|
||||
if function is None:
|
||||
return None
|
||||
return function.returns
|
||||
|
||||
case AppliedType(body=body):
|
||||
return self._get_call_result(
|
||||
location, body, positional, keywords, report_errors
|
||||
)
|
||||
|
||||
case UnknownType():
|
||||
return UnknownType()
|
||||
|
||||
case DerivedType(type=base):
|
||||
return self._get_call_result(
|
||||
location, base, positional, keywords, report_errors
|
||||
)
|
||||
|
||||
case GenericType():
|
||||
unifier: Unifier = Unifier(self.types)
|
||||
pos: list[Type] = [a[1] for a in positional]
|
||||
kw: dict[str, Type] = {k: v[1] for k, v in keywords.items()}
|
||||
unified: Optional[Type] = unifier.unify_call(callee, pos, kw)
|
||||
if unified is None:
|
||||
if report_errors:
|
||||
pos_str: str = ", ".join(str(t) for t in pos)
|
||||
kw_str: str = ", ".join(f"{k}: {v}" for k, v in kw.items())
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Could not unify {callee}={callee.body} with pos=[{pos_str}] and kw={{{kw_str}}}",
|
||||
)
|
||||
return None
|
||||
return self._get_call_result(
|
||||
location,
|
||||
unified,
|
||||
positional,
|
||||
keywords,
|
||||
report_errors,
|
||||
)
|
||||
|
||||
case _:
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"{callee} ({callee.__class__.__name__}) is not callable",
|
||||
)
|
||||
return None
|
||||
|
||||
def _are_arguments_valid(
|
||||
self,
|
||||
arguments: list[MappedArgument],
|
||||
report_errors: bool = True,
|
||||
) -> bool:
|
||||
"""Check whether the passed argument types correspond to their matched parameter definitions
|
||||
|
||||
Args:
|
||||
arguments (list[MappedArgument]): the list of argument/parameter pairs
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
bool: True if all arguments fit the matching parameter definitions, False otherwise
|
||||
"""
|
||||
valid: bool = True
|
||||
for arg in arguments:
|
||||
if not self.is_subtype(arg.type, arg.argument.type):
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
arg.expr.location,
|
||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||
)
|
||||
valid = False
|
||||
return valid
|
||||
|
||||
def _match_overload(
|
||||
self,
|
||||
overloads: list[Type],
|
||||
location: Location,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
report_errors: bool = True,
|
||||
) -> Optional[Function]:
|
||||
"""Try and resolve the appropriate overload for the given arguments
|
||||
|
||||
Args:
|
||||
overloads (list[Type]): the list of possible overloads
|
||||
location (Location): the call location
|
||||
positional (list[TypedExpr]): the list of positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keywords arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Optional[Function]: the resolved function signature if it can be
|
||||
determined unambiguously, or `None`.
|
||||
"""
|
||||
candidates: list[OverloadCandidate] = []
|
||||
for overload in overloads:
|
||||
function: Type = unfold_type(overload)
|
||||
if not isinstance(function, Function):
|
||||
if report_errors:
|
||||
self.logger.error(
|
||||
f"Overload is not a function: {overload} is {function}"
|
||||
)
|
||||
continue
|
||||
valid, mapped = self.map_call_arguments(
|
||||
function=function,
|
||||
location=location,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
report_errors=False,
|
||||
)
|
||||
if valid and self._are_arguments_valid(mapped, report_errors=False):
|
||||
candidates.append(
|
||||
OverloadCandidate(
|
||||
function=function,
|
||||
mapped=mapped,
|
||||
)
|
||||
)
|
||||
|
||||
pos_types: str = ", ".join(str(type) for _, type in positional)
|
||||
kw_types: str = ", ".join(
|
||||
f"{name}: {type}" for name, (_, type) in keywords.items()
|
||||
)
|
||||
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
|
||||
|
||||
n_candidates: int = len(candidates)
|
||||
|
||||
# Exactly 1 match -> return it
|
||||
if n_candidates == 1:
|
||||
return candidates[0].function
|
||||
|
||||
# No match -> invalid call
|
||||
if n_candidates == 0:
|
||||
overloads_str: str = ", ".join(map(str, overloads))
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"No matching overload in [{overloads_str}] {for_args}",
|
||||
)
|
||||
return None
|
||||
|
||||
# Multiple matches -> see if one <: all others (more specific)
|
||||
for i1, c1 in enumerate(candidates):
|
||||
mapped1: list[MappedArgument] = c1.mapped
|
||||
best_match: bool = True
|
||||
for i2, c2 in enumerate(candidates):
|
||||
if i1 == i2:
|
||||
continue
|
||||
mapped2: list[MappedArgument] = c2.mapped
|
||||
if not self._are_mapped_subtypes(mapped1, mapped2):
|
||||
best_match = False
|
||||
break
|
||||
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
|
||||
if best_match:
|
||||
return c1.function
|
||||
|
||||
candidates_str: str = ", ".join(
|
||||
str(candidate.function) for candidate in candidates
|
||||
)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Multiple matching overloads {for_args}: {candidates_str}",
|
||||
)
|
||||
return None
|
||||
|
||||
def map_call_arguments(
|
||||
self,
|
||||
function: Function,
|
||||
location: Location,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
report_errors: bool = True,
|
||||
) -> tuple[bool, list[MappedArgument]]:
|
||||
"""Map call arguments to a function's parameters as defined in its signature
|
||||
|
||||
This method maps positional-only, keyword-only and mixed parameter definitions
|
||||
with the arguments passed at the call site
|
||||
|
||||
Any mismatched, missing or unexpected argument is reported as a diagnostic,
|
||||
unless `report_errors` is set to `False`
|
||||
|
||||
Args:
|
||||
function (Function): the function definition
|
||||
location (Location): the call location
|
||||
positional (list[TypedExpr]): the list of positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
tuple[bool, list[MappedArgument]]: a boolean reporting whether
|
||||
the call is valid and the list of mapped arguments
|
||||
"""
|
||||
set_args: set[str] = set()
|
||||
|
||||
required_positional: list[str] = [
|
||||
arg.name for arg in function.pos_args + function.args if arg.required
|
||||
]
|
||||
required_keyword: list[str] = [
|
||||
arg.name for arg in function.kw_args if arg.required
|
||||
]
|
||||
|
||||
mapped: list[MappedArgument] = []
|
||||
|
||||
pos_params: list[Function.Argument] = list(function.pos_args)
|
||||
mixed_params: list[Function.Argument] = list(function.args)
|
||||
kw_params: dict[str, Function.Argument] = {
|
||||
arg.name: arg for arg in function.kw_args
|
||||
}
|
||||
|
||||
valid_call: bool = True
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
for arg in positional:
|
||||
param: Function.Argument
|
||||
if len(pos_params) != 0:
|
||||
param = pos_params.pop(0)
|
||||
elif len(mixed_params) != 0:
|
||||
param = mixed_params.pop(0)
|
||||
else:
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
arg[0].location, "Too many positional arguments"
|
||||
)
|
||||
valid_call = False
|
||||
break
|
||||
name: str = param.name
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
kw_params.update({arg.name: arg for arg in mixed_params})
|
||||
for name, arg in keywords.items():
|
||||
param: Function.Argument
|
||||
if name not in kw_params:
|
||||
if report_errors:
|
||||
if name in set_args:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Multiple values for argument '{name}'"
|
||||
)
|
||||
else:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Unknown keyword argument '{name}'"
|
||||
)
|
||||
valid_call = False
|
||||
continue
|
||||
param = kw_params.pop(name)
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
def join_args(args: list[str]) -> str:
|
||||
args = list(map(lambda a: f"'{a}'", args))
|
||||
if len(args) == 0:
|
||||
return ""
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return ", ".join(args[:-1]) + " and " + args[-1]
|
||||
|
||||
if len(required_positional) != 0:
|
||||
plural: str = "" if len(required_positional) == 1 else "s"
|
||||
args: str = join_args(required_positional)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Missing required positional argument{plural}: {args}",
|
||||
)
|
||||
valid_call = False
|
||||
|
||||
if len(required_keyword) != 0:
|
||||
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||
args: str = join_args(required_keyword)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Missing required keyword argument{plural}: {args}",
|
||||
)
|
||||
valid_call = False
|
||||
|
||||
return valid_call, mapped
|
||||
|
||||
def _are_mapped_subtypes(
|
||||
self, mapped1: list[MappedArgument], mapped2: list[MappedArgument]
|
||||
) -> bool:
|
||||
"""Check whether the given argument mappings are subtype/supertype of one another
|
||||
|
||||
This function checks whether the argument mappings `mapped1` are subtypes
|
||||
of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
|
||||
of the corresponding parameter in `mapped2`, `False` is returned.
|
||||
|
||||
This is used to check whether a given overload is
|
||||
a more specific function/ a subtype of another.
|
||||
|
||||
Args:
|
||||
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
|
||||
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
|
||||
|
||||
Returns:
|
||||
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
|
||||
"""
|
||||
by_expr: dict[p.Expr, Type] = {}
|
||||
for arg in mapped1:
|
||||
by_expr[arg.expr] = arg.argument.type
|
||||
|
||||
for arg in mapped2:
|
||||
type2: Type = arg.argument.type
|
||||
type1: Type = by_expr[arg.expr]
|
||||
if not self.is_subtype(type1, type2):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _get_iterator_type(self, expr: p.Expr, type: Type) -> Optional[Type]:
|
||||
# TODO: lookup __iter__
|
||||
getitem: Optional[Type] = self.types.lookup_member(type, "__getitem__")
|
||||
@@ -1174,14 +813,16 @@ class PythonTyper(
|
||||
|
||||
index: p.Expr = p.LiteralExpr(location=expr.location, value=0)
|
||||
index_type: Type = self.compute_type(index)
|
||||
result: Optional[Type] = self._get_call_result(
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=expr.location,
|
||||
callee=getitem,
|
||||
positional=[(index, index_type)],
|
||||
keywords={},
|
||||
report_errors=False,
|
||||
)
|
||||
return result
|
||||
if not result.is_valid:
|
||||
return None
|
||||
return result.result
|
||||
|
||||
def define_typevar(self, call: p.CallExpr) -> Optional[TypeVar]:
|
||||
def is_kw_true(name: str) -> bool:
|
||||
@@ -1272,6 +913,22 @@ class PythonTyper(
|
||||
pairs.append((key_val, value_val))
|
||||
return True, dict(pairs)
|
||||
|
||||
case p.UnaryExpr(operator=operator, right=operand):
|
||||
is_lit, operand_val = self._get_literal(operand)
|
||||
if not is_lit:
|
||||
return False, None
|
||||
match operator:
|
||||
case ast.UAdd():
|
||||
return True, operand_val
|
||||
case ast.USub():
|
||||
return True, -operand_val
|
||||
case ast.Invert():
|
||||
return True, ~operand_val
|
||||
case ast.Not():
|
||||
return True, not operand_val
|
||||
case _: # Should never be reached
|
||||
return False, None
|
||||
|
||||
case _:
|
||||
return False, None
|
||||
|
||||
@@ -1284,6 +941,40 @@ class PythonTyper(
|
||||
expr, subject_type, base, lit_value
|
||||
)
|
||||
|
||||
case AppliedType(name="list", args=[item_type]) if isinstance(
|
||||
lit_value, list
|
||||
):
|
||||
match subject_type:
|
||||
case AppliedType(name="list", args=[lit_item_type]):
|
||||
evaluated: bool = True
|
||||
for item in lit_value:
|
||||
if not self._evaluate_cast_statically(
|
||||
expr, lit_item_type, item_type, item
|
||||
):
|
||||
evaluated = False
|
||||
return evaluated
|
||||
case _:
|
||||
return False
|
||||
|
||||
case AppliedType(name="dict", args=[key_type, value_type]) if isinstance(
|
||||
lit_value, dict
|
||||
):
|
||||
match subject_type:
|
||||
case AppliedType(name="dict", args=[lit_key_type, lit_value_type]):
|
||||
evaluated: bool = True
|
||||
for key, value in lit_value.items():
|
||||
if not self._evaluate_cast_statically(
|
||||
expr, lit_key_type, key_type, key
|
||||
):
|
||||
evaluated = False
|
||||
if not self._evaluate_cast_statically(
|
||||
expr, lit_value_type, value_type, value
|
||||
):
|
||||
evaluated = False
|
||||
return evaluated
|
||||
case _:
|
||||
return False
|
||||
|
||||
case AppliedType(body=body):
|
||||
return self._evaluate_cast_statically(
|
||||
expr, subject_type, body, lit_value
|
||||
@@ -1298,10 +989,19 @@ class PythonTyper(
|
||||
|
||||
evaluator = Evaluator(self.types)
|
||||
evaluator.set_value("_", lit_value)
|
||||
res = evaluator.evaluate(constraint)
|
||||
printer = MidasPrinter()
|
||||
constraint_str: str = printer.print(constraint)
|
||||
res: Any
|
||||
try:
|
||||
res = evaluator.evaluate(constraint)
|
||||
except Exception as e:
|
||||
self.reporter.error(
|
||||
expr.location,
|
||||
f"An error occurred while checking constraint '{constraint_str}' on the value {lit_value!r}: {e}",
|
||||
)
|
||||
return False
|
||||
|
||||
if not res:
|
||||
printer = MidasPrinter()
|
||||
constraint_str: str = printer.print(constraint)
|
||||
self.reporter.error(
|
||||
expr.location,
|
||||
f"Value {lit_value!r} does not fit constraint '{constraint_str}'",
|
||||
|
||||
@@ -46,8 +46,8 @@ class MidasLexer(Lexer):
|
||||
self.add_token(TokenType.UNDERSCORE)
|
||||
case "-" if self.match(">"):
|
||||
self.add_token(TokenType.ARROW)
|
||||
# case "+":
|
||||
# self.add_token(TokenType.PLUS)
|
||||
case "+":
|
||||
self.add_token(TokenType.PLUS)
|
||||
case "-":
|
||||
self.add_token(TokenType.MINUS)
|
||||
case "*":
|
||||
|
||||
@@ -25,7 +25,7 @@ class TokenType(Enum):
|
||||
DOT = auto()
|
||||
|
||||
# Operators
|
||||
# PLUS = auto()
|
||||
PLUS = auto()
|
||||
MINUS = auto()
|
||||
STAR = auto()
|
||||
SLASH = auto()
|
||||
|
||||
@@ -361,13 +361,35 @@ class MidasParser(Parser):
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
"""
|
||||
expr: Expr = self.unary()
|
||||
expr: Expr = self.term()
|
||||
while self.match(
|
||||
TokenType.LESS,
|
||||
TokenType.LESS_EQUAL,
|
||||
TokenType.GREATER,
|
||||
TokenType.GREATER_EQUAL,
|
||||
):
|
||||
operator: Token = self.previous()
|
||||
right: Expr = self.term()
|
||||
location: Location = Location.span(expr.location, right.location)
|
||||
expr = BinaryExpr(
|
||||
location=location, left=expr, operator=operator, right=right
|
||||
)
|
||||
return expr
|
||||
|
||||
def term(self) -> Expr:
|
||||
expr: Expr = self.factor()
|
||||
while self.match(TokenType.PLUS, TokenType.MINUS):
|
||||
operator: Token = self.previous()
|
||||
right: Expr = self.factor()
|
||||
location: Location = Location.span(expr.location, right.location)
|
||||
expr = BinaryExpr(
|
||||
location=location, left=expr, operator=operator, right=right
|
||||
)
|
||||
return expr
|
||||
|
||||
def factor(self) -> Expr:
|
||||
expr: Expr = self.unary()
|
||||
while self.match(TokenType.STAR, TokenType.SLASH):
|
||||
operator: Token = self.previous()
|
||||
right: Expr = self.unary()
|
||||
location: Location = Location.span(expr.location, right.location)
|
||||
|
||||
Reference in New Issue
Block a user