18 Commits

Author SHA1 Message Date
45894b66b0 feat(cli): generate stubs in build dir when compiling 2026-06-30 11:19:17 +02:00
16239db479 feat(gen): add tuple expr to generator 2026-06-29 22:44:39 +02:00
dc2134c87d tests: update with multi-parameter generics 2026-06-29 22:43:08 +02:00
89f3c945e4 fix: minor fixes 2026-06-29 22:41:47 +02:00
45f7d1be2b feat: add Python tuple expression 2026-06-29 14:35:31 +02:00
27f3fa7d1e feat: handle multi-parameter generic in Python 2026-06-29 14:24:38 +02:00
78eba39ae3 feat(checker): add len() 2026-06-29 14:02:52 +02:00
3b78b37306 fix(checker): allow some assignments to unknown 2026-06-29 14:02:29 +02:00
9e14b30bc9 feat(checker): add methods on str 2026-06-29 14:01:33 +02:00
a6a1075f91 feat(checker): type check tuple instantiation in Midas 2026-06-29 14:00:37 +02:00
11be47fce3 fix(parser): parse empty calls 2026-06-29 13:59:03 +02:00
2eeede9826 fix(gen): prevent empty loop for column asserts 2026-06-29 11:19:26 +02:00
f796f4c6fa fix(checker): allow iterating on unknown 2026-06-29 11:13:47 +02:00
c333735580 fix(checker): allow subtypes and unknown as if test 2026-06-29 11:06:35 +02:00
2416102494 feat(gen): assertions for column values 2026-06-29 11:05:59 +02:00
eb4971686a fix(checker): allow calling unknown method on dataframes 2026-06-29 11:01:53 +02:00
9f59366289 feat(gen): generate asserts for dataframes and columns 2026-06-26 14:56:15 +02:00
fd0b410d74 fix(checker): change heterogeneous errors to warnings 2026-06-26 11:55:31 +02:00
24 changed files with 612 additions and 184 deletions

View File

@@ -15,7 +15,7 @@ from midas.ast.location import Location
###> MidasType | Type annotations | node
class BaseType:
base: str
param: Optional[MidasType]
args: tuple[MidasType, ...]
class ConstraintType:
@@ -174,6 +174,10 @@ class SliceExpr:
step: Optional[Expr]
class TupleExpr:
items: tuple[Expr, ...]
class RawExpr:
expr: ast.expr

View File

@@ -549,7 +549,13 @@ class PythonAstPrinter(
self._write_line("BaseType")
with self._child_level():
self._write_line(f"base: {node.base}")
self._write_optional_child("param", node.param, last=True)
self._write_line("args:", last=True)
with self._child_level():
for i, arg in enumerate(node.args):
self._idx = i
if i == len(node.args) - 1:
self._mark_last()
arg.accept(self)
def visit_constraint_type(self, node: p.ConstraintType) -> None:
self._write_line("ConstraintType")
@@ -862,6 +868,17 @@ class PythonAstPrinter(
self._write_optional_child("upper", expr.upper)
self._write_optional_child("step", expr.step, last=True)
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
self._write_line("TupleExpr")
with self._child_level():
self._write_line("items", last=True)
with self._child_level():
for i, item in enumerate(expr.items):
self._idx = i
if i == len(expr.items) - 1:
self._mark_last()
item.accept(self)
def visit_raw_expr(self, expr: p.RawExpr) -> None:
self._write_line("RawExpr")
with self._child_level(single=True):

View File

@@ -44,7 +44,7 @@ class MidasType(ABC):
@dataclass(frozen=True)
class BaseType(MidasType):
base: str
param: Optional[MidasType]
args: tuple[MidasType, ...]
def accept(self, visitor: MidasType.Visitor[T]) -> T:
return visitor.visit_base_type(self)
@@ -268,6 +268,9 @@ class Expr(ABC):
@abstractmethod
def visit_slice_expr(self, expr: SliceExpr) -> T: ...
@abstractmethod
def visit_tuple_expr(self, expr: TupleExpr) -> T: ...
@abstractmethod
def visit_raw_expr(self, expr: RawExpr) -> T: ...
@@ -402,6 +405,14 @@ class SliceExpr(Expr):
return visitor.visit_slice_expr(self)
@dataclass(frozen=True)
class TupleExpr(Expr):
items: tuple[Expr, ...]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_tuple_expr(self)
@dataclass(frozen=True)
class RawExpr(Expr):
expr: ast.expr

View File

@@ -178,4 +178,100 @@ extend dict[K, V] {
// def __ior__: fn(value: SupportsKeysAndGetItem[K, V], /) -> dict[K, V]
// def __ior__: fn(value: Iterable[tuple[K, V]], /) -> dict[K, V]
}
}
extend str {
def capitalize: fn() -> str
def casefold: fn() -> str
def center: fn(width: int, fillchar: str?, /) -> str
def count: fn(sub: str, start: None?, end: None?, /) -> int
def count: fn(sub: str, start: int, end: None?, /) -> int
def count: fn(sub: str, start: None, end: int, /) -> int
def count: fn(sub: str, start: int, end: int, /) -> int
def encode: fn(encoding: str?, errors: str?) -> bytes
def endswith: fn(suffix: str, start: None?, end: None?, /) -> bool
def endswith: fn(suffix: str, start: int, end: None?, /) -> bool
def endswith: fn(suffix: str, start: None, end: int, /) -> bool
def endswith: fn(suffix: str, start: int, end: int, /) -> bool
def expandtabs: fn(tabsize: int?) -> str
def find: fn(sub: str, start: None?, end: None?, /) -> int
def find: fn(sub: str, start: int, end: None?, /) -> int
def find: fn(sub: str, start: None, end: int, /) -> int
def find: fn(sub: str, start: int, end: int, /) -> int
// def format: fn(*args: object, **kwargs: object) -> str
// def format_map: fn(mapping: _FormatMapMapping, /) -> str
def index: fn(sub: str, start: None?, end: None?, /) -> int
def index: fn(sub: str, start: int, end: None?, /) -> int
def index: fn(sub: str, start: None, end: int, /) -> int
def index: fn(sub: str, start: int, end: int, /) -> int
def isalnum: fn() -> bool
def isalpha: fn() -> bool
def isascii: fn() -> bool
def isdecimal: fn() -> bool
def isdigit: fn() -> bool
def isidentifier: fn() -> bool
def islower: fn() -> bool
def isnumeric: fn() -> bool
def isprintable: fn() -> bool
def isspace: fn() -> bool
def istitle: fn() -> bool
def isupper: fn() -> bool
def join: fn(iterable: list[str], /) -> str // TODO: use Iterable
def ljust: fn(width: int, fillchar: str?, /) -> str
def lower: fn() -> str
def lstrip: fn(chars: None?, /) -> str
def lstrip: fn(chars: str, /) -> str
def partition: fn(sep: str, /) -> tuple[str, str, str]
def replace: fn(old: str, new: str, count: int?, /) -> str
def removeprefix: fn(prefix: str, /) -> str
def removesuffix: fn(suffix: str, /) -> str
def rfind: fn(sub: str, start: None?, end: None?, /) -> int
def rfind: fn(sub: str, start: int, end: None?, /) -> int
def rfind: fn(sub: str, start: None, end: int, /) -> int
def rfind: fn(sub: str, start: int, end: int, /) -> int
def rindex: fn(sub: str, start: None?, end: None?, /) -> int
def rindex: fn(sub: str, start: int, end: None?, /) -> int
def rindex: fn(sub: str, start: None, end: int, /) -> int
def rindex: fn(sub: str, start: int, end: int, /) -> int
def rjust: fn(width: int, fillchar: str?, /) -> str
def rpartition: fn(sep: str, /) -> tuple[str, str, str]
def rsplit: fn(sep: None?, maxsplit: int?) -> list[str]
def rsplit: fn(sep: str, maxsplit: int?) -> list[str]
def rstrip: fn(chars: None?, /) -> str
def rstrip: fn(chars: str, /) -> str
def split: fn(sep: None?, maxsplit: int?) -> list[str]
def split: fn(sep: str, maxsplit: int?) -> list[str]
def splitlines: fn(keepends: bool?) -> list[str]
def startswith: fn(prefix: str, start: None?, end: None?, /) -> bool
def startswith: fn(prefix: str, start: int, end: None?, /) -> bool
def startswith: fn(prefix: str, start: None, end: int, /) -> bool
def startswith: fn(prefix: str, start: int, end: int, /) -> bool
def strip: fn(chars: None?, /) -> str
def strip: fn(chars: str, /) -> str
def swapcase: fn() -> str
def title: fn() -> str
// def translate: fn(table: _TranslateTable, /) -> str
def upper: fn() -> str
def zfill: fn(width: int, /) -> str
def __add__: fn(value: str, /) -> str
// Incompatible with Sequence.__contains__
def __contains__: fn(key: str, /) -> bool
def __eq__: fn(value: object, /) -> bool
def __ge__: fn(value: str, /) -> bool
def __getitem__: fn(key: slice, /) -> str
def __getitem__: fn(key: int, /) -> str
def __gt__: fn(value: str, /) -> bool
def __hash__: fn() -> int
// def __iter__: fn() -> Iterator[str]
def __le__: fn(value: str, /) -> bool
def __len__: fn() -> int
def __lt__: fn(value: str, /) -> bool
def __mod__: fn(value: Any, /) -> str
def __mul__: fn(value: int, /) -> str
def __ne__: fn(value: object, /) -> bool
def __rmul__: fn(value: int, /) -> str
def __getnewargs__: fn() -> tuple[str]
def __format__: fn(format_spec: str, /) -> str
}

View File

@@ -15,7 +15,7 @@ if TYPE_CHECKING:
BUILTIN_SUBTYPES: dict[str, set[str]] = {
"object": {"float", "list", "dict", "str"},
"object": {"float", "list", "dict", "str", "bytes", "tuple"},
"float": {"int"},
"int": {"bool"},
}
@@ -26,12 +26,15 @@ def define_builtins(reg: TypesRegistry):
any = reg.define_type("Any", TopType())
unit = reg.define_type("None", UnitType())
object = reg.define_type("object", BaseType(name="object"))
bytes = reg.define_type("bytes", BaseType(name="bytes"))
bool = reg.define_type("bool", BaseType(name="bool"))
int = reg.define_type("int", BaseType(name="int"))
float = reg.define_type("float", BaseType(name="float"))
str = reg.define_type("str", BaseType(name="str"))
slice = reg.define_type("slice", BaseType(name="slice"))
tuple = reg.define_type("tuple", BaseType(name="tuple"))
list = reg.define_type(
"list",
GenericType(

View File

@@ -42,7 +42,7 @@ class Call:
class _MethodRegistryMeta(type):
_methods: dict[str, Callable] = {}
_methods: dict[str, Callable[..., Type]] = {}
def __new__(
cls,
@@ -55,7 +55,7 @@ class _MethodRegistryMeta(type):
for attr in namespace.values():
if callable(attr) and hasattr(attr, "__method_names__"):
for name in attr.__method_names__: # type: ignore
new_class._methods[name] = attr
new_class._methods[name] = attr # type: ignore
return new_class
@@ -76,9 +76,9 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
method: str,
call: Call,
) -> Type:
func: Optional[Callable] = self._methods.get(method)
func: Optional[Callable[..., Type]] = self._methods.get(method)
if func is None:
self.reporter.error(call.location, f"Unknown method {method}")
self.reporter.warning(call.location, f"Unknown method {method}")
return UnknownType()
return func(self, call)

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Callable, Optional
from typing import Any, Callable, Optional
from midas.checker.environment import Environment
from midas.checker.registry import TypesRegistry
@@ -17,7 +17,7 @@ class Preamble(Environment):
def __init__(self, types: TypesRegistry) -> None:
super().__init__()
self._types: TypesRegistry = types
self._python_funcs: dict[str, Callable] = {}
self._python_funcs: dict[str, Callable[..., Any]] = {}
self._def_type_constructor("object", object)
self._def_type_constructor("float", float)
@@ -34,7 +34,7 @@ class Preamble(Environment):
# TODO: use sink
self._def_function(
name="print",
pos=[Param("object", TopType())],
pos=[Param("object", TopType(), required=False)],
returns=UnitType(),
py_function=print,
)
@@ -64,11 +64,18 @@ class Preamble(Environment):
pos=[Param("prompt", TopType(), required=False)],
returns=self._types.get_type("str"),
)
self._def_function(
name="len",
pos=[Param("object", TopType())],
returns=self._types.get_type("int"),
)
def _list_of(self, item_type: Type) -> Type:
return self._types.apply_generic(self._types.get_type("list"), [item_type])
def _def_type_constructor(self, name: str, py_function: Optional[Callable] = None):
def _def_type_constructor(
self, name: str, py_function: Optional[Callable[..., Any]] = None
):
# TODO: more specific arg types
self._def_function(
name=name,
@@ -121,7 +128,7 @@ class Preamble(Environment):
kw: list[Param] = [],
returns: Type = UnitType(),
type_vars: list[TypeVar] = [],
py_function: Optional[Callable] = None,
py_function: Optional[Callable[..., Any]] = None,
):
function: Type = self._make_function(
name=name,
@@ -135,5 +142,5 @@ class Preamble(Environment):
if py_function is not None:
self._python_funcs[name] = py_function
def get_py_func(self, name: str) -> Optional[Callable]:
def get_py_func(self, name: str) -> Optional[Callable[..., Any]]:
return self._python_funcs.get(name)

View File

@@ -413,13 +413,16 @@ class PythonTyper(
value_type: Type,
):
var_type: Type = self.type_of(var)
unfolded_type: Type = unfold_type(var_type)
# TODO: what happens if type is an alias of a dataframe type
match var_type:
match unfolded_type:
case DataFrameType() as frame:
new_type: Type = self.frame_mgr.assign(
self.reporter, location, frame, index, value_type
)
self.env.assign(var.name, new_type)
case UnknownType():
return
case _:
self.reporter.error(
location,
@@ -439,8 +442,10 @@ class PythonTyper(
# print(m) # <- m is still defined
test_type: Type = self.type_of(stmt.test)
# TODO Allow subtypes or any type
if test_type != self.types.get_type("bool"):
if (
not self.types.is_subtype(test_type, self.types.get_type("bool"))
and test_type != UnknownType()
):
self.reporter.error(
stmt.test.location, f"If test must be a boolean, got {test_type}"
)
@@ -456,13 +461,16 @@ class PythonTyper(
pass
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
item_type: Optional[Type] = self._get_iterator_type(stmt.iterator)
if item_type is None:
iterator_type: Type = self.compute_type(stmt.iterator)
self.reporter.error(
stmt.iterator.location, f"{iterator_type} is not iterable"
)
item_type = UnknownType()
item_type: Type = UnknownType()
iterator_type: Type = self.type_of(stmt.iterator)
if iterator_type != UnknownType():
maybe_item_type = self._get_iterator_type(stmt.iterator, iterator_type)
if maybe_item_type is None:
self.reporter.error(
stmt.iterator.location, f"{iterator_type} is not iterable"
)
else:
item_type = maybe_item_type
self._assign(stmt.location, stmt.target, item_type)
self.judge(stmt.target, item_type)
@@ -577,7 +585,7 @@ class PythonTyper(
object: Type = self.type_of(expr.object)
member: Optional[Type] = self.types.lookup_member(object, expr.name)
if member is None:
self.reporter.error(
self.reporter.warning(
expr.location, f"Unknown member '{expr.name}' of {object}"
)
return UnknownType()
@@ -638,7 +646,10 @@ class PythonTyper(
test_type: Type = self.type_of(expr.test)
# TODO Allow subtypes or any type
if test_type != self.types.get_type("bool"):
if (
not self.is_subtype(test_type, self.types.get_type("bool"))
and test_type != UnknownType()
):
self.reporter.error(
expr.test.location, f"If test must be a boolean, got {test_type}"
)
@@ -667,7 +678,7 @@ class PythonTyper(
if len(item_types) == 1:
item_type: Type = item_types[0]
return self.types.apply_generic(list_type, [item_type])
self.reporter.error(
self.reporter.warning(
expr.location,
f"Heterogeneous list items: [{', '.join(map(str, item_types))}]",
)
@@ -699,7 +710,7 @@ class PythonTyper(
if len(key_types) == 1:
key_type = key_types[0]
else:
self.reporter.error(
self.reporter.warning(
expr.location,
f"Heterogeneous dict keys: [{', '.join(map(str, key_types))}]",
)
@@ -707,7 +718,7 @@ class PythonTyper(
if len(value_types) == 1:
value_type = value_types[0]
else:
self.reporter.error(
self.reporter.warning(
expr.location,
f"Heterogeneous dict values: [{', '.join(map(str, value_types))}]",
)
@@ -739,6 +750,11 @@ class PythonTyper(
def visit_slice_expr(self, expr: p.SliceExpr) -> Type:
return self.types.get_type("slice")
def visit_tuple_expr(self, expr: p.TupleExpr) -> Type:
return TupleType(
items=tuple(self.type_of(item) for item in expr.items),
)
def visit_raw_expr(self, expr: p.RawExpr) -> Type:
return UnknownType()
@@ -750,9 +766,9 @@ class PythonTyper(
self.reporter.warning(node.location, f"Unknown type '{node.base}'")
return UnknownType()
if node.param is not None:
param: Type = self.resolve_type_expr(node.param)
return self.types.apply_generic(base, [param])
if len(node.args) != 0:
args: list[Type] = [self.resolve_type_expr(arg) for arg in node.args]
return self.types.apply_generic(base, args)
return base
def visit_constraint_type(self, node: p.ConstraintType) -> Type:
@@ -1150,9 +1166,8 @@ class PythonTyper(
return False
return True
def _get_iterator_type(self, expr: p.Expr) -> Optional[Type]:
def _get_iterator_type(self, expr: p.Expr, type: Type) -> Optional[Type]:
# TODO: lookup __iter__
type: Type = self.type_of(expr)
getitem: Optional[Type] = self.types.lookup_member(type, "__getitem__")
if getitem is None:
return None
@@ -1218,7 +1233,7 @@ class PythonTyper(
node: ast.Expression = ast.parse(value, mode="eval")
return parser._parse_type(node.body)
case p.VariableExpr(name=name):
return p.BaseType(location=location, base=name, param=None)
return p.BaseType(location=location, base=name, args=())
case _:
raise NotImplementedError
@@ -1306,6 +1321,12 @@ class PythonTyper(
return False
return True
case DataFrameType() | ColumnType():
self.reporter.error(
expr.location, f"Cannot cast {lit_value!r} to {target_type}"
)
return False
case _:
self.reporter.info(
expr.location, f"Cannot evaluate cast to {target_type} statically"

View File

@@ -18,6 +18,7 @@ from midas.checker.types import (
OverloadedFunction,
Predicate,
TopType,
TupleType,
Type,
TypeVar,
UnknownType,
@@ -346,6 +347,9 @@ class TypesRegistry:
body=substitute_typevars(body, substitutions),
)
case BaseType(name="tuple"):
return TupleType(items=tuple(args))
case _:
raise ValueError(f"{type} is not a generic type")

View File

@@ -236,5 +236,9 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
if expr.step is not None:
self.resolve(expr.step)
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
for item in expr.items:
self.resolve(item)
def visit_raw_expr(self, expr: p.RawExpr) -> None:
pass

View File

@@ -5,7 +5,7 @@
import sys
from pathlib import Path
from typing import TextIO
from typing import Optional, TextIO
import click
@@ -19,18 +19,23 @@ from midas.utils import TypedAST
@click.command(help="Compile source")
@click.argument("file", type=click.File("r"))
@click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.option("-s", "--stubs", type=str, multiple=True)
@click.option("--ignore-errors", is_flag=True)
def compile(
file: TextIO,
types: tuple[TextIO],
stubs: tuple[str],
ignore_errors: bool,
):
source: str = file.read()
source_path: Path = Path(file.name).resolve()
checker = TypeChecker()
for types_file in types:
checker.import_midas(Path(types_file.name).resolve())
type_files: list[tuple[Path, Optional[str]]] = []
for i, types_file in enumerate(types):
in_path: Path = Path(types_file.name).resolve()
checker.import_midas(in_path)
type_files.append((in_path, stubs[i] if i < len(stubs) else None))
typed_ast: TypedAST = checker.type_check_source(source, str(source_path))
diagnostics: list[Diagnostic] = checker.diagnostics.copy()
@@ -43,4 +48,4 @@ def compile(
sys.exit(1)
generator = Generator(workdir=source_path.parent, types=checker.types)
generator.generate(typed_ast, source_path)
generator.generate(typed_ast, source_path, type_files=type_files)

View File

@@ -1,7 +1,7 @@
import ast
import time
from pathlib import Path
from typing import TextIO
from typing import Optional, TextIO
import black
import click
@@ -38,15 +38,17 @@ class Handler(FileSystemEventHandler):
@click.command(help="Generate stubs from Midas definitions")
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-")
@click.option("-o", "--output", type=click.File("w"))
@click.option("-w", "--watch", is_flag=True)
def stubs(
file: TextIO,
output: TextIO,
output: Optional[TextIO],
watch: bool,
):
source_path: Path = Path(file.name).resolve()
out_path: Path = Path(output.name).resolve()
out_path: Path = source_path.with_suffix(".pyi")
if output is not None:
out_path = Path(output.name).resolve()
generate_stubs(source_path, out_path)
if watch:

View File

@@ -134,9 +134,9 @@ class PythonHighlighter(
def visit_base_type(self, node: p.BaseType) -> None:
self.wrap(node, "base-type")
if node.param is not None:
self.wrap(node.param, "param")
node.param.accept(self)
for arg in node.args:
self.wrap(arg, "arg")
arg.accept(self)
def visit_constraint_type(self, node: p.ConstraintType) -> None:
self.wrap(node, "constraint-type")
@@ -247,6 +247,10 @@ class PythonHighlighter(
if expr.step is not None:
expr.step.accept(self)
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
for item in expr.items:
item.accept(self)
def visit_raw_expr(self, expr: p.RawExpr) -> None: ...
def visit_raw_stmt(self, stmt: p.RawStmt) -> None: ...

View File

@@ -3,7 +3,7 @@ span {
--col: 108, 233, 108;
}
&.param {
&.arg {
--col: 103, 192, 224;
}

View File

@@ -9,6 +9,7 @@ import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.ast.printer import MidasPrinter
from midas.checker.checker import TypeChecker
from midas.checker.registry import TypesRegistry
from midas.checker.types import (
AliasType,
@@ -30,16 +31,20 @@ from midas.checker.types import (
UnknownType,
)
from midas.generator.constraints import ConstraintGenerator
from midas.generator.stubs import StubsGenerator
from midas.utils import TypedAST
@dataclass
class Scope:
pre_assertions: list[ast.stmt] = field(default_factory=list)
aliases: list[str] = field(default_factory=list)
pre_assertions: list[ast.stmt] = field(default_factory=list[ast.stmt])
aliases: list[str] = field(default_factory=list[str])
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
IS_DATAFRAME_FUNC = "__midas_is_dataframe__"
IS_COLUMN_FUNC = "__midas_is_column__"
def __init__(self, workdir: Path, types: TypesRegistry) -> None:
self.workdir: Path = workdir.resolve()
self.build_dir: Path = self.workdir / "build" / "midas"
@@ -58,20 +63,37 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
self._constraints: list[tuple[m.Expr, ast.expr]] = []
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
self.rel_src_path = src_path.resolve().relative_to(self.workdir)
self.define_is_dataframe: bool = False
self.define_is_column: bool = False
def set_src_path(self, path: Path):
self.rel_src_path = path.resolve().relative_to(self.workdir)
def generate_ast(self, typed_ast: TypedAST) -> ast.AST:
self._typed_ast = typed_ast
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
predicates: list[ast.stmt] = self._constraint_generator.get_definitions()
module = ast.Module(body=predicates + body, type_ignores=[])
body = predicates + body
if self.define_is_dataframe:
body = [self._is_dataframe_definition()] + body
if self.define_is_column:
body = [self._is_column_definition()] + body
module = ast.Module(body=body, type_ignores=[])
module = ast.fix_missing_locations(module)
return module
def generate(
self, typed_ast: TypedAST, src_path: Path, out_path: Optional[Path] = None
self,
typed_ast: TypedAST,
src_path: Path,
out_path: Optional[Path] = None,
type_files: Optional[list[tuple[Path, Optional[str]]]] = None,
) -> Path:
module: ast.AST = self.generate_ast(typed_ast, src_path)
compiled: str = ast.unparse(module)
self.set_src_path(src_path)
if out_path is None:
if self.build_dir.exists():
shutil.rmtree(self.build_dir)
@@ -83,10 +105,30 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
raise ValueError(
f"Directory traversal, {self.rel_src_path} points outside of parent directory"
)
out_path.parent.mkdir(parents=True, exist_ok=True)
out_dir: Path = out_path.parent
out_dir.parent.mkdir(parents=True, exist_ok=True)
if type_files is not None:
for in_path, out_name in type_files:
if out_name is None:
out_name = in_path.stem
self.generate_stubs(in_path, out_dir / f"{out_name}.py")
module: ast.AST = self.generate_ast(typed_ast)
compiled: str = ast.unparse(module)
out_path.write_text(compiled)
return out_path
def generate_stubs(self, in_path: Path, out_path: Path):
checker = TypeChecker()
checker.import_midas(in_path)
generator = StubsGenerator(checker.types)
module: ast.Module = generator.generate_stubs()
module = ast.fix_missing_locations(module)
output: str = ast.unparse(module)
out_path.write_text(output)
def visit_binary_expr(self, expr: p.BinaryExpr) -> ast.expr:
return ast.BinOp(
left=expr.left.accept(self),
@@ -144,7 +186,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
alias: ast.expr = self._make_alias(expr2)
type: Type = self._get_expr_type(expr)
self._make_cast_asserts(expr.location, alias, type)
asserts: list[ast.stmt] = self._make_cast_asserts(expr.location, alias, type)
for assert_ in asserts:
self._add_assert(assert_)
return alias
@@ -179,6 +223,11 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
step=expr.step.accept(self) if expr.step is not None else None,
)
def visit_tuple_expr(self, expr: p.TupleExpr) -> ast.expr:
return ast.Tuple(
elts=[item.accept(self) for item in expr.items],
)
def visit_raw_expr(self, expr: p.RawExpr) -> ast.expr:
return expr.expr
@@ -279,76 +328,156 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
)
return alias
def _add_assert(self, expr: ast.expr, message: str | ast.expr):
def _build_assert(self, expr: ast.expr, message: str | ast.expr) -> ast.stmt:
if isinstance(message, str):
message = ast.Constant(value=message)
self._scopes[-1].pre_assertions.append(
ast.Assert(
test=expr,
msg=message,
)
return ast.Assert(
test=expr,
msg=message,
)
def _add_assert(self, assertion: ast.stmt):
self._scopes[-1].pre_assertions.append(assertion)
def _get_expr_type(self, query: p.Expr) -> Type:
for expr, type in self._typed_ast.judgements:
if expr == query:
return type
raise RuntimeError(f"Cannot get type judgement for {query}")
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type):
def _make_cast_asserts(
self, src_location: Location, expr: ast.expr, type: Type
) -> list[ast.stmt]:
match type:
case UnknownType():
pass
return []
case BaseType(name=name):
self._add_assert(
ast.Call(
func=ast.Name(id="isinstance"),
args=[expr, ast.Name(id=name)],
keywords=[],
),
self._make_cast_assert_message(src_location, expr, type),
)
return [
self._build_assert(
ast.Call(
func=ast.Name(id="isinstance"),
args=[expr, ast.Name(id=name)],
keywords=[],
),
self._make_cast_assert_message(src_location, expr, type),
)
]
case AliasType(type=base):
self._make_cast_asserts(src_location, expr, base)
return self._make_cast_asserts(src_location, expr, base)
case UnitType():
self._add_assert(
ast.Compare(
left=expr,
ops=[ast.Is()],
comparators=[
ast.Constant(value=None),
],
return [
self._build_assert(
ast.Compare(
left=expr,
ops=[ast.Is()],
comparators=[
ast.Constant(value=None),
],
),
self._make_cast_assert_message(src_location, expr, type),
),
self._make_cast_assert_message(src_location, expr, type),
)
]
case AppliedType(body=body):
self._make_cast_asserts(src_location, expr, body)
return self._make_cast_asserts(src_location, expr, body)
case ConstraintType(type=base, constraint=constraint):
self._make_cast_asserts(src_location, expr, base)
self._make_constraint_assert(src_location, expr, constraint)
asserts: list[ast.stmt] = self._make_cast_asserts(
src_location, expr, base
)
asserts.append(
self._make_constraint_assert(src_location, expr, constraint)
)
return asserts
case TypeVar(bound=bound):
# TODO: check with type from arguments / use call-site context
if bound is not None:
self._make_cast_asserts(src_location, expr, bound)
if bound is None:
return []
return self._make_cast_asserts(src_location, expr, bound)
case TupleType(items=items):
self._add_assert(
ast.Call(
func=ast.Name(id="isinstance"),
args=[expr, ast.Name(id="tuple")],
keywords=[],
asserts: list[ast.stmt] = [
self._build_assert(
ast.Call(
func=ast.Name(id="isinstance"),
args=[expr, ast.Name(id="tuple")],
keywords=[],
),
self._make_cast_assert_message(src_location, expr, type),
),
self._make_cast_assert_message(src_location, expr, type),
)
]
assert isinstance(expr, ast.Tuple)
for item, item_type in zip(expr.elts, items):
self._make_cast_asserts(src_location, item, item_type)
asserts.extend(
self._make_cast_asserts(src_location, item, item_type)
)
return asserts
case DataFrameType(columns=columns):
self.define_is_dataframe = True
asserts: list[ast.stmt] = [
self._build_assert(
ast.Call(
func=ast.Name(id=self.IS_DATAFRAME_FUNC),
args=[expr],
keywords=[],
),
self._make_cast_assert_message(
src_location, expr, type, ": Not a dataframe"
),
),
]
for column in columns:
asserts.append(
self._build_assert(
ast.Compare(
left=ast.Constant(value=column.name),
ops=[ast.In()],
comparators=[expr],
),
self._make_cast_assert_message(
src_location,
expr,
type,
f": Missing column {column.name}",
),
)
)
asserts.extend(
self._make_cast_asserts(
src_location,
ast.Subscript(
value=expr, slice=ast.Constant(value=column.name)
),
column.type,
)
)
return asserts
case ColumnType():
self.define_is_column = True
asserts: list[ast.stmt] = [
self._build_assert(
ast.Call(
func=ast.Name(id=self.IS_COLUMN_FUNC),
args=[expr],
keywords=[],
),
self._make_cast_assert_message(
src_location, expr, type, ": Not a column"
),
),
]
inner_assert: Optional[ast.stmt] = self._make_column_inner_assert(
src_location, expr, type
)
if inner_assert is not None:
asserts.append(inner_assert)
return asserts
case (
TopType()
@@ -357,17 +486,20 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
| ComplexType()
| ExtensionType()
| GenericType()
| ColumnType()
| DataFrameType()
):
self.logger.warning(f"Can't make assertion for type {type}")
return []
# Ensure exhaustiveness
case _:
assert_never(type)
def _make_cast_assert_message(
self, location: Location, expr: ast.expr, type: Type
self,
location: Location,
expr: ast.expr,
type: Type,
extra: Optional[str] = None,
) -> ast.expr:
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
# f"file.py:L1:1: CastError: Cannot cast {type(expr).__name__} to Type"
@@ -385,15 +517,15 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
),
conversion=-1,
),
ast.Constant(f" to {type}"),
ast.Constant(f" to {type}{extra or ''}"),
]
)
def _make_constraint_assert(
self, src_location: Location, expr: ast.expr, constraint: m.Expr
):
) -> ast.stmt:
test_func: ast.expr = self._get_constraint(constraint)
self._add_assert(
return self._build_assert(
ast.Call(
func=test_func,
args=[expr],
@@ -421,3 +553,90 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
constraint: ast.expr = self._constraint_generator.generate(expr)
self._constraints.append((expr, constraint))
return constraint
def _is_dataframe_definition(self) -> ast.stmt:
"""
def IS_DATAFRAME_FUNC(obj) -> bool:
import pandas as pd
return isinstance(obj, pd.DataFrame)
"""
return ast.FunctionDef(
name=self.IS_DATAFRAME_FUNC,
args=ast.arguments(
posonlyargs=[ast.arg(arg="obj")],
args=[],
kwonlyargs=[],
defaults=[],
kw_defaults=[],
),
body=[
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
ast.Return(
value=ast.Call(
func=ast.Name(id="isinstance"),
args=[
ast.Name(id="obj"),
ast.Attribute(
value=ast.Name(id="pd"),
attr="DataFrame",
),
],
keywords=[],
)
),
],
decorator_list=[],
returns=ast.Name(id="bool"),
)
def _is_column_definition(self) -> ast.stmt:
"""
def IS_COLUMN_FUNC(obj) -> bool:
import pandas as pd
return isinstance(obj, pd.Series)
"""
return ast.FunctionDef(
name=self.IS_COLUMN_FUNC,
args=ast.arguments(
posonlyargs=[ast.arg(arg="obj")],
args=[],
kwonlyargs=[],
defaults=[],
kw_defaults=[],
),
body=[
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
ast.Return(
value=ast.Call(
func=ast.Name(id="isinstance"),
args=[
ast.Name(id="obj"),
ast.Attribute(
value=ast.Name(id="pd"),
attr="Series",
),
],
keywords=[],
)
),
],
decorator_list=[],
returns=ast.Name(id="bool"),
)
def _make_column_inner_assert(
self, src_location: Location, column: ast.expr, type: ColumnType
) -> Optional[ast.stmt]:
# TODO: improve message, maybe chain contexts
col: ast.expr = ast.Name(id="col")
body: list[ast.stmt] = self._make_cast_asserts(src_location, col, type.type)
if len(body) == 0:
return None
return ast.For(
target=col,
iter=column,
body=body,
orelse=[],
)

View File

@@ -377,7 +377,7 @@ class MidasParser(Parser):
pos_args: list[Expr] = []
kw_args: dict[str, Expr] = {}
keywords: bool = False
while not self.match(TokenType.RIGHT_PAREN):
while not self.check(TokenType.RIGHT_PAREN):
if self.check_identifier() and self.check_next(TokenType.EQUAL):
keywords = True
keyword: Token = self.advance()

View File

@@ -30,6 +30,7 @@ from midas.ast.python import (
Stmt,
SubscriptExpr,
TernaryExpr,
TupleExpr,
TypeAssign,
UnaryExpr,
VariableExpr,
@@ -300,26 +301,28 @@ class PythonParser:
case ast.Subscript(value=ast.Name(id="Frame"), slice=schema):
return self._parse_frame_type(schema)
case ast.Subscript(value=ast.Name(id=name), slice=param):
case ast.Subscript(value=ast.Name(id=name), slice=arg):
args: tuple[MidasType, ...] = (
tuple(self._parse_type(a) for a in arg.elts)
if isinstance(arg, ast.Tuple)
else (self._parse_type(arg),)
)
return BaseType(
location=loc,
base=name,
param=self._parse_type(param),
args=args,
)
case ast.Name(id=name):
return BaseType(
location=loc,
base=name,
param=None,
args=(),
)
case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr):
left = self._parse_type(left_expr)
match left:
case None:
raise InvalidSyntaxError()
# If chained constraints, separate base type and rebuild constraint
case ConstraintType(type=left_type, constraint=left_constraint):
constraint = ast.BinOp(
@@ -345,7 +348,7 @@ class PythonParser:
return BaseType(
location=loc,
base="None",
param=None,
args=(),
)
case _:
@@ -477,6 +480,12 @@ class PythonParser:
step=self.parse_expr(step) if step is not None else None,
)
case ast.Tuple(elts=items):
return TupleExpr(
location=location,
items=tuple(self.parse_expr(item) for item in items),
)
case _:
print(f"Unsupported expression: {ast.unparse(node)}")
return RawExpr(location=location, expr=node)

View File

@@ -24,7 +24,7 @@
"type": {
"_type": "BaseType",
"base": "Meter",
"param": null
"args": []
},
"expr": {
"_type": "LiteralExpr",
@@ -62,7 +62,7 @@
"type": {
"_type": "BaseType",
"base": "Second",
"param": null
"args": []
},
"expr": {
"_type": "LiteralExpr",

View File

@@ -317,7 +317,7 @@
"pos": 0,
"name": "object",
"type": {},
"required": true
"required": false
}
],
"args": [],

View File

@@ -16,7 +16,7 @@
"type": {
"_type": "BaseType",
"base": "bool",
"param": null
"args": []
}
},
{
@@ -25,7 +25,7 @@
"type": {
"_type": "BaseType",
"base": "int",
"param": null
"args": []
}
},
{
@@ -36,7 +36,7 @@
"type": {
"_type": "BaseType",
"base": "float",
"param": null
"args": []
},
"constraint": "(_ > 0) + (_ < 250)"
}
@@ -47,7 +47,7 @@
"type": {
"_type": "BaseType",
"base": "str",
"param": null
"args": []
}
},
{
@@ -56,7 +56,7 @@
"type": {
"_type": "BaseType",
"base": "datetime",
"param": null
"args": []
}
},
{
@@ -65,7 +65,7 @@
"type": {
"_type": "BaseType",
"base": "float",
"param": null
"args": []
}
},
{
@@ -79,7 +79,7 @@
"type": {
"_type": "BaseType",
"base": "_",
"param": null
"args": []
}
}
]

View File

@@ -16,7 +16,7 @@
"type": {
"_type": "BaseType",
"base": "GeoLocation",
"param": null
"args": []
}
}
]
@@ -28,11 +28,13 @@
"type": {
"_type": "BaseType",
"base": "Column",
"param": {
"_type": "BaseType",
"base": "GeoLocation",
"param": null
}
"args": [
{
"_type": "BaseType",
"base": "GeoLocation",
"args": []
}
]
}
},
{
@@ -65,11 +67,13 @@
"type": {
"_type": "BaseType",
"base": "Column",
"param": {
"_type": "BaseType",
"base": "GeoLocation",
"param": null
}
"args": [
{
"_type": "BaseType",
"base": "GeoLocation",
"args": []
}
]
}
},
{
@@ -117,7 +121,7 @@
"type": {
"_type": "BaseType",
"base": "Latitude",
"param": null
"args": []
}
},
{
@@ -146,7 +150,7 @@
"type": {
"_type": "BaseType",
"base": "Latitude",
"param": null
"args": []
}
},
{
@@ -175,11 +179,13 @@
"type": {
"_type": "BaseType",
"base": "Difference",
"param": {
"_type": "BaseType",
"base": "Latitude",
"param": null
}
"args": [
{
"_type": "BaseType",
"base": "Latitude",
"args": []
}
]
}
},
{
@@ -217,7 +223,7 @@
"type": {
"_type": "BaseType",
"base": "int",
"param": null
"args": []
},
"constraint": "_ >= 0"
}
@@ -230,7 +236,7 @@
"type": {
"_type": "BaseType",
"base": "float",
"param": null
"args": []
},
"constraint": "_ >= 0"
}
@@ -252,7 +258,7 @@
"type": {
"_type": "BaseType",
"base": "int",
"param": null
"args": []
},
"constraint": "Positive"
}
@@ -265,7 +271,7 @@
"type": {
"_type": "BaseType",
"base": "float",
"param": null
"args": []
},
"constraint": "Positive"
}

View File

@@ -14,15 +14,17 @@
"type": {
"_type": "BaseType",
"base": "Column",
"param": {
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"param": null
},
"constraint": "0 <= _ <= 1"
}
"args": [
{
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"args": []
},
"constraint": "0 <= _ <= 1"
}
]
},
"default": null
},
@@ -31,15 +33,17 @@
"type": {
"_type": "BaseType",
"base": "Column",
"param": {
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"param": null
},
"constraint": "0 <= _ <= 1"
}
"args": [
{
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"args": []
},
"constraint": "0 <= _ <= 1"
}
]
},
"default": null
}
@@ -50,15 +54,17 @@
"returns": {
"_type": "BaseType",
"base": "Column",
"param": {
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"param": null
},
"constraint": "0 <= _ <= 2"
}
"args": [
{
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"args": []
},
"constraint": "0 <= _ <= 2"
}
]
},
"body": [
{
@@ -67,15 +73,17 @@
"type": {
"_type": "BaseType",
"base": "Column",
"param": {
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"param": null
},
"constraint": "0 <= _ <= 2"
}
"args": [
{
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"args": []
},
"constraint": "0 <= _ <= 2"
}
]
}
},
{
@@ -117,7 +125,7 @@
"type": {
"_type": "BaseType",
"base": "int",
"param": null
"args": []
},
"default": null
}
@@ -128,7 +136,7 @@
"type": {
"_type": "BaseType",
"base": "float",
"param": null
"args": []
},
"default": null
}
@@ -140,7 +148,7 @@
"type": {
"_type": "BaseType",
"base": "str",
"param": null
"args": []
},
"default": null
}

View File

@@ -46,7 +46,8 @@ class GeneratorTester(Tester):
if not any(d.type == DiagnosticType.ERROR for d in checker.diagnostics):
generator = Generator(workdir=path.parent, types=checker.types)
result.compiled_ast = generator.generate_ast(typed_ast, path)
generator.set_src_path(path)
result.compiled_ast = generator.generate_ast(typed_ast)
return result

View File

@@ -30,6 +30,7 @@ from midas.ast.python import (
Stmt,
SubscriptExpr,
TernaryExpr,
TupleExpr,
TypeAssign,
UnaryExpr,
VariableExpr,
@@ -98,7 +99,7 @@ class PythonAstJsonSerializer(
return {
"_type": "BaseType",
"base": node.base,
"param": self._serialize_optional(node.param),
"args": self._serialize_list(node.args),
}
def visit_constraint_type(self, node: ConstraintType) -> dict:
@@ -302,6 +303,12 @@ class PythonAstJsonSerializer(
"step": self._serialize_optional(expr.step),
}
def visit_tuple_expr(self, expr: TupleExpr) -> dict:
return {
"_type": "TupleExpr",
"items": [item.accept(self) for item in expr.items],
}
def visit_raw_expr(self, expr: RawExpr) -> dict:
return {
"_type": "RawExpr",