Merge pull request 'For loops' (#11) from feat/for-loops into main

Reviewed-on: #11
This commit was merged in pull request #11.
This commit is contained in:
2026-06-15 22:51:12 +00:00
17 changed files with 224 additions and 51 deletions

View File

@@ -86,6 +86,12 @@ class Pass:
pass pass
class ForStmt:
target: Expr
iterator: Expr
body: list[Stmt]
###< ###<

View File

@@ -596,6 +596,23 @@ class PythonAstPrinter(
def visit_pass(self, stmt: p.Pass) -> None: def visit_pass(self, stmt: p.Pass) -> None:
self._write_line("Pass") self._write_line("Pass")
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
self._write_line("ForStmt")
with self._child_level():
self._write_line("target")
with self._child_level(single=True):
stmt.target.accept(self)
self._write_line("iterator")
with self._child_level(single=True):
stmt.iterator.accept(self)
self._write_line("body", last=True)
with self._child_level():
for i, body_stmt in enumerate(stmt.body):
self._idx = i
if i == len(stmt.body) - 1:
self._mark_last()
body_stmt.accept(self)
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
self._write_line("BinaryExpr") self._write_line("BinaryExpr")
with self._child_level(): with self._child_level():

View File

@@ -110,6 +110,9 @@ class Stmt(ABC):
@abstractmethod @abstractmethod
def visit_pass(self, stmt: Pass) -> T: ... def visit_pass(self, stmt: Pass) -> T: ...
@abstractmethod
def visit_for_stmt(self, stmt: ForStmt) -> T: ...
@dataclass(frozen=True) @dataclass(frozen=True)
class ExpressionStmt(Stmt): class ExpressionStmt(Stmt):
@@ -189,6 +192,16 @@ class Pass(Stmt):
return visitor.visit_pass(self) return visitor.visit_pass(self)
@dataclass(frozen=True)
class ForStmt(Stmt):
target: Expr
iterator: Expr
body: list[Stmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_for_stmt(self)
############### ###############
# Expressions # # Expressions #
############### ###############

View File

@@ -78,17 +78,37 @@ class PythonTyper(
return TypedAST(stmts=stmts, judgements=self.judgements) return TypedAST(stmts=stmts, judgements=self.judgements)
def type_of(self, expr: p.Expr) -> Type: def judge(self, expr: p.Expr, type: Type):
"""Record a typing judgement
Args:
expr (p.Expr): the judged expression
type (Type): the type of the expression
"""
self.judgements.append((expr, type))
def compute_type(self, expr: p.Expr) -> Type:
"""Evaluate the type of an expression """Evaluate the type of an expression
Args:
expr (p.Expr): the expression to type
Returns:
Type: the type of the given expression
"""
return expr.accept(self)
def type_of(self, expr: p.Expr) -> Type:
"""Evaluate the type of an expression and record the judgement
Args: Args:
expr (p.Expr): the expression to evaluate expr (p.Expr): the expression to evaluate
Returns: Returns:
Type: the type of the given expression Type: the type of the given expression
""" """
type: Type = expr.accept(self) type: Type = self.compute_type(expr)
self.judgements.append((expr, type)) self.judge(expr, type)
return type return type
def resolve_type_expr(self, expr: p.MidasType) -> Type: def resolve_type_expr(self, expr: p.MidasType) -> Type:
@@ -334,6 +354,22 @@ class PythonTyper(
def visit_pass(self, stmt: p.Pass) -> None: def visit_pass(self, stmt: p.Pass) -> None:
pass pass
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
item_type: Optional[Type] = self._get_iterator_type(stmt.iterator)
if item_type is None:
iterator_type: Type = self.compute_type(stmt.iterator)
self.reporter.error(
stmt.iterator.location, f"{iterator_type} is not iterable"
)
item_type = UnknownType()
self._assign(stmt.location, stmt.target, item_type)
self.judge(stmt.target, item_type)
env: Environment = Environment(self.env)
body_returned: bool = self.process_block(stmt.body, env)
if body_returned:
raise ReturnException()
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type: def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__) method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
if method is None: if method is None:
@@ -370,7 +406,13 @@ class PythonTyper(
) )
return UnknownType() return UnknownType()
return self._get_call_result(location, operation, [(right_expr, right)], {}) result: Optional[Type] = self._get_call_result(
location,
operation,
[(right_expr, right)],
{},
)
return result or UnknownType()
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__) method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__)
@@ -390,9 +432,13 @@ class PythonTyper(
) )
return UnknownType() return UnknownType()
return self._get_call_result( result: Optional[Type] = self._get_call_result(
expr.location, operation, [(expr.right, operand)], {} expr.location,
operation,
[],
{},
) )
return result or UnknownType()
def visit_call_expr(self, expr: p.CallExpr) -> Type: def visit_call_expr(self, expr: p.CallExpr) -> Type:
callee: Type = self.type_of(expr.callee) callee: Type = self.type_of(expr.callee)
@@ -402,11 +448,14 @@ class PythonTyper(
keywords: dict[str, TypedExpr] = { keywords: dict[str, TypedExpr] = {
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items() name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
} }
return self._get_call_result( return (
location=expr.location, self._get_call_result(
callee=callee, location=expr.location,
positional=positional, callee=callee,
keywords=keywords, positional=positional,
keywords=keywords,
)
or UnknownType()
) )
def visit_get_expr(self, expr: p.GetExpr) -> Type: def visit_get_expr(self, expr: p.GetExpr) -> Type:
@@ -509,8 +558,9 @@ class PythonTyper(
return UnknownType() return UnknownType()
index: Type = self.type_of(expr.index) index: Type = self.type_of(expr.index)
return self._get_call_result( return (
expr.location, operation, [(expr.index, index)], {} self._get_call_result(expr.location, operation, [(expr.index, index)], {})
or UnknownType()
) )
def visit_slice_expr(self, expr: p.SliceExpr) -> Type: def visit_slice_expr(self, expr: p.SliceExpr) -> Type:
@@ -547,7 +597,8 @@ class PythonTyper(
callee: Type, callee: Type,
positional: list[TypedExpr], positional: list[TypedExpr],
keywords: dict[str, TypedExpr], keywords: dict[str, TypedExpr],
) -> Type: report_errors: bool = True,
) -> Optional[Type]:
"""Get the result type of a function call """Get the result type of a function call
If the function has overloads, the function will try to resolve the If the function has overloads, the function will try to resolve the
@@ -561,9 +612,10 @@ class PythonTyper(
callee (Type): the called function callee (Type): the called function
positional (list[TypedExpr]): the list positional arguments positional (list[TypedExpr]): the list positional arguments
keywords (dict[str, TypedExpr]): the map of keyword arguments keywords (dict[str, TypedExpr]): the map of keyword arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns: Returns:
Type: the return type of the call, or `UnknownType` if either Type: the return type of the call, or `None` if either
the call is invalid or no overload matched the arguments uniquely the call is invalid or no overload matched the arguments uniquely
""" """
match callee: match callee:
@@ -573,21 +625,22 @@ class PythonTyper(
valid, mapped = self.map_call_arguments( valid, mapped = self.map_call_arguments(
function, location, positional, keywords function, location, positional, keywords
) )
valid = valid and self._are_arguments_valid(mapped) valid = valid and self._are_arguments_valid(mapped, report_errors)
if not valid: if not valid:
return UnknownType() return None
return function.returns return function.returns
case OverloadedFunction(overloads=overloads): case OverloadedFunction(overloads=overloads):
function = self._match_overload( function = self._match_overload(
overloads, location, positional, keywords overloads, location, positional, keywords, report_errors
) )
if function is None: if function is None:
return UnknownType() return None
return function.returns return function.returns
case _: case _:
self.reporter.error(location, f"{callee} is not callable") if report_errors:
return UnknownType() self.reporter.error(location, f"{callee} is not callable")
return None
def _are_arguments_valid( def _are_arguments_valid(
self, self,
@@ -620,6 +673,7 @@ class PythonTyper(
location: Location, location: Location,
positional: list[TypedExpr], positional: list[TypedExpr],
keywords: dict[str, TypedExpr], keywords: dict[str, TypedExpr],
report_errors: bool = True,
) -> Optional[Function]: ) -> Optional[Function]:
"""Try and resolve the appropriate overload for the given arguments """Try and resolve the appropriate overload for the given arguments
@@ -628,6 +682,7 @@ class PythonTyper(
location (Location): the call location location (Location): the call location
positional (list[TypedExpr]): the list of positional arguments positional (list[TypedExpr]): the list of positional arguments
keywords (dict[str, TypedExpr]): the map of keywords arguments keywords (dict[str, TypedExpr]): the map of keywords arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns: Returns:
Optional[Function]: the resolved function signature if it can be Optional[Function]: the resolved function signature if it can be
@@ -637,9 +692,10 @@ class PythonTyper(
for overload in overloads: for overload in overloads:
function: Type = unfold_type(overload) function: Type = unfold_type(overload)
if not isinstance(function, Function): if not isinstance(function, Function):
self.logger.error( if report_errors:
f"Overload is not a function: {overload} is {function}" self.logger.error(
) f"Overload is not a function: {overload} is {function}"
)
continue continue
valid, mapped = self.map_call_arguments( valid, mapped = self.map_call_arguments(
function=function, function=function,
@@ -671,10 +727,11 @@ class PythonTyper(
# No match -> invalid call # No match -> invalid call
if n_candidates == 0: if n_candidates == 0:
overloads_str: str = ", ".join(map(str, overloads)) overloads_str: str = ", ".join(map(str, overloads))
self.reporter.error( if report_errors:
location, self.reporter.error(
f"No matching overload in [{overloads_str}] {for_args}", location,
) f"No matching overload in [{overloads_str}] {for_args}",
)
return None return None
# Multiple matches -> see if one <: all others (more specific) # Multiple matches -> see if one <: all others (more specific)
@@ -695,10 +752,11 @@ class PythonTyper(
candidates_str: str = ", ".join( candidates_str: str = ", ".join(
str(candidate.function) for candidate in candidates str(candidate.function) for candidate in candidates
) )
self.reporter.error( if report_errors:
location, self.reporter.error(
f"Multiple matching overloads {for_args}: {candidates_str}", location,
) f"Multiple matching overloads {for_args}: {candidates_str}",
)
return None return None
def map_call_arguments( def map_call_arguments(
@@ -863,3 +921,21 @@ class PythonTyper(
if not self.is_subtype(type1, type2): if not self.is_subtype(type1, type2):
return False return False
return True return True
def _get_iterator_type(self, expr: p.Expr) -> Optional[Type]:
# TODO: lookup __iter__
type: Type = self.type_of(expr)
getitem: Optional[Type] = self.types.lookup_member(type, "__getitem__")
if getitem is None:
return None
index: p.Expr = p.LiteralExpr(location=expr.location, value=0)
index_type: Type = self.compute_type(index)
result: Optional[Type] = self._get_call_result(
location=expr.location,
callee=getitem,
positional=[(index, index_type)],
keywords={},
report_errors=False,
)
return result

View File

@@ -116,17 +116,20 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
self.resolve(stmt.value) self.resolve(stmt.value)
for target in stmt.targets: for target in stmt.targets:
match target: self._visit_assign(target)
case p.VariableExpr(name=name):
if not self.is_defined(name):
self.declare(name)
self.define(name)
target.accept(self)
case p.GetExpr(): def _visit_assign(self, target: p.Expr):
target.accept(self) match target:
case _: case p.VariableExpr(name=name):
raise Exception(f"Unsupported assignment to {target}") if not self.is_defined(name):
self.declare(name)
self.define(name)
target.accept(self)
case p.GetExpr():
target.accept(self)
case _:
raise Exception(f"Unsupported assignment to {target}")
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
if stmt.value is not None: if stmt.value is not None:
@@ -153,6 +156,13 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
def visit_pass(self, stmt: p.Pass) -> None: def visit_pass(self, stmt: p.Pass) -> None:
pass pass
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
self.resolve(stmt.iterator)
self._visit_assign(stmt.target)
self.begin_scope()
self.resolve(*stmt.body)
self.end_scope()
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
self.resolve(expr.left) self.resolve(expr.left)
self.resolve(expr.right) self.resolve(expr.right)

View File

@@ -14,7 +14,7 @@ from midas.cli.highlighter import DiagnosticsHighlighter
from midas.cli.utils import DiagnosticPrinter from midas.cli.utils import DiagnosticPrinter
@click.command() @click.command(help="Run type checker and report diagnostics")
@click.argument("file", type=click.File("r")) @click.argument("file", type=click.File("r"))
@click.option("-t", "--types", type=click.File("r"), multiple=True) @click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.option("-l", "--highlight", type=click.File("w")) @click.option("-l", "--highlight", type=click.File("w"))

View File

@@ -3,19 +3,20 @@
# midas compile <file.py> [--types <file.midas>] [-o <output>] [--assertions|--strict|--no-checks] # midas compile <file.py> [--types <file.midas>] [-o <output>] [--assertions|--strict|--no-checks]
# ``` # ```
import sys
from pathlib import Path from pathlib import Path
from typing import TextIO from typing import TextIO
import click import click
from midas.checker.checker import TypeChecker from midas.checker.checker import TypeChecker
from midas.checker.diagnostic import Diagnostic from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.cli.utils import DiagnosticPrinter from midas.cli.utils import DiagnosticPrinter
from midas.generator.generator import Generator from midas.generator.generator import Generator
from midas.utils import TypedAST from midas.utils import TypedAST
@click.command() @click.command(help="Compile source")
@click.argument("file", type=click.File("r")) @click.argument("file", type=click.File("r"))
@click.option("-t", "--types", type=click.File("r"), multiple=True) @click.option("-t", "--types", type=click.File("r"), multiple=True)
def compile( def compile(
@@ -34,5 +35,8 @@ def compile(
printer = DiagnosticPrinter() printer = DiagnosticPrinter()
printer.print_all(diagnostics) printer.print_all(diagnostics)
if any(map(lambda d: d.type == DiagnosticType.ERROR, diagnostics)):
sys.exit(1)
generator = Generator(workdir=source_path.parent) generator = Generator(workdir=source_path.parent)
generator.generate(typed_ast, source_path) generator.generate(typed_ast, source_path)

View File

@@ -9,7 +9,7 @@ from midas.lexer.token import Token
from midas.parser.midas import MidasParser from midas.parser.midas import MidasParser
@click.command() @click.command(help="Parse and pretty print a Midas file")
@click.argument("file", type=click.File("r")) @click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-") @click.option("-o", "--output", type=click.File("w"), default="-")
def format(file: TextIO, output: TextIO): def format(file: TextIO, output: TextIO):

View File

@@ -46,7 +46,10 @@ def highlight_midas(source: str, path: str) -> Highlighter:
return highlighter return highlighter
@click.command() @click.command(
help="Parse a Python or Midas file and produce a highlighted version showing AST node types inline",
short_help="Parse and highlight a Python or Midas file",
)
@click.argument("file", type=click.File("r")) @click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-") @click.option("-o", "--output", type=click.File("w"), default="-")
def highlight(output: TextIO, file: TextIO): def highlight(output: TextIO, file: TextIO):

View File

@@ -45,7 +45,7 @@ def dump_midas_ast(source: str, filename: str) -> str:
return dump return dump
@click.command() @click.command(help="Parse a Python or Midas file and pretty-print its AST")
@click.argument("file", type=click.File("r")) @click.argument("file", type=click.File("r"))
@click.option("--raw", is_flag=True) @click.option("--raw", is_flag=True)
def parse(file: TextIO, raw: bool): def parse(file: TextIO, raw: bool):

View File

@@ -12,7 +12,7 @@ from midas.checker.checker import TypeChecker
from midas.checker.types import Type from midas.checker.types import Type
@click.command() @click.command(help="Dump types registry")
@click.option("-t", "--types", type=click.File("r"), multiple=True) @click.option("-t", "--types", type=click.File("r"), multiple=True)
def dump_registry( def dump_registry(
types: tuple[TextIO], types: tuple[TextIO],

View File

@@ -14,7 +14,7 @@ from midas.cli.highlighter import DiagnosticsHighlighter
from midas.cli.utils import DiagnosticPrinter from midas.cli.utils import DiagnosticPrinter
@click.command() @click.command(help="Print typing judgements")
@click.argument("file", type=click.File("r")) @click.argument("file", type=click.File("r"))
@click.option("-t", "--types", type=click.File("r"), multiple=True) @click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.option("-l", "--highlight", type=click.File("w")) @click.option("-l", "--highlight", type=click.File("w"))

View File

@@ -14,7 +14,7 @@ from midas.cli.highlighter import DiagnosticsHighlighter
from midas.cli.utils import DiagnosticPrinter from midas.cli.utils import DiagnosticPrinter
@click.command() @click.command(help="Validate Midas definitions")
@click.argument("file", type=click.File("r")) @click.argument("file", type=click.File("r"))
@click.option("-l", "--highlight", type=click.File("w")) @click.option("-l", "--highlight", type=click.File("w"))
def validate( def validate(

View File

@@ -191,6 +191,13 @@ class PythonHighlighter(
def visit_pass(self, stmt: p.Pass) -> None: def visit_pass(self, stmt: p.Pass) -> None:
pass pass
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
self.wrap(stmt, "for")
stmt.iterator.accept(self)
stmt.target.accept(self)
for body_stmt in stmt.body:
body_stmt.accept(self)
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ... def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ...
def visit_compare_expr(self, expr: p.CompareExpr) -> None: ... def visit_compare_expr(self, expr: p.CompareExpr) -> None: ...

View File

@@ -161,5 +161,13 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def visit_pass(self, stmt: p.Pass) -> ast.stmt: def visit_pass(self, stmt: p.Pass) -> ast.stmt:
return ast.Pass() return ast.Pass()
def visit_for_stmt(self, stmt: p.ForStmt) -> ast.stmt:
return ast.For(
target=stmt.target.accept(self),
iter=stmt.iterator.accept(self),
body=self._visit_body(stmt.body),
orelse=[],
)
def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]: def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]:
return [stmt.accept(self) for stmt in stmts] return [stmt.accept(self) for stmt in stmts]

View File

@@ -12,6 +12,7 @@ from midas.ast.python import (
ConstraintType, ConstraintType,
Expr, Expr,
ExpressionStmt, ExpressionStmt,
ForStmt,
FrameColumn, FrameColumn,
FrameType, FrameType,
Function, Function,
@@ -93,6 +94,9 @@ class PythonParser:
case ast.Pass(): case ast.Pass():
return None return None
case ast.For(orelse=[]):
return self.parse_for(node)
case _: case _:
print(f"Unsupported statement: {ast.unparse(node)}") print(f"Unsupported statement: {ast.unparse(node)}")
return None return None
@@ -182,6 +186,22 @@ class PythonParser:
orelse=orelse, orelse=orelse,
) )
def parse_for(self, node: ast.For) -> ForStmt:
body: list[Stmt] = []
for stmt in node.body:
stmts = self.parse_stmt(stmt)
if isinstance(stmts, Stmt):
body.append(stmts)
elif stmts is not None:
body.extend(stmts)
return ForStmt(
location=Location.from_ast(node),
target=self.parse_expr(node.target),
iterator=self.parse_expr(node.iter),
body=body,
)
def parse_function(self, node: ast.FunctionDef) -> Function: def parse_function(self, node: ast.FunctionDef) -> Function:
loc: Location = Location.from_ast(node) loc: Location = Location.from_ast(node)
match node: match node:

View File

@@ -11,6 +11,7 @@ from midas.ast.python import (
ConstraintType, ConstraintType,
Expr, Expr,
ExpressionStmt, ExpressionStmt,
ForStmt,
FrameColumn, FrameColumn,
FrameType, FrameType,
Function, Function,
@@ -182,6 +183,14 @@ class PythonAstJsonSerializer(
"_type": "Pass", "_type": "Pass",
} }
def visit_for_stmt(self, stmt: ForStmt) -> dict:
return {
"_type": "ForStmt",
"target": stmt.target.accept(self),
"iterator": stmt.iterator.accept(self),
"body": self._serialize_list(stmt.body),
}
def visit_binary_expr(self, expr: BinaryExpr) -> dict: def visit_binary_expr(self, expr: BinaryExpr) -> dict:
return { return {
"_type": "BinaryExpr", "_type": "BinaryExpr",