diff --git a/midas/checker/midas.py b/midas/checker/midas.py index 3764c03..42d3062 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -102,7 +102,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type base_name, member.name.lexeme, member_type, - member.kind == m.MemberKind.METHOD, + member.kind, ) def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: diff --git a/midas/checker/registry.py b/midas/checker/registry.py index 6591548..8b7c2fc 100644 --- a/midas/checker/registry.py +++ b/midas/checker/registry.py @@ -1,6 +1,8 @@ import logging +from dataclasses import dataclass from typing import Optional +from midas.ast.midas import MemberKind from midas.checker.builtins import BUILTIN_SUBTYPES from midas.checker.types import ( AliasType, @@ -19,11 +21,17 @@ from midas.checker.types import ( ) +@dataclass +class Member: + kind: MemberKind + type: Type + + class TypesRegistry: def __init__(self) -> None: self.logger: logging.Logger = logging.getLogger("TypesRegistry") self._types: dict[str, Type] = {} - self._members: dict[str, dict[str, Type]] = {} + self._members: dict[str, dict[str, Member]] = {} def get_type(self, name: str) -> Type: """Get a type from its name @@ -60,26 +68,38 @@ class TypesRegistry: return type def define_member( - self, type_name: str, member_name: str, member_type: Type, is_method: bool + self, + type_name: str, + member_name: str, + member_type: Type, + kind: MemberKind, ): - members: dict[str, Type] = self._members.setdefault(type_name, {}) + members: dict[str, Member] = self._members.setdefault(type_name, {}) if member_name in members: - if not is_method: + current: Member = members[member_name] + if current.kind != kind: self.logger.error( - f"Member '{member_name}' already defined for type {type_name}" + f"Member '{member_name}' is already defined as a {current.kind}," + + f" cannot define a {kind} with the same name" ) return - current: Type = members[member_name] + if kind != MemberKind.METHOD: + self.logger.error( + f"Member '{member_name}' already defined for type {type_name}," + + " only methods can be overloaded" + ) + return + combined: Type - match current: + match current.type: case OverloadedFunction(overloads=overloads): combined = OverloadedFunction(overloads=overloads + [member_type]) case _: - combined = OverloadedFunction(overloads=[current, member_type]) - members[member_name] = combined + combined = OverloadedFunction(overloads=[current.type, member_type]) + members[member_name] = Member(kind=current.kind, type=combined) else: - members[member_name] = member_type + members[member_name] = Member(kind=kind, type=member_type) def is_subtype(self, type1: Type, type2: Type) -> bool: """Check whether `type1` is a subtype of `type2` @@ -297,13 +317,13 @@ class TypesRegistry: case BaseType(name=name): if name in self._members: if member_name in self._members[name]: - return self._members[name][member_name] + return self._members[name][member_name].type return None case AliasType(name=name, type=base): if name in self._members: if member_name in self._members[name]: - return self._members[name][member_name] + return self._members[name][member_name].type return self.lookup_member(base, member_name) case AppliedType(name=name, body=body, args=args): @@ -317,7 +337,7 @@ class TypesRegistry: } if name in self._members: if member_name in self._members[name]: - member_type: Type = self._members[name][member_name] + member_type: Type = self._members[name][member_name].type return substitute_typevars(member_type, substitutions) member_type2: Optional[Type] = self.lookup_member(body, member_name)