feat(checker): resolve operation overloads with subtypes
This commit is contained in:
@@ -15,6 +15,7 @@ from midas.checker.types import (
|
||||
BaseType,
|
||||
ComplexType,
|
||||
Function,
|
||||
Operation,
|
||||
Type,
|
||||
UnitType,
|
||||
UnknownType,
|
||||
@@ -490,14 +491,48 @@ class Checker(
|
||||
left: Type = self.type_of(expr.left)
|
||||
right: Type = self.type_of(expr.right)
|
||||
|
||||
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
|
||||
if result is None:
|
||||
operations: list[Operation] = self.ctx.get_operations_by_name(method)
|
||||
valid_operations: list[Operation] = []
|
||||
for op in operations:
|
||||
sig: Operation.CallSignature = op.signature
|
||||
if self.is_subtype(left, sig.left) and self.is_subtype(right, sig.right):
|
||||
valid_operations.append(op)
|
||||
|
||||
if len(valid_operations) == 0:
|
||||
self.error(
|
||||
expr.location,
|
||||
f"Undefined operation {method} between {left} and {right}",
|
||||
)
|
||||
return UnknownType()
|
||||
return result
|
||||
elif len(valid_operations) == 1:
|
||||
self.logger.debug(f"Unique operation {method} between {left} and {right}")
|
||||
return valid_operations[0].result
|
||||
|
||||
for i, op1 in enumerate(valid_operations):
|
||||
sig1: Operation.CallSignature = op1.signature
|
||||
best_match: bool = True
|
||||
for j, op2 in enumerate(valid_operations):
|
||||
if i == j:
|
||||
continue
|
||||
sig2: Operation.CallSignature = op2.signature
|
||||
if not self.is_subtype(sig1.left, sig2.left) or not self.is_subtype(
|
||||
sig1.right, sig2.right
|
||||
):
|
||||
best_match = False
|
||||
break
|
||||
self.logger.debug(f"{op1} is a full overload of {op2}")
|
||||
if best_match:
|
||||
return op1.result
|
||||
|
||||
overloads: list[str] = [
|
||||
f"({op.signature.left} {op.signature.method} {op.signature.right}) -> {op.result}"
|
||||
for op in valid_operations
|
||||
]
|
||||
self.error(
|
||||
expr.location,
|
||||
f"Ambiguous operation {method} between {left} and {right}, multiple matching overloads: {', '.join(overloads)}",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
||||
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
|
||||
|
||||
@@ -45,4 +45,16 @@ class ComplexType:
|
||||
properties: dict[str, Type]
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Operation:
|
||||
signature: CallSignature
|
||||
result: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class CallSignature:
|
||||
left: Type
|
||||
method: str
|
||||
right: Type
|
||||
|
||||
|
||||
Type = BaseType | AliasType | UnknownType | UnitType | Function | ComplexType
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Optional
|
||||
import midas.ast.midas as m
|
||||
from midas.checker.types import (
|
||||
AliasType,
|
||||
Operation,
|
||||
Type,
|
||||
UnknownType,
|
||||
)
|
||||
@@ -14,7 +15,7 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._types: dict[str, Type] = {}
|
||||
self._operations: dict[tuple[Type, str, Type], Type] = {}
|
||||
self._operations: dict[Operation.CallSignature, Type] = {}
|
||||
|
||||
define_builtins(self)
|
||||
|
||||
@@ -48,10 +49,26 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T
|
||||
Returns:
|
||||
Optional[Type]: the result type, or None if no matching operation was found
|
||||
"""
|
||||
operation: tuple[Type, str, Type] = (left, operator, right)
|
||||
result: Optional[Type] = self._operations.get(operation)
|
||||
signature: Operation.CallSignature = Operation.CallSignature(
|
||||
left=left,
|
||||
method=operator,
|
||||
right=right,
|
||||
)
|
||||
result: Optional[Type] = self._operations.get(signature)
|
||||
return result
|
||||
|
||||
def get_operations_by_name(self, name: str) -> list[Operation]:
|
||||
operations: list[Operation] = []
|
||||
for signature, result in self._operations.items():
|
||||
if signature.method == name:
|
||||
operations.append(
|
||||
Operation(
|
||||
signature=signature,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
return operations
|
||||
|
||||
def define_type(self, name: str, type: Type) -> Type:
|
||||
"""Define a type in the registry
|
||||
|
||||
@@ -82,12 +99,16 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T
|
||||
Raises:
|
||||
ValueError: if an operation is already defined with these operands and name
|
||||
"""
|
||||
operation: tuple[Type, str, Type] = (left, operator, right)
|
||||
if operation in self._operations:
|
||||
signature: Operation.CallSignature = Operation.CallSignature(
|
||||
left=left,
|
||||
method=operator,
|
||||
right=right,
|
||||
)
|
||||
if signature in self._operations:
|
||||
raise ValueError(
|
||||
f"Operation {operator} already defined between {left} and {right}"
|
||||
)
|
||||
self._operations[operation] = result
|
||||
self._operations[signature] = result
|
||||
|
||||
def resolve(self, stmts: list[m.Stmt]):
|
||||
"""Process a sequence of statements
|
||||
|
||||
Reference in New Issue
Block a user