diff --git a/midas/checker/midas.py b/midas/checker/midas.py index 8c0fede..cc2cec2 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -85,15 +85,19 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: self._resolve_type_params(stmt.params) - base: Type = stmt.type.accept(self) - for op in stmt.operations: - right: Type = op.operand.accept(self) - result: Type = op.result.accept(self) - self.types.define_operation( - left=base, - operator=op.name.lexeme, - right=right, - result=result, + base_name: str = stmt.name.lexeme + try: + _ = self.get_type(base_name) + except NameError: + self.reporter.error(stmt.name.get_location(), f"Unknown type '{base_name}'") + + for member in stmt.members: + member_type: Type = member.type.accept(self) + self.types.define_member( + base_name, + member.name.lexeme, + member_type, + member.kind == m.MemberKind.METHOD, ) def visit_op_stmt(self, stmt: m.OpStmt) -> None: ... diff --git a/midas/checker/registry.py b/midas/checker/registry.py index d5c432a..16c36a5 100644 --- a/midas/checker/registry.py +++ b/midas/checker/registry.py @@ -1,3 +1,4 @@ +import logging from typing import Optional from midas.checker.builtins import BUILTIN_SUBTYPES @@ -9,6 +10,7 @@ from midas.checker.types import ( Function, GenericType, Operation, + OverloadedFunction, Type, substitute_typevars, ) @@ -16,7 +18,9 @@ from midas.checker.types import ( class TypesRegistry: def __init__(self) -> None: + self.logger: logging.Logger = logging.getLogger("TypesRegistry") self._types: dict[str, Type] = {} + self._members: dict[str, dict[str, Type]] = {} self._operations: dict[Operation.CallSignature, Type] = {} def get_type(self, name: str) -> Type: @@ -86,6 +90,28 @@ class TypesRegistry: self._types[name] = type return type + def define_member( + self, type_name: str, member_name: str, member_type: Type, is_method: bool + ): + members: dict[str, Type] = self._members.setdefault(type_name, {}) + if member_name in members: + if not is_method: + self.logger.error( + f"Member '{member_name}' already defined for type {type_name}" + ) + return + current: Type = members[member_name] + combined: Type + match current: + case OverloadedFunction(overloads=overloads): + combined = OverloadedFunction(overloads=overloads + [member_type]) + case _: + combined = OverloadedFunction(overloads=[current, member_type]) + members[member_name] = combined + + else: + members[member_name] = member_type + def define_operation(self, left: Type, operator: str, right: Type, result: Type): """Define an operation in the registry diff --git a/midas/checker/types.py b/midas/checker/types.py index 9057e4c..dd8c173 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -69,6 +69,14 @@ class Function: return f"{self.name}: {self.type}{opt}" +@dataclass(frozen=True, kw_only=True) +class OverloadedFunction: + overloads: list[Type] + + def __str__(self) -> str: + return "" + + @dataclass(frozen=True, kw_only=True) class ComplexType: members: dict[str, Type] @@ -209,6 +217,7 @@ Type = ( | UnknownType | UnitType | Function + | OverloadedFunction | ComplexType | ExtensionType | TypeVar