diff --git a/midas/resolver/midas.py b/midas/resolver/midas.py index 468f59a..063a752 100644 --- a/midas/resolver/midas.py +++ b/midas/resolver/midas.py @@ -4,8 +4,10 @@ import midas.ast.midas as m from midas.checker.types import ( AliasType, ComplexType, + GenericType, Operation, Type, + TypeVar, UnknownType, ) from midas.resolver.builtin import define_builtins @@ -18,6 +20,8 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T self._types: dict[str, Type] = {} self._operations: dict[Operation.CallSignature, Type] = {} + self._local_variables: dict[str, TypeVar] = {} + define_builtins(self) def get_type(self, name: str) -> Type: @@ -32,10 +36,11 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T Returns: Type: the type """ - type: Optional[Type] = self._types.get(name) - if type is None: - raise NameError(f"Undefined type {name}") - return type + if name in self._local_variables: + return self._local_variables[name] + if name in self._types: + return self._types[name] + raise NameError(f"Undefined type {name}") def get_operation_result( self, left: Type, operator: str, right: Type @@ -121,12 +126,21 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T stmt.accept(self) def visit_type_stmt(self, stmt: m.TypeStmt) -> None: - type: Type = stmt.type.accept(self) + params: list[TypeVar] = [] for param in stmt.params: + name: str = param.name.lexeme + bound: Optional[Type] = None if param.bound is not None: - param.bound.accept(self) + bound = param.bound.accept(self) + var = TypeVar(name=name, bound=bound) + self._local_variables[name] = var + params.append(var) + type: Type = stmt.type.accept(self) + if len(params) != 0: + type = GenericType(params=params, body=type) name: str = stmt.name.lexeme self.define_type(name, AliasType(name=name, type=type)) + self._local_variables.clear() def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...