fix(checker): store member kind in registry
This commit is contained in:
@@ -173,7 +173,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
|||||||
base_name,
|
base_name,
|
||||||
member.name.lexeme,
|
member.name.lexeme,
|
||||||
member_type,
|
member_type,
|
||||||
member.kind == m.MemberKind.METHOD,
|
member.kind,
|
||||||
)
|
)
|
||||||
|
|
||||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from midas.ast.midas import MemberKind
|
||||||
from midas.checker.builtins import BUILTIN_SUBTYPES
|
from midas.checker.builtins import BUILTIN_SUBTYPES
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AliasType,
|
AliasType,
|
||||||
@@ -22,11 +24,17 @@ from midas.checker.types import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Member:
|
||||||
|
kind: MemberKind
|
||||||
|
type: Type
|
||||||
|
|
||||||
|
|
||||||
class TypesRegistry:
|
class TypesRegistry:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.logger: logging.Logger = logging.getLogger("TypesRegistry")
|
self.logger: logging.Logger = logging.getLogger("TypesRegistry")
|
||||||
self._types: dict[str, Type] = {}
|
self._types: dict[str, Type] = {}
|
||||||
self._members: dict[str, dict[str, Type]] = {}
|
self._members: dict[str, dict[str, Member]] = {}
|
||||||
self._predicates: dict[str, Predicate] = {}
|
self._predicates: dict[str, Predicate] = {}
|
||||||
|
|
||||||
def get_type(self, name: str) -> Type:
|
def get_type(self, name: str) -> Type:
|
||||||
@@ -64,26 +72,38 @@ class TypesRegistry:
|
|||||||
return type
|
return type
|
||||||
|
|
||||||
def define_member(
|
def define_member(
|
||||||
self, type_name: str, member_name: str, member_type: Type, is_method: bool
|
self,
|
||||||
|
type_name: str,
|
||||||
|
member_name: str,
|
||||||
|
member_type: Type,
|
||||||
|
kind: MemberKind,
|
||||||
):
|
):
|
||||||
members: dict[str, Type] = self._members.setdefault(type_name, {})
|
members: dict[str, Member] = self._members.setdefault(type_name, {})
|
||||||
if member_name in members:
|
if member_name in members:
|
||||||
if not is_method:
|
current: Member = members[member_name]
|
||||||
|
if current.kind != kind:
|
||||||
self.logger.error(
|
self.logger.error(
|
||||||
f"Member '{member_name}' already defined for type {type_name}"
|
f"Member '{member_name}' is already defined as a {current.kind},"
|
||||||
|
+ f" cannot define a {kind} with the same name"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
current: Type = members[member_name]
|
if kind != MemberKind.METHOD:
|
||||||
|
self.logger.error(
|
||||||
|
f"Member '{member_name}' already defined for type {type_name},"
|
||||||
|
+ " only methods can be overloaded"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
combined: Type
|
combined: Type
|
||||||
match current:
|
match current.type:
|
||||||
case OverloadedFunction(overloads=overloads):
|
case OverloadedFunction(overloads=overloads):
|
||||||
combined = OverloadedFunction(overloads=overloads + [member_type])
|
combined = OverloadedFunction(overloads=overloads + [member_type])
|
||||||
case _:
|
case _:
|
||||||
combined = OverloadedFunction(overloads=[current, member_type])
|
combined = OverloadedFunction(overloads=[current.type, member_type])
|
||||||
members[member_name] = combined
|
members[member_name] = Member(kind=current.kind, type=combined)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
members[member_name] = member_type
|
members[member_name] = Member(kind=kind, type=member_type)
|
||||||
|
|
||||||
def define_predicate(self, name: str, predicate: Predicate):
|
def define_predicate(self, name: str, predicate: Predicate):
|
||||||
if name in self._predicates:
|
if name in self._predicates:
|
||||||
@@ -327,13 +347,13 @@ class TypesRegistry:
|
|||||||
case BaseType(name=name):
|
case BaseType(name=name):
|
||||||
if name in self._members:
|
if name in self._members:
|
||||||
if member_name in self._members[name]:
|
if member_name in self._members[name]:
|
||||||
return self._members[name][member_name]
|
return self._members[name][member_name].type
|
||||||
return None
|
return None
|
||||||
|
|
||||||
case AliasType(name=name, type=base):
|
case AliasType(name=name, type=base):
|
||||||
if name in self._members:
|
if name in self._members:
|
||||||
if member_name in self._members[name]:
|
if member_name in self._members[name]:
|
||||||
return self._members[name][member_name]
|
return self._members[name][member_name].type
|
||||||
return self.lookup_member(base, member_name)
|
return self.lookup_member(base, member_name)
|
||||||
|
|
||||||
case AppliedType(name=name, body=body, args=args):
|
case AppliedType(name=name, body=body, args=args):
|
||||||
@@ -347,7 +367,7 @@ class TypesRegistry:
|
|||||||
}
|
}
|
||||||
if name in self._members:
|
if name in self._members:
|
||||||
if member_name in self._members[name]:
|
if member_name in self._members[name]:
|
||||||
member_type: Type = self._members[name][member_name]
|
member_type: Type = self._members[name][member_name].type
|
||||||
return substitute_typevars(member_type, substitutions)
|
return substitute_typevars(member_type, substitutions)
|
||||||
|
|
||||||
member_type2: Optional[Type] = self.lookup_member(body, member_name)
|
member_type2: Optional[Type] = self.lookup_member(body, member_name)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import Literal, Optional, cast
|
from typing import Literal, Optional, cast
|
||||||
|
|
||||||
from midas.checker.registry import TypesRegistry
|
from midas.checker.registry import Member, TypesRegistry
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AppliedType,
|
AppliedType,
|
||||||
ConstraintType,
|
ConstraintType,
|
||||||
@@ -54,9 +54,9 @@ class VarianceInferrer:
|
|||||||
self.tracker = Tracker(type.params)
|
self.tracker = Tracker(type.params)
|
||||||
|
|
||||||
self.walk(type.body, 1, type.name)
|
self.walk(type.body, 1, type.name)
|
||||||
members: dict[str, Type] = self.types._members.get(type.name, {})
|
members: dict[str, Member] = self.types._members.get(type.name, {})
|
||||||
for name, member in members.items():
|
for name, member in members.items():
|
||||||
self.walk(member, 1, type.name, [f"member:'{name}'"])
|
self.walk(member.type, 1, type.name, [f"member:'{name}'"])
|
||||||
|
|
||||||
return GenericType(
|
return GenericType(
|
||||||
name=type.name,
|
name=type.name,
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import click
|
|||||||
|
|
||||||
from midas.ast.printer import MidasPrinter
|
from midas.ast.printer import MidasPrinter
|
||||||
from midas.checker.checker import TypeChecker
|
from midas.checker.checker import TypeChecker
|
||||||
|
from midas.checker.registry import Member
|
||||||
from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type
|
from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type
|
||||||
|
|
||||||
|
|
||||||
@@ -38,7 +39,7 @@ def dump_registry(
|
|||||||
|
|
||||||
print("##### Types #####")
|
print("##### Types #####")
|
||||||
for name, type in checker.types._types.items():
|
for name, type in checker.types._types.items():
|
||||||
members: dict[str, Type] = checker.types._members.get(name, {})
|
members: dict[str, Member] = checker.types._members.get(name, {})
|
||||||
params: str = ""
|
params: str = ""
|
||||||
if isinstance(type, GenericType):
|
if isinstance(type, GenericType):
|
||||||
params = ", ".join(map(str, type.params))
|
params = ", ".join(map(str, type.params))
|
||||||
@@ -46,8 +47,9 @@ def dump_registry(
|
|||||||
print(f"{name}{params} = {base_type(type)}")
|
print(f"{name}{params} = {base_type(type)}")
|
||||||
if len(members) != 0:
|
if len(members) != 0:
|
||||||
print(" " * 4 + "Members:")
|
print(" " * 4 + "Members:")
|
||||||
for member_name, member_type in members.items():
|
for member_name, member in members.items():
|
||||||
print(" " * 8 + f"{member_name}: {member_type}")
|
kind: str = member.kind.name
|
||||||
|
print(" " * 8 + f"({kind:8}) {member_name}: {member.type}")
|
||||||
|
|
||||||
print("##### Predicates #####")
|
print("##### Predicates #####")
|
||||||
printer = MidasPrinter()
|
printer = MidasPrinter()
|
||||||
|
|||||||
Reference in New Issue
Block a user