fix(checker): handle all operations and calls in predicates
This commit is contained in:
@@ -4,22 +4,27 @@ from pathlib import Path
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import midas.ast.midas as m
|
import midas.ast.midas as m
|
||||||
|
from midas.ast.location import Location
|
||||||
from midas.checker.builtins import define_builtins
|
from midas.checker.builtins import define_builtins
|
||||||
from midas.checker.environment import Environment
|
from midas.checker.environment import Environment
|
||||||
|
from midas.checker.operators import MIDAS_BINARY_METHODS, MIDAS_UNARY_METHODS
|
||||||
from midas.checker.preamble import Preamble
|
from midas.checker.preamble import Preamble
|
||||||
from midas.checker.registry import TypesRegistry
|
from midas.checker.registry import TypesRegistry
|
||||||
from midas.checker.reporter import FileReporter, Reporter
|
from midas.checker.reporter import FileReporter, Reporter
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AliasType,
|
AliasType,
|
||||||
|
AppliedType,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
ConstraintType,
|
ConstraintType,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
|
OverloadedFunction,
|
||||||
Predicate,
|
Predicate,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
UnknownType,
|
UnknownType,
|
||||||
|
unfold_type,
|
||||||
)
|
)
|
||||||
from midas.lexer.midas import MidasLexer
|
from midas.lexer.midas import MidasLexer
|
||||||
from midas.lexer.token import Token
|
from midas.lexer.token import Token
|
||||||
@@ -33,6 +38,26 @@ class TypedParamSpec:
|
|||||||
kw: list[Function.Argument]
|
kw: list[Function.Argument]
|
||||||
|
|
||||||
|
|
||||||
|
TypedExpr = tuple[m.Expr, Type]
|
||||||
|
|
||||||
|
|
||||||
|
class ReturnException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class MappedArgument:
|
||||||
|
expr: m.Expr
|
||||||
|
type: Type
|
||||||
|
argument: Function.Argument
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class OverloadCandidate:
|
||||||
|
function: Function
|
||||||
|
mapped: list[MappedArgument]
|
||||||
|
|
||||||
|
|
||||||
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type]):
|
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type]):
|
||||||
"""A resolver which evaluates Midas type definitions and build a registry"""
|
"""A resolver which evaluates Midas type definitions and build a registry"""
|
||||||
|
|
||||||
@@ -197,42 +222,81 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
|||||||
return self._bool
|
return self._bool
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> Type:
|
def visit_binary_expr(self, expr: m.BinaryExpr) -> Type:
|
||||||
# TODO
|
method: Optional[str] = MIDAS_BINARY_METHODS.get(expr.operator.type)
|
||||||
self.reporter.warning(expr.location, "BinaryExpr not yet supported")
|
if method is None:
|
||||||
return UnknownType()
|
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||||
|
self.reporter.warning(
|
||||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
|
expr.location, f"Unsupported operator {expr.operator}"
|
||||||
# TODO
|
|
||||||
self.reporter.warning(expr.location, "UnaryExpr not yet supported")
|
|
||||||
return UnknownType()
|
|
||||||
|
|
||||||
def visit_call_expr(self, expr: m.CallExpr) -> Type:
|
|
||||||
callee: Type = expr.callee.accept(self)
|
|
||||||
if not isinstance(callee, Function):
|
|
||||||
self.reporter.error(expr.location, f"Cannot call {callee}")
|
|
||||||
return UnknownType()
|
|
||||||
args: list[Type] = [arg.accept(self) for arg in expr.arguments]
|
|
||||||
|
|
||||||
n_args: int = len(args)
|
|
||||||
n_params: int = len(callee.args)
|
|
||||||
if n_args != n_params:
|
|
||||||
self.reporter.error(
|
|
||||||
expr.location,
|
|
||||||
f"Wrong number of argument, expected {n_params}, got {n_args}",
|
|
||||||
)
|
)
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
valid: bool = True
|
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
|
||||||
for arg, param in zip(args, callee.args):
|
|
||||||
if not self.types.is_subtype(arg, param.type):
|
def _visit_binary_expr(
|
||||||
self.reporter.error(
|
self, location: Location, left_expr: m.Expr, right_expr: m.Expr, method: str
|
||||||
expr.location,
|
) -> Type:
|
||||||
f"Invalid argument type at pos {param.pos}, expected {param.type}, got {arg}",
|
left: Type = self.type_of(left_expr)
|
||||||
)
|
right: Type = self.type_of(right_expr)
|
||||||
valid = False
|
|
||||||
if not valid:
|
operation: Optional[Type] = self.types.lookup_member(left, method)
|
||||||
|
if operation is None:
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Undefined operation {method} between {left} and {right}",
|
||||||
|
)
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
return callee.returns
|
|
||||||
|
result: Optional[Type] = self._get_call_result(
|
||||||
|
location,
|
||||||
|
operation,
|
||||||
|
[(right_expr, right)],
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
return result or UnknownType()
|
||||||
|
|
||||||
|
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
|
||||||
|
method: Optional[str] = MIDAS_UNARY_METHODS.get(expr.operator.type)
|
||||||
|
if method is None:
|
||||||
|
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||||
|
self.reporter.warning(
|
||||||
|
expr.location, f"Unsupported operator {expr.operator}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
operand: Type = self.type_of(expr.right)
|
||||||
|
operation: Optional[Type] = self.types.lookup_member(operand, method)
|
||||||
|
if operation is None:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Undefined operation {method} for {operand}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
result: Optional[Type] = self._get_call_result(
|
||||||
|
expr.location,
|
||||||
|
operation,
|
||||||
|
[],
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
return result or UnknownType()
|
||||||
|
|
||||||
|
def visit_call_expr(self, expr: m.CallExpr) -> Type:
|
||||||
|
callee: Type = expr.callee.accept(self)
|
||||||
|
positional: list[TypedExpr] = [
|
||||||
|
(arg, self.type_of(arg)) for arg in expr.arguments
|
||||||
|
]
|
||||||
|
keywords: dict[str, TypedExpr] = {
|
||||||
|
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
self._get_call_result(
|
||||||
|
expr.location,
|
||||||
|
callee,
|
||||||
|
positional,
|
||||||
|
keywords,
|
||||||
|
)
|
||||||
|
or UnknownType()
|
||||||
|
)
|
||||||
|
|
||||||
def visit_get_expr(self, expr: m.GetExpr) -> Type:
|
def visit_get_expr(self, expr: m.GetExpr) -> Type:
|
||||||
object: Type = expr.expr.accept(self)
|
object: Type = expr.expr.accept(self)
|
||||||
@@ -344,3 +408,343 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
|||||||
self._local_variables[name] = var
|
self._local_variables[name] = var
|
||||||
vars.append(var)
|
vars.append(var)
|
||||||
return vars
|
return vars
|
||||||
|
|
||||||
|
def _get_call_result(
|
||||||
|
self,
|
||||||
|
location: Location,
|
||||||
|
callee: Type,
|
||||||
|
positional: list[TypedExpr],
|
||||||
|
keywords: dict[str, TypedExpr],
|
||||||
|
report_errors: bool = True,
|
||||||
|
) -> Optional[Type]:
|
||||||
|
"""Get the result type of a function call
|
||||||
|
|
||||||
|
If the function has overloads, the function will try to resolve the
|
||||||
|
appropriate signature.
|
||||||
|
Argument types are matched to the defined parameters.
|
||||||
|
The function doesn't take the raw expression as a parameter to accommodate
|
||||||
|
for desugared calls such as for operators.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
location (Location): the call location
|
||||||
|
callee (Type): the called function
|
||||||
|
positional (list[TypedExpr]): the list positional arguments
|
||||||
|
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||||
|
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Type: the return type of the call, or `None` if either
|
||||||
|
the call is invalid or no overload matched the arguments uniquely
|
||||||
|
"""
|
||||||
|
match callee:
|
||||||
|
case Function() as function:
|
||||||
|
valid: bool
|
||||||
|
mapped: list[MappedArgument]
|
||||||
|
valid, mapped = self.map_call_arguments(
|
||||||
|
function, location, positional, keywords
|
||||||
|
)
|
||||||
|
valid = valid and self._are_arguments_valid(mapped, report_errors)
|
||||||
|
if not valid:
|
||||||
|
return None
|
||||||
|
return function.returns
|
||||||
|
|
||||||
|
case OverloadedFunction(overloads=overloads):
|
||||||
|
function = self._match_overload(
|
||||||
|
overloads, location, positional, keywords, report_errors
|
||||||
|
)
|
||||||
|
if function is None:
|
||||||
|
return None
|
||||||
|
return function.returns
|
||||||
|
|
||||||
|
case AppliedType(body=body):
|
||||||
|
return self._get_call_result(
|
||||||
|
location, body, positional, keywords, report_errors
|
||||||
|
)
|
||||||
|
|
||||||
|
case UnknownType():
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
case _:
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(location, f"{callee} is not callable")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _are_arguments_valid(
|
||||||
|
self,
|
||||||
|
arguments: list[MappedArgument],
|
||||||
|
report_errors: bool = True,
|
||||||
|
) -> bool:
|
||||||
|
"""Check whether the passed argument types correspond to their matched parameter definitions
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arguments (list[MappedArgument]): the list of argument/parameter pairs
|
||||||
|
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if all arguments fit the matching parameter definitions, False otherwise
|
||||||
|
"""
|
||||||
|
valid: bool = True
|
||||||
|
for arg in arguments:
|
||||||
|
if not self.types.is_subtype(arg.type, arg.argument.type):
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(
|
||||||
|
arg.expr.location,
|
||||||
|
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||||
|
)
|
||||||
|
valid = False
|
||||||
|
return valid
|
||||||
|
|
||||||
|
def _match_overload(
|
||||||
|
self,
|
||||||
|
overloads: list[Type],
|
||||||
|
location: Location,
|
||||||
|
positional: list[TypedExpr],
|
||||||
|
keywords: dict[str, TypedExpr],
|
||||||
|
report_errors: bool = True,
|
||||||
|
) -> Optional[Function]:
|
||||||
|
"""Try and resolve the appropriate overload for the given arguments
|
||||||
|
|
||||||
|
Args:
|
||||||
|
overloads (list[Type]): the list of possible overloads
|
||||||
|
location (Location): the call location
|
||||||
|
positional (list[TypedExpr]): the list of positional arguments
|
||||||
|
keywords (dict[str, TypedExpr]): the map of keywords arguments
|
||||||
|
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Function]: the resolved function signature if it can be
|
||||||
|
determined unambiguously, or `None`.
|
||||||
|
"""
|
||||||
|
candidates: list[OverloadCandidate] = []
|
||||||
|
for overload in overloads:
|
||||||
|
function: Type = unfold_type(overload)
|
||||||
|
if not isinstance(function, Function):
|
||||||
|
if report_errors:
|
||||||
|
self.logger.error(
|
||||||
|
f"Overload is not a function: {overload} is {function}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
valid, mapped = self.map_call_arguments(
|
||||||
|
function=function,
|
||||||
|
location=location,
|
||||||
|
positional=positional,
|
||||||
|
keywords=keywords,
|
||||||
|
report_errors=False,
|
||||||
|
)
|
||||||
|
if valid and self._are_arguments_valid(mapped, report_errors=False):
|
||||||
|
candidates.append(
|
||||||
|
OverloadCandidate(
|
||||||
|
function=function,
|
||||||
|
mapped=mapped,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
pos_types: str = ", ".join(str(type) for _, type in positional)
|
||||||
|
kw_types: str = ", ".join(
|
||||||
|
f"{name}: {type}" for name, (_, type) in keywords.items()
|
||||||
|
)
|
||||||
|
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
|
||||||
|
|
||||||
|
n_candidates: int = len(candidates)
|
||||||
|
|
||||||
|
# Exactly 1 match -> return it
|
||||||
|
if n_candidates == 1:
|
||||||
|
return candidates[0].function
|
||||||
|
|
||||||
|
# No match -> invalid call
|
||||||
|
if n_candidates == 0:
|
||||||
|
overloads_str: str = ", ".join(map(str, overloads))
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"No matching overload in [{overloads_str}] {for_args}",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Multiple matches -> see if one <: all others (more specific)
|
||||||
|
for i1, c1 in enumerate(candidates):
|
||||||
|
mapped1: list[MappedArgument] = c1.mapped
|
||||||
|
best_match: bool = True
|
||||||
|
for i2, c2 in enumerate(candidates):
|
||||||
|
if i1 == i2:
|
||||||
|
continue
|
||||||
|
mapped2: list[MappedArgument] = c2.mapped
|
||||||
|
if not self._are_mapped_subtypes(mapped1, mapped2):
|
||||||
|
best_match = False
|
||||||
|
break
|
||||||
|
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
|
||||||
|
if best_match:
|
||||||
|
return c1.function
|
||||||
|
|
||||||
|
candidates_str: str = ", ".join(
|
||||||
|
str(candidate.function) for candidate in candidates
|
||||||
|
)
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Multiple matching overloads {for_args}: {candidates_str}",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def map_call_arguments(
|
||||||
|
self,
|
||||||
|
function: Function,
|
||||||
|
location: Location,
|
||||||
|
positional: list[TypedExpr],
|
||||||
|
keywords: dict[str, TypedExpr],
|
||||||
|
report_errors: bool = True,
|
||||||
|
) -> tuple[bool, list[MappedArgument]]:
|
||||||
|
"""Map call arguments to a function's parameters as defined in its signature
|
||||||
|
|
||||||
|
This method maps positional-only, keyword-only and mixed parameter definitions
|
||||||
|
with the arguments passed at the call site
|
||||||
|
|
||||||
|
Any mismatched, missing or unexpected argument is reported as a diagnostic,
|
||||||
|
unless `report_errors` is set to `False`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
function (Function): the function definition
|
||||||
|
location (Location): the call location
|
||||||
|
positional (list[TypedExpr]): the list of positional arguments
|
||||||
|
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||||
|
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[bool, list[MappedArgument]]: a boolean reporting whether
|
||||||
|
the call is valid and the list of mapped arguments
|
||||||
|
"""
|
||||||
|
set_args: set[str] = set()
|
||||||
|
|
||||||
|
required_positional: list[str] = [
|
||||||
|
arg.name for arg in function.pos_args + function.args if arg.required
|
||||||
|
]
|
||||||
|
required_keyword: list[str] = [
|
||||||
|
arg.name for arg in function.kw_args if arg.required
|
||||||
|
]
|
||||||
|
|
||||||
|
mapped: list[MappedArgument] = []
|
||||||
|
|
||||||
|
pos_params: list[Function.Argument] = list(function.pos_args)
|
||||||
|
mixed_params: list[Function.Argument] = list(function.args)
|
||||||
|
kw_params: dict[str, Function.Argument] = {
|
||||||
|
arg.name: arg for arg in function.kw_args
|
||||||
|
}
|
||||||
|
|
||||||
|
valid_call: bool = True
|
||||||
|
|
||||||
|
# TODO: handle *args and **kwargs sinks
|
||||||
|
for arg in positional:
|
||||||
|
param: Function.Argument
|
||||||
|
if len(pos_params) != 0:
|
||||||
|
param = pos_params.pop(0)
|
||||||
|
elif len(mixed_params) != 0:
|
||||||
|
param = mixed_params.pop(0)
|
||||||
|
else:
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(
|
||||||
|
arg[0].location, "Too many positional arguments"
|
||||||
|
)
|
||||||
|
valid_call = False
|
||||||
|
break
|
||||||
|
name: str = param.name
|
||||||
|
if name in required_positional:
|
||||||
|
required_positional.remove(name)
|
||||||
|
if name in required_keyword:
|
||||||
|
required_keyword.remove(name)
|
||||||
|
set_args.add(name)
|
||||||
|
mapped.append(
|
||||||
|
MappedArgument(
|
||||||
|
expr=arg[0],
|
||||||
|
type=arg[1],
|
||||||
|
argument=param,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
kw_params.update({arg.name: arg for arg in mixed_params})
|
||||||
|
for name, arg in keywords.items():
|
||||||
|
param: Function.Argument
|
||||||
|
if name not in kw_params:
|
||||||
|
if report_errors:
|
||||||
|
if name in set_args:
|
||||||
|
self.reporter.error(
|
||||||
|
arg[0].location, f"Multiple values for argument '{name}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.reporter.error(
|
||||||
|
arg[0].location, f"Unknown keyword argument '{name}'"
|
||||||
|
)
|
||||||
|
valid_call = False
|
||||||
|
continue
|
||||||
|
param = kw_params.pop(name)
|
||||||
|
if name in required_positional:
|
||||||
|
required_positional.remove(name)
|
||||||
|
if name in required_keyword:
|
||||||
|
required_keyword.remove(name)
|
||||||
|
set_args.add(name)
|
||||||
|
mapped.append(
|
||||||
|
MappedArgument(
|
||||||
|
expr=arg[0],
|
||||||
|
type=arg[1],
|
||||||
|
argument=param,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def join_args(args: list[str]) -> str:
|
||||||
|
args = list(map(lambda a: f"'{a}'", args))
|
||||||
|
if len(args) == 0:
|
||||||
|
return ""
|
||||||
|
if len(args) == 1:
|
||||||
|
return args[0]
|
||||||
|
return ", ".join(args[:-1]) + " and " + args[-1]
|
||||||
|
|
||||||
|
if len(required_positional) != 0:
|
||||||
|
plural: str = "" if len(required_positional) == 1 else "s"
|
||||||
|
args: str = join_args(required_positional)
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Missing required positional argument{plural}: {args}",
|
||||||
|
)
|
||||||
|
valid_call = False
|
||||||
|
|
||||||
|
if len(required_keyword) != 0:
|
||||||
|
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||||
|
args: str = join_args(required_keyword)
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Missing required keyword argument{plural}: {args}",
|
||||||
|
)
|
||||||
|
valid_call = False
|
||||||
|
|
||||||
|
return valid_call, mapped
|
||||||
|
|
||||||
|
def _are_mapped_subtypes(
|
||||||
|
self, mapped1: list[MappedArgument], mapped2: list[MappedArgument]
|
||||||
|
) -> bool:
|
||||||
|
"""Check whether the given argument mappings are subtype/supertype of one another
|
||||||
|
|
||||||
|
This function checks whether the argument mappings `mapped1` are subtypes
|
||||||
|
of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
|
||||||
|
of the corresponding parameter in `mapped2`, `False` is returned.
|
||||||
|
|
||||||
|
This is used to check whether a given overload is
|
||||||
|
a more specific function/ a subtype of another.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
|
||||||
|
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
|
||||||
|
"""
|
||||||
|
by_expr: dict[m.Expr, Type] = {}
|
||||||
|
for arg in mapped1:
|
||||||
|
by_expr[arg.expr] = arg.argument.type
|
||||||
|
|
||||||
|
for arg in mapped2:
|
||||||
|
type2: Type = arg.argument.type
|
||||||
|
type1: Type = by_expr[arg.expr]
|
||||||
|
if not self.types.is_subtype(type1, type2):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
import ast
|
import ast
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
OPERATOR_METHODS: dict[Type[ast.operator], str] = {
|
from midas.lexer.token import TokenType
|
||||||
|
|
||||||
|
PY_OPERATOR_METHODS: dict[Type[ast.operator], str] = {
|
||||||
ast.Add: "__add__",
|
ast.Add: "__add__",
|
||||||
ast.Sub: "__sub__",
|
ast.Sub: "__sub__",
|
||||||
ast.Mult: "__mul__",
|
ast.Mult: "__mul__",
|
||||||
@@ -17,7 +19,7 @@ OPERATOR_METHODS: dict[Type[ast.operator], str] = {
|
|||||||
ast.FloorDiv: "__floordiv__",
|
ast.FloorDiv: "__floordiv__",
|
||||||
}
|
}
|
||||||
|
|
||||||
COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
|
PY_COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
|
||||||
ast.Eq: "__eq__",
|
ast.Eq: "__eq__",
|
||||||
# ast.NotEq: "__noteq__",
|
# ast.NotEq: "__noteq__",
|
||||||
ast.Lt: "__lt__",
|
ast.Lt: "__lt__",
|
||||||
@@ -30,9 +32,40 @@ COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
|
|||||||
# ast.NotIn: "__notin__",
|
# ast.NotIn: "__notin__",
|
||||||
}
|
}
|
||||||
|
|
||||||
UNARY_METHODS: dict[Type[ast.unaryop], str] = {
|
PY_UNARY_METHODS: dict[Type[ast.unaryop], str] = {
|
||||||
ast.Invert: "__invert__",
|
ast.Invert: "__invert__",
|
||||||
# ast.Not: "",
|
# ast.Not: "",
|
||||||
ast.UAdd: "__pos__",
|
ast.UAdd: "__pos__",
|
||||||
ast.USub: "__neg__",
|
ast.USub: "__neg__",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
MIDAS_BINARY_METHODS: dict[TokenType, str] = {
|
||||||
|
# TokenType.PLUS: "__add__",
|
||||||
|
TokenType.MINUS: "__sub__",
|
||||||
|
TokenType.STAR: "__mul__",
|
||||||
|
TokenType.SLASH: "__truediv__",
|
||||||
|
# TokenType.MODULO: "__mod__",
|
||||||
|
# TokenType.POW: "__pow__",
|
||||||
|
# ast.BitOr: "__or__",
|
||||||
|
# ast.BitXor: "__xor__",
|
||||||
|
# ast.BitAnd: "__and__",
|
||||||
|
# ast.FloorDiv: "__floordiv__",
|
||||||
|
TokenType.EQUAL_EQUAL: "__eq__",
|
||||||
|
# ast.NotEq: "__noteq__",
|
||||||
|
TokenType.LESS: "__lt__",
|
||||||
|
TokenType.LESS_EQUAL: "__le__",
|
||||||
|
TokenType.GREATER: "__gt__",
|
||||||
|
TokenType.GREATER_EQUAL: "__ge__",
|
||||||
|
# ast.Is: "__is__",
|
||||||
|
# ast.IsNot: "__isnot__",
|
||||||
|
# ast.In: "__in__",
|
||||||
|
# ast.NotIn: "__notin__",
|
||||||
|
}
|
||||||
|
|
||||||
|
MIDAS_UNARY_METHODS: dict[TokenType, str] = {
|
||||||
|
# ast.Invert: "__invert__",
|
||||||
|
# ast.Not: "",
|
||||||
|
# TokenType.PLUS: "__pos__",
|
||||||
|
TokenType.MINUS: "__neg__",
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,7 +6,11 @@ from typing import Optional
|
|||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
from midas.checker.environment import Environment
|
from midas.checker.environment import Environment
|
||||||
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS
|
from midas.checker.operators import (
|
||||||
|
PY_COMPARATOR_METHODS,
|
||||||
|
PY_OPERATOR_METHODS,
|
||||||
|
PY_UNARY_METHODS,
|
||||||
|
)
|
||||||
from midas.checker.preamble import Preamble
|
from midas.checker.preamble import Preamble
|
||||||
from midas.checker.registry import TypesRegistry
|
from midas.checker.registry import TypesRegistry
|
||||||
from midas.checker.reporter import FileReporter, Reporter
|
from midas.checker.reporter import FileReporter, Reporter
|
||||||
@@ -376,7 +380,7 @@ class PythonTyper(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
||||||
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
method: Optional[str] = PY_OPERATOR_METHODS.get(expr.operator.__class__)
|
||||||
if method is None:
|
if method is None:
|
||||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||||
self.reporter.warning(
|
self.reporter.warning(
|
||||||
@@ -387,7 +391,7 @@ class PythonTyper(
|
|||||||
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
|
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
|
||||||
|
|
||||||
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
||||||
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
|
method: Optional[str] = PY_COMPARATOR_METHODS.get(expr.operator.__class__)
|
||||||
if method is None:
|
if method is None:
|
||||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||||
self.reporter.warning(
|
self.reporter.warning(
|
||||||
@@ -420,7 +424,7 @@ class PythonTyper(
|
|||||||
return result or UnknownType()
|
return result or UnknownType()
|
||||||
|
|
||||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
|
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
|
||||||
method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__)
|
method: Optional[str] = PY_UNARY_METHODS.get(expr.operator.__class__)
|
||||||
if method is None:
|
if method is None:
|
||||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||||
self.reporter.warning(
|
self.reporter.warning(
|
||||||
|
|||||||
Reference in New Issue
Block a user