Merge pull request 'Cast assertions and generator tests' (#12) from feat/cast-assertions into main

Reviewed-on: #12
This commit was merged in pull request #12.
This commit is contained in:
2026-06-16 12:57:49 +00:00
19 changed files with 572 additions and 28 deletions

View File

@@ -92,6 +92,10 @@ class ForStmt:
body: list[Stmt] body: list[Stmt]
class RawStmt:
stmt: ast.stmt
###< ###<
@@ -164,4 +168,8 @@ class SliceExpr:
step: Optional[Expr] step: Optional[Expr]
class RawExpr:
expr: ast.expr
###< ###<

View File

@@ -613,6 +613,11 @@ class PythonAstPrinter(
self._mark_last() self._mark_last()
body_stmt.accept(self) body_stmt.accept(self)
def visit_raw_stmt(self, stmt: p.RawStmt) -> None:
self._write_line("RawStmt")
with self._child_level(single=True):
self._write_line(f"stmt: {ast.unparse(stmt.stmt)}")
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
self._write_line("BinaryExpr") self._write_line("BinaryExpr")
with self._child_level(): with self._child_level():
@@ -756,3 +761,8 @@ class PythonAstPrinter(
self._write_optional_child("lower", expr.lower) self._write_optional_child("lower", expr.lower)
self._write_optional_child("upper", expr.upper) self._write_optional_child("upper", expr.upper)
self._write_optional_child("step", expr.step, last=True) self._write_optional_child("step", expr.step, last=True)
def visit_raw_expr(self, expr: p.RawExpr) -> None:
self._write_line("RawExpr")
with self._child_level(single=True):
self._write_line(f"expr: {ast.unparse(expr.expr)}")

View File

@@ -113,6 +113,9 @@ class Stmt(ABC):
@abstractmethod @abstractmethod
def visit_for_stmt(self, stmt: ForStmt) -> T: ... def visit_for_stmt(self, stmt: ForStmt) -> T: ...
@abstractmethod
def visit_raw_stmt(self, stmt: RawStmt) -> T: ...
@dataclass(frozen=True) @dataclass(frozen=True)
class ExpressionStmt(Stmt): class ExpressionStmt(Stmt):
@@ -202,6 +205,14 @@ class ForStmt(Stmt):
return visitor.visit_for_stmt(self) return visitor.visit_for_stmt(self)
@dataclass(frozen=True)
class RawStmt(Stmt):
stmt: ast.stmt
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_raw_stmt(self)
############### ###############
# Expressions # # Expressions #
############### ###############
@@ -254,6 +265,9 @@ class Expr(ABC):
@abstractmethod @abstractmethod
def visit_slice_expr(self, expr: SliceExpr) -> T: ... def visit_slice_expr(self, expr: SliceExpr) -> T: ...
@abstractmethod
def visit_raw_expr(self, expr: RawExpr) -> T: ...
@dataclass(frozen=True) @dataclass(frozen=True)
class BinaryExpr(Expr): class BinaryExpr(Expr):
@@ -373,3 +387,11 @@ class SliceExpr(Expr):
def accept(self, visitor: Expr.Visitor[T]) -> T: def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_slice_expr(self) return visitor.visit_slice_expr(self)
@dataclass(frozen=True)
class RawExpr(Expr):
expr: ast.expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_raw_expr(self)

121
midas/checker/preamble.py Normal file
View File

@@ -0,0 +1,121 @@
from dataclasses import dataclass
from midas.checker.environment import Environment
from midas.checker.registry import TypesRegistry
from midas.checker.types import Function, GenericType, TopType, Type, TypeVar, UnitType
@dataclass(frozen=True)
class Param:
name: str
type: Type
required: bool = True
class Preamble(Environment):
def __init__(self, types: TypesRegistry) -> None:
super().__init__()
self._types: TypesRegistry = types
self._def_type_constructor("object")
self._def_type_constructor("float")
self._def_type_constructor("int")
self._def_type_constructor("bool")
self._def_type_constructor("str")
self._def_function(
name="list",
pos=[Param("object", TopType())],
returns=self._list_of(TopType()),
)
# TODO: use sink
self._def_function(
name="print",
pos=[Param("object", TopType())],
returns=UnitType(),
)
map_in = TypeVar(name="T", bound=None)
map_out = TypeVar(name="U", bound=None)
mapper = self._make_function(
name="MapTransform",
pos=[Param("v", map_in)],
returns=map_out,
)
self._def_function(
name="map",
pos=[
Param("transform", mapper),
Param(
"iterable",
self._list_of(map_in), # TODO: replace with Iterable[T]
),
],
returns=self._list_of(map_out), # TODO: replace with Iterable[U]
)
def _list_of(self, item_type: Type) -> Type:
return self._types.apply_generic(self._types.get_type("list"), [item_type])
def _def_type_constructor(self, name: str):
# TODO: more specific arg types
self._def_function(
name=name,
pos=[Param("object", TopType())],
returns=self._types.get_type(name),
)
def _make_function(
self,
*,
name: str,
pos: list[Param] = [],
mixed: list[Param] = [],
kw: list[Param] = [],
returns: Type = UnitType(),
type_vars: list[TypeVar] = [],
) -> Type:
def map_args(params: list[Param], offset: int) -> list[Function.Argument]:
return [
Function.Argument(
pos=i + offset,
name=param.name,
type=param.type,
required=param.required,
)
for i, param in enumerate(params)
]
function = Function(
pos_args=map_args(pos, 0),
args=map_args(mixed, len(pos)),
kw_args=map_args(kw, len(pos) + len(mixed)),
returns=returns,
)
if len(type_vars) != 0:
function = GenericType(
name=name,
params=type_vars,
body=function,
)
return function
def _def_function(
self,
*,
name: str,
pos: list[Param] = [],
mixed: list[Param] = [],
kw: list[Param] = [],
returns: Type = UnitType(),
type_vars: list[TypeVar] = [],
):
function: Type = self._make_function(
name=name,
pos=pos,
mixed=mixed,
kw=kw,
returns=returns,
type_vars=type_vars,
)
self.define(name, function)

View File

@@ -7,10 +7,12 @@ import midas.ast.python as p
from midas.ast.location import Location from midas.ast.location import Location
from midas.checker.environment import Environment from midas.checker.environment import Environment
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS
from midas.checker.preamble import Preamble
from midas.checker.registry import TypesRegistry from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter from midas.checker.reporter import FileReporter, Reporter
from midas.checker.resolver import Resolver from midas.checker.resolver import Resolver
from midas.checker.types import ( from midas.checker.types import (
AppliedType,
Function, Function,
OverloadedFunction, OverloadedFunction,
Type, Type,
@@ -56,7 +58,7 @@ class PythonTyper(
self.logger: logging.Logger = logging.getLogger("PythonTyper") self.logger: logging.Logger = logging.getLogger("PythonTyper")
self.reporter: FileReporter = reporter.for_file(None) self.reporter: FileReporter = reporter.for_file(None)
self.types: TypesRegistry = types self.types: TypesRegistry = types
self.global_env: Environment = Environment() self.global_env: Environment = Preamble(self.types)
self.env: Environment = self.global_env self.env: Environment = self.global_env
self.locals: dict[p.Expr, int] = {} self.locals: dict[p.Expr, int] = {}
self.judgements: list[tuple[p.Expr, Type]] = [] self.judgements: list[tuple[p.Expr, Type]] = []
@@ -252,7 +254,7 @@ class PythonTyper(
if returns_hint is not None: if returns_hint is not None:
assert stmt.returns is not None assert stmt.returns is not None
returns = returns_hint returns = returns_hint
if returns != inferred_return: if not self.is_subtype(inferred_return, returns):
self.reporter.error( self.reporter.error(
stmt.returns.location, stmt.returns.location,
f"Return type mismatch, annotated {returns} but returns {inferred_return}", f"Return type mismatch, annotated {returns} but returns {inferred_return}",
@@ -370,6 +372,9 @@ class PythonTyper(
if body_returned: if body_returned:
raise ReturnException() raise ReturnException()
def visit_raw_stmt(self, stmt: p.RawStmt) -> None:
pass
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type: def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__) method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
if method is None: if method is None:
@@ -566,6 +571,9 @@ class PythonTyper(
def visit_slice_expr(self, expr: p.SliceExpr) -> Type: def visit_slice_expr(self, expr: p.SliceExpr) -> Type:
return self.types.get_type("slice") return self.types.get_type("slice")
def visit_raw_expr(self, expr: p.RawExpr) -> Type:
return UnknownType()
def visit_base_type(self, node: p.BaseType) -> Type: def visit_base_type(self, node: p.BaseType) -> Type:
base: Type base: Type
try: try:
@@ -637,6 +645,15 @@ class PythonTyper(
if function is None: if function is None:
return None return None
return function.returns return function.returns
case AppliedType(body=body):
return self._get_call_result(
location, body, positional, keywords, report_errors
)
case UnknownType():
return UnknownType()
case _: case _:
if report_errors: if report_errors:
self.reporter.error(location, f"{callee} is not callable") self.reporter.error(location, f"{callee} is not callable")

View File

@@ -163,6 +163,9 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
self.resolve(*stmt.body) self.resolve(*stmt.body)
self.end_scope() self.end_scope()
def visit_raw_stmt(self, stmt: p.RawStmt) -> None:
pass
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
self.resolve(expr.left) self.resolve(expr.left)
self.resolve(expr.right) self.resolve(expr.right)
@@ -221,3 +224,6 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
self.resolve(expr.upper) self.resolve(expr.upper)
if expr.step is not None: if expr.step is not None:
self.resolve(expr.step) self.resolve(expr.step)
def visit_raw_expr(self, expr: p.RawExpr) -> None:
pass

View File

@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Optional from typing import Optional
@@ -41,23 +41,21 @@ class UnitType:
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class Function: class Function:
pos_args: list[Argument] pos_args: list[Argument] = field(default_factory=list)
args: list[Argument] args: list[Argument] = field(default_factory=list)
kw_args: list[Argument] kw_args: list[Argument] = field(default_factory=list)
returns: Type returns: Type
def __str__(self) -> str: def __str__(self) -> str:
args: list[str] = [] args: list[str] = []
if len(self.pos_args) != 0: if len(self.pos_args) != 0:
args += list(map(str, self.pos_args)) args += list(map(str, self.pos_args))
if len(self.args) + len(self.kw_args) != 0:
args.append("/") args.append("/")
if len(self.args) != 0: if len(self.args) != 0:
args += list(map(str, self.args)) args += list(map(str, self.args))
if len(self.kw_args) != 0: if len(self.kw_args) != 0:
if len(args) != 0:
args.append("*") args.append("*")
args += list(map(str, self.kw_args)) args += list(map(str, self.kw_args))

View File

@@ -9,7 +9,21 @@ from typing import TextIO
import click import click
from midas.checker.checker import TypeChecker from midas.checker.checker import TypeChecker
from midas.checker.types import Type from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type
def base_type(type: Type) -> Type:
match type:
case BaseType():
return type
case AliasType(type=base):
return base
case AppliedType(body=body):
return body
case GenericType(body=body):
return body
case _:
return type
@click.command(help="Dump types registry") @click.command(help="Dump types registry")
@@ -23,7 +37,7 @@ def dump_registry(
for name, type in checker.types._types.items(): for name, type in checker.types._types.items():
members: dict[str, Type] = checker.types._members.get(name, {}) members: dict[str, Type] = checker.types._members.get(name, {})
print(f"{name} = {type}") print(f"{name} = {base_type(type)}")
if len(members) != 0: if len(members) != 0:
print(" " * 4 + "Members:") print(" " * 4 + "Members:")
for member_name, member_type in members.items(): for member_name, member_type in members.items():

View File

@@ -1,11 +1,34 @@
import ast import ast
import shutil import shutil
from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Optional
from midas.ast.location import Location
import midas.ast.python as p import midas.ast.python as p
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ComplexType,
ExtensionType,
Function,
GenericType,
OverloadedFunction,
TopType,
Type,
TypeVar,
UnitType,
)
from midas.utils import TypedAST from midas.utils import TypedAST
@dataclass
class Scope:
pre_assertions: list[ast.stmt] = field(default_factory=list)
aliases: list[str] = field(default_factory=list)
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def __init__(self, workdir: Path) -> None: def __init__(self, workdir: Path) -> None:
self.workdir: Path = workdir.resolve() self.workdir: Path = workdir.resolve()
@@ -13,19 +36,35 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
if self.build_dir.exists(): if self.build_dir.exists():
shutil.rmtree(self.build_dir) shutil.rmtree(self.build_dir)
self.build_dir.mkdir(parents=True, exist_ok=True) self.build_dir.mkdir(parents=True, exist_ok=True)
self.rel_src_path: Path = Path()
def generate(self, typed_ast: TypedAST, src_path: Path) -> Path: self._typed_ast: TypedAST = TypedAST(
stmts=[],
judgements=[],
)
self._alias_count: int = 0
self._scopes: list[Scope] = []
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
self.rel_src_path = src_path.relative_to(self.workdir)
self._typed_ast = typed_ast
body: list[ast.stmt] = self._visit_body(typed_ast.stmts) body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
module = ast.Module(body=body, type_ignores=[]) module = ast.Module(body=body, type_ignores=[])
module = ast.fix_missing_locations(module) module = ast.fix_missing_locations(module)
return module
def generate(
self, typed_ast: TypedAST, src_path: Path, out_path: Optional[Path] = None
) -> Path:
module: ast.AST = self.generate_ast(typed_ast, src_path)
compiled: str = ast.unparse(module) compiled: str = ast.unparse(module)
rel_src_path: Path = src_path.relative_to(self.workdir) if out_path is None:
out_path: Path = (self.build_dir / rel_src_path).resolve() out_path = (self.build_dir / self.rel_src_path).resolve()
try: try:
_ = out_path.relative_to(self.build_dir) _ = out_path.relative_to(self.build_dir)
except ValueError: except ValueError:
raise ValueError( raise ValueError(
f"Directory traversal, {rel_src_path} points outside of parent directory" f"Directory traversal, {self.rel_src_path} points outside of parent directory"
) )
out_path.parent.mkdir(parents=True, exist_ok=True) out_path.parent.mkdir(parents=True, exist_ok=True)
out_path.write_text(compiled) out_path.write_text(compiled)
@@ -80,8 +119,13 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
) )
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr: def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
# TODO: insert assertion expr2: ast.expr = expr.expr.accept(self)
return expr.expr.accept(self) alias: ast.expr = self._make_alias(expr2)
type: Type = self._get_expr_type(expr)
self._make_cast_asserts(expr.location, alias, type)
return alias
def visit_ternary_expr(self, expr: p.TernaryExpr) -> ast.expr: def visit_ternary_expr(self, expr: p.TernaryExpr) -> ast.expr:
return ast.IfExp( return ast.IfExp(
@@ -108,6 +152,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
step=expr.step.accept(self) if expr.step is not None else None, step=expr.step.accept(self) if expr.step is not None else None,
) )
def visit_raw_expr(self, expr: p.RawExpr) -> ast.expr:
return expr.expr
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> ast.stmt: def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> ast.stmt:
return ast.Expr( return ast.Expr(
value=stmt.expr.accept(self), value=stmt.expr.accept(self),
@@ -169,5 +216,120 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
orelse=[], orelse=[],
) )
def visit_raw_stmt(self, stmt: p.RawStmt) -> ast.stmt:
return stmt.stmt
def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]: def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]:
return [stmt.accept(self) for stmt in stmts] generated: list[ast.stmt] = []
for stmt in stmts:
scope = Scope()
self._scopes.append(scope)
stmt2 = stmt.accept(self)
generated.extend(scope.pre_assertions)
generated.append(stmt2)
if len(scope.aliases) != 0:
generated.append(
ast.Delete(targets=[ast.Name(id=alias) for alias in scope.aliases])
)
self._scopes.pop()
# Remove redundant pass statements
if len(generated) > 1:
generated = [stmt for stmt in generated if not isinstance(stmt, ast.Pass)]
return generated
def _make_alias(self, expr: ast.expr) -> ast.expr:
name: str = f"__midas_alias_{self._alias_count}__"
alias = ast.Name(id=name)
self._alias_count += 1
self._scopes[-1].aliases.append(name)
self._scopes[-1].pre_assertions.append(
ast.Assign(
targets=[alias],
value=expr,
)
)
return alias
def _add_assert(self, expr: ast.expr, message: str | ast.expr):
if isinstance(message, str):
message = ast.Constant(value=message)
self._scopes[-1].pre_assertions.append(
ast.Assert(
test=expr,
msg=message,
)
)
def _get_expr_type(self, query: p.Expr) -> Type:
for expr, type in self._typed_ast.judgements:
if expr == query:
return type
raise RuntimeError(f"Cannot get type judgement for {query}")
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type):
match type:
case BaseType(name=name):
self._add_assert(
ast.Call(
func=ast.Name(id="isinstance"),
args=[expr, ast.Name(id=name)],
keywords=[],
),
self._make_cast_assert_message(src_location, expr, type),
)
case AliasType(type=base):
self._make_cast_asserts(src_location, expr, base)
case UnitType():
self._add_assert(
ast.Compare(
left=expr,
ops=[ast.Is()],
comparators=[
ast.Constant(value=None),
],
),
self._make_cast_assert_message(src_location, expr, type),
)
case AppliedType():
self._make_cast_asserts(src_location, expr, type.body)
case (
TopType()
| Function()
| OverloadedFunction()
| ComplexType()
| ExtensionType()
| GenericType()
):
raise NotImplementedError(f"Can't make assertion for type {type}")
case TypeVar():
raise RuntimeError("Unexpected TypeVar")
def _make_cast_assert_message(
self, location: Location, expr: ast.expr, type: Type
) -> ast.expr:
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
# f"file.py:L1:1: CastError: Cannot cast {type(expr).__name__} to Type"
return ast.JoinedStr(
values=[
ast.Constant(f"{loc_str}: CastError: Cannot cast "),
ast.FormattedValue(
value=ast.Attribute(
value=ast.Call(
func=ast.Name(id="type"),
args=[expr],
keywords=[],
),
attr="__name__",
),
conversion=-1,
),
ast.Constant(f" to {type}"),
]
)

View File

@@ -22,6 +22,8 @@ from midas.ast.python import (
LiteralExpr, LiteralExpr,
LogicalExpr, LogicalExpr,
MidasType, MidasType,
RawExpr,
RawStmt,
ReturnStmt, ReturnStmt,
SliceExpr, SliceExpr,
Stmt, Stmt,
@@ -99,7 +101,7 @@ class PythonParser:
case _: case _:
print(f"Unsupported statement: {ast.unparse(node)}") print(f"Unsupported statement: {ast.unparse(node)}")
return None return RawStmt(location=location, stmt=node)
def parse_annotation_assign(self, node: ast.AnnAssign) -> list[Stmt]: def parse_annotation_assign(self, node: ast.AnnAssign) -> list[Stmt]:
statements: list[Stmt] = [] statements: list[Stmt] = []
@@ -461,7 +463,8 @@ class PythonParser:
) )
case _: case _:
raise UnsupportedSyntaxError(node) print(f"Unsupported expression: {ast.unparse(node)}")
return RawExpr(location=location, expr=node)
def parse_bool_op(self, node: ast.BoolOp) -> LogicalExpr: def parse_bool_op(self, node: ast.BoolOp) -> LogicalExpr:
op: ast.boolop = node.op op: ast.boolop = node.op

View File

@@ -21,6 +21,10 @@ class Tester(ABC):
@abstractmethod @abstractmethod
def namespace(self) -> str: ... def namespace(self) -> str: ...
@property
def extension(self) -> str:
return "json"
@property @property
def base_dir(self) -> Path: def base_dir(self) -> Path:
return self.CASES_DIR / self.namespace return self.CASES_DIR / self.namespace
@@ -99,7 +103,7 @@ class Tester(ABC):
return True return True
def _result_path(self, test_path: Path) -> Path: def _result_path(self, test_path: Path) -> Path:
return test_path.parent / (test_path.name + ".ref.json") return test_path.parent / (test_path.name + f".ref.{self.extension}")
def _print_diff(self, diff: Iterator[str]): def _print_diff(self, diff: Iterator[str]):
for line in diff: for line in diff:

View File

@@ -0,0 +1,14 @@
type Meter = float
type Second = float
type MeterPerSecond = float
extend Meter {
def __add__: fn(Meter, /) -> Meter
def __sub__: fn(Meter, /) -> Meter
def __truediv__: fn(Second, /) -> MeterPerSecond
}
extend Second {
def __add__: fn(Second, /) -> Second
def __sub__: fn(Second, /) -> Second
}

View File

@@ -0,0 +1,5 @@
from midas import cast, Meter, Second
distance: Meter = cast(Meter, 123.45)
time: Second = cast(Second, 6.7)
speed = distance / time

View File

@@ -0,0 +1,79 @@
Module(
body=[
ImportFrom(
module='midas',
names=[
alias(name='cast'),
alias(name='Meter'),
alias(name='Second')],
level=0),
Assign(
targets=[
Name(id='__midas_alias_0__')],
value=Constant(value=123.45)),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_alias_0__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
values=[
Constant(value='01_simple_types.py:L3:19: CastError: Cannot cast '),
FormattedValue(
value=Attribute(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_alias_0__')],
keywords=[]),
attr='__name__'),
conversion=-1),
Constant(value=' to float')])),
Assign(
targets=[
Name(id='distance')],
value=Name(id='__midas_alias_0__')),
Delete(
targets=[
Name(id='__midas_alias_0__')]),
Assign(
targets=[
Name(id='__midas_alias_1__')],
value=Constant(value=6.7)),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_alias_1__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
values=[
Constant(value='01_simple_types.py:L4:16: CastError: Cannot cast '),
FormattedValue(
value=Attribute(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_alias_1__')],
keywords=[]),
attr='__name__'),
conversion=-1),
Constant(value=' to float')])),
Assign(
targets=[
Name(id='time')],
value=Name(id='__midas_alias_1__')),
Delete(
targets=[
Name(id='__midas_alias_1__')]),
Assign(
targets=[
Name(id='speed')],
value=BinOp(
left=Name(id='distance'),
op=Div(),
right=Name(id='time')))],
type_ignores=[])

View File

@@ -1,5 +1,9 @@
{ {
"stmts": [ "stmts": [
{
"_type": "RawStmt",
"stmt": "from __future__ import annotations"
},
{ {
"_type": "TypeAssign", "_type": "TypeAssign",
"name": "df", "name": "df",

View File

@@ -1,5 +1,9 @@
{ {
"stmts": [ "stmts": [
{
"_type": "RawStmt",
"stmt": "from __future__ import annotations"
},
{ {
"_type": "TypeAssign", "_type": "TypeAssign",
"name": "df", "name": "df",

View File

@@ -1,5 +1,9 @@
{ {
"stmts": [ "stmts": [
{
"_type": "RawStmt",
"stmt": "from __future__ import annotations"
},
{ {
"_type": "Function", "_type": "Function",
"name": "func", "name": "func",

55
tests/generator.py Normal file
View File

@@ -0,0 +1,55 @@
import ast
from dataclasses import dataclass
from pathlib import Path
from midas.checker.checker import TypeChecker
from midas.checker.diagnostic import DiagnosticType
from midas.generator.generator import Generator
from midas.utils import TypedAST
from tests.base import Tester
@dataclass
class CaseResult:
compiled_ast: ast.AST = ast.Module([], [])
def dumps(self) -> str:
return ast.dump(self.compiled_ast, indent=2)
class GeneratorTester(Tester):
@property
def namespace(self) -> str:
return "generator"
@property
def extension(self) -> str:
return "txt"
def _list_tests(self) -> list[Path]:
return list(self.base_dir.rglob("*.py"))
def _exec_case(self, path: Path) -> CaseResult:
if not path.exists():
raise FileNotFoundError(f"Could not find test '{path}'")
if not path.is_file():
raise TypeError(f"Test '{path}' is not a file")
result: CaseResult = CaseResult()
checker = TypeChecker()
types_path: Path = path.with_suffix(".midas")
if types_path.exists():
checker.import_midas(types_path)
typed_ast: TypedAST = checker.type_check(path)
if not any(d.type == DiagnosticType.ERROR for d in checker.diagnostics):
generator = Generator(workdir=path.parent)
result.compiled_ast = generator.generate_ast(typed_ast, path)
return result
if __name__ == "__main__":
GeneratorTester.main()

View File

@@ -22,6 +22,8 @@ from midas.ast.python import (
LogicalExpr, LogicalExpr,
MidasType, MidasType,
Pass, Pass,
RawExpr,
RawStmt,
ReturnStmt, ReturnStmt,
SliceExpr, SliceExpr,
Stmt, Stmt,
@@ -191,6 +193,12 @@ class PythonAstJsonSerializer(
"body": self._serialize_list(stmt.body), "body": self._serialize_list(stmt.body),
} }
def visit_raw_stmt(self, stmt: RawStmt) -> dict:
return {
"_type": "RawStmt",
"stmt": ast.unparse(stmt.stmt),
}
def visit_binary_expr(self, expr: BinaryExpr) -> dict: def visit_binary_expr(self, expr: BinaryExpr) -> dict:
return { return {
"_type": "BinaryExpr", "_type": "BinaryExpr",
@@ -284,3 +292,9 @@ class PythonAstJsonSerializer(
"upper": self._serialize_optional(expr.upper), "upper": self._serialize_optional(expr.upper),
"step": self._serialize_optional(expr.step), "step": self._serialize_optional(expr.step),
} }
def visit_raw_expr(self, expr: RawExpr) -> dict:
return {
"_type": "RawExpr",
"expr": ast.unparse(expr.expr),
}