Merge pull request 'For loops' (#11) from feat/for-loops into main
Reviewed-on: #11
This commit was merged in pull request #11.
This commit is contained in:
@@ -86,6 +86,12 @@ class Pass:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ForStmt:
|
||||||
|
target: Expr
|
||||||
|
iterator: Expr
|
||||||
|
body: list[Stmt]
|
||||||
|
|
||||||
|
|
||||||
###<
|
###<
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -596,6 +596,23 @@ class PythonAstPrinter(
|
|||||||
def visit_pass(self, stmt: p.Pass) -> None:
|
def visit_pass(self, stmt: p.Pass) -> None:
|
||||||
self._write_line("Pass")
|
self._write_line("Pass")
|
||||||
|
|
||||||
|
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
|
||||||
|
self._write_line("ForStmt")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("target")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
stmt.target.accept(self)
|
||||||
|
self._write_line("iterator")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
stmt.iterator.accept(self)
|
||||||
|
self._write_line("body", last=True)
|
||||||
|
with self._child_level():
|
||||||
|
for i, body_stmt in enumerate(stmt.body):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(stmt.body) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
body_stmt.accept(self)
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
|
||||||
self._write_line("BinaryExpr")
|
self._write_line("BinaryExpr")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
|
|||||||
@@ -110,6 +110,9 @@ class Stmt(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_pass(self, stmt: Pass) -> T: ...
|
def visit_pass(self, stmt: Pass) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_for_stmt(self, stmt: ForStmt) -> T: ...
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ExpressionStmt(Stmt):
|
class ExpressionStmt(Stmt):
|
||||||
@@ -189,6 +192,16 @@ class Pass(Stmt):
|
|||||||
return visitor.visit_pass(self)
|
return visitor.visit_pass(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ForStmt(Stmt):
|
||||||
|
target: Expr
|
||||||
|
iterator: Expr
|
||||||
|
body: list[Stmt]
|
||||||
|
|
||||||
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_for_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
###############
|
###############
|
||||||
# Expressions #
|
# Expressions #
|
||||||
###############
|
###############
|
||||||
|
|||||||
@@ -78,17 +78,37 @@ class PythonTyper(
|
|||||||
|
|
||||||
return TypedAST(stmts=stmts, judgements=self.judgements)
|
return TypedAST(stmts=stmts, judgements=self.judgements)
|
||||||
|
|
||||||
def type_of(self, expr: p.Expr) -> Type:
|
def judge(self, expr: p.Expr, type: Type):
|
||||||
|
"""Record a typing judgement
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expr (p.Expr): the judged expression
|
||||||
|
type (Type): the type of the expression
|
||||||
|
"""
|
||||||
|
self.judgements.append((expr, type))
|
||||||
|
|
||||||
|
def compute_type(self, expr: p.Expr) -> Type:
|
||||||
"""Evaluate the type of an expression
|
"""Evaluate the type of an expression
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expr (p.Expr): the expression to type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Type: the type of the given expression
|
||||||
|
"""
|
||||||
|
return expr.accept(self)
|
||||||
|
|
||||||
|
def type_of(self, expr: p.Expr) -> Type:
|
||||||
|
"""Evaluate the type of an expression and record the judgement
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
expr (p.Expr): the expression to evaluate
|
expr (p.Expr): the expression to evaluate
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Type: the type of the given expression
|
Type: the type of the given expression
|
||||||
"""
|
"""
|
||||||
type: Type = expr.accept(self)
|
type: Type = self.compute_type(expr)
|
||||||
self.judgements.append((expr, type))
|
self.judge(expr, type)
|
||||||
return type
|
return type
|
||||||
|
|
||||||
def resolve_type_expr(self, expr: p.MidasType) -> Type:
|
def resolve_type_expr(self, expr: p.MidasType) -> Type:
|
||||||
@@ -334,6 +354,22 @@ class PythonTyper(
|
|||||||
def visit_pass(self, stmt: p.Pass) -> None:
|
def visit_pass(self, stmt: p.Pass) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
|
||||||
|
item_type: Optional[Type] = self._get_iterator_type(stmt.iterator)
|
||||||
|
if item_type is None:
|
||||||
|
iterator_type: Type = self.compute_type(stmt.iterator)
|
||||||
|
self.reporter.error(
|
||||||
|
stmt.iterator.location, f"{iterator_type} is not iterable"
|
||||||
|
)
|
||||||
|
item_type = UnknownType()
|
||||||
|
|
||||||
|
self._assign(stmt.location, stmt.target, item_type)
|
||||||
|
self.judge(stmt.target, item_type)
|
||||||
|
env: Environment = Environment(self.env)
|
||||||
|
body_returned: bool = self.process_block(stmt.body, env)
|
||||||
|
if body_returned:
|
||||||
|
raise ReturnException()
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
||||||
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
||||||
if method is None:
|
if method is None:
|
||||||
@@ -370,7 +406,13 @@ class PythonTyper(
|
|||||||
)
|
)
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
return self._get_call_result(location, operation, [(right_expr, right)], {})
|
result: Optional[Type] = self._get_call_result(
|
||||||
|
location,
|
||||||
|
operation,
|
||||||
|
[(right_expr, right)],
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
return result or UnknownType()
|
||||||
|
|
||||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
|
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
|
||||||
method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__)
|
method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__)
|
||||||
@@ -390,9 +432,13 @@ class PythonTyper(
|
|||||||
)
|
)
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
return self._get_call_result(
|
result: Optional[Type] = self._get_call_result(
|
||||||
expr.location, operation, [(expr.right, operand)], {}
|
expr.location,
|
||||||
|
operation,
|
||||||
|
[],
|
||||||
|
{},
|
||||||
)
|
)
|
||||||
|
return result or UnknownType()
|
||||||
|
|
||||||
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
||||||
callee: Type = self.type_of(expr.callee)
|
callee: Type = self.type_of(expr.callee)
|
||||||
@@ -402,12 +448,15 @@ class PythonTyper(
|
|||||||
keywords: dict[str, TypedExpr] = {
|
keywords: dict[str, TypedExpr] = {
|
||||||
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
|
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
|
||||||
}
|
}
|
||||||
return self._get_call_result(
|
return (
|
||||||
|
self._get_call_result(
|
||||||
location=expr.location,
|
location=expr.location,
|
||||||
callee=callee,
|
callee=callee,
|
||||||
positional=positional,
|
positional=positional,
|
||||||
keywords=keywords,
|
keywords=keywords,
|
||||||
)
|
)
|
||||||
|
or UnknownType()
|
||||||
|
)
|
||||||
|
|
||||||
def visit_get_expr(self, expr: p.GetExpr) -> Type:
|
def visit_get_expr(self, expr: p.GetExpr) -> Type:
|
||||||
object: Type = self.type_of(expr.object)
|
object: Type = self.type_of(expr.object)
|
||||||
@@ -509,8 +558,9 @@ class PythonTyper(
|
|||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
index: Type = self.type_of(expr.index)
|
index: Type = self.type_of(expr.index)
|
||||||
return self._get_call_result(
|
return (
|
||||||
expr.location, operation, [(expr.index, index)], {}
|
self._get_call_result(expr.location, operation, [(expr.index, index)], {})
|
||||||
|
or UnknownType()
|
||||||
)
|
)
|
||||||
|
|
||||||
def visit_slice_expr(self, expr: p.SliceExpr) -> Type:
|
def visit_slice_expr(self, expr: p.SliceExpr) -> Type:
|
||||||
@@ -547,7 +597,8 @@ class PythonTyper(
|
|||||||
callee: Type,
|
callee: Type,
|
||||||
positional: list[TypedExpr],
|
positional: list[TypedExpr],
|
||||||
keywords: dict[str, TypedExpr],
|
keywords: dict[str, TypedExpr],
|
||||||
) -> Type:
|
report_errors: bool = True,
|
||||||
|
) -> Optional[Type]:
|
||||||
"""Get the result type of a function call
|
"""Get the result type of a function call
|
||||||
|
|
||||||
If the function has overloads, the function will try to resolve the
|
If the function has overloads, the function will try to resolve the
|
||||||
@@ -561,9 +612,10 @@ class PythonTyper(
|
|||||||
callee (Type): the called function
|
callee (Type): the called function
|
||||||
positional (list[TypedExpr]): the list positional arguments
|
positional (list[TypedExpr]): the list positional arguments
|
||||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||||
|
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Type: the return type of the call, or `UnknownType` if either
|
Type: the return type of the call, or `None` if either
|
||||||
the call is invalid or no overload matched the arguments uniquely
|
the call is invalid or no overload matched the arguments uniquely
|
||||||
"""
|
"""
|
||||||
match callee:
|
match callee:
|
||||||
@@ -573,21 +625,22 @@ class PythonTyper(
|
|||||||
valid, mapped = self.map_call_arguments(
|
valid, mapped = self.map_call_arguments(
|
||||||
function, location, positional, keywords
|
function, location, positional, keywords
|
||||||
)
|
)
|
||||||
valid = valid and self._are_arguments_valid(mapped)
|
valid = valid and self._are_arguments_valid(mapped, report_errors)
|
||||||
if not valid:
|
if not valid:
|
||||||
return UnknownType()
|
return None
|
||||||
return function.returns
|
return function.returns
|
||||||
|
|
||||||
case OverloadedFunction(overloads=overloads):
|
case OverloadedFunction(overloads=overloads):
|
||||||
function = self._match_overload(
|
function = self._match_overload(
|
||||||
overloads, location, positional, keywords
|
overloads, location, positional, keywords, report_errors
|
||||||
)
|
)
|
||||||
if function is None:
|
if function is None:
|
||||||
return UnknownType()
|
return None
|
||||||
return function.returns
|
return function.returns
|
||||||
case _:
|
case _:
|
||||||
|
if report_errors:
|
||||||
self.reporter.error(location, f"{callee} is not callable")
|
self.reporter.error(location, f"{callee} is not callable")
|
||||||
return UnknownType()
|
return None
|
||||||
|
|
||||||
def _are_arguments_valid(
|
def _are_arguments_valid(
|
||||||
self,
|
self,
|
||||||
@@ -620,6 +673,7 @@ class PythonTyper(
|
|||||||
location: Location,
|
location: Location,
|
||||||
positional: list[TypedExpr],
|
positional: list[TypedExpr],
|
||||||
keywords: dict[str, TypedExpr],
|
keywords: dict[str, TypedExpr],
|
||||||
|
report_errors: bool = True,
|
||||||
) -> Optional[Function]:
|
) -> Optional[Function]:
|
||||||
"""Try and resolve the appropriate overload for the given arguments
|
"""Try and resolve the appropriate overload for the given arguments
|
||||||
|
|
||||||
@@ -628,6 +682,7 @@ class PythonTyper(
|
|||||||
location (Location): the call location
|
location (Location): the call location
|
||||||
positional (list[TypedExpr]): the list of positional arguments
|
positional (list[TypedExpr]): the list of positional arguments
|
||||||
keywords (dict[str, TypedExpr]): the map of keywords arguments
|
keywords (dict[str, TypedExpr]): the map of keywords arguments
|
||||||
|
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[Function]: the resolved function signature if it can be
|
Optional[Function]: the resolved function signature if it can be
|
||||||
@@ -637,6 +692,7 @@ class PythonTyper(
|
|||||||
for overload in overloads:
|
for overload in overloads:
|
||||||
function: Type = unfold_type(overload)
|
function: Type = unfold_type(overload)
|
||||||
if not isinstance(function, Function):
|
if not isinstance(function, Function):
|
||||||
|
if report_errors:
|
||||||
self.logger.error(
|
self.logger.error(
|
||||||
f"Overload is not a function: {overload} is {function}"
|
f"Overload is not a function: {overload} is {function}"
|
||||||
)
|
)
|
||||||
@@ -671,6 +727,7 @@ class PythonTyper(
|
|||||||
# No match -> invalid call
|
# No match -> invalid call
|
||||||
if n_candidates == 0:
|
if n_candidates == 0:
|
||||||
overloads_str: str = ", ".join(map(str, overloads))
|
overloads_str: str = ", ".join(map(str, overloads))
|
||||||
|
if report_errors:
|
||||||
self.reporter.error(
|
self.reporter.error(
|
||||||
location,
|
location,
|
||||||
f"No matching overload in [{overloads_str}] {for_args}",
|
f"No matching overload in [{overloads_str}] {for_args}",
|
||||||
@@ -695,6 +752,7 @@ class PythonTyper(
|
|||||||
candidates_str: str = ", ".join(
|
candidates_str: str = ", ".join(
|
||||||
str(candidate.function) for candidate in candidates
|
str(candidate.function) for candidate in candidates
|
||||||
)
|
)
|
||||||
|
if report_errors:
|
||||||
self.reporter.error(
|
self.reporter.error(
|
||||||
location,
|
location,
|
||||||
f"Multiple matching overloads {for_args}: {candidates_str}",
|
f"Multiple matching overloads {for_args}: {candidates_str}",
|
||||||
@@ -863,3 +921,21 @@ class PythonTyper(
|
|||||||
if not self.is_subtype(type1, type2):
|
if not self.is_subtype(type1, type2):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def _get_iterator_type(self, expr: p.Expr) -> Optional[Type]:
|
||||||
|
# TODO: lookup __iter__
|
||||||
|
type: Type = self.type_of(expr)
|
||||||
|
getitem: Optional[Type] = self.types.lookup_member(type, "__getitem__")
|
||||||
|
if getitem is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
index: p.Expr = p.LiteralExpr(location=expr.location, value=0)
|
||||||
|
index_type: Type = self.compute_type(index)
|
||||||
|
result: Optional[Type] = self._get_call_result(
|
||||||
|
location=expr.location,
|
||||||
|
callee=getitem,
|
||||||
|
positional=[(index, index_type)],
|
||||||
|
keywords={},
|
||||||
|
report_errors=False,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|||||||
@@ -116,6 +116,9 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
|||||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||||
self.resolve(stmt.value)
|
self.resolve(stmt.value)
|
||||||
for target in stmt.targets:
|
for target in stmt.targets:
|
||||||
|
self._visit_assign(target)
|
||||||
|
|
||||||
|
def _visit_assign(self, target: p.Expr):
|
||||||
match target:
|
match target:
|
||||||
case p.VariableExpr(name=name):
|
case p.VariableExpr(name=name):
|
||||||
if not self.is_defined(name):
|
if not self.is_defined(name):
|
||||||
@@ -153,6 +156,13 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
|||||||
def visit_pass(self, stmt: p.Pass) -> None:
|
def visit_pass(self, stmt: p.Pass) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
|
||||||
|
self.resolve(stmt.iterator)
|
||||||
|
self._visit_assign(stmt.target)
|
||||||
|
self.begin_scope()
|
||||||
|
self.resolve(*stmt.body)
|
||||||
|
self.end_scope()
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
|
||||||
self.resolve(expr.left)
|
self.resolve(expr.left)
|
||||||
self.resolve(expr.right)
|
self.resolve(expr.right)
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from midas.cli.highlighter import DiagnosticsHighlighter
|
|||||||
from midas.cli.utils import DiagnosticPrinter
|
from midas.cli.utils import DiagnosticPrinter
|
||||||
|
|
||||||
|
|
||||||
@click.command()
|
@click.command(help="Run type checker and report diagnostics")
|
||||||
@click.argument("file", type=click.File("r"))
|
@click.argument("file", type=click.File("r"))
|
||||||
@click.option("-t", "--types", type=click.File("r"), multiple=True)
|
@click.option("-t", "--types", type=click.File("r"), multiple=True)
|
||||||
@click.option("-l", "--highlight", type=click.File("w"))
|
@click.option("-l", "--highlight", type=click.File("w"))
|
||||||
|
|||||||
@@ -3,19 +3,20 @@
|
|||||||
# midas compile <file.py> [--types <file.midas>] [-o <output>] [--assertions|--strict|--no-checks]
|
# midas compile <file.py> [--types <file.midas>] [-o <output>] [--assertions|--strict|--no-checks]
|
||||||
# ```
|
# ```
|
||||||
|
|
||||||
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TextIO
|
from typing import TextIO
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from midas.checker.checker import TypeChecker
|
from midas.checker.checker import TypeChecker
|
||||||
from midas.checker.diagnostic import Diagnostic
|
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||||
from midas.cli.utils import DiagnosticPrinter
|
from midas.cli.utils import DiagnosticPrinter
|
||||||
from midas.generator.generator import Generator
|
from midas.generator.generator import Generator
|
||||||
from midas.utils import TypedAST
|
from midas.utils import TypedAST
|
||||||
|
|
||||||
|
|
||||||
@click.command()
|
@click.command(help="Compile source")
|
||||||
@click.argument("file", type=click.File("r"))
|
@click.argument("file", type=click.File("r"))
|
||||||
@click.option("-t", "--types", type=click.File("r"), multiple=True)
|
@click.option("-t", "--types", type=click.File("r"), multiple=True)
|
||||||
def compile(
|
def compile(
|
||||||
@@ -34,5 +35,8 @@ def compile(
|
|||||||
printer = DiagnosticPrinter()
|
printer = DiagnosticPrinter()
|
||||||
printer.print_all(diagnostics)
|
printer.print_all(diagnostics)
|
||||||
|
|
||||||
|
if any(map(lambda d: d.type == DiagnosticType.ERROR, diagnostics)):
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
generator = Generator(workdir=source_path.parent)
|
generator = Generator(workdir=source_path.parent)
|
||||||
generator.generate(typed_ast, source_path)
|
generator.generate(typed_ast, source_path)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from midas.lexer.token import Token
|
|||||||
from midas.parser.midas import MidasParser
|
from midas.parser.midas import MidasParser
|
||||||
|
|
||||||
|
|
||||||
@click.command()
|
@click.command(help="Parse and pretty print a Midas file")
|
||||||
@click.argument("file", type=click.File("r"))
|
@click.argument("file", type=click.File("r"))
|
||||||
@click.option("-o", "--output", type=click.File("w"), default="-")
|
@click.option("-o", "--output", type=click.File("w"), default="-")
|
||||||
def format(file: TextIO, output: TextIO):
|
def format(file: TextIO, output: TextIO):
|
||||||
|
|||||||
@@ -46,7 +46,10 @@ def highlight_midas(source: str, path: str) -> Highlighter:
|
|||||||
return highlighter
|
return highlighter
|
||||||
|
|
||||||
|
|
||||||
@click.command()
|
@click.command(
|
||||||
|
help="Parse a Python or Midas file and produce a highlighted version showing AST node types inline",
|
||||||
|
short_help="Parse and highlight a Python or Midas file",
|
||||||
|
)
|
||||||
@click.argument("file", type=click.File("r"))
|
@click.argument("file", type=click.File("r"))
|
||||||
@click.option("-o", "--output", type=click.File("w"), default="-")
|
@click.option("-o", "--output", type=click.File("w"), default="-")
|
||||||
def highlight(output: TextIO, file: TextIO):
|
def highlight(output: TextIO, file: TextIO):
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ def dump_midas_ast(source: str, filename: str) -> str:
|
|||||||
return dump
|
return dump
|
||||||
|
|
||||||
|
|
||||||
@click.command()
|
@click.command(help="Parse a Python or Midas file and pretty-print its AST")
|
||||||
@click.argument("file", type=click.File("r"))
|
@click.argument("file", type=click.File("r"))
|
||||||
@click.option("--raw", is_flag=True)
|
@click.option("--raw", is_flag=True)
|
||||||
def parse(file: TextIO, raw: bool):
|
def parse(file: TextIO, raw: bool):
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from midas.checker.checker import TypeChecker
|
|||||||
from midas.checker.types import Type
|
from midas.checker.types import Type
|
||||||
|
|
||||||
|
|
||||||
@click.command()
|
@click.command(help="Dump types registry")
|
||||||
@click.option("-t", "--types", type=click.File("r"), multiple=True)
|
@click.option("-t", "--types", type=click.File("r"), multiple=True)
|
||||||
def dump_registry(
|
def dump_registry(
|
||||||
types: tuple[TextIO],
|
types: tuple[TextIO],
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from midas.cli.highlighter import DiagnosticsHighlighter
|
|||||||
from midas.cli.utils import DiagnosticPrinter
|
from midas.cli.utils import DiagnosticPrinter
|
||||||
|
|
||||||
|
|
||||||
@click.command()
|
@click.command(help="Print typing judgements")
|
||||||
@click.argument("file", type=click.File("r"))
|
@click.argument("file", type=click.File("r"))
|
||||||
@click.option("-t", "--types", type=click.File("r"), multiple=True)
|
@click.option("-t", "--types", type=click.File("r"), multiple=True)
|
||||||
@click.option("-l", "--highlight", type=click.File("w"))
|
@click.option("-l", "--highlight", type=click.File("w"))
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from midas.cli.highlighter import DiagnosticsHighlighter
|
|||||||
from midas.cli.utils import DiagnosticPrinter
|
from midas.cli.utils import DiagnosticPrinter
|
||||||
|
|
||||||
|
|
||||||
@click.command()
|
@click.command(help="Validate Midas definitions")
|
||||||
@click.argument("file", type=click.File("r"))
|
@click.argument("file", type=click.File("r"))
|
||||||
@click.option("-l", "--highlight", type=click.File("w"))
|
@click.option("-l", "--highlight", type=click.File("w"))
|
||||||
def validate(
|
def validate(
|
||||||
|
|||||||
@@ -191,6 +191,13 @@ class PythonHighlighter(
|
|||||||
def visit_pass(self, stmt: p.Pass) -> None:
|
def visit_pass(self, stmt: p.Pass) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
|
||||||
|
self.wrap(stmt, "for")
|
||||||
|
stmt.iterator.accept(self)
|
||||||
|
stmt.target.accept(self)
|
||||||
|
for body_stmt in stmt.body:
|
||||||
|
body_stmt.accept(self)
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ...
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ...
|
||||||
|
|
||||||
def visit_compare_expr(self, expr: p.CompareExpr) -> None: ...
|
def visit_compare_expr(self, expr: p.CompareExpr) -> None: ...
|
||||||
|
|||||||
@@ -161,5 +161,13 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
def visit_pass(self, stmt: p.Pass) -> ast.stmt:
|
def visit_pass(self, stmt: p.Pass) -> ast.stmt:
|
||||||
return ast.Pass()
|
return ast.Pass()
|
||||||
|
|
||||||
|
def visit_for_stmt(self, stmt: p.ForStmt) -> ast.stmt:
|
||||||
|
return ast.For(
|
||||||
|
target=stmt.target.accept(self),
|
||||||
|
iter=stmt.iterator.accept(self),
|
||||||
|
body=self._visit_body(stmt.body),
|
||||||
|
orelse=[],
|
||||||
|
)
|
||||||
|
|
||||||
def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]:
|
def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]:
|
||||||
return [stmt.accept(self) for stmt in stmts]
|
return [stmt.accept(self) for stmt in stmts]
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from midas.ast.python import (
|
|||||||
ConstraintType,
|
ConstraintType,
|
||||||
Expr,
|
Expr,
|
||||||
ExpressionStmt,
|
ExpressionStmt,
|
||||||
|
ForStmt,
|
||||||
FrameColumn,
|
FrameColumn,
|
||||||
FrameType,
|
FrameType,
|
||||||
Function,
|
Function,
|
||||||
@@ -93,6 +94,9 @@ class PythonParser:
|
|||||||
case ast.Pass():
|
case ast.Pass():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
case ast.For(orelse=[]):
|
||||||
|
return self.parse_for(node)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
print(f"Unsupported statement: {ast.unparse(node)}")
|
print(f"Unsupported statement: {ast.unparse(node)}")
|
||||||
return None
|
return None
|
||||||
@@ -182,6 +186,22 @@ class PythonParser:
|
|||||||
orelse=orelse,
|
orelse=orelse,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def parse_for(self, node: ast.For) -> ForStmt:
|
||||||
|
body: list[Stmt] = []
|
||||||
|
for stmt in node.body:
|
||||||
|
stmts = self.parse_stmt(stmt)
|
||||||
|
if isinstance(stmts, Stmt):
|
||||||
|
body.append(stmts)
|
||||||
|
elif stmts is not None:
|
||||||
|
body.extend(stmts)
|
||||||
|
|
||||||
|
return ForStmt(
|
||||||
|
location=Location.from_ast(node),
|
||||||
|
target=self.parse_expr(node.target),
|
||||||
|
iterator=self.parse_expr(node.iter),
|
||||||
|
body=body,
|
||||||
|
)
|
||||||
|
|
||||||
def parse_function(self, node: ast.FunctionDef) -> Function:
|
def parse_function(self, node: ast.FunctionDef) -> Function:
|
||||||
loc: Location = Location.from_ast(node)
|
loc: Location = Location.from_ast(node)
|
||||||
match node:
|
match node:
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from midas.ast.python import (
|
|||||||
ConstraintType,
|
ConstraintType,
|
||||||
Expr,
|
Expr,
|
||||||
ExpressionStmt,
|
ExpressionStmt,
|
||||||
|
ForStmt,
|
||||||
FrameColumn,
|
FrameColumn,
|
||||||
FrameType,
|
FrameType,
|
||||||
Function,
|
Function,
|
||||||
@@ -182,6 +183,14 @@ class PythonAstJsonSerializer(
|
|||||||
"_type": "Pass",
|
"_type": "Pass",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def visit_for_stmt(self, stmt: ForStmt) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "ForStmt",
|
||||||
|
"target": stmt.target.accept(self),
|
||||||
|
"iterator": stmt.iterator.accept(self),
|
||||||
|
"body": self._serialize_list(stmt.body),
|
||||||
|
}
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: BinaryExpr) -> dict:
|
def visit_binary_expr(self, expr: BinaryExpr) -> dict:
|
||||||
return {
|
return {
|
||||||
"_type": "BinaryExpr",
|
"_type": "BinaryExpr",
|
||||||
|
|||||||
Reference in New Issue
Block a user