feat(checker): resolve operation overloads with subtypes

This commit is contained in:
2026-06-07 13:43:43 +02:00
parent 25bd895dde
commit c24eb5125e
3 changed files with 77 additions and 9 deletions

View File

@@ -15,6 +15,7 @@ from midas.checker.types import (
BaseType,
ComplexType,
Function,
Operation,
Type,
UnitType,
UnknownType,
@@ -490,14 +491,48 @@ class Checker(
left: Type = self.type_of(expr.left)
right: Type = self.type_of(expr.right)
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
if result is None:
operations: list[Operation] = self.ctx.get_operations_by_name(method)
valid_operations: list[Operation] = []
for op in operations:
sig: Operation.CallSignature = op.signature
if self.is_subtype(left, sig.left) and self.is_subtype(right, sig.right):
valid_operations.append(op)
if len(valid_operations) == 0:
self.error(
expr.location,
f"Undefined operation {method} between {left} and {right}",
)
return UnknownType()
return result
elif len(valid_operations) == 1:
self.logger.debug(f"Unique operation {method} between {left} and {right}")
return valid_operations[0].result
for i, op1 in enumerate(valid_operations):
sig1: Operation.CallSignature = op1.signature
best_match: bool = True
for j, op2 in enumerate(valid_operations):
if i == j:
continue
sig2: Operation.CallSignature = op2.signature
if not self.is_subtype(sig1.left, sig2.left) or not self.is_subtype(
sig1.right, sig2.right
):
best_match = False
break
self.logger.debug(f"{op1} is a full overload of {op2}")
if best_match:
return op1.result
overloads: list[str] = [
f"({op.signature.left} {op.signature.method} {op.signature.right}) -> {op.result}"
for op in valid_operations
]
self.error(
expr.location,
f"Ambiguous operation {method} between {left} and {right}, multiple matching overloads: {', '.join(overloads)}",
)
return UnknownType()
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)

View File

@@ -45,4 +45,16 @@ class ComplexType:
properties: dict[str, Type]
@dataclass(frozen=True, kw_only=True)
class Operation:
signature: CallSignature
result: Type
@dataclass(frozen=True, kw_only=True)
class CallSignature:
left: Type
method: str
right: Type
Type = BaseType | AliasType | UnknownType | UnitType | Function | ComplexType

View File

@@ -3,6 +3,7 @@ from typing import Optional
import midas.ast.midas as m
from midas.checker.types import (
AliasType,
Operation,
Type,
UnknownType,
)
@@ -14,7 +15,7 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T
def __init__(self) -> None:
self._types: dict[str, Type] = {}
self._operations: dict[tuple[Type, str, Type], Type] = {}
self._operations: dict[Operation.CallSignature, Type] = {}
define_builtins(self)
@@ -48,10 +49,26 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T
Returns:
Optional[Type]: the result type, or None if no matching operation was found
"""
operation: tuple[Type, str, Type] = (left, operator, right)
result: Optional[Type] = self._operations.get(operation)
signature: Operation.CallSignature = Operation.CallSignature(
left=left,
method=operator,
right=right,
)
result: Optional[Type] = self._operations.get(signature)
return result
def get_operations_by_name(self, name: str) -> list[Operation]:
operations: list[Operation] = []
for signature, result in self._operations.items():
if signature.method == name:
operations.append(
Operation(
signature=signature,
result=result,
)
)
return operations
def define_type(self, name: str, type: Type) -> Type:
"""Define a type in the registry
@@ -82,12 +99,16 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T
Raises:
ValueError: if an operation is already defined with these operands and name
"""
operation: tuple[Type, str, Type] = (left, operator, right)
if operation in self._operations:
signature: Operation.CallSignature = Operation.CallSignature(
left=left,
method=operator,
right=right,
)
if signature in self._operations:
raise ValueError(
f"Operation {operator} already defined between {left} and {right}"
)
self._operations[operation] = result
self._operations[signature] = result
def resolve(self, stmts: list[m.Stmt]):
"""Process a sequence of statements