fix(checker): store member kind in registry

This commit is contained in:
2026-06-17 12:11:16 +02:00
parent 12782dda1e
commit ff79f25628
4 changed files with 42 additions and 20 deletions

View File

@@ -173,7 +173,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
base_name,
member.name.lexeme,
member_type,
member.kind == m.MemberKind.METHOD,
member.kind,
)
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:

View File

@@ -1,6 +1,8 @@
import logging
from dataclasses import dataclass
from typing import Optional
from midas.ast.midas import MemberKind
from midas.checker.builtins import BUILTIN_SUBTYPES
from midas.checker.types import (
AliasType,
@@ -22,11 +24,17 @@ from midas.checker.types import (
)
@dataclass
class Member:
kind: MemberKind
type: Type
class TypesRegistry:
def __init__(self) -> None:
self.logger: logging.Logger = logging.getLogger("TypesRegistry")
self._types: dict[str, Type] = {}
self._members: dict[str, dict[str, Type]] = {}
self._members: dict[str, dict[str, Member]] = {}
self._predicates: dict[str, Predicate] = {}
def get_type(self, name: str) -> Type:
@@ -64,26 +72,38 @@ class TypesRegistry:
return type
def define_member(
self, type_name: str, member_name: str, member_type: Type, is_method: bool
self,
type_name: str,
member_name: str,
member_type: Type,
kind: MemberKind,
):
members: dict[str, Type] = self._members.setdefault(type_name, {})
members: dict[str, Member] = self._members.setdefault(type_name, {})
if member_name in members:
if not is_method:
current: Member = members[member_name]
if current.kind != kind:
self.logger.error(
f"Member '{member_name}' already defined for type {type_name}"
f"Member '{member_name}' is already defined as a {current.kind},"
+ f" cannot define a {kind} with the same name"
)
return
current: Type = members[member_name]
if kind != MemberKind.METHOD:
self.logger.error(
f"Member '{member_name}' already defined for type {type_name},"
+ " only methods can be overloaded"
)
return
combined: Type
match current:
match current.type:
case OverloadedFunction(overloads=overloads):
combined = OverloadedFunction(overloads=overloads + [member_type])
case _:
combined = OverloadedFunction(overloads=[current, member_type])
members[member_name] = combined
combined = OverloadedFunction(overloads=[current.type, member_type])
members[member_name] = Member(kind=current.kind, type=combined)
else:
members[member_name] = member_type
members[member_name] = Member(kind=kind, type=member_type)
def define_predicate(self, name: str, predicate: Predicate):
if name in self._predicates:
@@ -327,13 +347,13 @@ class TypesRegistry:
case BaseType(name=name):
if name in self._members:
if member_name in self._members[name]:
return self._members[name][member_name]
return self._members[name][member_name].type
return None
case AliasType(name=name, type=base):
if name in self._members:
if member_name in self._members[name]:
return self._members[name][member_name]
return self._members[name][member_name].type
return self.lookup_member(base, member_name)
case AppliedType(name=name, body=body, args=args):
@@ -347,7 +367,7 @@ class TypesRegistry:
}
if name in self._members:
if member_name in self._members[name]:
member_type: Type = self._members[name][member_name]
member_type: Type = self._members[name][member_name].type
return substitute_typevars(member_type, substitutions)
member_type2: Optional[Type] = self.lookup_member(body, member_name)

View File

@@ -1,6 +1,6 @@
from typing import Literal, Optional, cast
from midas.checker.registry import TypesRegistry
from midas.checker.registry import Member, TypesRegistry
from midas.checker.types import (
AppliedType,
ConstraintType,
@@ -54,9 +54,9 @@ class VarianceInferrer:
self.tracker = Tracker(type.params)
self.walk(type.body, 1, type.name)
members: dict[str, Type] = self.types._members.get(type.name, {})
members: dict[str, Member] = self.types._members.get(type.name, {})
for name, member in members.items():
self.walk(member, 1, type.name, [f"member:'{name}'"])
self.walk(member.type, 1, type.name, [f"member:'{name}'"])
return GenericType(
name=type.name,

View File

@@ -10,6 +10,7 @@ import click
from midas.ast.printer import MidasPrinter
from midas.checker.checker import TypeChecker
from midas.checker.registry import Member
from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type
@@ -38,7 +39,7 @@ def dump_registry(
print("##### Types #####")
for name, type in checker.types._types.items():
members: dict[str, Type] = checker.types._members.get(name, {})
members: dict[str, Member] = checker.types._members.get(name, {})
params: str = ""
if isinstance(type, GenericType):
params = ", ".join(map(str, type.params))
@@ -46,8 +47,9 @@ def dump_registry(
print(f"{name}{params} = {base_type(type)}")
if len(members) != 0:
print(" " * 4 + "Members:")
for member_name, member_type in members.items():
print(" " * 8 + f"{member_name}: {member_type}")
for member_name, member in members.items():
kind: str = member.kind.name
print(" " * 8 + f"({kind:8}) {member_name}: {member.type}")
print("##### Predicates #####")
printer = MidasPrinter()