fix(checker): store member kind in registry

This commit is contained in:
2026-06-17 12:11:16 +02:00
parent 12782dda1e
commit ff79f25628
4 changed files with 42 additions and 20 deletions

View File

@@ -173,7 +173,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
base_name, base_name,
member.name.lexeme, member.name.lexeme,
member_type, member_type,
member.kind == m.MemberKind.METHOD, member.kind,
) )
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:

View File

@@ -1,6 +1,8 @@
import logging import logging
from dataclasses import dataclass
from typing import Optional from typing import Optional
from midas.ast.midas import MemberKind
from midas.checker.builtins import BUILTIN_SUBTYPES from midas.checker.builtins import BUILTIN_SUBTYPES
from midas.checker.types import ( from midas.checker.types import (
AliasType, AliasType,
@@ -22,11 +24,17 @@ from midas.checker.types import (
) )
@dataclass
class Member:
kind: MemberKind
type: Type
class TypesRegistry: class TypesRegistry:
def __init__(self) -> None: def __init__(self) -> None:
self.logger: logging.Logger = logging.getLogger("TypesRegistry") self.logger: logging.Logger = logging.getLogger("TypesRegistry")
self._types: dict[str, Type] = {} self._types: dict[str, Type] = {}
self._members: dict[str, dict[str, Type]] = {} self._members: dict[str, dict[str, Member]] = {}
self._predicates: dict[str, Predicate] = {} self._predicates: dict[str, Predicate] = {}
def get_type(self, name: str) -> Type: def get_type(self, name: str) -> Type:
@@ -64,26 +72,38 @@ class TypesRegistry:
return type return type
def define_member( def define_member(
self, type_name: str, member_name: str, member_type: Type, is_method: bool self,
type_name: str,
member_name: str,
member_type: Type,
kind: MemberKind,
): ):
members: dict[str, Type] = self._members.setdefault(type_name, {}) members: dict[str, Member] = self._members.setdefault(type_name, {})
if member_name in members: if member_name in members:
if not is_method: current: Member = members[member_name]
if current.kind != kind:
self.logger.error( self.logger.error(
f"Member '{member_name}' already defined for type {type_name}" f"Member '{member_name}' is already defined as a {current.kind},"
+ f" cannot define a {kind} with the same name"
) )
return return
current: Type = members[member_name] if kind != MemberKind.METHOD:
self.logger.error(
f"Member '{member_name}' already defined for type {type_name},"
+ " only methods can be overloaded"
)
return
combined: Type combined: Type
match current: match current.type:
case OverloadedFunction(overloads=overloads): case OverloadedFunction(overloads=overloads):
combined = OverloadedFunction(overloads=overloads + [member_type]) combined = OverloadedFunction(overloads=overloads + [member_type])
case _: case _:
combined = OverloadedFunction(overloads=[current, member_type]) combined = OverloadedFunction(overloads=[current.type, member_type])
members[member_name] = combined members[member_name] = Member(kind=current.kind, type=combined)
else: else:
members[member_name] = member_type members[member_name] = Member(kind=kind, type=member_type)
def define_predicate(self, name: str, predicate: Predicate): def define_predicate(self, name: str, predicate: Predicate):
if name in self._predicates: if name in self._predicates:
@@ -327,13 +347,13 @@ class TypesRegistry:
case BaseType(name=name): case BaseType(name=name):
if name in self._members: if name in self._members:
if member_name in self._members[name]: if member_name in self._members[name]:
return self._members[name][member_name] return self._members[name][member_name].type
return None return None
case AliasType(name=name, type=base): case AliasType(name=name, type=base):
if name in self._members: if name in self._members:
if member_name in self._members[name]: if member_name in self._members[name]:
return self._members[name][member_name] return self._members[name][member_name].type
return self.lookup_member(base, member_name) return self.lookup_member(base, member_name)
case AppliedType(name=name, body=body, args=args): case AppliedType(name=name, body=body, args=args):
@@ -347,7 +367,7 @@ class TypesRegistry:
} }
if name in self._members: if name in self._members:
if member_name in self._members[name]: if member_name in self._members[name]:
member_type: Type = self._members[name][member_name] member_type: Type = self._members[name][member_name].type
return substitute_typevars(member_type, substitutions) return substitute_typevars(member_type, substitutions)
member_type2: Optional[Type] = self.lookup_member(body, member_name) member_type2: Optional[Type] = self.lookup_member(body, member_name)

View File

@@ -1,6 +1,6 @@
from typing import Literal, Optional, cast from typing import Literal, Optional, cast
from midas.checker.registry import TypesRegistry from midas.checker.registry import Member, TypesRegistry
from midas.checker.types import ( from midas.checker.types import (
AppliedType, AppliedType,
ConstraintType, ConstraintType,
@@ -54,9 +54,9 @@ class VarianceInferrer:
self.tracker = Tracker(type.params) self.tracker = Tracker(type.params)
self.walk(type.body, 1, type.name) self.walk(type.body, 1, type.name)
members: dict[str, Type] = self.types._members.get(type.name, {}) members: dict[str, Member] = self.types._members.get(type.name, {})
for name, member in members.items(): for name, member in members.items():
self.walk(member, 1, type.name, [f"member:'{name}'"]) self.walk(member.type, 1, type.name, [f"member:'{name}'"])
return GenericType( return GenericType(
name=type.name, name=type.name,

View File

@@ -10,6 +10,7 @@ import click
from midas.ast.printer import MidasPrinter from midas.ast.printer import MidasPrinter
from midas.checker.checker import TypeChecker from midas.checker.checker import TypeChecker
from midas.checker.registry import Member
from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type
@@ -38,7 +39,7 @@ def dump_registry(
print("##### Types #####") print("##### Types #####")
for name, type in checker.types._types.items(): for name, type in checker.types._types.items():
members: dict[str, Type] = checker.types._members.get(name, {}) members: dict[str, Member] = checker.types._members.get(name, {})
params: str = "" params: str = ""
if isinstance(type, GenericType): if isinstance(type, GenericType):
params = ", ".join(map(str, type.params)) params = ", ".join(map(str, type.params))
@@ -46,8 +47,9 @@ def dump_registry(
print(f"{name}{params} = {base_type(type)}") print(f"{name}{params} = {base_type(type)}")
if len(members) != 0: if len(members) != 0:
print(" " * 4 + "Members:") print(" " * 4 + "Members:")
for member_name, member_type in members.items(): for member_name, member in members.items():
print(" " * 8 + f"{member_name}: {member_type}") kind: str = member.kind.name
print(" " * 8 + f"({kind:8}) {member_name}: {member.type}")
print("##### Predicates #####") print("##### Predicates #####")
printer = MidasPrinter() printer = MidasPrinter()