diff --git a/midas/checker/python.py b/midas/checker/python.py index 4481ba7..316836e 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -78,17 +78,37 @@ class PythonTyper( return TypedAST(stmts=stmts, judgements=self.judgements) - def type_of(self, expr: p.Expr) -> Type: + def judge(self, expr: p.Expr, type: Type): + """Record a typing judgement + + Args: + expr (p.Expr): the judged expression + type (Type): the type of the expression + """ + self.judgements.append((expr, type)) + + def compute_type(self, expr: p.Expr) -> Type: """Evaluate the type of an expression + Args: + expr (p.Expr): the expression to type + + Returns: + Type: the type of the given expression + """ + return expr.accept(self) + + def type_of(self, expr: p.Expr) -> Type: + """Evaluate the type of an expression and record the judgement + Args: expr (p.Expr): the expression to evaluate Returns: Type: the type of the given expression """ - type: Type = expr.accept(self) - self.judgements.append((expr, type)) + type: Type = self.compute_type(expr) + self.judge(expr, type) return type def resolve_type_expr(self, expr: p.MidasType) -> Type: @@ -337,10 +357,14 @@ class PythonTyper( def visit_for_stmt(self, stmt: p.ForStmt) -> None: item_type: Optional[Type] = self._get_iterator_type(stmt.iterator) if item_type is None: - self.reporter.error(stmt.iterator.location, "Iterator is not an iterator") + iterator_type: Type = self.compute_type(stmt.iterator) + self.reporter.error( + stmt.iterator.location, f"{iterator_type} is not iterable" + ) item_type = UnknownType() self._assign(stmt.location, stmt.target, item_type) + self.judge(stmt.target, item_type) env: Environment = Environment(self.env) body_returned: bool = self.process_block(stmt.body, env) if body_returned: @@ -906,9 +930,7 @@ class PythonTyper( return None index: p.Expr = p.LiteralExpr(location=expr.location, value=0) - index_type: Type = index.accept( - self - ) # skip type_of() to avoid recording judgement + index_type: Type = self.compute_type(index) result: Optional[Type] = self._get_call_result( location=expr.location, callee=getitem,