diff --git a/midas/checker/midas.py b/midas/checker/midas.py index e79719c..13e83b6 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -173,7 +173,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type base_name, member.name.lexeme, member_type, - member.kind == m.MemberKind.METHOD, + member.kind, ) def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: diff --git a/midas/checker/registry.py b/midas/checker/registry.py index b8f7dfa..b787f20 100644 --- a/midas/checker/registry.py +++ b/midas/checker/registry.py @@ -1,6 +1,8 @@ import logging +from dataclasses import dataclass from typing import Optional +from midas.ast.midas import MemberKind from midas.checker.builtins import BUILTIN_SUBTYPES from midas.checker.types import ( AliasType, @@ -22,11 +24,17 @@ from midas.checker.types import ( ) +@dataclass +class Member: + kind: MemberKind + type: Type + + class TypesRegistry: def __init__(self) -> None: self.logger: logging.Logger = logging.getLogger("TypesRegistry") self._types: dict[str, Type] = {} - self._members: dict[str, dict[str, Type]] = {} + self._members: dict[str, dict[str, Member]] = {} self._predicates: dict[str, Predicate] = {} def get_type(self, name: str) -> Type: @@ -64,26 +72,38 @@ class TypesRegistry: return type def define_member( - self, type_name: str, member_name: str, member_type: Type, is_method: bool + self, + type_name: str, + member_name: str, + member_type: Type, + kind: MemberKind, ): - members: dict[str, Type] = self._members.setdefault(type_name, {}) + members: dict[str, Member] = self._members.setdefault(type_name, {}) if member_name in members: - if not is_method: + current: Member = members[member_name] + if current.kind != kind: self.logger.error( - f"Member '{member_name}' already defined for type {type_name}" + f"Member '{member_name}' is already defined as a {current.kind}," + + f" cannot define a {kind} with the same name" ) return - current: Type = members[member_name] + if kind != MemberKind.METHOD: + self.logger.error( + f"Member '{member_name}' already defined for type {type_name}," + + " only methods can be overloaded" + ) + return + combined: Type - match current: + match current.type: case OverloadedFunction(overloads=overloads): combined = OverloadedFunction(overloads=overloads + [member_type]) case _: - combined = OverloadedFunction(overloads=[current, member_type]) - members[member_name] = combined + combined = OverloadedFunction(overloads=[current.type, member_type]) + members[member_name] = Member(kind=current.kind, type=combined) else: - members[member_name] = member_type + members[member_name] = Member(kind=kind, type=member_type) def define_predicate(self, name: str, predicate: Predicate): if name in self._predicates: @@ -327,13 +347,13 @@ class TypesRegistry: case BaseType(name=name): if name in self._members: if member_name in self._members[name]: - return self._members[name][member_name] + return self._members[name][member_name].type return None case AliasType(name=name, type=base): if name in self._members: if member_name in self._members[name]: - return self._members[name][member_name] + return self._members[name][member_name].type return self.lookup_member(base, member_name) case AppliedType(name=name, body=body, args=args): @@ -347,7 +367,7 @@ class TypesRegistry: } if name in self._members: if member_name in self._members[name]: - member_type: Type = self._members[name][member_name] + member_type: Type = self._members[name][member_name].type return substitute_typevars(member_type, substitutions) member_type2: Optional[Type] = self.lookup_member(body, member_name) diff --git a/midas/checker/variance.py b/midas/checker/variance.py index b620958..e7f5ac0 100644 --- a/midas/checker/variance.py +++ b/midas/checker/variance.py @@ -1,6 +1,6 @@ from typing import Literal, Optional, cast -from midas.checker.registry import TypesRegistry +from midas.checker.registry import Member, TypesRegistry from midas.checker.types import ( AppliedType, ConstraintType, @@ -54,9 +54,9 @@ class VarianceInferrer: self.tracker = Tracker(type.params) self.walk(type.body, 1, type.name) - members: dict[str, Type] = self.types._members.get(type.name, {}) + members: dict[str, Member] = self.types._members.get(type.name, {}) for name, member in members.items(): - self.walk(member, 1, type.name, [f"member:'{name}'"]) + self.walk(member.type, 1, type.name, [f"member:'{name}'"]) return GenericType( name=type.name, diff --git a/midas/cli/commands/registry.py b/midas/cli/commands/registry.py index 7c0c521..12c222b 100644 --- a/midas/cli/commands/registry.py +++ b/midas/cli/commands/registry.py @@ -10,6 +10,7 @@ import click from midas.ast.printer import MidasPrinter from midas.checker.checker import TypeChecker +from midas.checker.registry import Member from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type @@ -38,7 +39,7 @@ def dump_registry( print("##### Types #####") for name, type in checker.types._types.items(): - members: dict[str, Type] = checker.types._members.get(name, {}) + members: dict[str, Member] = checker.types._members.get(name, {}) params: str = "" if isinstance(type, GenericType): params = ", ".join(map(str, type.params)) @@ -46,8 +47,9 @@ def dump_registry( print(f"{name}{params} = {base_type(type)}") if len(members) != 0: print(" " * 4 + "Members:") - for member_name, member_type in members.items(): - print(" " * 8 + f"{member_name}: {member_type}") + for member_name, member in members.items(): + kind: str = member.kind.name + print(" " * 8 + f"({kind:8}) {member_name}: {member.type}") print("##### Predicates #####") printer = MidasPrinter()