fix(resolver): define variable on assignment

if a variable is not already defined when an assignment is visited, it is then defined in the current scope
This commit is contained in:
2026-06-09 08:06:46 +02:00
parent a4f5db7ece
commit b8bb8190c4

View File

@@ -13,7 +13,7 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
def __init__(self): def __init__(self):
self.locals: dict[p.Expr, int] = {} self.locals: dict[p.Expr, int] = {}
self.scopes: list[dict[str, bool]] = [] self.scopes: list[dict[str, bool]] = [{}]
def resolve(self, *objects: p.Stmt | p.Expr) -> None: def resolve(self, *objects: p.Stmt | p.Expr) -> None:
"""Resolve the given statements or expressions""" """Resolve the given statements or expressions"""
@@ -77,6 +77,12 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
self.locals[expr] = i self.locals[expr] = i
return return
def is_defined(self, name: str) -> bool:
for scope in self.scopes:
if name in scope:
return True
return False
def resolve_function(self, function: p.Function) -> None: def resolve_function(self, function: p.Function) -> None:
"""Resolve a function definition """Resolve a function definition
@@ -111,7 +117,13 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
self.resolve(stmt.value) self.resolve(stmt.value)
for target in stmt.targets: for target in stmt.targets:
match target: match target:
case p.VariableExpr() | p.GetExpr(): case p.VariableExpr(name=name):
if not self.is_defined(name):
self.declare(name)
self.define(name)
target.accept(self)
case p.GetExpr():
target.accept(self) target.accept(self)
case _: case _:
raise Exception(f"Unsupported assignment to {target}") raise Exception(f"Unsupported assignment to {target}")