Merge pull request 'For loops' (#11) from feat/for-loops into main

Reviewed-on: #11
This commit was merged in pull request #11.
This commit is contained in:
2026-06-15 22:51:12 +00:00
17 changed files with 224 additions and 51 deletions

View File

@@ -86,6 +86,12 @@ class Pass:
pass
class ForStmt:
target: Expr
iterator: Expr
body: list[Stmt]
###<

View File

@@ -596,6 +596,23 @@ class PythonAstPrinter(
def visit_pass(self, stmt: p.Pass) -> None:
self._write_line("Pass")
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
self._write_line("ForStmt")
with self._child_level():
self._write_line("target")
with self._child_level(single=True):
stmt.target.accept(self)
self._write_line("iterator")
with self._child_level(single=True):
stmt.iterator.accept(self)
self._write_line("body", last=True)
with self._child_level():
for i, body_stmt in enumerate(stmt.body):
self._idx = i
if i == len(stmt.body) - 1:
self._mark_last()
body_stmt.accept(self)
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
self._write_line("BinaryExpr")
with self._child_level():

View File

@@ -110,6 +110,9 @@ class Stmt(ABC):
@abstractmethod
def visit_pass(self, stmt: Pass) -> T: ...
@abstractmethod
def visit_for_stmt(self, stmt: ForStmt) -> T: ...
@dataclass(frozen=True)
class ExpressionStmt(Stmt):
@@ -189,6 +192,16 @@ class Pass(Stmt):
return visitor.visit_pass(self)
@dataclass(frozen=True)
class ForStmt(Stmt):
target: Expr
iterator: Expr
body: list[Stmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_for_stmt(self)
###############
# Expressions #
###############

View File

@@ -78,17 +78,37 @@ class PythonTyper(
return TypedAST(stmts=stmts, judgements=self.judgements)
def type_of(self, expr: p.Expr) -> Type:
def judge(self, expr: p.Expr, type: Type):
"""Record a typing judgement
Args:
expr (p.Expr): the judged expression
type (Type): the type of the expression
"""
self.judgements.append((expr, type))
def compute_type(self, expr: p.Expr) -> Type:
"""Evaluate the type of an expression
Args:
expr (p.Expr): the expression to type
Returns:
Type: the type of the given expression
"""
return expr.accept(self)
def type_of(self, expr: p.Expr) -> Type:
"""Evaluate the type of an expression and record the judgement
Args:
expr (p.Expr): the expression to evaluate
Returns:
Type: the type of the given expression
"""
type: Type = expr.accept(self)
self.judgements.append((expr, type))
type: Type = self.compute_type(expr)
self.judge(expr, type)
return type
def resolve_type_expr(self, expr: p.MidasType) -> Type:
@@ -334,6 +354,22 @@ class PythonTyper(
def visit_pass(self, stmt: p.Pass) -> None:
pass
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
item_type: Optional[Type] = self._get_iterator_type(stmt.iterator)
if item_type is None:
iterator_type: Type = self.compute_type(stmt.iterator)
self.reporter.error(
stmt.iterator.location, f"{iterator_type} is not iterable"
)
item_type = UnknownType()
self._assign(stmt.location, stmt.target, item_type)
self.judge(stmt.target, item_type)
env: Environment = Environment(self.env)
body_returned: bool = self.process_block(stmt.body, env)
if body_returned:
raise ReturnException()
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
if method is None:
@@ -370,7 +406,13 @@ class PythonTyper(
)
return UnknownType()
return self._get_call_result(location, operation, [(right_expr, right)], {})
result: Optional[Type] = self._get_call_result(
location,
operation,
[(right_expr, right)],
{},
)
return result or UnknownType()
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__)
@@ -390,9 +432,13 @@ class PythonTyper(
)
return UnknownType()
return self._get_call_result(
expr.location, operation, [(expr.right, operand)], {}
result: Optional[Type] = self._get_call_result(
expr.location,
operation,
[],
{},
)
return result or UnknownType()
def visit_call_expr(self, expr: p.CallExpr) -> Type:
callee: Type = self.type_of(expr.callee)
@@ -402,11 +448,14 @@ class PythonTyper(
keywords: dict[str, TypedExpr] = {
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
}
return self._get_call_result(
location=expr.location,
callee=callee,
positional=positional,
keywords=keywords,
return (
self._get_call_result(
location=expr.location,
callee=callee,
positional=positional,
keywords=keywords,
)
or UnknownType()
)
def visit_get_expr(self, expr: p.GetExpr) -> Type:
@@ -509,8 +558,9 @@ class PythonTyper(
return UnknownType()
index: Type = self.type_of(expr.index)
return self._get_call_result(
expr.location, operation, [(expr.index, index)], {}
return (
self._get_call_result(expr.location, operation, [(expr.index, index)], {})
or UnknownType()
)
def visit_slice_expr(self, expr: p.SliceExpr) -> Type:
@@ -547,7 +597,8 @@ class PythonTyper(
callee: Type,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
report_errors: bool = True,
) -> Optional[Type]:
"""Get the result type of a function call
If the function has overloads, the function will try to resolve the
@@ -561,9 +612,10 @@ class PythonTyper(
callee (Type): the called function
positional (list[TypedExpr]): the list positional arguments
keywords (dict[str, TypedExpr]): the map of keyword arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
Type: the return type of the call, or `UnknownType` if either
Type: the return type of the call, or `None` if either
the call is invalid or no overload matched the arguments uniquely
"""
match callee:
@@ -573,21 +625,22 @@ class PythonTyper(
valid, mapped = self.map_call_arguments(
function, location, positional, keywords
)
valid = valid and self._are_arguments_valid(mapped)
valid = valid and self._are_arguments_valid(mapped, report_errors)
if not valid:
return UnknownType()
return None
return function.returns
case OverloadedFunction(overloads=overloads):
function = self._match_overload(
overloads, location, positional, keywords
overloads, location, positional, keywords, report_errors
)
if function is None:
return UnknownType()
return None
return function.returns
case _:
self.reporter.error(location, f"{callee} is not callable")
return UnknownType()
if report_errors:
self.reporter.error(location, f"{callee} is not callable")
return None
def _are_arguments_valid(
self,
@@ -620,6 +673,7 @@ class PythonTyper(
location: Location,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
report_errors: bool = True,
) -> Optional[Function]:
"""Try and resolve the appropriate overload for the given arguments
@@ -628,6 +682,7 @@ class PythonTyper(
location (Location): the call location
positional (list[TypedExpr]): the list of positional arguments
keywords (dict[str, TypedExpr]): the map of keywords arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
Optional[Function]: the resolved function signature if it can be
@@ -637,9 +692,10 @@ class PythonTyper(
for overload in overloads:
function: Type = unfold_type(overload)
if not isinstance(function, Function):
self.logger.error(
f"Overload is not a function: {overload} is {function}"
)
if report_errors:
self.logger.error(
f"Overload is not a function: {overload} is {function}"
)
continue
valid, mapped = self.map_call_arguments(
function=function,
@@ -671,10 +727,11 @@ class PythonTyper(
# No match -> invalid call
if n_candidates == 0:
overloads_str: str = ", ".join(map(str, overloads))
self.reporter.error(
location,
f"No matching overload in [{overloads_str}] {for_args}",
)
if report_errors:
self.reporter.error(
location,
f"No matching overload in [{overloads_str}] {for_args}",
)
return None
# Multiple matches -> see if one <: all others (more specific)
@@ -695,10 +752,11 @@ class PythonTyper(
candidates_str: str = ", ".join(
str(candidate.function) for candidate in candidates
)
self.reporter.error(
location,
f"Multiple matching overloads {for_args}: {candidates_str}",
)
if report_errors:
self.reporter.error(
location,
f"Multiple matching overloads {for_args}: {candidates_str}",
)
return None
def map_call_arguments(
@@ -863,3 +921,21 @@ class PythonTyper(
if not self.is_subtype(type1, type2):
return False
return True
def _get_iterator_type(self, expr: p.Expr) -> Optional[Type]:
# TODO: lookup __iter__
type: Type = self.type_of(expr)
getitem: Optional[Type] = self.types.lookup_member(type, "__getitem__")
if getitem is None:
return None
index: p.Expr = p.LiteralExpr(location=expr.location, value=0)
index_type: Type = self.compute_type(index)
result: Optional[Type] = self._get_call_result(
location=expr.location,
callee=getitem,
positional=[(index, index_type)],
keywords={},
report_errors=False,
)
return result

View File

@@ -116,17 +116,20 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
self.resolve(stmt.value)
for target in stmt.targets:
match target:
case p.VariableExpr(name=name):
if not self.is_defined(name):
self.declare(name)
self.define(name)
target.accept(self)
self._visit_assign(target)
case p.GetExpr():
target.accept(self)
case _:
raise Exception(f"Unsupported assignment to {target}")
def _visit_assign(self, target: p.Expr):
match target:
case p.VariableExpr(name=name):
if not self.is_defined(name):
self.declare(name)
self.define(name)
target.accept(self)
case p.GetExpr():
target.accept(self)
case _:
raise Exception(f"Unsupported assignment to {target}")
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
if stmt.value is not None:
@@ -153,6 +156,13 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
def visit_pass(self, stmt: p.Pass) -> None:
pass
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
self.resolve(stmt.iterator)
self._visit_assign(stmt.target)
self.begin_scope()
self.resolve(*stmt.body)
self.end_scope()
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
self.resolve(expr.left)
self.resolve(expr.right)

View File

@@ -14,7 +14,7 @@ from midas.cli.highlighter import DiagnosticsHighlighter
from midas.cli.utils import DiagnosticPrinter
@click.command()
@click.command(help="Run type checker and report diagnostics")
@click.argument("file", type=click.File("r"))
@click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.option("-l", "--highlight", type=click.File("w"))

View File

@@ -3,19 +3,20 @@
# midas compile <file.py> [--types <file.midas>] [-o <output>] [--assertions|--strict|--no-checks]
# ```
import sys
from pathlib import Path
from typing import TextIO
import click
from midas.checker.checker import TypeChecker
from midas.checker.diagnostic import Diagnostic
from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.cli.utils import DiagnosticPrinter
from midas.generator.generator import Generator
from midas.utils import TypedAST
@click.command()
@click.command(help="Compile source")
@click.argument("file", type=click.File("r"))
@click.option("-t", "--types", type=click.File("r"), multiple=True)
def compile(
@@ -34,5 +35,8 @@ def compile(
printer = DiagnosticPrinter()
printer.print_all(diagnostics)
if any(map(lambda d: d.type == DiagnosticType.ERROR, diagnostics)):
sys.exit(1)
generator = Generator(workdir=source_path.parent)
generator.generate(typed_ast, source_path)

View File

@@ -9,7 +9,7 @@ from midas.lexer.token import Token
from midas.parser.midas import MidasParser
@click.command()
@click.command(help="Parse and pretty print a Midas file")
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-")
def format(file: TextIO, output: TextIO):

View File

@@ -46,7 +46,10 @@ def highlight_midas(source: str, path: str) -> Highlighter:
return highlighter
@click.command()
@click.command(
help="Parse a Python or Midas file and produce a highlighted version showing AST node types inline",
short_help="Parse and highlight a Python or Midas file",
)
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-")
def highlight(output: TextIO, file: TextIO):

View File

@@ -45,7 +45,7 @@ def dump_midas_ast(source: str, filename: str) -> str:
return dump
@click.command()
@click.command(help="Parse a Python or Midas file and pretty-print its AST")
@click.argument("file", type=click.File("r"))
@click.option("--raw", is_flag=True)
def parse(file: TextIO, raw: bool):

View File

@@ -12,7 +12,7 @@ from midas.checker.checker import TypeChecker
from midas.checker.types import Type
@click.command()
@click.command(help="Dump types registry")
@click.option("-t", "--types", type=click.File("r"), multiple=True)
def dump_registry(
types: tuple[TextIO],

View File

@@ -14,7 +14,7 @@ from midas.cli.highlighter import DiagnosticsHighlighter
from midas.cli.utils import DiagnosticPrinter
@click.command()
@click.command(help="Print typing judgements")
@click.argument("file", type=click.File("r"))
@click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.option("-l", "--highlight", type=click.File("w"))

View File

@@ -14,7 +14,7 @@ from midas.cli.highlighter import DiagnosticsHighlighter
from midas.cli.utils import DiagnosticPrinter
@click.command()
@click.command(help="Validate Midas definitions")
@click.argument("file", type=click.File("r"))
@click.option("-l", "--highlight", type=click.File("w"))
def validate(

View File

@@ -191,6 +191,13 @@ class PythonHighlighter(
def visit_pass(self, stmt: p.Pass) -> None:
pass
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
self.wrap(stmt, "for")
stmt.iterator.accept(self)
stmt.target.accept(self)
for body_stmt in stmt.body:
body_stmt.accept(self)
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ...
def visit_compare_expr(self, expr: p.CompareExpr) -> None: ...

View File

@@ -161,5 +161,13 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def visit_pass(self, stmt: p.Pass) -> ast.stmt:
return ast.Pass()
def visit_for_stmt(self, stmt: p.ForStmt) -> ast.stmt:
return ast.For(
target=stmt.target.accept(self),
iter=stmt.iterator.accept(self),
body=self._visit_body(stmt.body),
orelse=[],
)
def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]:
return [stmt.accept(self) for stmt in stmts]

View File

@@ -12,6 +12,7 @@ from midas.ast.python import (
ConstraintType,
Expr,
ExpressionStmt,
ForStmt,
FrameColumn,
FrameType,
Function,
@@ -93,6 +94,9 @@ class PythonParser:
case ast.Pass():
return None
case ast.For(orelse=[]):
return self.parse_for(node)
case _:
print(f"Unsupported statement: {ast.unparse(node)}")
return None
@@ -182,6 +186,22 @@ class PythonParser:
orelse=orelse,
)
def parse_for(self, node: ast.For) -> ForStmt:
body: list[Stmt] = []
for stmt in node.body:
stmts = self.parse_stmt(stmt)
if isinstance(stmts, Stmt):
body.append(stmts)
elif stmts is not None:
body.extend(stmts)
return ForStmt(
location=Location.from_ast(node),
target=self.parse_expr(node.target),
iterator=self.parse_expr(node.iter),
body=body,
)
def parse_function(self, node: ast.FunctionDef) -> Function:
loc: Location = Location.from_ast(node)
match node:

View File

@@ -11,6 +11,7 @@ from midas.ast.python import (
ConstraintType,
Expr,
ExpressionStmt,
ForStmt,
FrameColumn,
FrameType,
Function,
@@ -182,6 +183,14 @@ class PythonAstJsonSerializer(
"_type": "Pass",
}
def visit_for_stmt(self, stmt: ForStmt) -> dict:
return {
"_type": "ForStmt",
"target": stmt.target.accept(self),
"iterator": stmt.iterator.accept(self),
"body": self._serialize_list(stmt.body),
}
def visit_binary_expr(self, expr: BinaryExpr) -> dict:
return {
"_type": "BinaryExpr",