Compare commits
10 Commits
c81287df7f
...
e2abc04fe4
| Author | SHA1 | Date | |
|---|---|---|---|
|
e2abc04fe4
|
|||
|
a4016b55ce
|
|||
|
1ea5da7024
|
|||
|
a017a8cf1f
|
|||
|
8fc5ab623e
|
|||
|
14007db846
|
|||
|
6ad2ce4b68
|
|||
|
9a276c34c7
|
|||
|
6e717a3f9e
|
|||
|
77aadfa264
|
484
midas/checker/dispatcher.py
Normal file
484
midas/checker/dispatcher.py
Normal file
@@ -0,0 +1,484 @@
|
|||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import Generic, Optional, Protocol, TypeVar, Union
|
||||||
|
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
|
from midas.checker.reporter import FileReporter
|
||||||
|
from midas.checker.types import (
|
||||||
|
AppliedType,
|
||||||
|
DerivedType,
|
||||||
|
Function,
|
||||||
|
GenericType,
|
||||||
|
OverloadedFunction,
|
||||||
|
Type,
|
||||||
|
UnknownType,
|
||||||
|
)
|
||||||
|
from midas.checker.unifier import Unifier
|
||||||
|
|
||||||
|
|
||||||
|
class HasLocation(Protocol):
|
||||||
|
@property
|
||||||
|
def location(self) -> Location: ...
|
||||||
|
|
||||||
|
|
||||||
|
E = TypeVar("E", bound=HasLocation)
|
||||||
|
|
||||||
|
TypedExpr = tuple[E, Type]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class MappedArgument(Generic[E]):
|
||||||
|
expr: E
|
||||||
|
type: Type
|
||||||
|
argument: Function.Argument
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class OverloadCandidate:
|
||||||
|
function: Function
|
||||||
|
mapped: list[MappedArgument]
|
||||||
|
|
||||||
|
|
||||||
|
class CallError(StrEnum):
|
||||||
|
INVALID_ARGS = "Invalid arguments"
|
||||||
|
NO_MATCHING_OVERLOAD = "No matching overload"
|
||||||
|
IMPOSSIBLE_UNIFICATION = "Parameters unification failed"
|
||||||
|
NOT_CALLABLE = "Not callable"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class CallResult:
|
||||||
|
error: Optional[CallError] = None
|
||||||
|
result: Type = UnknownType()
|
||||||
|
message: Optional[str] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_valid(self) -> bool:
|
||||||
|
return self.error is None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def error_message(self) -> str:
|
||||||
|
if self.message is not None:
|
||||||
|
return self.message
|
||||||
|
if self.error is not None:
|
||||||
|
return str(self.error)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
class CallDispatcher(Generic[E]):
|
||||||
|
def __init__(self, types: TypesRegistry, reporter: FileReporter) -> None:
|
||||||
|
self.types: TypesRegistry = types
|
||||||
|
self.reporter: FileReporter = reporter
|
||||||
|
self.logger: logging.Logger = logging.getLogger("CallDispatcher")
|
||||||
|
|
||||||
|
def set_reporter(self, reporter: FileReporter):
|
||||||
|
self.reporter = reporter
|
||||||
|
|
||||||
|
def get_result(
|
||||||
|
self,
|
||||||
|
location: Location,
|
||||||
|
callee: Type,
|
||||||
|
positional: list[TypedExpr[E]],
|
||||||
|
keywords: dict[str, TypedExpr[E]],
|
||||||
|
report_errors: bool = True,
|
||||||
|
) -> CallResult:
|
||||||
|
"""Get the result type of a function call
|
||||||
|
|
||||||
|
If the function has overloads, the function will try to resolve the
|
||||||
|
appropriate signature.
|
||||||
|
Argument types are matched to the defined parameters.
|
||||||
|
The function doesn't take the raw expression as a parameter to accommodate
|
||||||
|
for desugared calls such as for operators.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
location (Location): the call location
|
||||||
|
callee (Type): the called function
|
||||||
|
positional (list[TypedExpr]): the list positional arguments
|
||||||
|
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||||
|
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Type: the return type of the call, or `None` if either
|
||||||
|
the call is invalid or no overload matched the arguments uniquely
|
||||||
|
"""
|
||||||
|
match callee:
|
||||||
|
case Function() as function:
|
||||||
|
valid: bool
|
||||||
|
mapped: list[MappedArgument[E]]
|
||||||
|
valid, mapped = self.map_call_arguments(
|
||||||
|
function, location, positional, keywords
|
||||||
|
)
|
||||||
|
valid = valid and self._are_arguments_valid(mapped, report_errors)
|
||||||
|
if not valid:
|
||||||
|
return CallResult(error=CallError.INVALID_ARGS)
|
||||||
|
return CallResult(result=function.returns)
|
||||||
|
|
||||||
|
case OverloadedFunction(overloads=overloads):
|
||||||
|
res = self._match_overload(
|
||||||
|
overloads, location, positional, keywords, report_errors
|
||||||
|
)
|
||||||
|
if res[0] is None:
|
||||||
|
return CallResult(
|
||||||
|
error=CallError.NO_MATCHING_OVERLOAD,
|
||||||
|
message=res[1],
|
||||||
|
)
|
||||||
|
return CallResult(result=res[0].returns)
|
||||||
|
|
||||||
|
case AppliedType(body=body):
|
||||||
|
return self.get_result(
|
||||||
|
location, body, positional, keywords, report_errors
|
||||||
|
)
|
||||||
|
|
||||||
|
case UnknownType():
|
||||||
|
return CallResult(result=UnknownType())
|
||||||
|
|
||||||
|
case DerivedType(type=base):
|
||||||
|
return self.get_result(
|
||||||
|
location, base, positional, keywords, report_errors
|
||||||
|
)
|
||||||
|
|
||||||
|
case GenericType():
|
||||||
|
unifier: Unifier = Unifier(self.types)
|
||||||
|
pos: list[Type] = [a[1] for a in positional]
|
||||||
|
kw: dict[str, Type] = {k: v[1] for k, v in keywords.items()}
|
||||||
|
unified: Optional[Type] = unifier.unify_call(callee, pos, kw)
|
||||||
|
if unified is None:
|
||||||
|
pos_str: str = ", ".join(str(t) for t in pos)
|
||||||
|
kw_str: str = ", ".join(f"{k}: {v}" for k, v in kw.items())
|
||||||
|
message: str = (
|
||||||
|
f"Could not unify {callee}={callee.body} with pos=[{pos_str}] and kw={{{kw_str}}}"
|
||||||
|
)
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(location, message)
|
||||||
|
return CallResult(
|
||||||
|
error=CallError.IMPOSSIBLE_UNIFICATION,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
return self.get_result(
|
||||||
|
location,
|
||||||
|
unified,
|
||||||
|
positional,
|
||||||
|
keywords,
|
||||||
|
report_errors,
|
||||||
|
)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
message: str = f"{callee} ({callee.__class__.__name__}) is not callable"
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(location, message)
|
||||||
|
return CallResult(
|
||||||
|
error=CallError.NOT_CALLABLE,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _unwrap_function(
|
||||||
|
self,
|
||||||
|
callee: Type,
|
||||||
|
positional: list[TypedExpr[E]],
|
||||||
|
keywords: dict[str, TypedExpr[E]],
|
||||||
|
) -> Union[tuple[Function, None], tuple[None, CallError]]:
|
||||||
|
match callee:
|
||||||
|
case DerivedType(type=base):
|
||||||
|
return self._unwrap_function(base, positional, keywords)
|
||||||
|
|
||||||
|
case GenericType():
|
||||||
|
unifier: Unifier = Unifier(self.types)
|
||||||
|
unified: Optional[Type] = unifier.unify_call(
|
||||||
|
callee,
|
||||||
|
[a[1] for a in positional],
|
||||||
|
{k: v[1] for k, v in keywords.items()},
|
||||||
|
)
|
||||||
|
if unified is None:
|
||||||
|
return None, CallError.IMPOSSIBLE_UNIFICATION
|
||||||
|
return self._unwrap_function(unified, positional, keywords)
|
||||||
|
|
||||||
|
case Function():
|
||||||
|
return callee, None
|
||||||
|
|
||||||
|
case AppliedType(body=body):
|
||||||
|
return self._unwrap_function(body, positional, keywords)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
return None, CallError.NOT_CALLABLE
|
||||||
|
|
||||||
|
def _are_arguments_valid(
|
||||||
|
self,
|
||||||
|
arguments: list[MappedArgument[E]],
|
||||||
|
report_errors: bool = True,
|
||||||
|
) -> bool:
|
||||||
|
"""Check whether the passed argument types correspond to their matched parameter definitions
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arguments (list[MappedArgument]): the list of argument/parameter pairs
|
||||||
|
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if all arguments fit the matching parameter definitions, False otherwise
|
||||||
|
"""
|
||||||
|
valid: bool = True
|
||||||
|
for arg in arguments:
|
||||||
|
if not self.types.is_subtype(arg.type, arg.argument.type):
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(
|
||||||
|
arg.expr.location,
|
||||||
|
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||||
|
)
|
||||||
|
valid = False
|
||||||
|
return valid
|
||||||
|
|
||||||
|
def _match_overload(
|
||||||
|
self,
|
||||||
|
overloads: list[Type],
|
||||||
|
location: Location,
|
||||||
|
positional: list[TypedExpr[E]],
|
||||||
|
keywords: dict[str, TypedExpr[E]],
|
||||||
|
report_errors: bool = True,
|
||||||
|
) -> Union[tuple[Function, None], tuple[None, str]]:
|
||||||
|
"""Try and resolve the appropriate overload for the given arguments
|
||||||
|
|
||||||
|
Args:
|
||||||
|
overloads (list[Type]): the list of possible overloads
|
||||||
|
location (Location): the call location
|
||||||
|
positional (list[TypedExpr]): the list of positional arguments
|
||||||
|
keywords (dict[str, TypedExpr]): the map of keywords arguments
|
||||||
|
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Function]: the resolved function signature if it can be
|
||||||
|
determined unambiguously, or `None`.
|
||||||
|
"""
|
||||||
|
candidates: list[OverloadCandidate] = []
|
||||||
|
errors: list[CallError] = []
|
||||||
|
for overload in overloads:
|
||||||
|
function, unwrap_error = self._unwrap_function(
|
||||||
|
overload, positional, keywords
|
||||||
|
)
|
||||||
|
if function is None:
|
||||||
|
errors.append(unwrap_error) # type: ignore
|
||||||
|
continue
|
||||||
|
|
||||||
|
valid, mapped = self.map_call_arguments(
|
||||||
|
function=function,
|
||||||
|
location=location,
|
||||||
|
positional=positional,
|
||||||
|
keywords=keywords,
|
||||||
|
report_errors=False,
|
||||||
|
)
|
||||||
|
if valid and self._are_arguments_valid(mapped, report_errors=False):
|
||||||
|
candidates.append(
|
||||||
|
OverloadCandidate(
|
||||||
|
function=function,
|
||||||
|
mapped=mapped,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
pos_types: str = ", ".join(str(type) for _, type in positional)
|
||||||
|
kw_types: str = ", ".join(
|
||||||
|
f"{name}: {type}" for name, (_, type) in keywords.items()
|
||||||
|
)
|
||||||
|
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
|
||||||
|
|
||||||
|
n_candidates: int = len(candidates)
|
||||||
|
|
||||||
|
# Exactly 1 match -> return it
|
||||||
|
if n_candidates == 1:
|
||||||
|
return candidates[0].function, None
|
||||||
|
|
||||||
|
# No match -> invalid call
|
||||||
|
if n_candidates == 0:
|
||||||
|
overloads_str: str = ", ".join(map(str, overloads))
|
||||||
|
errors_str: str = ", ".join(errors)
|
||||||
|
message: str = (
|
||||||
|
f"No matching overload in [{overloads_str}] {for_args} (errors: {errors_str})"
|
||||||
|
)
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(location, message)
|
||||||
|
return None, message
|
||||||
|
|
||||||
|
# Multiple matches -> see if one <: all others (more specific)
|
||||||
|
for i1, c1 in enumerate(candidates):
|
||||||
|
mapped1: list[MappedArgument[E]] = c1.mapped
|
||||||
|
best_match: bool = True
|
||||||
|
for i2, c2 in enumerate(candidates):
|
||||||
|
if i1 == i2:
|
||||||
|
continue
|
||||||
|
mapped2: list[MappedArgument[E]] = c2.mapped
|
||||||
|
if not self._are_mapped_subtypes(mapped1, mapped2):
|
||||||
|
best_match = False
|
||||||
|
break
|
||||||
|
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
|
||||||
|
if best_match:
|
||||||
|
return c1.function, None
|
||||||
|
|
||||||
|
candidates_str: str = ", ".join(
|
||||||
|
str(candidate.function) for candidate in candidates
|
||||||
|
)
|
||||||
|
message: str = f"Multiple matching overloads {for_args}: {candidates_str}"
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(location, message)
|
||||||
|
return None, message
|
||||||
|
|
||||||
|
def map_call_arguments(
|
||||||
|
self,
|
||||||
|
function: Function,
|
||||||
|
location: Location,
|
||||||
|
positional: list[TypedExpr[E]],
|
||||||
|
keywords: dict[str, TypedExpr[E]],
|
||||||
|
report_errors: bool = True,
|
||||||
|
) -> tuple[bool, list[MappedArgument]]:
|
||||||
|
"""Map call arguments to a function's parameters as defined in its signature
|
||||||
|
|
||||||
|
This method maps positional-only, keyword-only and mixed parameter definitions
|
||||||
|
with the arguments passed at the call site
|
||||||
|
|
||||||
|
Any mismatched, missing or unexpected argument is reported as a diagnostic,
|
||||||
|
unless `report_errors` is set to `False`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
function (Function): the function definition
|
||||||
|
location (Location): the call location
|
||||||
|
positional (list[TypedExpr]): the list of positional arguments
|
||||||
|
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||||
|
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[bool, list[MappedArgument]]: a boolean reporting whether
|
||||||
|
the call is valid and the list of mapped arguments
|
||||||
|
"""
|
||||||
|
set_args: set[str] = set()
|
||||||
|
|
||||||
|
required_positional: list[str] = [
|
||||||
|
arg.name for arg in function.pos_args + function.args if arg.required
|
||||||
|
]
|
||||||
|
required_keyword: list[str] = [
|
||||||
|
arg.name for arg in function.kw_args if arg.required
|
||||||
|
]
|
||||||
|
|
||||||
|
mapped: list[MappedArgument[E]] = []
|
||||||
|
|
||||||
|
pos_params: list[Function.Argument] = list(function.pos_args)
|
||||||
|
mixed_params: list[Function.Argument] = list(function.args)
|
||||||
|
kw_params: dict[str, Function.Argument] = {
|
||||||
|
arg.name: arg for arg in function.kw_args
|
||||||
|
}
|
||||||
|
|
||||||
|
valid_call: bool = True
|
||||||
|
|
||||||
|
# TODO: handle *args and **kwargs sinks
|
||||||
|
for arg in positional:
|
||||||
|
param: Function.Argument
|
||||||
|
if len(pos_params) != 0:
|
||||||
|
param = pos_params.pop(0)
|
||||||
|
elif len(mixed_params) != 0:
|
||||||
|
param = mixed_params.pop(0)
|
||||||
|
else:
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(
|
||||||
|
arg[0].location, "Too many positional arguments"
|
||||||
|
)
|
||||||
|
valid_call = False
|
||||||
|
break
|
||||||
|
name: str = param.name
|
||||||
|
if name in required_positional:
|
||||||
|
required_positional.remove(name)
|
||||||
|
if name in required_keyword:
|
||||||
|
required_keyword.remove(name)
|
||||||
|
set_args.add(name)
|
||||||
|
mapped.append(
|
||||||
|
MappedArgument(
|
||||||
|
expr=arg[0],
|
||||||
|
type=arg[1],
|
||||||
|
argument=param,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
kw_params.update({arg.name: arg for arg in mixed_params})
|
||||||
|
for name, arg in keywords.items():
|
||||||
|
param: Function.Argument
|
||||||
|
if name not in kw_params:
|
||||||
|
if report_errors:
|
||||||
|
if name in set_args:
|
||||||
|
self.reporter.error(
|
||||||
|
arg[0].location, f"Multiple values for argument '{name}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.reporter.error(
|
||||||
|
arg[0].location, f"Unknown keyword argument '{name}'"
|
||||||
|
)
|
||||||
|
valid_call = False
|
||||||
|
continue
|
||||||
|
param = kw_params.pop(name)
|
||||||
|
if name in required_positional:
|
||||||
|
required_positional.remove(name)
|
||||||
|
if name in required_keyword:
|
||||||
|
required_keyword.remove(name)
|
||||||
|
set_args.add(name)
|
||||||
|
mapped.append(
|
||||||
|
MappedArgument(
|
||||||
|
expr=arg[0],
|
||||||
|
type=arg[1],
|
||||||
|
argument=param,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def join_args(args: list[str]) -> str:
|
||||||
|
args = list(map(lambda a: f"'{a}'", args))
|
||||||
|
if len(args) == 0:
|
||||||
|
return ""
|
||||||
|
if len(args) == 1:
|
||||||
|
return args[0]
|
||||||
|
return ", ".join(args[:-1]) + " and " + args[-1]
|
||||||
|
|
||||||
|
if len(required_positional) != 0:
|
||||||
|
plural: str = "" if len(required_positional) == 1 else "s"
|
||||||
|
args: str = join_args(required_positional)
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Missing required positional argument{plural}: {args}",
|
||||||
|
)
|
||||||
|
valid_call = False
|
||||||
|
|
||||||
|
if len(required_keyword) != 0:
|
||||||
|
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||||
|
args: str = join_args(required_keyword)
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Missing required keyword argument{plural}: {args}",
|
||||||
|
)
|
||||||
|
valid_call = False
|
||||||
|
|
||||||
|
return valid_call, mapped
|
||||||
|
|
||||||
|
def _are_mapped_subtypes(
|
||||||
|
self, mapped1: list[MappedArgument[E]], mapped2: list[MappedArgument[E]]
|
||||||
|
) -> bool:
|
||||||
|
"""Check whether the given argument mappings are subtype/supertype of one another
|
||||||
|
|
||||||
|
This function checks whether the argument mappings `mapped1` are subtypes
|
||||||
|
of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
|
||||||
|
of the corresponding parameter in `mapped2`, `False` is returned.
|
||||||
|
|
||||||
|
This is used to check whether a given overload is
|
||||||
|
a more specific function/ a subtype of another.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
|
||||||
|
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
|
||||||
|
"""
|
||||||
|
by_expr: dict[E, Type] = {}
|
||||||
|
for arg in mapped1:
|
||||||
|
by_expr[arg.expr] = arg.argument.type
|
||||||
|
|
||||||
|
for arg in mapped2:
|
||||||
|
type2: Type = arg.argument.type
|
||||||
|
type1: Type = by_expr[arg.expr]
|
||||||
|
if not self.types.is_subtype(type1, type2):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
@@ -6,13 +6,13 @@ from typing import Optional
|
|||||||
import midas.ast.midas as m
|
import midas.ast.midas as m
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
from midas.checker.builtins import define_builtins
|
from midas.checker.builtins import define_builtins
|
||||||
|
from midas.checker.dispatcher import CallDispatcher, CallResult
|
||||||
from midas.checker.environment import Environment
|
from midas.checker.environment import Environment
|
||||||
from midas.checker.operators import MIDAS_BINARY_METHODS, MIDAS_UNARY_METHODS
|
from midas.checker.operators import MIDAS_BINARY_METHODS, MIDAS_UNARY_METHODS
|
||||||
from midas.checker.preamble import Preamble
|
from midas.checker.preamble import Preamble
|
||||||
from midas.checker.registry import TypesRegistry
|
from midas.checker.registry import TypesRegistry
|
||||||
from midas.checker.reporter import FileReporter, Reporter
|
from midas.checker.reporter import FileReporter, Reporter
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AppliedType,
|
|
||||||
ColumnType,
|
ColumnType,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
ConstraintType,
|
ConstraintType,
|
||||||
@@ -21,12 +21,10 @@ from midas.checker.types import (
|
|||||||
ExtensionType,
|
ExtensionType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
OverloadedFunction,
|
|
||||||
Predicate,
|
Predicate,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
UnknownType,
|
UnknownType,
|
||||||
unfold_type,
|
|
||||||
)
|
)
|
||||||
from midas.checker.variance import VarianceInferrer
|
from midas.checker.variance import VarianceInferrer
|
||||||
from midas.lexer.midas import MidasLexer
|
from midas.lexer.midas import MidasLexer
|
||||||
@@ -41,9 +39,6 @@ class TypedParamSpec:
|
|||||||
kw: list[Function.Argument]
|
kw: list[Function.Argument]
|
||||||
|
|
||||||
|
|
||||||
TypedExpr = tuple[m.Expr, Type]
|
|
||||||
|
|
||||||
|
|
||||||
class ReturnException(Exception):
|
class ReturnException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -67,8 +62,11 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
|||||||
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
|
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
|
||||||
self.logger: logging.Logger = logging.getLogger("MidasTyper")
|
self.logger: logging.Logger = logging.getLogger("MidasTyper")
|
||||||
self.reporter: FileReporter = reporter.for_file(None)
|
self.reporter: FileReporter = reporter.for_file(None)
|
||||||
|
|
||||||
self.types: TypesRegistry = types
|
self.types: TypesRegistry = types
|
||||||
|
self.dispatcher: CallDispatcher[m.Expr] = CallDispatcher[m.Expr](
|
||||||
|
self.types, self.reporter
|
||||||
|
)
|
||||||
|
|
||||||
self._local_variables: dict[str, TypeVar] = {}
|
self._local_variables: dict[str, TypeVar] = {}
|
||||||
|
|
||||||
self._predicate_params: dict[str, Type] = {}
|
self._predicate_params: dict[str, Type] = {}
|
||||||
@@ -83,8 +81,14 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
|||||||
|
|
||||||
self._preamble: Environment = Preamble(self.types)
|
self._preamble: Environment = Preamble(self.types)
|
||||||
|
|
||||||
|
def set_reporter(self, reporter: FileReporter):
|
||||||
|
self.reporter = reporter
|
||||||
|
self.dispatcher.set_reporter(reporter)
|
||||||
|
|
||||||
def process(self, source: str, path: Optional[str]):
|
def process(self, source: str, path: Optional[str]):
|
||||||
self.reporter = self.reporter.for_file(path)
|
reporter: FileReporter = self.reporter.for_file(path)
|
||||||
|
self.set_reporter(reporter)
|
||||||
|
|
||||||
lexer: MidasLexer = MidasLexer(source)
|
lexer: MidasLexer = MidasLexer(source)
|
||||||
tokens: list[Token] = lexer.process()
|
tokens: list[Token] = lexer.process()
|
||||||
parser: MidasParser = MidasParser(tokens)
|
parser: MidasParser = MidasParser(tokens)
|
||||||
@@ -259,13 +263,13 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
|||||||
)
|
)
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
result: Optional[Type] = self._get_call_result(
|
result: CallResult = self.dispatcher.get_result(
|
||||||
location,
|
location=location,
|
||||||
operation,
|
callee=operation,
|
||||||
[(right_expr, right)],
|
positional=[(right_expr, right)],
|
||||||
{},
|
keywords={},
|
||||||
)
|
)
|
||||||
return result or UnknownType()
|
return result.result
|
||||||
|
|
||||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
|
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
|
||||||
method: Optional[str] = MIDAS_UNARY_METHODS.get(expr.operator.type)
|
method: Optional[str] = MIDAS_UNARY_METHODS.get(expr.operator.type)
|
||||||
@@ -285,31 +289,29 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
|||||||
)
|
)
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
result: Optional[Type] = self._get_call_result(
|
result: CallResult = self.dispatcher.get_result(
|
||||||
expr.location,
|
location=expr.location,
|
||||||
operation,
|
callee=operation,
|
||||||
[],
|
positional=[],
|
||||||
{},
|
keywords={},
|
||||||
)
|
)
|
||||||
return result or UnknownType()
|
return result.result
|
||||||
|
|
||||||
def visit_call_expr(self, expr: m.CallExpr) -> Type:
|
def visit_call_expr(self, expr: m.CallExpr) -> Type:
|
||||||
callee: Type = expr.callee.accept(self)
|
callee: Type = expr.callee.accept(self)
|
||||||
positional: list[TypedExpr] = [
|
positional: list[tuple[m.Expr, Type]] = [
|
||||||
(arg, self.type_of(arg)) for arg in expr.arguments
|
(arg, self.type_of(arg)) for arg in expr.arguments
|
||||||
]
|
]
|
||||||
keywords: dict[str, TypedExpr] = {
|
keywords: dict[str, tuple[m.Expr, Type]] = {
|
||||||
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
|
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
|
||||||
}
|
}
|
||||||
return (
|
result: CallResult = self.dispatcher.get_result(
|
||||||
self._get_call_result(
|
location=expr.location,
|
||||||
expr.location,
|
callee=callee,
|
||||||
callee,
|
positional=positional,
|
||||||
positional,
|
keywords=keywords,
|
||||||
keywords,
|
|
||||||
)
|
|
||||||
or UnknownType()
|
|
||||||
)
|
)
|
||||||
|
return result.result
|
||||||
|
|
||||||
def visit_get_expr(self, expr: m.GetExpr) -> Type:
|
def visit_get_expr(self, expr: m.GetExpr) -> Type:
|
||||||
object: Type = expr.expr.accept(self)
|
object: Type = expr.expr.accept(self)
|
||||||
@@ -433,343 +435,3 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
|||||||
self._local_variables[name] = var
|
self._local_variables[name] = var
|
||||||
vars.append(var)
|
vars.append(var)
|
||||||
return vars
|
return vars
|
||||||
|
|
||||||
def _get_call_result(
|
|
||||||
self,
|
|
||||||
location: Location,
|
|
||||||
callee: Type,
|
|
||||||
positional: list[TypedExpr],
|
|
||||||
keywords: dict[str, TypedExpr],
|
|
||||||
report_errors: bool = True,
|
|
||||||
) -> Optional[Type]:
|
|
||||||
"""Get the result type of a function call
|
|
||||||
|
|
||||||
If the function has overloads, the function will try to resolve the
|
|
||||||
appropriate signature.
|
|
||||||
Argument types are matched to the defined parameters.
|
|
||||||
The function doesn't take the raw expression as a parameter to accommodate
|
|
||||||
for desugared calls such as for operators.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
location (Location): the call location
|
|
||||||
callee (Type): the called function
|
|
||||||
positional (list[TypedExpr]): the list positional arguments
|
|
||||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
|
||||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Type: the return type of the call, or `None` if either
|
|
||||||
the call is invalid or no overload matched the arguments uniquely
|
|
||||||
"""
|
|
||||||
match callee:
|
|
||||||
case Function() as function:
|
|
||||||
valid: bool
|
|
||||||
mapped: list[MappedArgument]
|
|
||||||
valid, mapped = self.map_call_arguments(
|
|
||||||
function, location, positional, keywords
|
|
||||||
)
|
|
||||||
valid = valid and self._are_arguments_valid(mapped, report_errors)
|
|
||||||
if not valid:
|
|
||||||
return None
|
|
||||||
return function.returns
|
|
||||||
|
|
||||||
case OverloadedFunction(overloads=overloads):
|
|
||||||
function = self._match_overload(
|
|
||||||
overloads, location, positional, keywords, report_errors
|
|
||||||
)
|
|
||||||
if function is None:
|
|
||||||
return None
|
|
||||||
return function.returns
|
|
||||||
|
|
||||||
case AppliedType(body=body):
|
|
||||||
return self._get_call_result(
|
|
||||||
location, body, positional, keywords, report_errors
|
|
||||||
)
|
|
||||||
|
|
||||||
case UnknownType():
|
|
||||||
return UnknownType()
|
|
||||||
|
|
||||||
case _:
|
|
||||||
if report_errors:
|
|
||||||
self.reporter.error(location, f"{callee} is not callable")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _are_arguments_valid(
|
|
||||||
self,
|
|
||||||
arguments: list[MappedArgument],
|
|
||||||
report_errors: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""Check whether the passed argument types correspond to their matched parameter definitions
|
|
||||||
|
|
||||||
Args:
|
|
||||||
arguments (list[MappedArgument]): the list of argument/parameter pairs
|
|
||||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if all arguments fit the matching parameter definitions, False otherwise
|
|
||||||
"""
|
|
||||||
valid: bool = True
|
|
||||||
for arg in arguments:
|
|
||||||
if not self.types.is_subtype(arg.type, arg.argument.type):
|
|
||||||
if report_errors:
|
|
||||||
self.reporter.error(
|
|
||||||
arg.expr.location,
|
|
||||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
|
||||||
)
|
|
||||||
valid = False
|
|
||||||
return valid
|
|
||||||
|
|
||||||
def _match_overload(
|
|
||||||
self,
|
|
||||||
overloads: list[Type],
|
|
||||||
location: Location,
|
|
||||||
positional: list[TypedExpr],
|
|
||||||
keywords: dict[str, TypedExpr],
|
|
||||||
report_errors: bool = True,
|
|
||||||
) -> Optional[Function]:
|
|
||||||
"""Try and resolve the appropriate overload for the given arguments
|
|
||||||
|
|
||||||
Args:
|
|
||||||
overloads (list[Type]): the list of possible overloads
|
|
||||||
location (Location): the call location
|
|
||||||
positional (list[TypedExpr]): the list of positional arguments
|
|
||||||
keywords (dict[str, TypedExpr]): the map of keywords arguments
|
|
||||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[Function]: the resolved function signature if it can be
|
|
||||||
determined unambiguously, or `None`.
|
|
||||||
"""
|
|
||||||
candidates: list[OverloadCandidate] = []
|
|
||||||
for overload in overloads:
|
|
||||||
function: Type = unfold_type(overload)
|
|
||||||
if not isinstance(function, Function):
|
|
||||||
if report_errors:
|
|
||||||
self.logger.error(
|
|
||||||
f"Overload is not a function: {overload} is {function}"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
valid, mapped = self.map_call_arguments(
|
|
||||||
function=function,
|
|
||||||
location=location,
|
|
||||||
positional=positional,
|
|
||||||
keywords=keywords,
|
|
||||||
report_errors=False,
|
|
||||||
)
|
|
||||||
if valid and self._are_arguments_valid(mapped, report_errors=False):
|
|
||||||
candidates.append(
|
|
||||||
OverloadCandidate(
|
|
||||||
function=function,
|
|
||||||
mapped=mapped,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
pos_types: str = ", ".join(str(type) for _, type in positional)
|
|
||||||
kw_types: str = ", ".join(
|
|
||||||
f"{name}: {type}" for name, (_, type) in keywords.items()
|
|
||||||
)
|
|
||||||
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
|
|
||||||
|
|
||||||
n_candidates: int = len(candidates)
|
|
||||||
|
|
||||||
# Exactly 1 match -> return it
|
|
||||||
if n_candidates == 1:
|
|
||||||
return candidates[0].function
|
|
||||||
|
|
||||||
# No match -> invalid call
|
|
||||||
if n_candidates == 0:
|
|
||||||
overloads_str: str = ", ".join(map(str, overloads))
|
|
||||||
if report_errors:
|
|
||||||
self.reporter.error(
|
|
||||||
location,
|
|
||||||
f"No matching overload in [{overloads_str}] {for_args}",
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Multiple matches -> see if one <: all others (more specific)
|
|
||||||
for i1, c1 in enumerate(candidates):
|
|
||||||
mapped1: list[MappedArgument] = c1.mapped
|
|
||||||
best_match: bool = True
|
|
||||||
for i2, c2 in enumerate(candidates):
|
|
||||||
if i1 == i2:
|
|
||||||
continue
|
|
||||||
mapped2: list[MappedArgument] = c2.mapped
|
|
||||||
if not self._are_mapped_subtypes(mapped1, mapped2):
|
|
||||||
best_match = False
|
|
||||||
break
|
|
||||||
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
|
|
||||||
if best_match:
|
|
||||||
return c1.function
|
|
||||||
|
|
||||||
candidates_str: str = ", ".join(
|
|
||||||
str(candidate.function) for candidate in candidates
|
|
||||||
)
|
|
||||||
if report_errors:
|
|
||||||
self.reporter.error(
|
|
||||||
location,
|
|
||||||
f"Multiple matching overloads {for_args}: {candidates_str}",
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def map_call_arguments(
|
|
||||||
self,
|
|
||||||
function: Function,
|
|
||||||
location: Location,
|
|
||||||
positional: list[TypedExpr],
|
|
||||||
keywords: dict[str, TypedExpr],
|
|
||||||
report_errors: bool = True,
|
|
||||||
) -> tuple[bool, list[MappedArgument]]:
|
|
||||||
"""Map call arguments to a function's parameters as defined in its signature
|
|
||||||
|
|
||||||
This method maps positional-only, keyword-only and mixed parameter definitions
|
|
||||||
with the arguments passed at the call site
|
|
||||||
|
|
||||||
Any mismatched, missing or unexpected argument is reported as a diagnostic,
|
|
||||||
unless `report_errors` is set to `False`
|
|
||||||
|
|
||||||
Args:
|
|
||||||
function (Function): the function definition
|
|
||||||
location (Location): the call location
|
|
||||||
positional (list[TypedExpr]): the list of positional arguments
|
|
||||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
|
||||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[bool, list[MappedArgument]]: a boolean reporting whether
|
|
||||||
the call is valid and the list of mapped arguments
|
|
||||||
"""
|
|
||||||
set_args: set[str] = set()
|
|
||||||
|
|
||||||
required_positional: list[str] = [
|
|
||||||
arg.name for arg in function.pos_args + function.args if arg.required
|
|
||||||
]
|
|
||||||
required_keyword: list[str] = [
|
|
||||||
arg.name for arg in function.kw_args if arg.required
|
|
||||||
]
|
|
||||||
|
|
||||||
mapped: list[MappedArgument] = []
|
|
||||||
|
|
||||||
pos_params: list[Function.Argument] = list(function.pos_args)
|
|
||||||
mixed_params: list[Function.Argument] = list(function.args)
|
|
||||||
kw_params: dict[str, Function.Argument] = {
|
|
||||||
arg.name: arg for arg in function.kw_args
|
|
||||||
}
|
|
||||||
|
|
||||||
valid_call: bool = True
|
|
||||||
|
|
||||||
# TODO: handle *args and **kwargs sinks
|
|
||||||
for arg in positional:
|
|
||||||
param: Function.Argument
|
|
||||||
if len(pos_params) != 0:
|
|
||||||
param = pos_params.pop(0)
|
|
||||||
elif len(mixed_params) != 0:
|
|
||||||
param = mixed_params.pop(0)
|
|
||||||
else:
|
|
||||||
if report_errors:
|
|
||||||
self.reporter.error(
|
|
||||||
arg[0].location, "Too many positional arguments"
|
|
||||||
)
|
|
||||||
valid_call = False
|
|
||||||
break
|
|
||||||
name: str = param.name
|
|
||||||
if name in required_positional:
|
|
||||||
required_positional.remove(name)
|
|
||||||
if name in required_keyword:
|
|
||||||
required_keyword.remove(name)
|
|
||||||
set_args.add(name)
|
|
||||||
mapped.append(
|
|
||||||
MappedArgument(
|
|
||||||
expr=arg[0],
|
|
||||||
type=arg[1],
|
|
||||||
argument=param,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
kw_params.update({arg.name: arg for arg in mixed_params})
|
|
||||||
for name, arg in keywords.items():
|
|
||||||
param: Function.Argument
|
|
||||||
if name not in kw_params:
|
|
||||||
if report_errors:
|
|
||||||
if name in set_args:
|
|
||||||
self.reporter.error(
|
|
||||||
arg[0].location, f"Multiple values for argument '{name}'"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.reporter.error(
|
|
||||||
arg[0].location, f"Unknown keyword argument '{name}'"
|
|
||||||
)
|
|
||||||
valid_call = False
|
|
||||||
continue
|
|
||||||
param = kw_params.pop(name)
|
|
||||||
if name in required_positional:
|
|
||||||
required_positional.remove(name)
|
|
||||||
if name in required_keyword:
|
|
||||||
required_keyword.remove(name)
|
|
||||||
set_args.add(name)
|
|
||||||
mapped.append(
|
|
||||||
MappedArgument(
|
|
||||||
expr=arg[0],
|
|
||||||
type=arg[1],
|
|
||||||
argument=param,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def join_args(args: list[str]) -> str:
|
|
||||||
args = list(map(lambda a: f"'{a}'", args))
|
|
||||||
if len(args) == 0:
|
|
||||||
return ""
|
|
||||||
if len(args) == 1:
|
|
||||||
return args[0]
|
|
||||||
return ", ".join(args[:-1]) + " and " + args[-1]
|
|
||||||
|
|
||||||
if len(required_positional) != 0:
|
|
||||||
plural: str = "" if len(required_positional) == 1 else "s"
|
|
||||||
args: str = join_args(required_positional)
|
|
||||||
if report_errors:
|
|
||||||
self.reporter.error(
|
|
||||||
location,
|
|
||||||
f"Missing required positional argument{plural}: {args}",
|
|
||||||
)
|
|
||||||
valid_call = False
|
|
||||||
|
|
||||||
if len(required_keyword) != 0:
|
|
||||||
plural: str = "" if len(required_keyword) == 1 else "s"
|
|
||||||
args: str = join_args(required_keyword)
|
|
||||||
if report_errors:
|
|
||||||
self.reporter.error(
|
|
||||||
location,
|
|
||||||
f"Missing required keyword argument{plural}: {args}",
|
|
||||||
)
|
|
||||||
valid_call = False
|
|
||||||
|
|
||||||
return valid_call, mapped
|
|
||||||
|
|
||||||
def _are_mapped_subtypes(
|
|
||||||
self, mapped1: list[MappedArgument], mapped2: list[MappedArgument]
|
|
||||||
) -> bool:
|
|
||||||
"""Check whether the given argument mappings are subtype/supertype of one another
|
|
||||||
|
|
||||||
This function checks whether the argument mappings `mapped1` are subtypes
|
|
||||||
of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
|
|
||||||
of the corresponding parameter in `mapped2`, `False` is returned.
|
|
||||||
|
|
||||||
This is used to check whether a given overload is
|
|
||||||
a more specific function/ a subtype of another.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
|
|
||||||
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
|
|
||||||
"""
|
|
||||||
by_expr: dict[m.Expr, Type] = {}
|
|
||||||
for arg in mapped1:
|
|
||||||
by_expr[arg.expr] = arg.argument.type
|
|
||||||
|
|
||||||
for arg in mapped2:
|
|
||||||
type2: Type = arg.argument.type
|
|
||||||
type1: Type = by_expr[arg.expr]
|
|
||||||
if not self.types.is_subtype(type1, type2):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ PY_UNARY_METHODS: dict[Type[ast.unaryop], str] = {
|
|||||||
|
|
||||||
|
|
||||||
MIDAS_BINARY_METHODS: dict[TokenType, str] = {
|
MIDAS_BINARY_METHODS: dict[TokenType, str] = {
|
||||||
# TokenType.PLUS: "__add__",
|
TokenType.PLUS: "__add__",
|
||||||
TokenType.MINUS: "__sub__",
|
TokenType.MINUS: "__sub__",
|
||||||
TokenType.STAR: "__mul__",
|
TokenType.STAR: "__mul__",
|
||||||
TokenType.SLASH: "__truediv__",
|
TokenType.SLASH: "__truediv__",
|
||||||
|
|||||||
@@ -3,7 +3,15 @@ from typing import Any, Callable, Optional
|
|||||||
|
|
||||||
from midas.checker.environment import Environment
|
from midas.checker.environment import Environment
|
||||||
from midas.checker.registry import TypesRegistry
|
from midas.checker.registry import TypesRegistry
|
||||||
from midas.checker.types import Function, GenericType, TopType, Type, TypeVar, UnitType
|
from midas.checker.types import (
|
||||||
|
Function,
|
||||||
|
GenericType,
|
||||||
|
OverloadedFunction,
|
||||||
|
TopType,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
UnitType,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -70,6 +78,36 @@ class Preamble(Environment):
|
|||||||
returns=self._types.get_type("int"),
|
returns=self._types.get_type("int"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
T = TypeVar(name="T", bound=None)
|
||||||
|
self._def_overloads(
|
||||||
|
name="max",
|
||||||
|
py_function=max,
|
||||||
|
signatures=[
|
||||||
|
(
|
||||||
|
[Param("arg1", T), Param("arg2", T)],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
T,
|
||||||
|
[T],
|
||||||
|
),
|
||||||
|
([Param("iterable", self._list_of(T))], [], [], T, [T]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
self._def_overloads(
|
||||||
|
name="min",
|
||||||
|
py_function=min,
|
||||||
|
signatures=[
|
||||||
|
(
|
||||||
|
[Param("arg1", T), Param("arg2", T)],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
T,
|
||||||
|
[T],
|
||||||
|
),
|
||||||
|
([Param("iterable", self._list_of(T))], [], [], T, [T]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
def _list_of(self, item_type: Type) -> Type:
|
def _list_of(self, item_type: Type) -> Type:
|
||||||
return self._types.apply_generic(self._types.get_type("list"), [item_type])
|
return self._types.apply_generic(self._types.get_type("list"), [item_type])
|
||||||
|
|
||||||
@@ -142,5 +180,31 @@ class Preamble(Environment):
|
|||||||
if py_function is not None:
|
if py_function is not None:
|
||||||
self._python_funcs[name] = py_function
|
self._python_funcs[name] = py_function
|
||||||
|
|
||||||
|
def _def_overloads(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
name: str,
|
||||||
|
signatures: list[
|
||||||
|
tuple[list[Param], list[Param], list[Param], Type, list[TypeVar]]
|
||||||
|
],
|
||||||
|
py_function: Optional[Callable[..., Any]] = None,
|
||||||
|
):
|
||||||
|
overloads: list[Type] = []
|
||||||
|
for pos, mixed, kw, returns, type_vars in signatures:
|
||||||
|
overloads.append(
|
||||||
|
self._make_function(
|
||||||
|
name=name,
|
||||||
|
pos=pos,
|
||||||
|
mixed=mixed,
|
||||||
|
kw=kw,
|
||||||
|
returns=returns,
|
||||||
|
type_vars=type_vars,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
function: Type = OverloadedFunction(overloads=overloads)
|
||||||
|
self.define(name, function)
|
||||||
|
if py_function is not None:
|
||||||
|
self._python_funcs[name] = py_function
|
||||||
|
|
||||||
def get_py_func(self, name: str) -> Optional[Callable[..., Any]]:
|
def get_py_func(self, name: str) -> Optional[Callable[..., Any]]:
|
||||||
return self._python_funcs.get(name)
|
return self._python_funcs.get(name)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from typing import Any, Optional
|
|||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
from midas.ast.printer import MidasPrinter
|
from midas.ast.printer import MidasPrinter
|
||||||
|
from midas.checker.dispatcher import CallDispatcher, CallResult
|
||||||
from midas.checker.environment import Environment
|
from midas.checker.environment import Environment
|
||||||
from midas.checker.evaluator import Evaluator
|
from midas.checker.evaluator import Evaluator
|
||||||
from midas.checker.frames import FrameManager
|
from midas.checker.frames import FrameManager
|
||||||
@@ -27,7 +28,6 @@ from midas.checker.types import (
|
|||||||
DerivedType,
|
DerivedType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
OverloadedFunction,
|
|
||||||
TupleType,
|
TupleType,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
@@ -36,7 +36,6 @@ from midas.checker.types import (
|
|||||||
Variance,
|
Variance,
|
||||||
unfold_type,
|
unfold_type,
|
||||||
)
|
)
|
||||||
from midas.checker.unifier import Unifier
|
|
||||||
from midas.parser.python import PythonParser
|
from midas.parser.python import PythonParser
|
||||||
from midas.utils import TypedAST
|
from midas.utils import TypedAST
|
||||||
|
|
||||||
@@ -85,9 +84,17 @@ class PythonTyper(
|
|||||||
self.locals: dict[p.Expr, int] = {}
|
self.locals: dict[p.Expr, int] = {}
|
||||||
self.judgements: list[tuple[p.Expr, Type]] = []
|
self.judgements: list[tuple[p.Expr, Type]] = []
|
||||||
self.evaluated_casts: list[p.CastExpr] = []
|
self.evaluated_casts: list[p.CastExpr] = []
|
||||||
|
self.dispatcher: CallDispatcher[p.Expr] = CallDispatcher[p.Expr](
|
||||||
|
self.types, self.reporter
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_reporter(self, reporter: FileReporter):
|
||||||
|
self.reporter = reporter
|
||||||
|
self.dispatcher.set_reporter(self.reporter)
|
||||||
|
|
||||||
def process(self, source: str, path: Optional[str]) -> TypedAST:
|
def process(self, source: str, path: Optional[str]) -> TypedAST:
|
||||||
self.reporter = self.reporter.for_file(path)
|
reporter: FileReporter = self.reporter.for_file(path)
|
||||||
|
self.set_reporter(reporter)
|
||||||
|
|
||||||
tree: ast.Module = ast.parse(source, filename=path or "<unknown>")
|
tree: ast.Module = ast.parse(source, filename=path or "<unknown>")
|
||||||
parser = PythonParser()
|
parser = PythonParser()
|
||||||
@@ -222,12 +229,13 @@ class PythonTyper(
|
|||||||
if method is None:
|
if method is None:
|
||||||
raise UndefinedMethodException
|
raise UndefinedMethodException
|
||||||
|
|
||||||
return self._get_call_result(
|
result: CallResult = self.dispatcher.get_result(
|
||||||
location,
|
location=location,
|
||||||
method,
|
callee=method,
|
||||||
positional,
|
positional=positional,
|
||||||
keywords,
|
keywords=keywords,
|
||||||
)
|
)
|
||||||
|
return result.result
|
||||||
|
|
||||||
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||||
return self.types.is_subtype(type1, type2)
|
return self.types.is_subtype(type1, type2)
|
||||||
@@ -571,15 +579,13 @@ class PythonTyper(
|
|||||||
)
|
)
|
||||||
|
|
||||||
callee: Type = self.type_of(expr.callee)
|
callee: Type = self.type_of(expr.callee)
|
||||||
return (
|
result: CallResult = self.dispatcher.get_result(
|
||||||
self._get_call_result(
|
location=expr.location,
|
||||||
location=expr.location,
|
callee=callee,
|
||||||
callee=callee,
|
positional=positional,
|
||||||
positional=positional,
|
keywords=keywords,
|
||||||
keywords=keywords,
|
|
||||||
)
|
|
||||||
or UnknownType()
|
|
||||||
)
|
)
|
||||||
|
return result.result
|
||||||
|
|
||||||
def visit_get_expr(self, expr: p.GetExpr) -> Type:
|
def visit_get_expr(self, expr: p.GetExpr) -> Type:
|
||||||
object: Type = self.type_of(expr.object)
|
object: Type = self.type_of(expr.object)
|
||||||
@@ -742,10 +748,13 @@ class PythonTyper(
|
|||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
index: Type = self.type_of(expr.index)
|
index: Type = self.type_of(expr.index)
|
||||||
return (
|
result: CallResult = self.dispatcher.get_result(
|
||||||
self._get_call_result(expr.location, operation, [(expr.index, index)], {})
|
location=expr.location,
|
||||||
or UnknownType()
|
callee=operation,
|
||||||
|
positional=[(expr.index, index)],
|
||||||
|
keywords={},
|
||||||
)
|
)
|
||||||
|
return result.result
|
||||||
|
|
||||||
def visit_slice_expr(self, expr: p.SliceExpr) -> Type:
|
def visit_slice_expr(self, expr: p.SliceExpr) -> Type:
|
||||||
return self.types.get_type("slice")
|
return self.types.get_type("slice")
|
||||||
@@ -796,376 +805,6 @@ class PythonTyper(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_call_result(
|
|
||||||
self,
|
|
||||||
location: Location,
|
|
||||||
callee: Type,
|
|
||||||
positional: list[TypedExpr],
|
|
||||||
keywords: dict[str, TypedExpr],
|
|
||||||
report_errors: bool = True,
|
|
||||||
) -> Optional[Type]:
|
|
||||||
"""Get the result type of a function call
|
|
||||||
|
|
||||||
If the function has overloads, the function will try to resolve the
|
|
||||||
appropriate signature.
|
|
||||||
Argument types are matched to the defined parameters.
|
|
||||||
The function doesn't take the raw expression as a parameter to accommodate
|
|
||||||
for desugared calls such as for operators.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
location (Location): the call location
|
|
||||||
callee (Type): the called function
|
|
||||||
positional (list[TypedExpr]): the list positional arguments
|
|
||||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
|
||||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Type: the return type of the call, or `None` if either
|
|
||||||
the call is invalid or no overload matched the arguments uniquely
|
|
||||||
"""
|
|
||||||
match callee:
|
|
||||||
case Function() as function:
|
|
||||||
valid: bool
|
|
||||||
mapped: list[MappedArgument]
|
|
||||||
valid, mapped = self.map_call_arguments(
|
|
||||||
function, location, positional, keywords
|
|
||||||
)
|
|
||||||
valid = valid and self._are_arguments_valid(mapped, report_errors)
|
|
||||||
if not valid:
|
|
||||||
return None
|
|
||||||
return function.returns
|
|
||||||
|
|
||||||
case OverloadedFunction(overloads=overloads):
|
|
||||||
function = self._match_overload(
|
|
||||||
overloads, location, positional, keywords, report_errors
|
|
||||||
)
|
|
||||||
if function is None:
|
|
||||||
return None
|
|
||||||
return function.returns
|
|
||||||
|
|
||||||
case AppliedType(body=body):
|
|
||||||
return self._get_call_result(
|
|
||||||
location, body, positional, keywords, report_errors
|
|
||||||
)
|
|
||||||
|
|
||||||
case UnknownType():
|
|
||||||
return UnknownType()
|
|
||||||
|
|
||||||
case DerivedType(type=base):
|
|
||||||
return self._get_call_result(
|
|
||||||
location, base, positional, keywords, report_errors
|
|
||||||
)
|
|
||||||
|
|
||||||
case GenericType():
|
|
||||||
unifier: Unifier = Unifier(self.types)
|
|
||||||
pos: list[Type] = [a[1] for a in positional]
|
|
||||||
kw: dict[str, Type] = {k: v[1] for k, v in keywords.items()}
|
|
||||||
unified: Optional[Type] = unifier.unify_call(callee, pos, kw)
|
|
||||||
if unified is None:
|
|
||||||
if report_errors:
|
|
||||||
pos_str: str = ", ".join(str(t) for t in pos)
|
|
||||||
kw_str: str = ", ".join(f"{k}: {v}" for k, v in kw.items())
|
|
||||||
self.reporter.error(
|
|
||||||
location,
|
|
||||||
f"Could not unify {callee}={callee.body} with pos=[{pos_str}] and kw={{{kw_str}}}",
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
return self._get_call_result(
|
|
||||||
location,
|
|
||||||
unified,
|
|
||||||
positional,
|
|
||||||
keywords,
|
|
||||||
report_errors,
|
|
||||||
)
|
|
||||||
|
|
||||||
case _:
|
|
||||||
if report_errors:
|
|
||||||
self.reporter.error(
|
|
||||||
location,
|
|
||||||
f"{callee} ({callee.__class__.__name__}) is not callable",
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _are_arguments_valid(
|
|
||||||
self,
|
|
||||||
arguments: list[MappedArgument],
|
|
||||||
report_errors: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""Check whether the passed argument types correspond to their matched parameter definitions
|
|
||||||
|
|
||||||
Args:
|
|
||||||
arguments (list[MappedArgument]): the list of argument/parameter pairs
|
|
||||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if all arguments fit the matching parameter definitions, False otherwise
|
|
||||||
"""
|
|
||||||
valid: bool = True
|
|
||||||
for arg in arguments:
|
|
||||||
if not self.is_subtype(arg.type, arg.argument.type):
|
|
||||||
if report_errors:
|
|
||||||
self.reporter.error(
|
|
||||||
arg.expr.location,
|
|
||||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
|
||||||
)
|
|
||||||
valid = False
|
|
||||||
return valid
|
|
||||||
|
|
||||||
def _match_overload(
|
|
||||||
self,
|
|
||||||
overloads: list[Type],
|
|
||||||
location: Location,
|
|
||||||
positional: list[TypedExpr],
|
|
||||||
keywords: dict[str, TypedExpr],
|
|
||||||
report_errors: bool = True,
|
|
||||||
) -> Optional[Function]:
|
|
||||||
"""Try and resolve the appropriate overload for the given arguments
|
|
||||||
|
|
||||||
Args:
|
|
||||||
overloads (list[Type]): the list of possible overloads
|
|
||||||
location (Location): the call location
|
|
||||||
positional (list[TypedExpr]): the list of positional arguments
|
|
||||||
keywords (dict[str, TypedExpr]): the map of keywords arguments
|
|
||||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[Function]: the resolved function signature if it can be
|
|
||||||
determined unambiguously, or `None`.
|
|
||||||
"""
|
|
||||||
candidates: list[OverloadCandidate] = []
|
|
||||||
for overload in overloads:
|
|
||||||
function: Type = unfold_type(overload)
|
|
||||||
if not isinstance(function, Function):
|
|
||||||
if report_errors:
|
|
||||||
self.logger.error(
|
|
||||||
f"Overload is not a function: {overload} is {function}"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
valid, mapped = self.map_call_arguments(
|
|
||||||
function=function,
|
|
||||||
location=location,
|
|
||||||
positional=positional,
|
|
||||||
keywords=keywords,
|
|
||||||
report_errors=False,
|
|
||||||
)
|
|
||||||
if valid and self._are_arguments_valid(mapped, report_errors=False):
|
|
||||||
candidates.append(
|
|
||||||
OverloadCandidate(
|
|
||||||
function=function,
|
|
||||||
mapped=mapped,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
pos_types: str = ", ".join(str(type) for _, type in positional)
|
|
||||||
kw_types: str = ", ".join(
|
|
||||||
f"{name}: {type}" for name, (_, type) in keywords.items()
|
|
||||||
)
|
|
||||||
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
|
|
||||||
|
|
||||||
n_candidates: int = len(candidates)
|
|
||||||
|
|
||||||
# Exactly 1 match -> return it
|
|
||||||
if n_candidates == 1:
|
|
||||||
return candidates[0].function
|
|
||||||
|
|
||||||
# No match -> invalid call
|
|
||||||
if n_candidates == 0:
|
|
||||||
overloads_str: str = ", ".join(map(str, overloads))
|
|
||||||
if report_errors:
|
|
||||||
self.reporter.error(
|
|
||||||
location,
|
|
||||||
f"No matching overload in [{overloads_str}] {for_args}",
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Multiple matches -> see if one <: all others (more specific)
|
|
||||||
for i1, c1 in enumerate(candidates):
|
|
||||||
mapped1: list[MappedArgument] = c1.mapped
|
|
||||||
best_match: bool = True
|
|
||||||
for i2, c2 in enumerate(candidates):
|
|
||||||
if i1 == i2:
|
|
||||||
continue
|
|
||||||
mapped2: list[MappedArgument] = c2.mapped
|
|
||||||
if not self._are_mapped_subtypes(mapped1, mapped2):
|
|
||||||
best_match = False
|
|
||||||
break
|
|
||||||
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
|
|
||||||
if best_match:
|
|
||||||
return c1.function
|
|
||||||
|
|
||||||
candidates_str: str = ", ".join(
|
|
||||||
str(candidate.function) for candidate in candidates
|
|
||||||
)
|
|
||||||
if report_errors:
|
|
||||||
self.reporter.error(
|
|
||||||
location,
|
|
||||||
f"Multiple matching overloads {for_args}: {candidates_str}",
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def map_call_arguments(
|
|
||||||
self,
|
|
||||||
function: Function,
|
|
||||||
location: Location,
|
|
||||||
positional: list[TypedExpr],
|
|
||||||
keywords: dict[str, TypedExpr],
|
|
||||||
report_errors: bool = True,
|
|
||||||
) -> tuple[bool, list[MappedArgument]]:
|
|
||||||
"""Map call arguments to a function's parameters as defined in its signature
|
|
||||||
|
|
||||||
This method maps positional-only, keyword-only and mixed parameter definitions
|
|
||||||
with the arguments passed at the call site
|
|
||||||
|
|
||||||
Any mismatched, missing or unexpected argument is reported as a diagnostic,
|
|
||||||
unless `report_errors` is set to `False`
|
|
||||||
|
|
||||||
Args:
|
|
||||||
function (Function): the function definition
|
|
||||||
location (Location): the call location
|
|
||||||
positional (list[TypedExpr]): the list of positional arguments
|
|
||||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
|
||||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[bool, list[MappedArgument]]: a boolean reporting whether
|
|
||||||
the call is valid and the list of mapped arguments
|
|
||||||
"""
|
|
||||||
set_args: set[str] = set()
|
|
||||||
|
|
||||||
required_positional: list[str] = [
|
|
||||||
arg.name for arg in function.pos_args + function.args if arg.required
|
|
||||||
]
|
|
||||||
required_keyword: list[str] = [
|
|
||||||
arg.name for arg in function.kw_args if arg.required
|
|
||||||
]
|
|
||||||
|
|
||||||
mapped: list[MappedArgument] = []
|
|
||||||
|
|
||||||
pos_params: list[Function.Argument] = list(function.pos_args)
|
|
||||||
mixed_params: list[Function.Argument] = list(function.args)
|
|
||||||
kw_params: dict[str, Function.Argument] = {
|
|
||||||
arg.name: arg for arg in function.kw_args
|
|
||||||
}
|
|
||||||
|
|
||||||
valid_call: bool = True
|
|
||||||
|
|
||||||
# TODO: handle *args and **kwargs sinks
|
|
||||||
for arg in positional:
|
|
||||||
param: Function.Argument
|
|
||||||
if len(pos_params) != 0:
|
|
||||||
param = pos_params.pop(0)
|
|
||||||
elif len(mixed_params) != 0:
|
|
||||||
param = mixed_params.pop(0)
|
|
||||||
else:
|
|
||||||
if report_errors:
|
|
||||||
self.reporter.error(
|
|
||||||
arg[0].location, "Too many positional arguments"
|
|
||||||
)
|
|
||||||
valid_call = False
|
|
||||||
break
|
|
||||||
name: str = param.name
|
|
||||||
if name in required_positional:
|
|
||||||
required_positional.remove(name)
|
|
||||||
if name in required_keyword:
|
|
||||||
required_keyword.remove(name)
|
|
||||||
set_args.add(name)
|
|
||||||
mapped.append(
|
|
||||||
MappedArgument(
|
|
||||||
expr=arg[0],
|
|
||||||
type=arg[1],
|
|
||||||
argument=param,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
kw_params.update({arg.name: arg for arg in mixed_params})
|
|
||||||
for name, arg in keywords.items():
|
|
||||||
param: Function.Argument
|
|
||||||
if name not in kw_params:
|
|
||||||
if report_errors:
|
|
||||||
if name in set_args:
|
|
||||||
self.reporter.error(
|
|
||||||
arg[0].location, f"Multiple values for argument '{name}'"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.reporter.error(
|
|
||||||
arg[0].location, f"Unknown keyword argument '{name}'"
|
|
||||||
)
|
|
||||||
valid_call = False
|
|
||||||
continue
|
|
||||||
param = kw_params.pop(name)
|
|
||||||
if name in required_positional:
|
|
||||||
required_positional.remove(name)
|
|
||||||
if name in required_keyword:
|
|
||||||
required_keyword.remove(name)
|
|
||||||
set_args.add(name)
|
|
||||||
mapped.append(
|
|
||||||
MappedArgument(
|
|
||||||
expr=arg[0],
|
|
||||||
type=arg[1],
|
|
||||||
argument=param,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def join_args(args: list[str]) -> str:
|
|
||||||
args = list(map(lambda a: f"'{a}'", args))
|
|
||||||
if len(args) == 0:
|
|
||||||
return ""
|
|
||||||
if len(args) == 1:
|
|
||||||
return args[0]
|
|
||||||
return ", ".join(args[:-1]) + " and " + args[-1]
|
|
||||||
|
|
||||||
if len(required_positional) != 0:
|
|
||||||
plural: str = "" if len(required_positional) == 1 else "s"
|
|
||||||
args: str = join_args(required_positional)
|
|
||||||
if report_errors:
|
|
||||||
self.reporter.error(
|
|
||||||
location,
|
|
||||||
f"Missing required positional argument{plural}: {args}",
|
|
||||||
)
|
|
||||||
valid_call = False
|
|
||||||
|
|
||||||
if len(required_keyword) != 0:
|
|
||||||
plural: str = "" if len(required_keyword) == 1 else "s"
|
|
||||||
args: str = join_args(required_keyword)
|
|
||||||
if report_errors:
|
|
||||||
self.reporter.error(
|
|
||||||
location,
|
|
||||||
f"Missing required keyword argument{plural}: {args}",
|
|
||||||
)
|
|
||||||
valid_call = False
|
|
||||||
|
|
||||||
return valid_call, mapped
|
|
||||||
|
|
||||||
def _are_mapped_subtypes(
|
|
||||||
self, mapped1: list[MappedArgument], mapped2: list[MappedArgument]
|
|
||||||
) -> bool:
|
|
||||||
"""Check whether the given argument mappings are subtype/supertype of one another
|
|
||||||
|
|
||||||
This function checks whether the argument mappings `mapped1` are subtypes
|
|
||||||
of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
|
|
||||||
of the corresponding parameter in `mapped2`, `False` is returned.
|
|
||||||
|
|
||||||
This is used to check whether a given overload is
|
|
||||||
a more specific function/ a subtype of another.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
|
|
||||||
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
|
|
||||||
"""
|
|
||||||
by_expr: dict[p.Expr, Type] = {}
|
|
||||||
for arg in mapped1:
|
|
||||||
by_expr[arg.expr] = arg.argument.type
|
|
||||||
|
|
||||||
for arg in mapped2:
|
|
||||||
type2: Type = arg.argument.type
|
|
||||||
type1: Type = by_expr[arg.expr]
|
|
||||||
if not self.is_subtype(type1, type2):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _get_iterator_type(self, expr: p.Expr, type: Type) -> Optional[Type]:
|
def _get_iterator_type(self, expr: p.Expr, type: Type) -> Optional[Type]:
|
||||||
# TODO: lookup __iter__
|
# TODO: lookup __iter__
|
||||||
getitem: Optional[Type] = self.types.lookup_member(type, "__getitem__")
|
getitem: Optional[Type] = self.types.lookup_member(type, "__getitem__")
|
||||||
@@ -1174,14 +813,16 @@ class PythonTyper(
|
|||||||
|
|
||||||
index: p.Expr = p.LiteralExpr(location=expr.location, value=0)
|
index: p.Expr = p.LiteralExpr(location=expr.location, value=0)
|
||||||
index_type: Type = self.compute_type(index)
|
index_type: Type = self.compute_type(index)
|
||||||
result: Optional[Type] = self._get_call_result(
|
result: CallResult = self.dispatcher.get_result(
|
||||||
location=expr.location,
|
location=expr.location,
|
||||||
callee=getitem,
|
callee=getitem,
|
||||||
positional=[(index, index_type)],
|
positional=[(index, index_type)],
|
||||||
keywords={},
|
keywords={},
|
||||||
report_errors=False,
|
report_errors=False,
|
||||||
)
|
)
|
||||||
return result
|
if not result.is_valid:
|
||||||
|
return None
|
||||||
|
return result.result
|
||||||
|
|
||||||
def define_typevar(self, call: p.CallExpr) -> Optional[TypeVar]:
|
def define_typevar(self, call: p.CallExpr) -> Optional[TypeVar]:
|
||||||
def is_kw_true(name: str) -> bool:
|
def is_kw_true(name: str) -> bool:
|
||||||
@@ -1272,6 +913,22 @@ class PythonTyper(
|
|||||||
pairs.append((key_val, value_val))
|
pairs.append((key_val, value_val))
|
||||||
return True, dict(pairs)
|
return True, dict(pairs)
|
||||||
|
|
||||||
|
case p.UnaryExpr(operator=operator, right=operand):
|
||||||
|
is_lit, operand_val = self._get_literal(operand)
|
||||||
|
if not is_lit:
|
||||||
|
return False, None
|
||||||
|
match operator:
|
||||||
|
case ast.UAdd():
|
||||||
|
return True, operand_val
|
||||||
|
case ast.USub():
|
||||||
|
return True, -operand_val
|
||||||
|
case ast.Invert():
|
||||||
|
return True, ~operand_val
|
||||||
|
case ast.Not():
|
||||||
|
return True, not operand_val
|
||||||
|
case _: # Should never be reached
|
||||||
|
return False, None
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
@@ -1284,6 +941,40 @@ class PythonTyper(
|
|||||||
expr, subject_type, base, lit_value
|
expr, subject_type, base, lit_value
|
||||||
)
|
)
|
||||||
|
|
||||||
|
case AppliedType(name="list", args=[item_type]) if isinstance(
|
||||||
|
lit_value, list
|
||||||
|
):
|
||||||
|
match subject_type:
|
||||||
|
case AppliedType(name="list", args=[lit_item_type]):
|
||||||
|
evaluated: bool = True
|
||||||
|
for item in lit_value:
|
||||||
|
if not self._evaluate_cast_statically(
|
||||||
|
expr, lit_item_type, item_type, item
|
||||||
|
):
|
||||||
|
evaluated = False
|
||||||
|
return evaluated
|
||||||
|
case _:
|
||||||
|
return False
|
||||||
|
|
||||||
|
case AppliedType(name="dict", args=[key_type, value_type]) if isinstance(
|
||||||
|
lit_value, dict
|
||||||
|
):
|
||||||
|
match subject_type:
|
||||||
|
case AppliedType(name="dict", args=[lit_key_type, lit_value_type]):
|
||||||
|
evaluated: bool = True
|
||||||
|
for key, value in lit_value.items():
|
||||||
|
if not self._evaluate_cast_statically(
|
||||||
|
expr, lit_key_type, key_type, key
|
||||||
|
):
|
||||||
|
evaluated = False
|
||||||
|
if not self._evaluate_cast_statically(
|
||||||
|
expr, lit_value_type, value_type, value
|
||||||
|
):
|
||||||
|
evaluated = False
|
||||||
|
return evaluated
|
||||||
|
case _:
|
||||||
|
return False
|
||||||
|
|
||||||
case AppliedType(body=body):
|
case AppliedType(body=body):
|
||||||
return self._evaluate_cast_statically(
|
return self._evaluate_cast_statically(
|
||||||
expr, subject_type, body, lit_value
|
expr, subject_type, body, lit_value
|
||||||
@@ -1298,10 +989,19 @@ class PythonTyper(
|
|||||||
|
|
||||||
evaluator = Evaluator(self.types)
|
evaluator = Evaluator(self.types)
|
||||||
evaluator.set_value("_", lit_value)
|
evaluator.set_value("_", lit_value)
|
||||||
res = evaluator.evaluate(constraint)
|
printer = MidasPrinter()
|
||||||
|
constraint_str: str = printer.print(constraint)
|
||||||
|
res: Any
|
||||||
|
try:
|
||||||
|
res = evaluator.evaluate(constraint)
|
||||||
|
except Exception as e:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"An error occurred while checking constraint '{constraint_str}' on the value {lit_value!r}: {e}",
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
if not res:
|
if not res:
|
||||||
printer = MidasPrinter()
|
|
||||||
constraint_str: str = printer.print(constraint)
|
|
||||||
self.reporter.error(
|
self.reporter.error(
|
||||||
expr.location,
|
expr.location,
|
||||||
f"Value {lit_value!r} does not fit constraint '{constraint_str}'",
|
f"Value {lit_value!r} does not fit constraint '{constraint_str}'",
|
||||||
|
|||||||
@@ -46,8 +46,8 @@ class MidasLexer(Lexer):
|
|||||||
self.add_token(TokenType.UNDERSCORE)
|
self.add_token(TokenType.UNDERSCORE)
|
||||||
case "-" if self.match(">"):
|
case "-" if self.match(">"):
|
||||||
self.add_token(TokenType.ARROW)
|
self.add_token(TokenType.ARROW)
|
||||||
# case "+":
|
case "+":
|
||||||
# self.add_token(TokenType.PLUS)
|
self.add_token(TokenType.PLUS)
|
||||||
case "-":
|
case "-":
|
||||||
self.add_token(TokenType.MINUS)
|
self.add_token(TokenType.MINUS)
|
||||||
case "*":
|
case "*":
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ class TokenType(Enum):
|
|||||||
DOT = auto()
|
DOT = auto()
|
||||||
|
|
||||||
# Operators
|
# Operators
|
||||||
# PLUS = auto()
|
PLUS = auto()
|
||||||
MINUS = auto()
|
MINUS = auto()
|
||||||
STAR = auto()
|
STAR = auto()
|
||||||
SLASH = auto()
|
SLASH = auto()
|
||||||
|
|||||||
@@ -361,13 +361,35 @@ class MidasParser(Parser):
|
|||||||
Returns:
|
Returns:
|
||||||
Expr: the parsed expression
|
Expr: the parsed expression
|
||||||
"""
|
"""
|
||||||
expr: Expr = self.unary()
|
expr: Expr = self.term()
|
||||||
while self.match(
|
while self.match(
|
||||||
TokenType.LESS,
|
TokenType.LESS,
|
||||||
TokenType.LESS_EQUAL,
|
TokenType.LESS_EQUAL,
|
||||||
TokenType.GREATER,
|
TokenType.GREATER,
|
||||||
TokenType.GREATER_EQUAL,
|
TokenType.GREATER_EQUAL,
|
||||||
):
|
):
|
||||||
|
operator: Token = self.previous()
|
||||||
|
right: Expr = self.term()
|
||||||
|
location: Location = Location.span(expr.location, right.location)
|
||||||
|
expr = BinaryExpr(
|
||||||
|
location=location, left=expr, operator=operator, right=right
|
||||||
|
)
|
||||||
|
return expr
|
||||||
|
|
||||||
|
def term(self) -> Expr:
|
||||||
|
expr: Expr = self.factor()
|
||||||
|
while self.match(TokenType.PLUS, TokenType.MINUS):
|
||||||
|
operator: Token = self.previous()
|
||||||
|
right: Expr = self.factor()
|
||||||
|
location: Location = Location.span(expr.location, right.location)
|
||||||
|
expr = BinaryExpr(
|
||||||
|
location=location, left=expr, operator=operator, right=right
|
||||||
|
)
|
||||||
|
return expr
|
||||||
|
|
||||||
|
def factor(self) -> Expr:
|
||||||
|
expr: Expr = self.unary()
|
||||||
|
while self.match(TokenType.STAR, TokenType.SLASH):
|
||||||
operator: Token = self.previous()
|
operator: Token = self.previous()
|
||||||
right: Expr = self.unary()
|
right: Expr = self.unary()
|
||||||
location: Location = Location.span(expr.location, right.location)
|
location: Location = Location.span(expr.location, right.location)
|
||||||
|
|||||||
Reference in New Issue
Block a user