fix(checker): early define fully-typed function
to handle simple recursion cases where the function has an explicit return type hint, the function must be defined before evaluating its body
This commit is contained in:
@@ -6,4 +6,9 @@ def minimum(x: int, y: int):
|
||||
|
||||
a = 15
|
||||
b = 72
|
||||
c = minimum(a, b)
|
||||
c = minimum(a, b)
|
||||
|
||||
def factorial(n: int) -> int:
|
||||
if n <= 1:
|
||||
return 1
|
||||
return n * factorial(n - 1)
|
||||
@@ -223,6 +223,19 @@ class Checker(
|
||||
|
||||
for arg in pos_args + args + kw_args:
|
||||
env.define(arg.name, arg.type)
|
||||
|
||||
returns_hint: Optional[Type] = None
|
||||
if stmt.returns is not None:
|
||||
returns_hint = stmt.returns.accept(self)
|
||||
# Early define to handle simple fully-typed recursion
|
||||
inside_function: Function = Function(
|
||||
name=stmt.name,
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
returns=returns_hint,
|
||||
)
|
||||
self.env.define(stmt.name, inside_function)
|
||||
|
||||
returned: bool = self.evaluate_block(stmt.body, env)
|
||||
inferred_return: Type = UnknownType()
|
||||
@@ -236,9 +249,11 @@ class Checker(
|
||||
stmt.location,
|
||||
f"Mixed return types: {env.return_types}",
|
||||
)
|
||||
|
||||
returns: Type = UnknownType()
|
||||
if stmt.returns is not None:
|
||||
returns = stmt.returns.accept(self)
|
||||
if returns_hint is not None:
|
||||
assert stmt.returns is not None
|
||||
returns = returns_hint
|
||||
if returns != inferred_return:
|
||||
self.error(
|
||||
stmt.returns.location,
|
||||
|
||||
Reference in New Issue
Block a user