diff --git a/examples/01_simple_type_checking/03_control_flow.py b/examples/01_simple_type_checking/03_control_flow.py index 2857f20..07e11a8 100644 --- a/examples/01_simple_type_checking/03_control_flow.py +++ b/examples/01_simple_type_checking/03_control_flow.py @@ -6,4 +6,9 @@ def minimum(x: int, y: int): a = 15 b = 72 -c = minimum(a, b) \ No newline at end of file +c = minimum(a, b) + +def factorial(n: int) -> int: + if n <= 1: + return 1 + return n * factorial(n - 1) \ No newline at end of file diff --git a/midas/checker/checker.py b/midas/checker/checker.py index 1b71f3f..603f0be 100644 --- a/midas/checker/checker.py +++ b/midas/checker/checker.py @@ -223,6 +223,19 @@ class Checker( for arg in pos_args + args + kw_args: env.define(arg.name, arg.type) + + returns_hint: Optional[Type] = None + if stmt.returns is not None: + returns_hint = stmt.returns.accept(self) + # Early define to handle simple fully-typed recursion + inside_function: Function = Function( + name=stmt.name, + pos_args=pos_args, + args=args, + kw_args=kw_args, + returns=returns_hint, + ) + self.env.define(stmt.name, inside_function) returned: bool = self.evaluate_block(stmt.body, env) inferred_return: Type = UnknownType() @@ -236,9 +249,11 @@ class Checker( stmt.location, f"Mixed return types: {env.return_types}", ) + returns: Type = UnknownType() - if stmt.returns is not None: - returns = stmt.returns.accept(self) + if returns_hint is not None: + assert stmt.returns is not None + returns = returns_hint if returns != inferred_return: self.error( stmt.returns.location,