diff --git a/midas/checker/checker.py b/midas/checker/checker.py index 13b7855..9ed50f0 100644 --- a/midas/checker/checker.py +++ b/midas/checker/checker.py @@ -15,6 +15,7 @@ from midas.checker.types import ( BaseType, ComplexType, Function, + Operation, Type, UnitType, UnknownType, @@ -490,14 +491,48 @@ class Checker( left: Type = self.type_of(expr.left) right: Type = self.type_of(expr.right) - result: Optional[Type] = self.ctx.get_operation_result(left, method, right) - if result is None: + operations: list[Operation] = self.ctx.get_operations_by_name(method) + valid_operations: list[Operation] = [] + for op in operations: + sig: Operation.CallSignature = op.signature + if self.is_subtype(left, sig.left) and self.is_subtype(right, sig.right): + valid_operations.append(op) + + if len(valid_operations) == 0: self.error( expr.location, f"Undefined operation {method} between {left} and {right}", ) return UnknownType() - return result + elif len(valid_operations) == 1: + self.logger.debug(f"Unique operation {method} between {left} and {right}") + return valid_operations[0].result + + for i, op1 in enumerate(valid_operations): + sig1: Operation.CallSignature = op1.signature + best_match: bool = True + for j, op2 in enumerate(valid_operations): + if i == j: + continue + sig2: Operation.CallSignature = op2.signature + if not self.is_subtype(sig1.left, sig2.left) or not self.is_subtype( + sig1.right, sig2.right + ): + best_match = False + break + self.logger.debug(f"{op1} is a full overload of {op2}") + if best_match: + return op1.result + + overloads: list[str] = [ + f"({op.signature.left} {op.signature.method} {op.signature.right}) -> {op.result}" + for op in valid_operations + ] + self.error( + expr.location, + f"Ambiguous operation {method} between {left} and {right}, multiple matching overloads: {', '.join(overloads)}", + ) + return UnknownType() def visit_compare_expr(self, expr: p.CompareExpr) -> Type: method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__) diff --git a/midas/checker/types.py b/midas/checker/types.py index c1d4449..83707b6 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -45,4 +45,16 @@ class ComplexType: properties: dict[str, Type] +@dataclass(frozen=True, kw_only=True) +class Operation: + signature: CallSignature + result: Type + + @dataclass(frozen=True, kw_only=True) + class CallSignature: + left: Type + method: str + right: Type + + Type = BaseType | AliasType | UnknownType | UnitType | Function | ComplexType diff --git a/midas/resolver/midas.py b/midas/resolver/midas.py index acbbe96..2962b09 100644 --- a/midas/resolver/midas.py +++ b/midas/resolver/midas.py @@ -3,6 +3,7 @@ from typing import Optional import midas.ast.midas as m from midas.checker.types import ( AliasType, + Operation, Type, UnknownType, ) @@ -14,7 +15,7 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T def __init__(self) -> None: self._types: dict[str, Type] = {} - self._operations: dict[tuple[Type, str, Type], Type] = {} + self._operations: dict[Operation.CallSignature, Type] = {} define_builtins(self) @@ -48,10 +49,26 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T Returns: Optional[Type]: the result type, or None if no matching operation was found """ - operation: tuple[Type, str, Type] = (left, operator, right) - result: Optional[Type] = self._operations.get(operation) + signature: Operation.CallSignature = Operation.CallSignature( + left=left, + method=operator, + right=right, + ) + result: Optional[Type] = self._operations.get(signature) return result + def get_operations_by_name(self, name: str) -> list[Operation]: + operations: list[Operation] = [] + for signature, result in self._operations.items(): + if signature.method == name: + operations.append( + Operation( + signature=signature, + result=result, + ) + ) + return operations + def define_type(self, name: str, type: Type) -> Type: """Define a type in the registry @@ -82,12 +99,16 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T Raises: ValueError: if an operation is already defined with these operands and name """ - operation: tuple[Type, str, Type] = (left, operator, right) - if operation in self._operations: + signature: Operation.CallSignature = Operation.CallSignature( + left=left, + method=operator, + right=right, + ) + if signature in self._operations: raise ValueError( f"Operation {operator} already defined between {left} and {right}" ) - self._operations[operation] = result + self._operations[signature] = result def resolve(self, stmts: list[m.Stmt]): """Process a sequence of statements