Merge pull request 'Cast assertions and generator tests' (#12) from feat/cast-assertions into main
Reviewed-on: #12
This commit was merged in pull request #12.
This commit is contained in:
@@ -92,6 +92,10 @@ class ForStmt:
|
|||||||
body: list[Stmt]
|
body: list[Stmt]
|
||||||
|
|
||||||
|
|
||||||
|
class RawStmt:
|
||||||
|
stmt: ast.stmt
|
||||||
|
|
||||||
|
|
||||||
###<
|
###<
|
||||||
|
|
||||||
|
|
||||||
@@ -164,4 +168,8 @@ class SliceExpr:
|
|||||||
step: Optional[Expr]
|
step: Optional[Expr]
|
||||||
|
|
||||||
|
|
||||||
|
class RawExpr:
|
||||||
|
expr: ast.expr
|
||||||
|
|
||||||
|
|
||||||
###<
|
###<
|
||||||
|
|||||||
@@ -613,6 +613,11 @@ class PythonAstPrinter(
|
|||||||
self._mark_last()
|
self._mark_last()
|
||||||
body_stmt.accept(self)
|
body_stmt.accept(self)
|
||||||
|
|
||||||
|
def visit_raw_stmt(self, stmt: p.RawStmt) -> None:
|
||||||
|
self._write_line("RawStmt")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
self._write_line(f"stmt: {ast.unparse(stmt.stmt)}")
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
|
||||||
self._write_line("BinaryExpr")
|
self._write_line("BinaryExpr")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
@@ -756,3 +761,8 @@ class PythonAstPrinter(
|
|||||||
self._write_optional_child("lower", expr.lower)
|
self._write_optional_child("lower", expr.lower)
|
||||||
self._write_optional_child("upper", expr.upper)
|
self._write_optional_child("upper", expr.upper)
|
||||||
self._write_optional_child("step", expr.step, last=True)
|
self._write_optional_child("step", expr.step, last=True)
|
||||||
|
|
||||||
|
def visit_raw_expr(self, expr: p.RawExpr) -> None:
|
||||||
|
self._write_line("RawExpr")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
self._write_line(f"expr: {ast.unparse(expr.expr)}")
|
||||||
|
|||||||
@@ -113,6 +113,9 @@ class Stmt(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_for_stmt(self, stmt: ForStmt) -> T: ...
|
def visit_for_stmt(self, stmt: ForStmt) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_raw_stmt(self, stmt: RawStmt) -> T: ...
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ExpressionStmt(Stmt):
|
class ExpressionStmt(Stmt):
|
||||||
@@ -202,6 +205,14 @@ class ForStmt(Stmt):
|
|||||||
return visitor.visit_for_stmt(self)
|
return visitor.visit_for_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class RawStmt(Stmt):
|
||||||
|
stmt: ast.stmt
|
||||||
|
|
||||||
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_raw_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
###############
|
###############
|
||||||
# Expressions #
|
# Expressions #
|
||||||
###############
|
###############
|
||||||
@@ -254,6 +265,9 @@ class Expr(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_slice_expr(self, expr: SliceExpr) -> T: ...
|
def visit_slice_expr(self, expr: SliceExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_raw_expr(self, expr: RawExpr) -> T: ...
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class BinaryExpr(Expr):
|
class BinaryExpr(Expr):
|
||||||
@@ -373,3 +387,11 @@ class SliceExpr(Expr):
|
|||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
return visitor.visit_slice_expr(self)
|
return visitor.visit_slice_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class RawExpr(Expr):
|
||||||
|
expr: ast.expr
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_raw_expr(self)
|
||||||
|
|||||||
121
midas/checker/preamble.py
Normal file
121
midas/checker/preamble.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from midas.checker.environment import Environment
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
|
from midas.checker.types import Function, GenericType, TopType, Type, TypeVar, UnitType
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Param:
|
||||||
|
name: str
|
||||||
|
type: Type
|
||||||
|
required: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class Preamble(Environment):
|
||||||
|
def __init__(self, types: TypesRegistry) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._types: TypesRegistry = types
|
||||||
|
|
||||||
|
self._def_type_constructor("object")
|
||||||
|
self._def_type_constructor("float")
|
||||||
|
self._def_type_constructor("int")
|
||||||
|
self._def_type_constructor("bool")
|
||||||
|
self._def_type_constructor("str")
|
||||||
|
self._def_function(
|
||||||
|
name="list",
|
||||||
|
pos=[Param("object", TopType())],
|
||||||
|
returns=self._list_of(TopType()),
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: use sink
|
||||||
|
self._def_function(
|
||||||
|
name="print",
|
||||||
|
pos=[Param("object", TopType())],
|
||||||
|
returns=UnitType(),
|
||||||
|
)
|
||||||
|
|
||||||
|
map_in = TypeVar(name="T", bound=None)
|
||||||
|
map_out = TypeVar(name="U", bound=None)
|
||||||
|
mapper = self._make_function(
|
||||||
|
name="MapTransform",
|
||||||
|
pos=[Param("v", map_in)],
|
||||||
|
returns=map_out,
|
||||||
|
)
|
||||||
|
self._def_function(
|
||||||
|
name="map",
|
||||||
|
pos=[
|
||||||
|
Param("transform", mapper),
|
||||||
|
Param(
|
||||||
|
"iterable",
|
||||||
|
self._list_of(map_in), # TODO: replace with Iterable[T]
|
||||||
|
),
|
||||||
|
],
|
||||||
|
returns=self._list_of(map_out), # TODO: replace with Iterable[U]
|
||||||
|
)
|
||||||
|
|
||||||
|
def _list_of(self, item_type: Type) -> Type:
|
||||||
|
return self._types.apply_generic(self._types.get_type("list"), [item_type])
|
||||||
|
|
||||||
|
def _def_type_constructor(self, name: str):
|
||||||
|
# TODO: more specific arg types
|
||||||
|
self._def_function(
|
||||||
|
name=name,
|
||||||
|
pos=[Param("object", TopType())],
|
||||||
|
returns=self._types.get_type(name),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _make_function(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
name: str,
|
||||||
|
pos: list[Param] = [],
|
||||||
|
mixed: list[Param] = [],
|
||||||
|
kw: list[Param] = [],
|
||||||
|
returns: Type = UnitType(),
|
||||||
|
type_vars: list[TypeVar] = [],
|
||||||
|
) -> Type:
|
||||||
|
def map_args(params: list[Param], offset: int) -> list[Function.Argument]:
|
||||||
|
return [
|
||||||
|
Function.Argument(
|
||||||
|
pos=i + offset,
|
||||||
|
name=param.name,
|
||||||
|
type=param.type,
|
||||||
|
required=param.required,
|
||||||
|
)
|
||||||
|
for i, param in enumerate(params)
|
||||||
|
]
|
||||||
|
|
||||||
|
function = Function(
|
||||||
|
pos_args=map_args(pos, 0),
|
||||||
|
args=map_args(mixed, len(pos)),
|
||||||
|
kw_args=map_args(kw, len(pos) + len(mixed)),
|
||||||
|
returns=returns,
|
||||||
|
)
|
||||||
|
if len(type_vars) != 0:
|
||||||
|
function = GenericType(
|
||||||
|
name=name,
|
||||||
|
params=type_vars,
|
||||||
|
body=function,
|
||||||
|
)
|
||||||
|
return function
|
||||||
|
|
||||||
|
def _def_function(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
name: str,
|
||||||
|
pos: list[Param] = [],
|
||||||
|
mixed: list[Param] = [],
|
||||||
|
kw: list[Param] = [],
|
||||||
|
returns: Type = UnitType(),
|
||||||
|
type_vars: list[TypeVar] = [],
|
||||||
|
):
|
||||||
|
function: Type = self._make_function(
|
||||||
|
name=name,
|
||||||
|
pos=pos,
|
||||||
|
mixed=mixed,
|
||||||
|
kw=kw,
|
||||||
|
returns=returns,
|
||||||
|
type_vars=type_vars,
|
||||||
|
)
|
||||||
|
self.define(name, function)
|
||||||
@@ -7,10 +7,12 @@ import midas.ast.python as p
|
|||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
from midas.checker.environment import Environment
|
from midas.checker.environment import Environment
|
||||||
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS
|
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS
|
||||||
|
from midas.checker.preamble import Preamble
|
||||||
from midas.checker.registry import TypesRegistry
|
from midas.checker.registry import TypesRegistry
|
||||||
from midas.checker.reporter import FileReporter, Reporter
|
from midas.checker.reporter import FileReporter, Reporter
|
||||||
from midas.checker.resolver import Resolver
|
from midas.checker.resolver import Resolver
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
|
AppliedType,
|
||||||
Function,
|
Function,
|
||||||
OverloadedFunction,
|
OverloadedFunction,
|
||||||
Type,
|
Type,
|
||||||
@@ -56,7 +58,7 @@ class PythonTyper(
|
|||||||
self.logger: logging.Logger = logging.getLogger("PythonTyper")
|
self.logger: logging.Logger = logging.getLogger("PythonTyper")
|
||||||
self.reporter: FileReporter = reporter.for_file(None)
|
self.reporter: FileReporter = reporter.for_file(None)
|
||||||
self.types: TypesRegistry = types
|
self.types: TypesRegistry = types
|
||||||
self.global_env: Environment = Environment()
|
self.global_env: Environment = Preamble(self.types)
|
||||||
self.env: Environment = self.global_env
|
self.env: Environment = self.global_env
|
||||||
self.locals: dict[p.Expr, int] = {}
|
self.locals: dict[p.Expr, int] = {}
|
||||||
self.judgements: list[tuple[p.Expr, Type]] = []
|
self.judgements: list[tuple[p.Expr, Type]] = []
|
||||||
@@ -252,7 +254,7 @@ class PythonTyper(
|
|||||||
if returns_hint is not None:
|
if returns_hint is not None:
|
||||||
assert stmt.returns is not None
|
assert stmt.returns is not None
|
||||||
returns = returns_hint
|
returns = returns_hint
|
||||||
if returns != inferred_return:
|
if not self.is_subtype(inferred_return, returns):
|
||||||
self.reporter.error(
|
self.reporter.error(
|
||||||
stmt.returns.location,
|
stmt.returns.location,
|
||||||
f"Return type mismatch, annotated {returns} but returns {inferred_return}",
|
f"Return type mismatch, annotated {returns} but returns {inferred_return}",
|
||||||
@@ -370,6 +372,9 @@ class PythonTyper(
|
|||||||
if body_returned:
|
if body_returned:
|
||||||
raise ReturnException()
|
raise ReturnException()
|
||||||
|
|
||||||
|
def visit_raw_stmt(self, stmt: p.RawStmt) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
||||||
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
||||||
if method is None:
|
if method is None:
|
||||||
@@ -566,6 +571,9 @@ class PythonTyper(
|
|||||||
def visit_slice_expr(self, expr: p.SliceExpr) -> Type:
|
def visit_slice_expr(self, expr: p.SliceExpr) -> Type:
|
||||||
return self.types.get_type("slice")
|
return self.types.get_type("slice")
|
||||||
|
|
||||||
|
def visit_raw_expr(self, expr: p.RawExpr) -> Type:
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
def visit_base_type(self, node: p.BaseType) -> Type:
|
def visit_base_type(self, node: p.BaseType) -> Type:
|
||||||
base: Type
|
base: Type
|
||||||
try:
|
try:
|
||||||
@@ -637,6 +645,15 @@ class PythonTyper(
|
|||||||
if function is None:
|
if function is None:
|
||||||
return None
|
return None
|
||||||
return function.returns
|
return function.returns
|
||||||
|
|
||||||
|
case AppliedType(body=body):
|
||||||
|
return self._get_call_result(
|
||||||
|
location, body, positional, keywords, report_errors
|
||||||
|
)
|
||||||
|
|
||||||
|
case UnknownType():
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
if report_errors:
|
if report_errors:
|
||||||
self.reporter.error(location, f"{callee} is not callable")
|
self.reporter.error(location, f"{callee} is not callable")
|
||||||
|
|||||||
@@ -163,6 +163,9 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
|||||||
self.resolve(*stmt.body)
|
self.resolve(*stmt.body)
|
||||||
self.end_scope()
|
self.end_scope()
|
||||||
|
|
||||||
|
def visit_raw_stmt(self, stmt: p.RawStmt) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
|
||||||
self.resolve(expr.left)
|
self.resolve(expr.left)
|
||||||
self.resolve(expr.right)
|
self.resolve(expr.right)
|
||||||
@@ -221,3 +224,6 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
|||||||
self.resolve(expr.upper)
|
self.resolve(expr.upper)
|
||||||
if expr.step is not None:
|
if expr.step is not None:
|
||||||
self.resolve(expr.step)
|
self.resolve(expr.step)
|
||||||
|
|
||||||
|
def visit_raw_expr(self, expr: p.RawExpr) -> None:
|
||||||
|
pass
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
@@ -41,23 +41,21 @@ class UnitType:
|
|||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class Function:
|
class Function:
|
||||||
pos_args: list[Argument]
|
pos_args: list[Argument] = field(default_factory=list)
|
||||||
args: list[Argument]
|
args: list[Argument] = field(default_factory=list)
|
||||||
kw_args: list[Argument]
|
kw_args: list[Argument] = field(default_factory=list)
|
||||||
returns: Type
|
returns: Type
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
args: list[str] = []
|
args: list[str] = []
|
||||||
if len(self.pos_args) != 0:
|
if len(self.pos_args) != 0:
|
||||||
args += list(map(str, self.pos_args))
|
args += list(map(str, self.pos_args))
|
||||||
if len(self.args) + len(self.kw_args) != 0:
|
|
||||||
args.append("/")
|
args.append("/")
|
||||||
|
|
||||||
if len(self.args) != 0:
|
if len(self.args) != 0:
|
||||||
args += list(map(str, self.args))
|
args += list(map(str, self.args))
|
||||||
|
|
||||||
if len(self.kw_args) != 0:
|
if len(self.kw_args) != 0:
|
||||||
if len(args) != 0:
|
|
||||||
args.append("*")
|
args.append("*")
|
||||||
args += list(map(str, self.kw_args))
|
args += list(map(str, self.kw_args))
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,21 @@ from typing import TextIO
|
|||||||
import click
|
import click
|
||||||
|
|
||||||
from midas.checker.checker import TypeChecker
|
from midas.checker.checker import TypeChecker
|
||||||
from midas.checker.types import Type
|
from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type
|
||||||
|
|
||||||
|
|
||||||
|
def base_type(type: Type) -> Type:
|
||||||
|
match type:
|
||||||
|
case BaseType():
|
||||||
|
return type
|
||||||
|
case AliasType(type=base):
|
||||||
|
return base
|
||||||
|
case AppliedType(body=body):
|
||||||
|
return body
|
||||||
|
case GenericType(body=body):
|
||||||
|
return body
|
||||||
|
case _:
|
||||||
|
return type
|
||||||
|
|
||||||
|
|
||||||
@click.command(help="Dump types registry")
|
@click.command(help="Dump types registry")
|
||||||
@@ -23,7 +37,7 @@ def dump_registry(
|
|||||||
|
|
||||||
for name, type in checker.types._types.items():
|
for name, type in checker.types._types.items():
|
||||||
members: dict[str, Type] = checker.types._members.get(name, {})
|
members: dict[str, Type] = checker.types._members.get(name, {})
|
||||||
print(f"{name} = {type}")
|
print(f"{name} = {base_type(type)}")
|
||||||
if len(members) != 0:
|
if len(members) != 0:
|
||||||
print(" " * 4 + "Members:")
|
print(" " * 4 + "Members:")
|
||||||
for member_name, member_type in members.items():
|
for member_name, member_type in members.items():
|
||||||
|
|||||||
@@ -1,11 +1,34 @@
|
|||||||
import ast
|
import ast
|
||||||
import shutil
|
import shutil
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from midas.ast.location import Location
|
||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
|
from midas.checker.types import (
|
||||||
|
AliasType,
|
||||||
|
AppliedType,
|
||||||
|
BaseType,
|
||||||
|
ComplexType,
|
||||||
|
ExtensionType,
|
||||||
|
Function,
|
||||||
|
GenericType,
|
||||||
|
OverloadedFunction,
|
||||||
|
TopType,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
UnitType,
|
||||||
|
)
|
||||||
from midas.utils import TypedAST
|
from midas.utils import TypedAST
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Scope:
|
||||||
|
pre_assertions: list[ast.stmt] = field(default_factory=list)
|
||||||
|
aliases: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||||
def __init__(self, workdir: Path) -> None:
|
def __init__(self, workdir: Path) -> None:
|
||||||
self.workdir: Path = workdir.resolve()
|
self.workdir: Path = workdir.resolve()
|
||||||
@@ -13,19 +36,35 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
if self.build_dir.exists():
|
if self.build_dir.exists():
|
||||||
shutil.rmtree(self.build_dir)
|
shutil.rmtree(self.build_dir)
|
||||||
self.build_dir.mkdir(parents=True, exist_ok=True)
|
self.build_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.rel_src_path: Path = Path()
|
||||||
|
|
||||||
def generate(self, typed_ast: TypedAST, src_path: Path) -> Path:
|
self._typed_ast: TypedAST = TypedAST(
|
||||||
|
stmts=[],
|
||||||
|
judgements=[],
|
||||||
|
)
|
||||||
|
self._alias_count: int = 0
|
||||||
|
self._scopes: list[Scope] = []
|
||||||
|
|
||||||
|
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
|
||||||
|
self.rel_src_path = src_path.relative_to(self.workdir)
|
||||||
|
self._typed_ast = typed_ast
|
||||||
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
|
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
|
||||||
module = ast.Module(body=body, type_ignores=[])
|
module = ast.Module(body=body, type_ignores=[])
|
||||||
module = ast.fix_missing_locations(module)
|
module = ast.fix_missing_locations(module)
|
||||||
|
return module
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self, typed_ast: TypedAST, src_path: Path, out_path: Optional[Path] = None
|
||||||
|
) -> Path:
|
||||||
|
module: ast.AST = self.generate_ast(typed_ast, src_path)
|
||||||
compiled: str = ast.unparse(module)
|
compiled: str = ast.unparse(module)
|
||||||
rel_src_path: Path = src_path.relative_to(self.workdir)
|
if out_path is None:
|
||||||
out_path: Path = (self.build_dir / rel_src_path).resolve()
|
out_path = (self.build_dir / self.rel_src_path).resolve()
|
||||||
try:
|
try:
|
||||||
_ = out_path.relative_to(self.build_dir)
|
_ = out_path.relative_to(self.build_dir)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Directory traversal, {rel_src_path} points outside of parent directory"
|
f"Directory traversal, {self.rel_src_path} points outside of parent directory"
|
||||||
)
|
)
|
||||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
out_path.write_text(compiled)
|
out_path.write_text(compiled)
|
||||||
@@ -80,8 +119,13 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
|
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
|
||||||
# TODO: insert assertion
|
expr2: ast.expr = expr.expr.accept(self)
|
||||||
return expr.expr.accept(self)
|
alias: ast.expr = self._make_alias(expr2)
|
||||||
|
|
||||||
|
type: Type = self._get_expr_type(expr)
|
||||||
|
self._make_cast_asserts(expr.location, alias, type)
|
||||||
|
|
||||||
|
return alias
|
||||||
|
|
||||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> ast.expr:
|
def visit_ternary_expr(self, expr: p.TernaryExpr) -> ast.expr:
|
||||||
return ast.IfExp(
|
return ast.IfExp(
|
||||||
@@ -108,6 +152,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
step=expr.step.accept(self) if expr.step is not None else None,
|
step=expr.step.accept(self) if expr.step is not None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def visit_raw_expr(self, expr: p.RawExpr) -> ast.expr:
|
||||||
|
return expr.expr
|
||||||
|
|
||||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> ast.stmt:
|
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> ast.stmt:
|
||||||
return ast.Expr(
|
return ast.Expr(
|
||||||
value=stmt.expr.accept(self),
|
value=stmt.expr.accept(self),
|
||||||
@@ -169,5 +216,120 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
orelse=[],
|
orelse=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def visit_raw_stmt(self, stmt: p.RawStmt) -> ast.stmt:
|
||||||
|
return stmt.stmt
|
||||||
|
|
||||||
def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]:
|
def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]:
|
||||||
return [stmt.accept(self) for stmt in stmts]
|
generated: list[ast.stmt] = []
|
||||||
|
for stmt in stmts:
|
||||||
|
scope = Scope()
|
||||||
|
self._scopes.append(scope)
|
||||||
|
|
||||||
|
stmt2 = stmt.accept(self)
|
||||||
|
generated.extend(scope.pre_assertions)
|
||||||
|
generated.append(stmt2)
|
||||||
|
if len(scope.aliases) != 0:
|
||||||
|
generated.append(
|
||||||
|
ast.Delete(targets=[ast.Name(id=alias) for alias in scope.aliases])
|
||||||
|
)
|
||||||
|
self._scopes.pop()
|
||||||
|
|
||||||
|
# Remove redundant pass statements
|
||||||
|
if len(generated) > 1:
|
||||||
|
generated = [stmt for stmt in generated if not isinstance(stmt, ast.Pass)]
|
||||||
|
return generated
|
||||||
|
|
||||||
|
def _make_alias(self, expr: ast.expr) -> ast.expr:
|
||||||
|
name: str = f"__midas_alias_{self._alias_count}__"
|
||||||
|
alias = ast.Name(id=name)
|
||||||
|
self._alias_count += 1
|
||||||
|
self._scopes[-1].aliases.append(name)
|
||||||
|
self._scopes[-1].pre_assertions.append(
|
||||||
|
ast.Assign(
|
||||||
|
targets=[alias],
|
||||||
|
value=expr,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return alias
|
||||||
|
|
||||||
|
def _add_assert(self, expr: ast.expr, message: str | ast.expr):
|
||||||
|
if isinstance(message, str):
|
||||||
|
message = ast.Constant(value=message)
|
||||||
|
self._scopes[-1].pre_assertions.append(
|
||||||
|
ast.Assert(
|
||||||
|
test=expr,
|
||||||
|
msg=message,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_expr_type(self, query: p.Expr) -> Type:
|
||||||
|
for expr, type in self._typed_ast.judgements:
|
||||||
|
if expr == query:
|
||||||
|
return type
|
||||||
|
raise RuntimeError(f"Cannot get type judgement for {query}")
|
||||||
|
|
||||||
|
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type):
|
||||||
|
match type:
|
||||||
|
case BaseType(name=name):
|
||||||
|
self._add_assert(
|
||||||
|
ast.Call(
|
||||||
|
func=ast.Name(id="isinstance"),
|
||||||
|
args=[expr, ast.Name(id=name)],
|
||||||
|
keywords=[],
|
||||||
|
),
|
||||||
|
self._make_cast_assert_message(src_location, expr, type),
|
||||||
|
)
|
||||||
|
|
||||||
|
case AliasType(type=base):
|
||||||
|
self._make_cast_asserts(src_location, expr, base)
|
||||||
|
|
||||||
|
case UnitType():
|
||||||
|
self._add_assert(
|
||||||
|
ast.Compare(
|
||||||
|
left=expr,
|
||||||
|
ops=[ast.Is()],
|
||||||
|
comparators=[
|
||||||
|
ast.Constant(value=None),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
self._make_cast_assert_message(src_location, expr, type),
|
||||||
|
)
|
||||||
|
|
||||||
|
case AppliedType():
|
||||||
|
self._make_cast_asserts(src_location, expr, type.body)
|
||||||
|
|
||||||
|
case (
|
||||||
|
TopType()
|
||||||
|
| Function()
|
||||||
|
| OverloadedFunction()
|
||||||
|
| ComplexType()
|
||||||
|
| ExtensionType()
|
||||||
|
| GenericType()
|
||||||
|
):
|
||||||
|
raise NotImplementedError(f"Can't make assertion for type {type}")
|
||||||
|
|
||||||
|
case TypeVar():
|
||||||
|
raise RuntimeError("Unexpected TypeVar")
|
||||||
|
|
||||||
|
def _make_cast_assert_message(
|
||||||
|
self, location: Location, expr: ast.expr, type: Type
|
||||||
|
) -> ast.expr:
|
||||||
|
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
|
||||||
|
# f"file.py:L1:1: CastError: Cannot cast {type(expr).__name__} to Type"
|
||||||
|
return ast.JoinedStr(
|
||||||
|
values=[
|
||||||
|
ast.Constant(f"{loc_str}: CastError: Cannot cast "),
|
||||||
|
ast.FormattedValue(
|
||||||
|
value=ast.Attribute(
|
||||||
|
value=ast.Call(
|
||||||
|
func=ast.Name(id="type"),
|
||||||
|
args=[expr],
|
||||||
|
keywords=[],
|
||||||
|
),
|
||||||
|
attr="__name__",
|
||||||
|
),
|
||||||
|
conversion=-1,
|
||||||
|
),
|
||||||
|
ast.Constant(f" to {type}"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ from midas.ast.python import (
|
|||||||
LiteralExpr,
|
LiteralExpr,
|
||||||
LogicalExpr,
|
LogicalExpr,
|
||||||
MidasType,
|
MidasType,
|
||||||
|
RawExpr,
|
||||||
|
RawStmt,
|
||||||
ReturnStmt,
|
ReturnStmt,
|
||||||
SliceExpr,
|
SliceExpr,
|
||||||
Stmt,
|
Stmt,
|
||||||
@@ -99,7 +101,7 @@ class PythonParser:
|
|||||||
|
|
||||||
case _:
|
case _:
|
||||||
print(f"Unsupported statement: {ast.unparse(node)}")
|
print(f"Unsupported statement: {ast.unparse(node)}")
|
||||||
return None
|
return RawStmt(location=location, stmt=node)
|
||||||
|
|
||||||
def parse_annotation_assign(self, node: ast.AnnAssign) -> list[Stmt]:
|
def parse_annotation_assign(self, node: ast.AnnAssign) -> list[Stmt]:
|
||||||
statements: list[Stmt] = []
|
statements: list[Stmt] = []
|
||||||
@@ -461,7 +463,8 @@ class PythonParser:
|
|||||||
)
|
)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
raise UnsupportedSyntaxError(node)
|
print(f"Unsupported expression: {ast.unparse(node)}")
|
||||||
|
return RawExpr(location=location, expr=node)
|
||||||
|
|
||||||
def parse_bool_op(self, node: ast.BoolOp) -> LogicalExpr:
|
def parse_bool_op(self, node: ast.BoolOp) -> LogicalExpr:
|
||||||
op: ast.boolop = node.op
|
op: ast.boolop = node.op
|
||||||
|
|||||||
@@ -21,6 +21,10 @@ class Tester(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def namespace(self) -> str: ...
|
def namespace(self) -> str: ...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def extension(self) -> str:
|
||||||
|
return "json"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def base_dir(self) -> Path:
|
def base_dir(self) -> Path:
|
||||||
return self.CASES_DIR / self.namespace
|
return self.CASES_DIR / self.namespace
|
||||||
@@ -99,7 +103,7 @@ class Tester(ABC):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def _result_path(self, test_path: Path) -> Path:
|
def _result_path(self, test_path: Path) -> Path:
|
||||||
return test_path.parent / (test_path.name + ".ref.json")
|
return test_path.parent / (test_path.name + f".ref.{self.extension}")
|
||||||
|
|
||||||
def _print_diff(self, diff: Iterator[str]):
|
def _print_diff(self, diff: Iterator[str]):
|
||||||
for line in diff:
|
for line in diff:
|
||||||
|
|||||||
14
tests/cases/generator/01_simple_types.midas
Normal file
14
tests/cases/generator/01_simple_types.midas
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
type Meter = float
|
||||||
|
type Second = float
|
||||||
|
type MeterPerSecond = float
|
||||||
|
|
||||||
|
extend Meter {
|
||||||
|
def __add__: fn(Meter, /) -> Meter
|
||||||
|
def __sub__: fn(Meter, /) -> Meter
|
||||||
|
def __truediv__: fn(Second, /) -> MeterPerSecond
|
||||||
|
}
|
||||||
|
|
||||||
|
extend Second {
|
||||||
|
def __add__: fn(Second, /) -> Second
|
||||||
|
def __sub__: fn(Second, /) -> Second
|
||||||
|
}
|
||||||
5
tests/cases/generator/01_simple_types.py
Normal file
5
tests/cases/generator/01_simple_types.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from midas import cast, Meter, Second
|
||||||
|
|
||||||
|
distance: Meter = cast(Meter, 123.45)
|
||||||
|
time: Second = cast(Second, 6.7)
|
||||||
|
speed = distance / time
|
||||||
79
tests/cases/generator/01_simple_types.py.ref.txt
Normal file
79
tests/cases/generator/01_simple_types.py.ref.txt
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
Module(
|
||||||
|
body=[
|
||||||
|
ImportFrom(
|
||||||
|
module='midas',
|
||||||
|
names=[
|
||||||
|
alias(name='cast'),
|
||||||
|
alias(name='Meter'),
|
||||||
|
alias(name='Second')],
|
||||||
|
level=0),
|
||||||
|
Assign(
|
||||||
|
targets=[
|
||||||
|
Name(id='__midas_alias_0__')],
|
||||||
|
value=Constant(value=123.45)),
|
||||||
|
Assert(
|
||||||
|
test=Call(
|
||||||
|
func=Name(id='isinstance'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_alias_0__'),
|
||||||
|
Name(id='float')],
|
||||||
|
keywords=[]),
|
||||||
|
msg=JoinedStr(
|
||||||
|
values=[
|
||||||
|
Constant(value='01_simple_types.py:L3:19: CastError: Cannot cast '),
|
||||||
|
FormattedValue(
|
||||||
|
value=Attribute(
|
||||||
|
value=Call(
|
||||||
|
func=Name(id='type'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_alias_0__')],
|
||||||
|
keywords=[]),
|
||||||
|
attr='__name__'),
|
||||||
|
conversion=-1),
|
||||||
|
Constant(value=' to float')])),
|
||||||
|
Assign(
|
||||||
|
targets=[
|
||||||
|
Name(id='distance')],
|
||||||
|
value=Name(id='__midas_alias_0__')),
|
||||||
|
Delete(
|
||||||
|
targets=[
|
||||||
|
Name(id='__midas_alias_0__')]),
|
||||||
|
Assign(
|
||||||
|
targets=[
|
||||||
|
Name(id='__midas_alias_1__')],
|
||||||
|
value=Constant(value=6.7)),
|
||||||
|
Assert(
|
||||||
|
test=Call(
|
||||||
|
func=Name(id='isinstance'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_alias_1__'),
|
||||||
|
Name(id='float')],
|
||||||
|
keywords=[]),
|
||||||
|
msg=JoinedStr(
|
||||||
|
values=[
|
||||||
|
Constant(value='01_simple_types.py:L4:16: CastError: Cannot cast '),
|
||||||
|
FormattedValue(
|
||||||
|
value=Attribute(
|
||||||
|
value=Call(
|
||||||
|
func=Name(id='type'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_alias_1__')],
|
||||||
|
keywords=[]),
|
||||||
|
attr='__name__'),
|
||||||
|
conversion=-1),
|
||||||
|
Constant(value=' to float')])),
|
||||||
|
Assign(
|
||||||
|
targets=[
|
||||||
|
Name(id='time')],
|
||||||
|
value=Name(id='__midas_alias_1__')),
|
||||||
|
Delete(
|
||||||
|
targets=[
|
||||||
|
Name(id='__midas_alias_1__')]),
|
||||||
|
Assign(
|
||||||
|
targets=[
|
||||||
|
Name(id='speed')],
|
||||||
|
value=BinOp(
|
||||||
|
left=Name(id='distance'),
|
||||||
|
op=Div(),
|
||||||
|
right=Name(id='time')))],
|
||||||
|
type_ignores=[])
|
||||||
@@ -1,5 +1,9 @@
|
|||||||
{
|
{
|
||||||
"stmts": [
|
"stmts": [
|
||||||
|
{
|
||||||
|
"_type": "RawStmt",
|
||||||
|
"stmt": "from __future__ import annotations"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"_type": "TypeAssign",
|
"_type": "TypeAssign",
|
||||||
"name": "df",
|
"name": "df",
|
||||||
|
|||||||
@@ -1,5 +1,9 @@
|
|||||||
{
|
{
|
||||||
"stmts": [
|
"stmts": [
|
||||||
|
{
|
||||||
|
"_type": "RawStmt",
|
||||||
|
"stmt": "from __future__ import annotations"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"_type": "TypeAssign",
|
"_type": "TypeAssign",
|
||||||
"name": "df",
|
"name": "df",
|
||||||
|
|||||||
@@ -1,5 +1,9 @@
|
|||||||
{
|
{
|
||||||
"stmts": [
|
"stmts": [
|
||||||
|
{
|
||||||
|
"_type": "RawStmt",
|
||||||
|
"stmt": "from __future__ import annotations"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"_type": "Function",
|
"_type": "Function",
|
||||||
"name": "func",
|
"name": "func",
|
||||||
|
|||||||
55
tests/generator.py
Normal file
55
tests/generator.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
import ast
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from midas.checker.checker import TypeChecker
|
||||||
|
from midas.checker.diagnostic import DiagnosticType
|
||||||
|
from midas.generator.generator import Generator
|
||||||
|
from midas.utils import TypedAST
|
||||||
|
from tests.base import Tester
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CaseResult:
|
||||||
|
compiled_ast: ast.AST = ast.Module([], [])
|
||||||
|
|
||||||
|
def dumps(self) -> str:
|
||||||
|
return ast.dump(self.compiled_ast, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
class GeneratorTester(Tester):
|
||||||
|
@property
|
||||||
|
def namespace(self) -> str:
|
||||||
|
return "generator"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def extension(self) -> str:
|
||||||
|
return "txt"
|
||||||
|
|
||||||
|
def _list_tests(self) -> list[Path]:
|
||||||
|
return list(self.base_dir.rglob("*.py"))
|
||||||
|
|
||||||
|
def _exec_case(self, path: Path) -> CaseResult:
|
||||||
|
if not path.exists():
|
||||||
|
raise FileNotFoundError(f"Could not find test '{path}'")
|
||||||
|
if not path.is_file():
|
||||||
|
raise TypeError(f"Test '{path}' is not a file")
|
||||||
|
|
||||||
|
result: CaseResult = CaseResult()
|
||||||
|
|
||||||
|
checker = TypeChecker()
|
||||||
|
types_path: Path = path.with_suffix(".midas")
|
||||||
|
if types_path.exists():
|
||||||
|
checker.import_midas(types_path)
|
||||||
|
|
||||||
|
typed_ast: TypedAST = checker.type_check(path)
|
||||||
|
|
||||||
|
if not any(d.type == DiagnosticType.ERROR for d in checker.diagnostics):
|
||||||
|
generator = Generator(workdir=path.parent)
|
||||||
|
result.compiled_ast = generator.generate_ast(typed_ast, path)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
GeneratorTester.main()
|
||||||
@@ -22,6 +22,8 @@ from midas.ast.python import (
|
|||||||
LogicalExpr,
|
LogicalExpr,
|
||||||
MidasType,
|
MidasType,
|
||||||
Pass,
|
Pass,
|
||||||
|
RawExpr,
|
||||||
|
RawStmt,
|
||||||
ReturnStmt,
|
ReturnStmt,
|
||||||
SliceExpr,
|
SliceExpr,
|
||||||
Stmt,
|
Stmt,
|
||||||
@@ -191,6 +193,12 @@ class PythonAstJsonSerializer(
|
|||||||
"body": self._serialize_list(stmt.body),
|
"body": self._serialize_list(stmt.body),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def visit_raw_stmt(self, stmt: RawStmt) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "RawStmt",
|
||||||
|
"stmt": ast.unparse(stmt.stmt),
|
||||||
|
}
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: BinaryExpr) -> dict:
|
def visit_binary_expr(self, expr: BinaryExpr) -> dict:
|
||||||
return {
|
return {
|
||||||
"_type": "BinaryExpr",
|
"_type": "BinaryExpr",
|
||||||
@@ -284,3 +292,9 @@ class PythonAstJsonSerializer(
|
|||||||
"upper": self._serialize_optional(expr.upper),
|
"upper": self._serialize_optional(expr.upper),
|
||||||
"step": self._serialize_optional(expr.step),
|
"step": self._serialize_optional(expr.step),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def visit_raw_expr(self, expr: RawExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "RawExpr",
|
||||||
|
"expr": ast.unparse(expr.expr),
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user