feat(checker): resolve operation overloads with subtypes
This commit is contained in:
@@ -15,6 +15,7 @@ from midas.checker.types import (
|
|||||||
BaseType,
|
BaseType,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
Function,
|
Function,
|
||||||
|
Operation,
|
||||||
Type,
|
Type,
|
||||||
UnitType,
|
UnitType,
|
||||||
UnknownType,
|
UnknownType,
|
||||||
@@ -490,14 +491,48 @@ class Checker(
|
|||||||
left: Type = self.type_of(expr.left)
|
left: Type = self.type_of(expr.left)
|
||||||
right: Type = self.type_of(expr.right)
|
right: Type = self.type_of(expr.right)
|
||||||
|
|
||||||
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
|
operations: list[Operation] = self.ctx.get_operations_by_name(method)
|
||||||
if result is None:
|
valid_operations: list[Operation] = []
|
||||||
|
for op in operations:
|
||||||
|
sig: Operation.CallSignature = op.signature
|
||||||
|
if self.is_subtype(left, sig.left) and self.is_subtype(right, sig.right):
|
||||||
|
valid_operations.append(op)
|
||||||
|
|
||||||
|
if len(valid_operations) == 0:
|
||||||
self.error(
|
self.error(
|
||||||
expr.location,
|
expr.location,
|
||||||
f"Undefined operation {method} between {left} and {right}",
|
f"Undefined operation {method} between {left} and {right}",
|
||||||
)
|
)
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
return result
|
elif len(valid_operations) == 1:
|
||||||
|
self.logger.debug(f"Unique operation {method} between {left} and {right}")
|
||||||
|
return valid_operations[0].result
|
||||||
|
|
||||||
|
for i, op1 in enumerate(valid_operations):
|
||||||
|
sig1: Operation.CallSignature = op1.signature
|
||||||
|
best_match: bool = True
|
||||||
|
for j, op2 in enumerate(valid_operations):
|
||||||
|
if i == j:
|
||||||
|
continue
|
||||||
|
sig2: Operation.CallSignature = op2.signature
|
||||||
|
if not self.is_subtype(sig1.left, sig2.left) or not self.is_subtype(
|
||||||
|
sig1.right, sig2.right
|
||||||
|
):
|
||||||
|
best_match = False
|
||||||
|
break
|
||||||
|
self.logger.debug(f"{op1} is a full overload of {op2}")
|
||||||
|
if best_match:
|
||||||
|
return op1.result
|
||||||
|
|
||||||
|
overloads: list[str] = [
|
||||||
|
f"({op.signature.left} {op.signature.method} {op.signature.right}) -> {op.result}"
|
||||||
|
for op in valid_operations
|
||||||
|
]
|
||||||
|
self.error(
|
||||||
|
expr.location,
|
||||||
|
f"Ambiguous operation {method} between {left} and {right}, multiple matching overloads: {', '.join(overloads)}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
||||||
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
|
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
|
||||||
|
|||||||
@@ -45,4 +45,16 @@ class ComplexType:
|
|||||||
properties: dict[str, Type]
|
properties: dict[str, Type]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Operation:
|
||||||
|
signature: CallSignature
|
||||||
|
result: Type
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class CallSignature:
|
||||||
|
left: Type
|
||||||
|
method: str
|
||||||
|
right: Type
|
||||||
|
|
||||||
|
|
||||||
Type = BaseType | AliasType | UnknownType | UnitType | Function | ComplexType
|
Type = BaseType | AliasType | UnknownType | UnitType | Function | ComplexType
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from typing import Optional
|
|||||||
import midas.ast.midas as m
|
import midas.ast.midas as m
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AliasType,
|
AliasType,
|
||||||
|
Operation,
|
||||||
Type,
|
Type,
|
||||||
UnknownType,
|
UnknownType,
|
||||||
)
|
)
|
||||||
@@ -14,7 +15,7 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T
|
|||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._types: dict[str, Type] = {}
|
self._types: dict[str, Type] = {}
|
||||||
self._operations: dict[tuple[Type, str, Type], Type] = {}
|
self._operations: dict[Operation.CallSignature, Type] = {}
|
||||||
|
|
||||||
define_builtins(self)
|
define_builtins(self)
|
||||||
|
|
||||||
@@ -48,10 +49,26 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T
|
|||||||
Returns:
|
Returns:
|
||||||
Optional[Type]: the result type, or None if no matching operation was found
|
Optional[Type]: the result type, or None if no matching operation was found
|
||||||
"""
|
"""
|
||||||
operation: tuple[Type, str, Type] = (left, operator, right)
|
signature: Operation.CallSignature = Operation.CallSignature(
|
||||||
result: Optional[Type] = self._operations.get(operation)
|
left=left,
|
||||||
|
method=operator,
|
||||||
|
right=right,
|
||||||
|
)
|
||||||
|
result: Optional[Type] = self._operations.get(signature)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def get_operations_by_name(self, name: str) -> list[Operation]:
|
||||||
|
operations: list[Operation] = []
|
||||||
|
for signature, result in self._operations.items():
|
||||||
|
if signature.method == name:
|
||||||
|
operations.append(
|
||||||
|
Operation(
|
||||||
|
signature=signature,
|
||||||
|
result=result,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return operations
|
||||||
|
|
||||||
def define_type(self, name: str, type: Type) -> Type:
|
def define_type(self, name: str, type: Type) -> Type:
|
||||||
"""Define a type in the registry
|
"""Define a type in the registry
|
||||||
|
|
||||||
@@ -82,12 +99,16 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: if an operation is already defined with these operands and name
|
ValueError: if an operation is already defined with these operands and name
|
||||||
"""
|
"""
|
||||||
operation: tuple[Type, str, Type] = (left, operator, right)
|
signature: Operation.CallSignature = Operation.CallSignature(
|
||||||
if operation in self._operations:
|
left=left,
|
||||||
|
method=operator,
|
||||||
|
right=right,
|
||||||
|
)
|
||||||
|
if signature in self._operations:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Operation {operator} already defined between {left} and {right}"
|
f"Operation {operator} already defined between {left} and {right}"
|
||||||
)
|
)
|
||||||
self._operations[operation] = result
|
self._operations[signature] = result
|
||||||
|
|
||||||
def resolve(self, stmts: list[m.Stmt]):
|
def resolve(self, stmts: list[m.Stmt]):
|
||||||
"""Process a sequence of statements
|
"""Process a sequence of statements
|
||||||
|
|||||||
Reference in New Issue
Block a user