diff --git a/midas/checker/resolver.py b/midas/checker/resolver.py index 0b7d990..02fcbbc 100644 --- a/midas/checker/resolver.py +++ b/midas/checker/resolver.py @@ -13,7 +13,7 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]): def __init__(self): self.locals: dict[p.Expr, int] = {} - self.scopes: list[dict[str, bool]] = [] + self.scopes: list[dict[str, bool]] = [{}] def resolve(self, *objects: p.Stmt | p.Expr) -> None: """Resolve the given statements or expressions""" @@ -77,6 +77,12 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]): self.locals[expr] = i return + def is_defined(self, name: str) -> bool: + for scope in self.scopes: + if name in scope: + return True + return False + def resolve_function(self, function: p.Function) -> None: """Resolve a function definition @@ -111,7 +117,13 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]): self.resolve(stmt.value) for target in stmt.targets: match target: - case p.VariableExpr() | p.GetExpr(): + case p.VariableExpr(name=name): + if not self.is_defined(name): + self.declare(name) + self.define(name) + target.accept(self) + + case p.GetExpr(): target.accept(self) case _: raise Exception(f"Unsupported assignment to {target}")