Merge pull request 'Frame / column operations' (#27) from feat/simple-frame-ops into main
Reviewed-on: #27
This commit was merged in pull request #27.
This commit is contained in:
@@ -678,6 +678,10 @@ In the following example, a runtime check would be generated to ensure that the
|
||||
caption: [Typing of `cast` expression],
|
||||
)
|
||||
|
||||
#gc.warning[
|
||||
Assertions are statements inserted just before a statement using a `cast` expression. This means that the expression is evaluated _before_ its actual intended usage location, which might cause issues if you rely on logical operator short-circuiting. See @eager-eval for more information.
|
||||
]
|
||||
|
||||
There may be some cases where the cost of checking a value at runtime is simply not worth the safety, for example when dealing with a big dataset. If do wish so, you can use `unsafe_cast` which will only tell the type checker the type of the value, without generating a runtime assertion. This maps to the default behavior of `typing`'s own `cast` function.
|
||||
|
||||
If the value passed to `cast` or `unsafe_cast` is a literal (e.g. an integer, a string, a list of literals, etc.), the assertion is evaluated _at compile-time_ and no runtime assertion is generated.
|
||||
@@ -695,3 +699,26 @@ If the value passed to `cast` or `unsafe_cast` is a literal (e.g. an integer, a
|
||||
== Generating Stubs (`stubs`) <cmd-stubs>
|
||||
== Showing Type Judgements (`types`) <cmd-types>
|
||||
== Validating Definitions (`validate`) <cmd-validate>
|
||||
|
||||
= Known limitations <limitations>
|
||||
|
||||
== Eager evaluation in runtime assertions <eager-eval>
|
||||
|
||||
The process of generating assertions to ensure safety at runtime, mainly for `cast` expressions, leads to the creation of aliases for the expressions being casted. These alias definitions eagerly evaluate before the assertion, and most importantly before the real usage location. This means that you should avoid using `cast` expressions inside logical expressions like `and` or `or`, because the normal "short-circuit" behavior will be irrelevant to the evaluations of the operands.
|
||||
|
||||
For example:
|
||||
|
||||
#figure(
|
||||
```py
|
||||
def foo():
|
||||
print("Foo")
|
||||
return True
|
||||
def bar():
|
||||
print("Bar")
|
||||
return True
|
||||
result = foo() or bar()
|
||||
# Foo
|
||||
# Bar
|
||||
```,
|
||||
caption: [Runtime assertions may eagerly evaluate expressions and bypass logical operator's short-circuit],
|
||||
)
|
||||
|
||||
@@ -14,6 +14,8 @@ if TYPE_CHECKING:
|
||||
from midas.checker.registry import TypesRegistry
|
||||
|
||||
|
||||
# Hard-coded subtype relationships between builtin types
|
||||
# Circular dependencies and diamond inheritance MUST be avoided
|
||||
BUILTIN_SUBTYPES: dict[str, set[str]] = {
|
||||
"object": {"float", "list", "dict", "str", "bytes", "tuple"},
|
||||
"float": {"int"},
|
||||
|
||||
@@ -1,201 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallDispatcher, CallResult
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter
|
||||
from midas.checker.types import (
|
||||
ColumnType,
|
||||
DataFrameType,
|
||||
Function,
|
||||
OverloadedFunction,
|
||||
TopType,
|
||||
Type,
|
||||
UnknownType,
|
||||
unfold_type,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import PythonTyper, TypedExpr
|
||||
|
||||
|
||||
@staticmethod
|
||||
def frame_method(*names: str):
|
||||
def wrapper(func):
|
||||
names_: tuple[str, ...] = names
|
||||
if len(names_) == 0:
|
||||
names_ = (func.__name__,)
|
||||
setattr(func, "__method_names__", names_)
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Call:
|
||||
location: Location
|
||||
frame: DataFrameType
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
|
||||
class _MethodRegistryMeta(type):
|
||||
_methods: dict[str, Callable[..., Type]] = {}
|
||||
|
||||
def __new__(
|
||||
cls,
|
||||
name: str,
|
||||
bases: tuple[type, ...],
|
||||
namespace: dict[str, Any],
|
||||
):
|
||||
new_class = super().__new__(cls, name, bases, namespace)
|
||||
new_class._methods = {}
|
||||
for attr in namespace.values():
|
||||
if callable(attr) and hasattr(attr, "__method_names__"):
|
||||
for name in attr.__method_names__: # type: ignore
|
||||
new_class._methods[name] = attr # type: ignore
|
||||
return new_class
|
||||
|
||||
|
||||
class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
def __init__(self, typer: PythonTyper) -> None:
|
||||
self.typer: PythonTyper = typer
|
||||
|
||||
@property
|
||||
def reporter(self) -> FileReporter:
|
||||
return self.typer.reporter
|
||||
|
||||
@property
|
||||
def types(self) -> TypesRegistry:
|
||||
return self.typer.types
|
||||
|
||||
@property
|
||||
def dispatcher(self) -> CallDispatcher[p.Expr]:
|
||||
return self.typer.dispatcher
|
||||
|
||||
def call(
|
||||
self,
|
||||
method: str,
|
||||
call: Call,
|
||||
) -> Type:
|
||||
func: Optional[Callable[..., Type]] = self._methods.get(method)
|
||||
if func is None:
|
||||
self.reporter.warning(call.location, f"Unknown method {method}")
|
||||
return UnknownType()
|
||||
return func(self, call)
|
||||
|
||||
@frame_method("add", "__add__")
|
||||
def add(
|
||||
self,
|
||||
call: Call,
|
||||
) -> Type:
|
||||
# TODO: support add with scalar, sequence, Series, dict
|
||||
# TODO: check operation exists on inner column types
|
||||
|
||||
new_columns: list[DataFrameType.Column] = []
|
||||
|
||||
by_name: dict[str, DataFrameType.Column] = {}
|
||||
frame2: Optional[DataFrameType] = None
|
||||
if len(call.positional) != 0:
|
||||
other: Type = call.positional[0][1]
|
||||
unfolded_other: Type = unfold_type(other)
|
||||
if isinstance(unfolded_other, DataFrameType):
|
||||
frame2 = unfolded_other
|
||||
by_name = {
|
||||
col.name: col for col in frame2.columns if col.name is not None
|
||||
}
|
||||
|
||||
in_frame1: set[str] = set()
|
||||
for column in call.frame.columns:
|
||||
if column.name is not None:
|
||||
in_frame1.add(column.name)
|
||||
|
||||
col_type1: Type = column.type
|
||||
col_type: Type = ColumnType(type=UnknownType())
|
||||
if column.name in by_name:
|
||||
column2 = by_name[column.name]
|
||||
col_type2: Type = column2.type
|
||||
if self.types.are_equivalent(col_type2, col_type1):
|
||||
col_type = col_type1
|
||||
|
||||
new_column = DataFrameType.Column(
|
||||
index=column.index,
|
||||
name=column.name,
|
||||
type=col_type,
|
||||
)
|
||||
new_columns.append(new_column)
|
||||
|
||||
if frame2 is not None:
|
||||
for column in frame2.columns:
|
||||
if column.name in in_frame1:
|
||||
continue
|
||||
new_columns.append(
|
||||
DataFrameType.Column(
|
||||
index=len(new_columns),
|
||||
name=column.name,
|
||||
type=ColumnType(type=UnknownType()),
|
||||
)
|
||||
)
|
||||
|
||||
signature = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="other",
|
||||
type=DataFrameType(columns=[]),
|
||||
required=True,
|
||||
),
|
||||
],
|
||||
returns=DataFrameType(columns=new_columns),
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@frame_method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
with_axis = Function(
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="axis",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
)
|
||||
],
|
||||
returns=ColumnType(type=TopType()),
|
||||
)
|
||||
without_axis = Function(
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="axis",
|
||||
type=self.types.get_type("None"),
|
||||
required=True,
|
||||
)
|
||||
],
|
||||
returns=TopType(),
|
||||
)
|
||||
overload = OverloadedFunction(
|
||||
overloads=[
|
||||
with_axis,
|
||||
without_axis,
|
||||
]
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=overload,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
203
midas/checker/frames/column_groupby_methods.py
Normal file
203
midas/checker/frames/column_groupby_methods.py
Normal file
@@ -0,0 +1,203 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallResult
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import ColumnGroupBy, ColumnType, Function, TopType, Type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import TypedExpr
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Call:
|
||||
location: Location
|
||||
call_expr: p.Expr
|
||||
groupby: ColumnGroupBy
|
||||
groupby_expr: p.Expr
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
@property
|
||||
def subject(self) -> TypedExpr:
|
||||
return (self.groupby_expr, self.groupby)
|
||||
|
||||
|
||||
class ColumnGroupByMethodRegistry(MethodRegistry[Call]):
|
||||
NAMED_ARGS: dict[str, str] = {
|
||||
"numeric_only": "bool",
|
||||
"skipna": "bool",
|
||||
"engine": "str",
|
||||
"engine_kwargs": "dict",
|
||||
}
|
||||
|
||||
def _aggregate(
|
||||
self,
|
||||
call: Call,
|
||||
args: list[str | tuple[str, str, bool]] = [],
|
||||
*,
|
||||
preserve_inner_type: bool = False,
|
||||
) -> Type:
|
||||
real_args: list[Function.Argument] = []
|
||||
for i, arg in enumerate(args):
|
||||
match arg:
|
||||
case str() as name:
|
||||
arg = Function.Argument(
|
||||
pos=i,
|
||||
name=name,
|
||||
type=self.types.get_type(self.NAMED_ARGS[name]),
|
||||
required=False,
|
||||
)
|
||||
case (name, type, required):
|
||||
arg = Function.Argument(
|
||||
pos=i,
|
||||
name=name,
|
||||
type=self.types.get_type(type),
|
||||
required=required,
|
||||
)
|
||||
real_args.append(arg)
|
||||
|
||||
signature = Function(
|
||||
args=real_args,
|
||||
returns=(
|
||||
call.groupby.column
|
||||
if preserve_inner_type
|
||||
else ColumnType(type=TopType())
|
||||
),
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def kurt(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
["skipna", "numeric_only"],
|
||||
)
|
||||
|
||||
@method()
|
||||
def max(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
],
|
||||
preserve_inner_type=True,
|
||||
)
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
["numeric_only", "skipna", "engine", "engine_kwargs"],
|
||||
)
|
||||
|
||||
@method()
|
||||
def median(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
["numeric_only", "skipna"],
|
||||
preserve_inner_type=True,
|
||||
)
|
||||
|
||||
@method()
|
||||
def min(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
],
|
||||
preserve_inner_type=True,
|
||||
)
|
||||
|
||||
@method()
|
||||
def prod(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def std(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
(
|
||||
"ddof",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
"numeric_only",
|
||||
"skipna",
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def sum(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def var(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
(
|
||||
"var",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
"numeric_only",
|
||||
"skipna",
|
||||
],
|
||||
)
|
||||
78
midas/checker/frames/column_manager.py
Normal file
78
midas/checker/frames/column_manager.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.frames.column_groupby_methods import Call as GroupByCall
|
||||
from midas.checker.frames.column_groupby_methods import ColumnGroupByMethodRegistry
|
||||
from midas.checker.frames.column_methods import Call, ColumnMethodRegistry
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import ColumnGroupBy, ColumnType, Type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import PythonTyper, TypedExpr
|
||||
|
||||
|
||||
class ColumnManager:
|
||||
def __init__(self, typer: PythonTyper) -> None:
|
||||
self.typer: PythonTyper = typer
|
||||
self.method_resolver: ColumnMethodRegistry = ColumnMethodRegistry(self.typer)
|
||||
self.groupby_method_resolver: ColumnGroupByMethodRegistry = (
|
||||
ColumnGroupByMethodRegistry(self.typer)
|
||||
)
|
||||
|
||||
def call(
|
||||
self,
|
||||
method: str,
|
||||
location: Location,
|
||||
call_expr: p.Expr,
|
||||
column: ColumnType,
|
||||
column_expr: p.Expr,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Type:
|
||||
call: Call = Call(
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
column=column,
|
||||
column_expr=column_expr,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
return self.method_resolver.call(method, call)
|
||||
|
||||
def groupby_call(
|
||||
self,
|
||||
method: str,
|
||||
location: Location,
|
||||
call_expr: p.Expr,
|
||||
groupby: ColumnGroupBy,
|
||||
groupby_expr: p.Expr,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Type:
|
||||
call: GroupByCall = GroupByCall(
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
groupby=groupby,
|
||||
groupby_expr=groupby_expr,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
return self.groupby_method_resolver.call(method, call)
|
||||
|
||||
def get_attribute(self, column: ColumnType, name: str) -> Optional[Type]:
|
||||
types: TypesRegistry = self.typer.types
|
||||
match name:
|
||||
case "ndim" | "size":
|
||||
return types.get_type("int")
|
||||
|
||||
case "shape":
|
||||
return types.tuple_of("int")
|
||||
|
||||
case "T":
|
||||
return column
|
||||
|
||||
case _:
|
||||
return None
|
||||
410
midas/checker/frames/column_methods.py
Normal file
410
midas/checker/frames/column_methods.py
Normal file
@@ -0,0 +1,410 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallResult
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import (
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
Function,
|
||||
GenericType,
|
||||
TopType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnknownType,
|
||||
unfold_type,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import TypedExpr
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Call:
|
||||
location: Location
|
||||
call_expr: p.Expr
|
||||
column: ColumnType
|
||||
column_expr: p.Expr
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
@property
|
||||
def subject(self) -> TypedExpr:
|
||||
return (self.column_expr, self.column)
|
||||
|
||||
|
||||
class ColumnMethodRegistry(MethodRegistry[Call]):
|
||||
def _element_binary_op(self, call: Call, method: str) -> ColumnType:
|
||||
"""Compute the result of an element-wise binary operation
|
||||
|
||||
This function delegates to the inner types for computing the resulting
|
||||
type.
|
||||
|
||||
Args:
|
||||
call (Call): the call that triggered this resolution
|
||||
method (str): the method name
|
||||
|
||||
Returns:
|
||||
ColumnType: the resulting column type
|
||||
"""
|
||||
column2: Optional[ColumnType] = None
|
||||
|
||||
col_type1: Type = call.column.type
|
||||
new_column: Type = ColumnType(type=UnknownType())
|
||||
if len(call.positional) != 0:
|
||||
other: Type = call.positional[0][1]
|
||||
unfolded_other: Type = unfold_type(other)
|
||||
if isinstance(unfolded_other, ColumnType):
|
||||
column2 = unfolded_other
|
||||
col_type2: Type = column2.type
|
||||
|
||||
new_inner_type = self.typer.result_of_binary_op(
|
||||
location=call.location,
|
||||
expr=call.call_expr,
|
||||
left=(call.column_expr, col_type1),
|
||||
right=(call.positional[0][0], col_type2),
|
||||
method=method,
|
||||
)
|
||||
new_column = ColumnType(type=new_inner_type)
|
||||
return new_column
|
||||
|
||||
def _element_wise(self, call: Call, method: str) -> Type:
|
||||
# TODO: support add with scalar
|
||||
|
||||
# Build signature with new column type and generic operand
|
||||
param_type: TypeVar = TypeVar(name="T", bound=None)
|
||||
signature = GenericType(
|
||||
name="add",
|
||||
params=[param_type],
|
||||
body=Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="other",
|
||||
type=ColumnType(type=param_type),
|
||||
required=True,
|
||||
),
|
||||
],
|
||||
returns=self._element_binary_op(call, method),
|
||||
),
|
||||
)
|
||||
|
||||
# Map arguments and compute result type
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
if result.is_valid:
|
||||
self._assert_same_length(
|
||||
call.call_expr, call.column_expr, call.positional[0][0]
|
||||
)
|
||||
|
||||
return result.result
|
||||
|
||||
@method("add", "__add__")
|
||||
def add(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__add__")
|
||||
|
||||
@method("sub", "__sub__")
|
||||
def sub(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__sub__")
|
||||
|
||||
@method("mul", "__mul__")
|
||||
def mul(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__mul__")
|
||||
|
||||
@method("div", "truediv", "__truediv__")
|
||||
def truediv(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__truediv__")
|
||||
|
||||
@method("floordiv", "__floordiv__")
|
||||
def floordiv(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__floordiv__")
|
||||
|
||||
@method("mod", "__mod__")
|
||||
def mod(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__mod__")
|
||||
|
||||
@method("pow", "__pow__")
|
||||
def pow(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__pow__")
|
||||
|
||||
@method("lt", "__lt__")
|
||||
def lt(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__lt__")
|
||||
|
||||
@method("gt", "__gt__")
|
||||
def gt(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__gt__")
|
||||
|
||||
@method("le", "__le__")
|
||||
def le(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__le__")
|
||||
|
||||
@method("ge", "__ge__")
|
||||
def ge(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__ge__")
|
||||
|
||||
@method("ne", "__ne__")
|
||||
def ne(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__ne__")
|
||||
|
||||
@method("eq", "__eq__")
|
||||
def eq(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__eq__")
|
||||
|
||||
def _aggregate(
|
||||
self,
|
||||
call: Call,
|
||||
kwargs: list[Function.Argument] = [],
|
||||
*,
|
||||
preserve_inner_type: bool = False,
|
||||
) -> Type:
|
||||
signature = Function(
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="axis",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
),
|
||||
*kwargs,
|
||||
],
|
||||
returns=call.column if preserve_inner_type else ColumnType(type=TopType()),
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method("kurtosis", "kurt")
|
||||
def kurtosis(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def max(self, call: Call) -> Type:
|
||||
return self._aggregate(call, preserve_inner_type=True)
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def median(self, call: Call) -> Type:
|
||||
return self._aggregate(call, preserve_inner_type=True)
|
||||
|
||||
@method()
|
||||
def min(self, call: Call) -> Type:
|
||||
return self._aggregate(call, preserve_inner_type=True)
|
||||
|
||||
@method()
|
||||
def mode(self, call: Call) -> Type:
|
||||
return self._aggregate(call, preserve_inner_type=True)
|
||||
|
||||
@method("product", "prod")
|
||||
def product(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def std(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="ddof",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def sum(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def var(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="var",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def head(self, call: Call) -> Type:
|
||||
signature = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="n",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
returns=call.column,
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def tail(self, call: Call) -> Type:
|
||||
signature = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="n",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
returns=call.column,
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def groupby(self, call: Call) -> Type:
|
||||
bool_: Type = self.types.get_type("bool")
|
||||
function: Function = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="by",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="level",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
pos=2,
|
||||
name="as_index",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=3,
|
||||
name="sort",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=4,
|
||||
name="group_keys",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=5,
|
||||
name="observed",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=6,
|
||||
name="dropna",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
returns=ColumnGroupBy(column=call.column),
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=function,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
def _assert_same_length(self, call_expr: p.Expr, column1: p.Expr, column2: p.Expr):
|
||||
func_name: str = "__midas_column_same_length__"
|
||||
|
||||
# Efficiently compute length
|
||||
# https://stackoverflow.com/a/15943975/11109181
|
||||
def len_of_col(col: ast.expr) -> ast.expr:
|
||||
return ast.Call(
|
||||
func=ast.Name(id="len"),
|
||||
args=[
|
||||
ast.Attribute(
|
||||
value=col,
|
||||
attr="index",
|
||||
)
|
||||
],
|
||||
keywords=[],
|
||||
)
|
||||
|
||||
self.assertions.define(
|
||||
func_name,
|
||||
ast.FunctionDef(
|
||||
name=func_name,
|
||||
args=ast.arguments(
|
||||
posonlyargs=[],
|
||||
args=[
|
||||
ast.arg(arg="column1"),
|
||||
ast.arg(arg="column2"),
|
||||
],
|
||||
kwonlyargs=[],
|
||||
defaults=[],
|
||||
kw_defaults=[],
|
||||
),
|
||||
body=[
|
||||
ast.Return(
|
||||
value=ast.Compare(
|
||||
left=len_of_col(ast.Name(id="column1")),
|
||||
ops=[ast.Eq()],
|
||||
comparators=[
|
||||
len_of_col(ast.Name(id="column2")),
|
||||
],
|
||||
)
|
||||
)
|
||||
],
|
||||
decorator_list=[],
|
||||
),
|
||||
)
|
||||
self.assertions.add(
|
||||
bound_expr=call_expr,
|
||||
inputs=[column1, column2],
|
||||
builder=lambda c1, c2: ast.Call(
|
||||
func=ast.Name(id=func_name),
|
||||
args=[c1, c2],
|
||||
keywords=[],
|
||||
),
|
||||
message="Columns must have the same length",
|
||||
)
|
||||
103
midas/checker/frames/frame_groupby_methods.py
Normal file
103
midas/checker/frames/frame_groupby_methods.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import (
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
DataFrameType,
|
||||
FrameGroupBy,
|
||||
Type,
|
||||
UnknownType,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import TypedExpr
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Call:
|
||||
location: Location
|
||||
call_expr: p.Expr
|
||||
groupby: FrameGroupBy
|
||||
groupby_expr: p.Expr
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
@property
|
||||
def subject(self) -> TypedExpr:
|
||||
return (self.groupby_expr, self.groupby)
|
||||
|
||||
|
||||
class FrameGroupByMethodRegistry(MethodRegistry[Call]):
|
||||
NAMED_ARGS: dict[str, str] = {
|
||||
"numeric_only": "bool",
|
||||
"skipna": "bool",
|
||||
"engine": "str",
|
||||
"engine_kwargs": "dict",
|
||||
}
|
||||
|
||||
def _aggregate(self, call: Call, method: str) -> Type:
|
||||
new_columns: list[DataFrameType.Column] = []
|
||||
|
||||
for column in call.groupby.frame.columns:
|
||||
column_groupby: ColumnGroupBy = ColumnGroupBy(column=column.type)
|
||||
result_type: Type = self.typer.call_method(
|
||||
location=call.location,
|
||||
call_expr=call.call_expr,
|
||||
obj=(call.groupby_expr, column_groupby),
|
||||
method_name=method,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
if not isinstance(result_type, ColumnType):
|
||||
result_type = ColumnType(type=UnknownType())
|
||||
new_columns.append(
|
||||
DataFrameType.Column(
|
||||
index=column.index,
|
||||
name=column.name,
|
||||
type=result_type,
|
||||
)
|
||||
)
|
||||
|
||||
return DataFrameType(columns=new_columns)
|
||||
|
||||
@method()
|
||||
def kurt(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "kurt")
|
||||
|
||||
@method()
|
||||
def max(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "max")
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "mean")
|
||||
|
||||
@method()
|
||||
def median(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "median")
|
||||
|
||||
@method()
|
||||
def min(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "min")
|
||||
|
||||
@method()
|
||||
def prod(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "prod")
|
||||
|
||||
@method()
|
||||
def std(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "std")
|
||||
|
||||
@method()
|
||||
def sum(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "sum")
|
||||
|
||||
@method()
|
||||
def var(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "var")
|
||||
@@ -4,9 +4,20 @@ from typing import TYPE_CHECKING, Optional, TypeGuard, cast
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.frame_methods import Call, MethodRegistry
|
||||
from midas.checker.frames.frame_groupby_methods import Call as GroupByCall
|
||||
from midas.checker.frames.frame_groupby_methods import FrameGroupByMethodRegistry
|
||||
from midas.checker.frames.frame_methods import Call, FrameMethodRegistry
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter
|
||||
from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType
|
||||
from midas.checker.types import (
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
DataFrameType,
|
||||
FrameGroupBy,
|
||||
TupleType,
|
||||
Type,
|
||||
UnknownType,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import PythonTyper, TypedExpr
|
||||
@@ -19,7 +30,10 @@ def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
|
||||
class FrameManager:
|
||||
def __init__(self, typer: PythonTyper) -> None:
|
||||
self.typer: PythonTyper = typer
|
||||
self.method_resolver: MethodRegistry = MethodRegistry(self.typer)
|
||||
self.method_resolver: FrameMethodRegistry = FrameMethodRegistry(self.typer)
|
||||
self.groupby_method_resolver: FrameGroupByMethodRegistry = (
|
||||
FrameGroupByMethodRegistry(self.typer)
|
||||
)
|
||||
|
||||
def assign(
|
||||
self,
|
||||
@@ -34,12 +48,41 @@ class FrameManager:
|
||||
return self.assign_column(reporter, location, frame, name, value_type)
|
||||
|
||||
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
|
||||
isinstance(idx, str) for idx in indices
|
||||
isinstance(index.value, str) for index in indices
|
||||
):
|
||||
raise NotImplementedError
|
||||
names: list[str] = [cast(str, index.value) for index in indices]
|
||||
|
||||
if not isinstance(value_type, TupleType):
|
||||
reporter.error(
|
||||
location,
|
||||
f"Cannot assign {type} to dataframe columns. Must be a tuple of columns",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
if len(names) != len(value_type.items):
|
||||
reporter.error(
|
||||
location,
|
||||
f"Wrong number of columns. Cannot assign {len(value_type.items)} to {len(names)} targets",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
new_frame: Type = frame
|
||||
for name, value in zip(names, value_type.items):
|
||||
new_frame = self.assign_column(
|
||||
reporter,
|
||||
location,
|
||||
new_frame,
|
||||
name,
|
||||
value,
|
||||
)
|
||||
if not isinstance(new_frame, DataFrameType):
|
||||
return new_frame
|
||||
return new_frame
|
||||
|
||||
case _:
|
||||
reporter.error(location, f"Invalid index type {index} on {frame}")
|
||||
reporter.error(
|
||||
location, f"Invalid index type {index} on {frame} (assignment)"
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
def assign_column(
|
||||
@@ -87,9 +130,31 @@ class FrameManager:
|
||||
return TupleType(items=tuple(columns))
|
||||
|
||||
case _:
|
||||
reporter.error(location, f"Invalid index type {index} on {frame}")
|
||||
reporter.error(
|
||||
location, f"Invalid index type {index} on {frame} (access)"
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
def groupby_get(
|
||||
self,
|
||||
reporter: FileReporter,
|
||||
location: Location,
|
||||
groupby: FrameGroupBy,
|
||||
index: p.Expr,
|
||||
) -> Type:
|
||||
result: Type = self.get(reporter, location, groupby.frame, index)
|
||||
match result:
|
||||
case ColumnType():
|
||||
result = ColumnGroupBy(column=result)
|
||||
case TupleType(items=columns):
|
||||
result = TupleType(
|
||||
items=tuple(
|
||||
ColumnGroupBy(column=cast(ColumnType, column))
|
||||
for column in columns
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _set_column(
|
||||
cls, frame: DataFrameType, name: str, column: ColumnType
|
||||
@@ -141,14 +206,50 @@ class FrameManager:
|
||||
self,
|
||||
method: str,
|
||||
location: Location,
|
||||
call_expr: p.Expr,
|
||||
frame: DataFrameType,
|
||||
frame_expr: p.Expr,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Type:
|
||||
call: Call = Call(
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
frame=frame,
|
||||
frame_expr=frame_expr,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
return self.method_resolver.call(method, call)
|
||||
|
||||
def groupby_call(
|
||||
self,
|
||||
method: str,
|
||||
location: Location,
|
||||
call_expr: p.Expr,
|
||||
groupby: FrameGroupBy,
|
||||
groupby_expr: p.Expr,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Type:
|
||||
call: GroupByCall = GroupByCall(
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
groupby=groupby,
|
||||
groupby_expr=groupby_expr,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
return self.groupby_method_resolver.call(method, call)
|
||||
|
||||
def get_attribute(self, frame: DataFrameType, name: str) -> Optional[Type]:
|
||||
types: TypesRegistry = self.typer.types
|
||||
match name:
|
||||
case "ndim" | "size":
|
||||
return types.get_type("int")
|
||||
|
||||
case "shape":
|
||||
return types.tuple_of("int", "int")
|
||||
|
||||
case _:
|
||||
return None
|
||||
487
midas/checker/frames/frame_methods.py
Normal file
487
midas/checker/frames/frame_methods.py
Normal file
@@ -0,0 +1,487 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallResult
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import (
|
||||
ColumnType,
|
||||
DataFrameType,
|
||||
FrameGroupBy,
|
||||
Function,
|
||||
OverloadedFunction,
|
||||
TopType,
|
||||
Type,
|
||||
UnknownType,
|
||||
unfold_type,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import TypedExpr
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Call:
|
||||
location: Location
|
||||
call_expr: p.Expr
|
||||
frame: DataFrameType
|
||||
frame_expr: p.Expr
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
@property
|
||||
def subject(self) -> TypedExpr:
|
||||
return (self.frame_expr, self.frame)
|
||||
|
||||
|
||||
class FrameMethodRegistry(MethodRegistry[Call]):
|
||||
def _get_method_result(
|
||||
self,
|
||||
call: Call,
|
||||
column1: ColumnType,
|
||||
column2: ColumnType,
|
||||
method: str,
|
||||
) -> ColumnType:
|
||||
"""Get the result of calling a method on a column, passing a second
|
||||
|
||||
This function delegates to the main typer the resolution of the method
|
||||
member, as well as computing the result type. Because we don't have any
|
||||
AST expression for the individual columns, the frame expressions are
|
||||
used instead.
|
||||
|
||||
Args:
|
||||
call (Call): the call that triggered this resolution
|
||||
column1 (ColumnType): the first column, i.e. left operand
|
||||
column2 (ColumnType): the second column, i.e. right operand
|
||||
method (str): the method name
|
||||
|
||||
Returns:
|
||||
ColumnType: the resulting column.
|
||||
If the operation is invalid / doesn't exist,
|
||||
`ColumnType(type=UnknownType())` is returned
|
||||
"""
|
||||
|
||||
result: Type = self.typer.result_of_binary_op(
|
||||
location=call.location,
|
||||
expr=call.call_expr,
|
||||
left=(call.frame_expr, column1),
|
||||
right=(call.positional[0][0], column2),
|
||||
method=method,
|
||||
)
|
||||
|
||||
if not isinstance(result, ColumnType):
|
||||
return ColumnType(type=UnknownType())
|
||||
return result
|
||||
|
||||
def _element_binary_op(self, call: Call, method: str) -> DataFrameType:
|
||||
"""Compute the result of an element-wise binary operation
|
||||
|
||||
This function delegates to the matching columns for computing resulting
|
||||
types. Any column only present in one of the frames is forwarded as a
|
||||
generic `ColumnType(type=UnknownType())`. Columns only in the second
|
||||
frame are append at the end of the schema.
|
||||
|
||||
Args:
|
||||
call (Call): the call that triggered this resolution
|
||||
method (str): the method name
|
||||
|
||||
Returns:
|
||||
DataFrameType: the resulting frame type
|
||||
"""
|
||||
new_columns: list[DataFrameType.Column] = []
|
||||
|
||||
by_name: dict[str, DataFrameType.Column] = {}
|
||||
frame2: Optional[DataFrameType] = None
|
||||
# Get map of operand's columns by name, if there is at least 1 operand, which is a dataframe
|
||||
if len(call.positional) != 0:
|
||||
operand: TypedExpr = call.positional[0]
|
||||
unfolded_other: Type = unfold_type(operand[1])
|
||||
if isinstance(unfolded_other, DataFrameType):
|
||||
frame2 = unfolded_other
|
||||
by_name = {
|
||||
col.name: col for col in frame2.columns if col.name is not None
|
||||
}
|
||||
|
||||
# Compute new schema:
|
||||
# Step 1: for all columns in frame1:
|
||||
# - if present in frame2 -> delegate operation to columns
|
||||
# - if not -> add to schema as unknown
|
||||
in_frame1: set[str] = set()
|
||||
for column in call.frame.columns:
|
||||
if column.name is not None:
|
||||
in_frame1.add(column.name)
|
||||
|
||||
col_type1: ColumnType = column.type
|
||||
col_type: ColumnType = ColumnType(type=UnknownType())
|
||||
if column.name in by_name:
|
||||
column2 = by_name[column.name]
|
||||
col_type2: ColumnType = column2.type
|
||||
|
||||
col_type = self._get_method_result(call, col_type1, col_type2, method)
|
||||
|
||||
new_column = DataFrameType.Column(
|
||||
index=column.index,
|
||||
name=column.name,
|
||||
type=col_type,
|
||||
)
|
||||
new_columns.append(new_column)
|
||||
|
||||
# Step 2: for all columns in frame2
|
||||
# - if not in frame1 -> add to schema as unknown
|
||||
if frame2 is not None:
|
||||
for column in frame2.columns:
|
||||
if column.name in in_frame1:
|
||||
continue
|
||||
new_columns.append(
|
||||
DataFrameType.Column(
|
||||
index=len(new_columns),
|
||||
name=column.name,
|
||||
type=ColumnType(type=UnknownType()),
|
||||
)
|
||||
)
|
||||
|
||||
return DataFrameType(columns=new_columns)
|
||||
|
||||
def _element_wise(self, call: Call, method: str) -> Type:
|
||||
# TODO: support scalar, sequence, Series, dict operand
|
||||
# Build signature with new schema and generic operand
|
||||
signature = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="other",
|
||||
type=DataFrameType(columns=[]),
|
||||
required=True,
|
||||
),
|
||||
],
|
||||
returns=self._element_binary_op(call, method),
|
||||
)
|
||||
|
||||
# Map arguments and compute result type
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
if result.is_valid:
|
||||
self._assert_same_length(
|
||||
call.call_expr, call.frame_expr, call.positional[0][0]
|
||||
)
|
||||
|
||||
return result.result
|
||||
|
||||
@method("add", "__add__")
|
||||
def add(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__add__")
|
||||
|
||||
@method("sub", "__sub__")
|
||||
def sub(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__sub__")
|
||||
|
||||
@method("mul", "__mul__")
|
||||
def mul(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__mul__")
|
||||
|
||||
@method("div", "truediv", "__truediv__")
|
||||
def truediv(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__truediv__")
|
||||
|
||||
@method("floordiv", "__floordiv__")
|
||||
def floordiv(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__floordiv__")
|
||||
|
||||
@method("mod", "__mod__")
|
||||
def mod(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__mod__")
|
||||
|
||||
@method("pow", "__pow__")
|
||||
def pow(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__pow__")
|
||||
|
||||
@method("lt", "__lt__")
|
||||
def lt(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__lt__")
|
||||
|
||||
@method("gt", "__gt__")
|
||||
def gt(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__gt__")
|
||||
|
||||
@method("le", "__le__")
|
||||
def le(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__le__")
|
||||
|
||||
@method("ge", "__ge__")
|
||||
def ge(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__ge__")
|
||||
|
||||
@method("ne", "__ne__")
|
||||
def ne(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__ne__")
|
||||
|
||||
@method("eq", "__eq__")
|
||||
def eq(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__eq__")
|
||||
|
||||
def _aggregate(self, call: Call, kwargs: list[Function.Argument] = []) -> Type:
|
||||
with_axis = Function(
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="axis",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
),
|
||||
*kwargs,
|
||||
],
|
||||
returns=ColumnType(type=TopType()),
|
||||
)
|
||||
without_axis = Function(
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="axis",
|
||||
type=self.types.get_type("None"),
|
||||
required=True,
|
||||
),
|
||||
*kwargs,
|
||||
],
|
||||
returns=TopType(),
|
||||
)
|
||||
overload = OverloadedFunction(
|
||||
overloads=[
|
||||
with_axis,
|
||||
without_axis,
|
||||
]
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=overload,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method("kurtosis", "kurt")
|
||||
def kurtosis(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def max(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def median(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def min(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def mode(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method("product", "prod")
|
||||
def product(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def std(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="ddof",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def sum(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def var(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="var",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def head(self, call: Call) -> Type:
|
||||
signature = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="n",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
returns=call.frame,
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def tail(self, call: Call) -> Type:
|
||||
signature = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="n",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
returns=call.frame,
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def groupby(self, call: Call) -> Type:
|
||||
bool_: Type = self.types.get_type("bool")
|
||||
function: Function = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="by",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="level",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
pos=2,
|
||||
name="as_index",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=3,
|
||||
name="sort",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=4,
|
||||
name="group_keys",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=5,
|
||||
name="observed",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=6,
|
||||
name="dropna",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
returns=FrameGroupBy(frame=call.frame),
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=function,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
def _assert_same_length(self, call_expr: p.Expr, frame1: p.Expr, frame2: p.Expr):
|
||||
func_name: str = "__midas_frame_same_length__"
|
||||
|
||||
# Efficiently compute length
|
||||
# https://stackoverflow.com/a/15943975/11109181
|
||||
def len_of_df(df: ast.expr) -> ast.expr:
|
||||
return ast.Call(
|
||||
func=ast.Name(id="len"),
|
||||
args=[
|
||||
ast.Attribute(
|
||||
value=df,
|
||||
attr="index",
|
||||
)
|
||||
],
|
||||
keywords=[],
|
||||
)
|
||||
|
||||
self.assertions.define(
|
||||
func_name,
|
||||
ast.FunctionDef(
|
||||
name=func_name,
|
||||
args=ast.arguments(
|
||||
posonlyargs=[],
|
||||
args=[
|
||||
ast.arg(arg="frame1"),
|
||||
ast.arg(arg="frame2"),
|
||||
],
|
||||
kwonlyargs=[],
|
||||
defaults=[],
|
||||
kw_defaults=[],
|
||||
),
|
||||
body=[
|
||||
ast.Return(
|
||||
value=ast.Compare(
|
||||
left=len_of_df(ast.Name(id="frame1")),
|
||||
ops=[ast.Eq()],
|
||||
comparators=[len_of_df(ast.Name(id="frame2"))],
|
||||
)
|
||||
)
|
||||
],
|
||||
decorator_list=[],
|
||||
),
|
||||
)
|
||||
self.assertions.add(
|
||||
bound_expr=call_expr,
|
||||
inputs=[frame1, frame2],
|
||||
builder=lambda f1, f2: ast.Call(
|
||||
func=ast.Name(id=func_name),
|
||||
args=[f1, f2],
|
||||
keywords=[],
|
||||
),
|
||||
message="DataFrames must have the same length",
|
||||
)
|
||||
100
midas/checker/frames/utils.py
Normal file
100
midas/checker/frames/utils.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Generic,
|
||||
Optional,
|
||||
Protocol,
|
||||
Self,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallDispatcher
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter
|
||||
from midas.checker.types import Type, UnknownType
|
||||
from midas.generator.collector import AssertionCollector
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import PythonTyper, TypedExpr
|
||||
|
||||
|
||||
class _MethodRegistryMeta(type):
|
||||
_methods: dict[str, Callable[..., Type]] = {}
|
||||
|
||||
def __new__(
|
||||
cls,
|
||||
name: str,
|
||||
bases: tuple[type, ...],
|
||||
namespace: dict[str, Any],
|
||||
):
|
||||
new_class = super().__new__(cls, name, bases, namespace)
|
||||
new_class._methods = {}
|
||||
for attr in namespace.values():
|
||||
if callable(attr) and hasattr(attr, "__method_names__"):
|
||||
for name in attr.__method_names__: # type: ignore
|
||||
new_class._methods[name] = attr # type: ignore
|
||||
return new_class
|
||||
|
||||
|
||||
class MethodCall(Protocol):
|
||||
@property
|
||||
def location(self) -> Location: ...
|
||||
|
||||
@property
|
||||
def call_expr(self) -> p.Expr: ...
|
||||
|
||||
@property
|
||||
def subject(self) -> TypedExpr: ...
|
||||
|
||||
|
||||
T = TypeVar("T", bound=MethodCall)
|
||||
|
||||
|
||||
class MethodRegistry(Generic[T], metaclass=_MethodRegistryMeta):
|
||||
def __init__(self, typer: PythonTyper) -> None:
|
||||
self.typer: PythonTyper = typer
|
||||
|
||||
@property
|
||||
def reporter(self) -> FileReporter:
|
||||
return self.typer.reporter
|
||||
|
||||
@property
|
||||
def types(self) -> TypesRegistry:
|
||||
return self.typer.types
|
||||
|
||||
@property
|
||||
def dispatcher(self) -> CallDispatcher[p.Expr]:
|
||||
return self.typer.dispatcher
|
||||
|
||||
@property
|
||||
def assertions(self) -> AssertionCollector:
|
||||
return self.typer.assertions
|
||||
|
||||
def call(self, method: str, call: T) -> Type:
|
||||
func: Optional[Callable[[Self, T], Type]] = self._methods.get(method)
|
||||
if func is None:
|
||||
self.reporter.warning(
|
||||
call.location, f"Unknown method {method} on {call.subject[1]}"
|
||||
)
|
||||
return UnknownType()
|
||||
return func(self, call)
|
||||
|
||||
|
||||
_Self = TypeVar("_Self", bound=MethodRegistry[Any])
|
||||
Method = Callable[[_Self, T], Type]
|
||||
|
||||
|
||||
def method(*names: str) -> Callable[[Method[_Self, T]], Method[_Self, T]]:
|
||||
def wrapper(func: Method[_Self, T]) -> Method[_Self, T]:
|
||||
names_: tuple[str, ...] = names
|
||||
if len(names_) == 0:
|
||||
names_ = (func.__name__,)
|
||||
setattr(func, "__method_names__", names_)
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
@@ -108,8 +108,8 @@ class Preamble(Environment):
|
||||
],
|
||||
)
|
||||
|
||||
def _list_of(self, item_type: Type) -> Type:
|
||||
return self._types.apply_generic(self._types.get_type("list"), [item_type])
|
||||
def _list_of(self, item_type: str | Type) -> Type:
|
||||
return self._types.list_of(item_type)
|
||||
|
||||
def _def_type_constructor(
|
||||
self, name: str, py_function: Optional[Callable[..., Any]] = None
|
||||
|
||||
@@ -9,7 +9,8 @@ from midas.ast.printer import MidasPrinter
|
||||
from midas.checker.dispatcher import CallDispatcher, CallResult
|
||||
from midas.checker.environment import Environment
|
||||
from midas.checker.evaluator import Evaluator
|
||||
from midas.checker.frames import FrameManager
|
||||
from midas.checker.frames.column_manager import ColumnManager
|
||||
from midas.checker.frames.frame_manager import FrameManager
|
||||
from midas.checker.operators import (
|
||||
PY_COMPARATOR_METHODS,
|
||||
PY_OPERATOR_METHODS,
|
||||
@@ -22,12 +23,15 @@ from midas.checker.resolver import Resolver
|
||||
from midas.checker.types import (
|
||||
AppliedType,
|
||||
BaseType,
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
ConstraintType,
|
||||
DataFrameType,
|
||||
DerivedType,
|
||||
FrameGroupBy,
|
||||
Function,
|
||||
GenericType,
|
||||
TopType,
|
||||
TupleType,
|
||||
Type,
|
||||
TypeVar,
|
||||
@@ -36,6 +40,7 @@ from midas.checker.types import (
|
||||
Variance,
|
||||
unfold_type,
|
||||
)
|
||||
from midas.generator.collector import AssertionCollector
|
||||
from midas.parser.python import PythonParser
|
||||
from midas.utils import TypedAST
|
||||
|
||||
@@ -79,6 +84,7 @@ class PythonTyper(
|
||||
self.reporter: FileReporter = reporter.for_file(None)
|
||||
self.types: TypesRegistry = types
|
||||
self.frame_mgr: FrameManager = FrameManager(self)
|
||||
self.column_mgr: ColumnManager = ColumnManager(self)
|
||||
self.global_env: Environment = Preamble(self.types)
|
||||
self.env: Environment = self.global_env
|
||||
self.locals: dict[p.Expr, int] = {}
|
||||
@@ -87,6 +93,7 @@ class PythonTyper(
|
||||
self.dispatcher: CallDispatcher[p.Expr] = CallDispatcher[p.Expr](
|
||||
self.types, self.reporter
|
||||
)
|
||||
self.assertions: AssertionCollector = AssertionCollector()
|
||||
|
||||
def set_reporter(self, reporter: FileReporter):
|
||||
self.reporter = reporter
|
||||
@@ -113,6 +120,7 @@ class PythonTyper(
|
||||
stmts=stmts,
|
||||
judgements=self.judgements,
|
||||
evaluated_casts=self.evaluated_casts,
|
||||
assertions=self.assertions,
|
||||
)
|
||||
|
||||
def judge(self, expr: p.Expr, type: Type):
|
||||
@@ -209,23 +217,59 @@ class PythonTyper(
|
||||
def call_method(
|
||||
self,
|
||||
location: Location,
|
||||
obj: Type,
|
||||
call_expr: p.Expr,
|
||||
obj: TypedExpr,
|
||||
method_name: str,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Optional[Type]:
|
||||
unfolded: Type = unfold_type(obj)
|
||||
) -> Type:
|
||||
unfolded: Type = unfold_type(obj[1])
|
||||
match unfolded:
|
||||
case DataFrameType():
|
||||
return self.frame_mgr.call(
|
||||
method=method_name,
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
frame=unfolded,
|
||||
frame_expr=obj[0],
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
|
||||
method: Optional[Type] = self.types.lookup_member(obj, method_name)
|
||||
case FrameGroupBy():
|
||||
return self.frame_mgr.groupby_call(
|
||||
method=method_name,
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
groupby=unfolded,
|
||||
groupby_expr=obj[0],
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
|
||||
case ColumnType():
|
||||
return self.column_mgr.call(
|
||||
method=method_name,
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
column=unfolded,
|
||||
column_expr=obj[0],
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
|
||||
case ColumnGroupBy():
|
||||
return self.column_mgr.groupby_call(
|
||||
method=method_name,
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
groupby=unfolded,
|
||||
groupby_expr=obj[0],
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
|
||||
method: Optional[Type] = self.types.lookup_member(obj[1], method_name)
|
||||
if method is None:
|
||||
raise UndefinedMethodException
|
||||
|
||||
@@ -499,7 +543,15 @@ class PythonTyper(
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
|
||||
left: Type = self.type_of(expr.left)
|
||||
right: Type = self.type_of(expr.right)
|
||||
return self.result_of_binary_op(
|
||||
expr.location,
|
||||
expr,
|
||||
(expr.left, left),
|
||||
(expr.right, right),
|
||||
method,
|
||||
)
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
||||
method: Optional[str] = PY_COMPARATOR_METHODS.get(expr.operator.__class__)
|
||||
@@ -510,26 +562,40 @@ class PythonTyper(
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
|
||||
left: Type = self.type_of(expr.left)
|
||||
right: Type = self.type_of(expr.right)
|
||||
return self.result_of_binary_op(
|
||||
expr.location,
|
||||
expr,
|
||||
(expr.left, left),
|
||||
(expr.right, right),
|
||||
method,
|
||||
)
|
||||
|
||||
def _visit_binary_expr(
|
||||
self, location: Location, left_expr: p.Expr, right_expr: p.Expr, method: str
|
||||
def result_of_binary_op(
|
||||
self,
|
||||
location: Location,
|
||||
expr: p.Expr,
|
||||
left: TypedExpr,
|
||||
right: TypedExpr,
|
||||
method: str,
|
||||
) -> Type:
|
||||
left: Type = self.type_of(left_expr)
|
||||
right: Type = self.type_of(right_expr)
|
||||
|
||||
result: Optional[Type]
|
||||
try:
|
||||
result = self.call_method(location, left, method, [(right_expr, right)], {})
|
||||
return self.call_method(
|
||||
location=location,
|
||||
call_expr=expr,
|
||||
obj=left,
|
||||
method_name=method,
|
||||
positional=[right],
|
||||
keywords={},
|
||||
)
|
||||
except UndefinedMethodException:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Undefined operation {method} between {left} and {right}",
|
||||
f"Undefined operation {method} between {left[1]} and {right[1]}",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
return result or UnknownType()
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
|
||||
method: Optional[str] = PY_UNARY_METHODS.get(expr.operator.__class__)
|
||||
if method is None:
|
||||
@@ -541,9 +607,15 @@ class PythonTyper(
|
||||
|
||||
operand: Type = self.type_of(expr.right)
|
||||
|
||||
result: Optional[Type]
|
||||
try:
|
||||
result = self.call_method(expr.location, operand, method, [], {})
|
||||
return self.call_method(
|
||||
location=expr.location,
|
||||
call_expr=expr,
|
||||
obj=(expr.right, operand),
|
||||
method_name=method,
|
||||
positional=[],
|
||||
keywords={},
|
||||
)
|
||||
except UndefinedMethodException:
|
||||
self.reporter.error(
|
||||
expr.location,
|
||||
@@ -551,8 +623,6 @@ class PythonTyper(
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
return result or UnknownType()
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
||||
match expr.callee:
|
||||
case p.VariableExpr(name="TypeVar"):
|
||||
@@ -568,15 +638,14 @@ class PythonTyper(
|
||||
match expr.callee:
|
||||
case p.GetExpr(object=obj, name=method):
|
||||
obj_type: Type = self.type_of(obj)
|
||||
unfolded: Type = unfold_type(obj_type)
|
||||
if isinstance(unfolded, DataFrameType):
|
||||
return self.frame_mgr.call(
|
||||
method,
|
||||
expr.location,
|
||||
unfolded,
|
||||
positional,
|
||||
keywords,
|
||||
)
|
||||
return self.call_method(
|
||||
location=expr.location,
|
||||
call_expr=expr,
|
||||
obj=(obj, obj_type),
|
||||
method_name=method,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
|
||||
callee: Type = self.type_of(expr.callee)
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
@@ -590,6 +659,14 @@ class PythonTyper(
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> Type:
|
||||
object: Type = self.type_of(expr.object)
|
||||
member: Optional[Type] = self.types.lookup_member(object, expr.name)
|
||||
|
||||
if member is None:
|
||||
match object:
|
||||
case DataFrameType():
|
||||
member = self.frame_mgr.get_attribute(object, expr.name)
|
||||
case ColumnType():
|
||||
member = self.column_mgr.get_attribute(object, expr.name)
|
||||
|
||||
if member is None:
|
||||
self.reporter.warning(
|
||||
expr.location, f"Unknown member '{expr.name}' of {object}"
|
||||
@@ -738,6 +815,8 @@ class PythonTyper(
|
||||
return self._visit_tuple_subscript(unfolded, expr)
|
||||
case DataFrameType():
|
||||
return self._visit_frame_subscript(unfolded, expr)
|
||||
case FrameGroupBy():
|
||||
return self._visit_frame_groupby_subscript(unfolded, expr)
|
||||
|
||||
operation: Optional[Type] = self.types.lookup_member(object, "__getitem__")
|
||||
if operation is None:
|
||||
@@ -936,6 +1015,17 @@ class PythonTyper(
|
||||
self, expr: p.CastExpr, subject_type: Type, target_type: Type, lit_value: Any
|
||||
) -> bool:
|
||||
match target_type:
|
||||
case TopType():
|
||||
return True
|
||||
|
||||
case UnitType():
|
||||
if lit_value is not None:
|
||||
self.reporter.error(
|
||||
expr.location, f"Value {lit_value!r} is not None"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
case DerivedType(type=base):
|
||||
return self._evaluate_cast_statically(
|
||||
expr, subject_type, base, lit_value
|
||||
@@ -1052,3 +1142,10 @@ class PythonTyper(
|
||||
self, frame: DataFrameType, expr: p.SubscriptExpr
|
||||
) -> Type:
|
||||
return self.frame_mgr.get(self.reporter, expr.location, frame, expr.index)
|
||||
|
||||
def _visit_frame_groupby_subscript(
|
||||
self, groupby: FrameGroupBy, expr: p.SubscriptExpr
|
||||
) -> Type:
|
||||
return self.frame_mgr.groupby_get(
|
||||
self.reporter, expr.location, groupby, expr.index
|
||||
)
|
||||
|
||||
@@ -113,6 +113,15 @@ class TypesRegistry:
|
||||
raise ValueError(f"Predicate {name} already defined")
|
||||
self._predicates[name] = predicate
|
||||
|
||||
def is_builtin_subtype(self, name1: str, name2: str) -> bool:
|
||||
subtypes: set[str] = BUILTIN_SUBTYPES.get(name2, set())
|
||||
if name1 in subtypes:
|
||||
return True
|
||||
for subtype in subtypes:
|
||||
if self.is_builtin_subtype(name1, subtype):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||
"""Check whether `type1` is a subtype of `type2`
|
||||
|
||||
@@ -150,7 +159,7 @@ class TypesRegistry:
|
||||
return self.is_subtype(base1, type2)
|
||||
|
||||
case (BaseType(name=name1), BaseType(name=name2)):
|
||||
return name1 in BUILTIN_SUBTYPES.get(name2, set())
|
||||
return self.is_builtin_subtype(name1, name2)
|
||||
|
||||
case (ComplexType(properties=props1), ComplexType(properties=props2)):
|
||||
for k, t in props2.items():
|
||||
@@ -443,3 +452,29 @@ class TypesRegistry:
|
||||
|
||||
def lookup_predicate(self, name: str) -> Optional[Predicate]:
|
||||
return self._predicates.get(name)
|
||||
|
||||
def _by_name_or_type(self, name_or_type: str | Type) -> Type:
|
||||
if isinstance(name_or_type, str):
|
||||
return self.get_type(name_or_type)
|
||||
return name_or_type
|
||||
|
||||
def list_of(self, item_type: str | Type) -> Type:
|
||||
list_ = self.get_type("list")
|
||||
return self.apply_generic(list_, [self._by_name_or_type(item_type)])
|
||||
|
||||
def tuple_of(self, *item_types: str | Type) -> Type:
|
||||
tuple_ = self.get_type("tuple")
|
||||
return self.apply_generic(
|
||||
tuple_,
|
||||
[self._by_name_or_type(item_type) for item_type in item_types],
|
||||
)
|
||||
|
||||
def dict_of(self, key_type: str | Type, value_type: str | Type) -> Type:
|
||||
dict_ = self.get_type("dict")
|
||||
return self.apply_generic(
|
||||
dict_,
|
||||
[
|
||||
self._by_name_or_type(key_type),
|
||||
self._by_name_or_type(value_type),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -187,6 +187,22 @@ class DataFrameType:
|
||||
type: ColumnType
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class FrameGroupBy:
|
||||
frame: DataFrameType
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"FrameGroupBy[{self.frame}]"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ColumnGroupBy:
|
||||
column: ColumnType
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ColumnGroupBy[{self.column}]"
|
||||
|
||||
|
||||
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||
def sub_argument(arg: Function.Argument):
|
||||
return Function.Argument(
|
||||
@@ -305,11 +321,20 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||
columns=list(map(sub_column, columns)),
|
||||
)
|
||||
|
||||
case FrameGroupBy(frame=frame):
|
||||
return FrameGroupBy(
|
||||
frame=cast(DataFrameType, substitute_typevars(frame, substitutions))
|
||||
)
|
||||
|
||||
case ColumnGroupBy(column=column):
|
||||
return ColumnGroupBy(
|
||||
column=cast(ColumnType, substitute_typevars(column, substitutions))
|
||||
)
|
||||
|
||||
case UnknownType() | UnitType():
|
||||
return type
|
||||
|
||||
case TopType() | GenericType():
|
||||
|
||||
raise NotImplementedError(f"Unsupported type {type}")
|
||||
|
||||
# Ensure exhaustiveness
|
||||
@@ -382,6 +407,12 @@ def to_annotation(type: Type) -> str:
|
||||
case DataFrameType():
|
||||
return "pd.DataFrame"
|
||||
|
||||
case FrameGroupBy():
|
||||
return "pd.api.typing.DataFrameGroupBy"
|
||||
|
||||
case ColumnGroupBy():
|
||||
return "pd.api.typing.SeriesGroupBy"
|
||||
|
||||
case _:
|
||||
assert_never(type)
|
||||
|
||||
@@ -410,4 +441,6 @@ Type = (
|
||||
| TupleType
|
||||
| ColumnType
|
||||
| DataFrameType
|
||||
| FrameGroupBy
|
||||
| ColumnGroupBy
|
||||
)
|
||||
|
||||
@@ -4,6 +4,8 @@ from typing import Optional
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import (
|
||||
AppliedType,
|
||||
ColumnType,
|
||||
DataFrameType,
|
||||
Function,
|
||||
GenericType,
|
||||
TopType,
|
||||
@@ -98,6 +100,30 @@ class Unifier:
|
||||
|
||||
return substitutions
|
||||
|
||||
case (
|
||||
DataFrameType(columns=template_columns),
|
||||
DataFrameType(columns=concrete_columns),
|
||||
) if len(template_columns) == len(concrete_columns):
|
||||
substitutions: dict[str, Type] = {}
|
||||
for template_column, concrete_column in zip(
|
||||
template_columns, concrete_columns
|
||||
):
|
||||
if template_column.index != concrete_column or (
|
||||
template_column.name != concrete_column.name
|
||||
):
|
||||
self.logger.debug(
|
||||
f"Column mismatch: template={template_column}, concrete={concrete_column}"
|
||||
)
|
||||
raise UnificationError
|
||||
new_substistutions: dict[str, Type] = self.match(
|
||||
template_column.type, concrete_column.type
|
||||
)
|
||||
substitutions = self.merge(substitutions, new_substistutions)
|
||||
return substitutions
|
||||
|
||||
case (ColumnType(type=template_column), ColumnType(type=concrete_column)):
|
||||
return self.match(template_column, concrete_column)
|
||||
|
||||
case (Function(), Function()):
|
||||
mapped: list[tuple[Function.Argument, Function.Argument]] = (
|
||||
self.map_params(template, concrete)
|
||||
|
||||
59
midas/generator/collector.py
Normal file
59
midas/generator/collector.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
|
||||
import midas.ast.python as p
|
||||
|
||||
AssertionBuilder = Callable[..., ast.expr]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Assertion:
|
||||
bound_expr: p.Expr
|
||||
inputs: list[p.Expr]
|
||||
builder: AssertionBuilder
|
||||
message: str
|
||||
|
||||
def is_bound_to(self, expr: p.Expr) -> bool:
|
||||
return expr == self.bound_expr
|
||||
|
||||
|
||||
class AssertionCollector:
|
||||
def __init__(self):
|
||||
self.assertions: list[Assertion] = []
|
||||
self.definitions: dict[str, ast.stmt] = {}
|
||||
|
||||
def add(
|
||||
self,
|
||||
bound_expr: p.Expr,
|
||||
inputs: list[p.Expr],
|
||||
builder: AssertionBuilder,
|
||||
message: str,
|
||||
):
|
||||
self.assertions.append(
|
||||
Assertion(
|
||||
bound_expr=bound_expr,
|
||||
inputs=inputs,
|
||||
builder=builder,
|
||||
message=message,
|
||||
)
|
||||
)
|
||||
|
||||
def remove(self, assertion: Assertion):
|
||||
try:
|
||||
self.assertions.remove(assertion)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def define(self, name: str, stmt: ast.stmt):
|
||||
if name not in self.definitions:
|
||||
self.definitions[name] = stmt
|
||||
|
||||
def get_definitions(self) -> list[ast.stmt]:
|
||||
return list(self.definitions.values())
|
||||
|
||||
def get_assertions(self) -> list[Assertion]:
|
||||
return self.assertions
|
||||
|
||||
def get_assertions_for(self, expr: p.Expr) -> list[Assertion]:
|
||||
return list(filter(lambda a: a.is_bound_to(expr), self.assertions))
|
||||
@@ -14,12 +14,14 @@ from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import (
|
||||
AppliedType,
|
||||
BaseType,
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
DataFrameType,
|
||||
DerivedType,
|
||||
ExtensionType,
|
||||
FrameGroupBy,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
@@ -30,6 +32,7 @@ from midas.checker.types import (
|
||||
UnitType,
|
||||
UnknownType,
|
||||
)
|
||||
from midas.generator.collector import Assertion, AssertionCollector
|
||||
from midas.generator.constraints import ConstraintGenerator
|
||||
from midas.generator.stubs import StubsGenerator
|
||||
from midas.utils import TypedAST
|
||||
@@ -55,10 +58,12 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
stmts=[],
|
||||
judgements=[],
|
||||
evaluated_casts=[],
|
||||
assertions=AssertionCollector(),
|
||||
)
|
||||
self._alias_count: int = 0
|
||||
self._predicate_count: int = 0
|
||||
self._scopes: list[Scope] = []
|
||||
self._aliases: list[tuple[p.Expr, ast.expr]] = []
|
||||
|
||||
self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
|
||||
self._constraints: list[tuple[m.Expr, ast.expr]] = []
|
||||
@@ -71,7 +76,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
|
||||
def generate_ast(self, typed_ast: TypedAST) -> ast.AST:
|
||||
self._typed_ast = typed_ast
|
||||
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
|
||||
body: list[ast.stmt] = self._visit_body(typed_ast.stmts, can_be_empty=True)
|
||||
predicates: list[ast.stmt] = self._constraint_generator.get_definitions()
|
||||
|
||||
body = predicates + body
|
||||
@@ -129,39 +134,48 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
output: str = ast.unparse(module)
|
||||
out_path.write_text(output)
|
||||
|
||||
def convert(self, expr: p.Expr) -> ast.expr:
|
||||
for expr2, alias in self._aliases:
|
||||
if expr2 == expr:
|
||||
return alias
|
||||
assertions = self._typed_ast.assertions.get_assertions_for(expr)
|
||||
if len(assertions) != 0:
|
||||
return self._apply_assertions(expr, assertions)
|
||||
return expr.accept(self)
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> ast.expr:
|
||||
return ast.BinOp(
|
||||
left=expr.left.accept(self),
|
||||
left=self.convert(expr.left),
|
||||
op=expr.operator,
|
||||
right=expr.right.accept(self),
|
||||
right=self.convert(expr.right),
|
||||
)
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> ast.expr:
|
||||
return ast.Compare(
|
||||
left=expr.left.accept(self),
|
||||
left=self.convert(expr.left),
|
||||
ops=[expr.operator],
|
||||
comparators=[expr.right.accept(self)],
|
||||
comparators=[self.convert(expr.right)],
|
||||
)
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> ast.expr:
|
||||
return ast.UnaryOp(
|
||||
op=expr.operator,
|
||||
operand=expr.right.accept(self),
|
||||
operand=self.convert(expr.right),
|
||||
)
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> ast.expr:
|
||||
return ast.Call(
|
||||
func=expr.callee.accept(self),
|
||||
args=[arg.accept(self) for arg in expr.arguments],
|
||||
func=self.convert(expr.callee),
|
||||
args=[self.convert(arg) for arg in expr.arguments],
|
||||
keywords=[
|
||||
ast.keyword(arg=name, value=arg.accept(self))
|
||||
ast.keyword(arg=name, value=self.convert(arg))
|
||||
for name, arg in expr.keywords.items()
|
||||
],
|
||||
)
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> ast.expr:
|
||||
return ast.Attribute(
|
||||
value=expr.object.accept(self),
|
||||
value=self.convert(expr.object),
|
||||
attr=expr.name,
|
||||
)
|
||||
|
||||
@@ -174,16 +188,16 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> ast.expr:
|
||||
return ast.BoolOp(
|
||||
op=expr.operator,
|
||||
values=[expr.left.accept(self), expr.right.accept(self)],
|
||||
values=[self.convert(expr.left), self.convert(expr.right)],
|
||||
)
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
|
||||
expr2: ast.expr = expr.expr.accept(self)
|
||||
expr2: ast.expr = self.convert(expr.expr)
|
||||
|
||||
if expr in self._typed_ast.evaluated_casts or expr.unsafe:
|
||||
return expr2
|
||||
|
||||
alias: ast.expr = self._make_alias(expr2)
|
||||
alias: ast.expr = self._make_alias(expr.expr, expr2)
|
||||
|
||||
type: Type = self._get_expr_type(expr)
|
||||
asserts: list[ast.stmt] = self._make_cast_asserts(expr.location, alias, type)
|
||||
@@ -194,38 +208,38 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> ast.expr:
|
||||
return ast.IfExp(
|
||||
test=expr.test.accept(self),
|
||||
body=expr.if_true.accept(self),
|
||||
orelse=expr.if_false.accept(self),
|
||||
test=self.convert(expr.test),
|
||||
body=self.convert(expr.if_true),
|
||||
orelse=self.convert(expr.if_false),
|
||||
)
|
||||
|
||||
def visit_list_expr(self, expr: p.ListExpr) -> ast.expr:
|
||||
return ast.List(
|
||||
elts=[item.accept(self) for item in expr.items],
|
||||
elts=[self.convert(item) for item in expr.items],
|
||||
)
|
||||
|
||||
def visit_dict_expr(self, expr: p.DictExpr) -> ast.expr:
|
||||
return ast.Dict(
|
||||
keys=[key.accept(self) if key is not None else None for key in expr.keys],
|
||||
values=[value.accept(self) for value in expr.values],
|
||||
keys=[self.convert(key) if key is not None else None for key in expr.keys],
|
||||
values=[self.convert(value) for value in expr.values],
|
||||
)
|
||||
|
||||
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> ast.expr:
|
||||
return ast.Subscript(
|
||||
value=expr.object.accept(self),
|
||||
slice=expr.index.accept(self),
|
||||
value=self.convert(expr.object),
|
||||
slice=self.convert(expr.index),
|
||||
)
|
||||
|
||||
def visit_slice_expr(self, expr: p.SliceExpr) -> ast.expr:
|
||||
return ast.Slice(
|
||||
lower=expr.lower.accept(self) if expr.lower is not None else None,
|
||||
upper=expr.upper.accept(self) if expr.upper is not None else None,
|
||||
step=expr.step.accept(self) if expr.step is not None else None,
|
||||
lower=self.convert(expr.lower) if expr.lower is not None else None,
|
||||
upper=self.convert(expr.upper) if expr.upper is not None else None,
|
||||
step=self.convert(expr.step) if expr.step is not None else None,
|
||||
)
|
||||
|
||||
def visit_tuple_expr(self, expr: p.TupleExpr) -> ast.expr:
|
||||
return ast.Tuple(
|
||||
elts=[item.accept(self) for item in expr.items],
|
||||
elts=[self.convert(item) for item in expr.items],
|
||||
)
|
||||
|
||||
def visit_raw_expr(self, expr: p.RawExpr) -> ast.expr:
|
||||
@@ -233,7 +247,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
|
||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> ast.stmt:
|
||||
return ast.Expr(
|
||||
value=stmt.expr.accept(self),
|
||||
value=self.convert(stmt.expr),
|
||||
)
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> ast.stmt:
|
||||
@@ -246,12 +260,12 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
kwonlyargs=[ast.arg(arg=arg.name) for arg in stmt.kwonlyargs],
|
||||
kwarg=None,
|
||||
defaults=[
|
||||
arg.default.accept(self)
|
||||
self.convert(arg.default)
|
||||
for arg in stmt.posonlyargs + stmt.args
|
||||
if arg.default is not None
|
||||
],
|
||||
kw_defaults=[
|
||||
arg.default.accept(self) if arg.default is not None else None
|
||||
self.convert(arg.default) if arg.default is not None else None
|
||||
for arg in stmt.kwonlyargs
|
||||
],
|
||||
),
|
||||
@@ -265,20 +279,20 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
|
||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> ast.stmt:
|
||||
return ast.Assign(
|
||||
targets=[target.accept(self) for target in stmt.targets],
|
||||
value=stmt.value.accept(self),
|
||||
targets=[self.convert(target) for target in stmt.targets],
|
||||
value=self.convert(stmt.value),
|
||||
)
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> ast.stmt:
|
||||
return ast.Return(
|
||||
value=stmt.value.accept(self) if stmt.value is not None else None,
|
||||
value=self.convert(stmt.value) if stmt.value is not None else None,
|
||||
)
|
||||
|
||||
def visit_if_stmt(self, stmt: p.IfStmt) -> ast.stmt:
|
||||
return ast.If(
|
||||
test=stmt.test.accept(self),
|
||||
test=self.convert(stmt.test),
|
||||
body=self._visit_body(stmt.body),
|
||||
orelse=self._visit_body(stmt.orelse),
|
||||
orelse=self._visit_body(stmt.orelse, can_be_empty=True),
|
||||
)
|
||||
|
||||
def visit_pass(self, stmt: p.Pass) -> ast.stmt:
|
||||
@@ -286,8 +300,8 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
|
||||
def visit_for_stmt(self, stmt: p.ForStmt) -> ast.stmt:
|
||||
return ast.For(
|
||||
target=stmt.target.accept(self),
|
||||
iter=stmt.iterator.accept(self),
|
||||
target=self.convert(stmt.target),
|
||||
iter=self.convert(stmt.iterator),
|
||||
body=self._visit_body(stmt.body),
|
||||
orelse=[],
|
||||
)
|
||||
@@ -295,7 +309,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
def visit_raw_stmt(self, stmt: p.RawStmt) -> ast.stmt:
|
||||
return stmt.stmt
|
||||
|
||||
def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]:
|
||||
def _visit_body(
|
||||
self, stmts: list[p.Stmt], can_be_empty: bool = False
|
||||
) -> list[ast.stmt]:
|
||||
generated: list[ast.stmt] = []
|
||||
for stmt in stmts:
|
||||
scope = Scope()
|
||||
@@ -313,9 +329,11 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
# Remove redundant pass statements
|
||||
if len(generated) > 1:
|
||||
generated = [stmt for stmt in generated if not isinstance(stmt, ast.Pass)]
|
||||
if len(generated) == 0 and not can_be_empty:
|
||||
generated = [ast.Pass()]
|
||||
return generated
|
||||
|
||||
def _make_alias(self, expr: ast.expr) -> ast.expr:
|
||||
def _make_alias(self, node: p.Expr, expr: ast.expr) -> ast.expr:
|
||||
name: str = f"__midas_a{self._alias_count}__"
|
||||
alias = ast.Name(id=name)
|
||||
self._alias_count += 1
|
||||
@@ -326,6 +344,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
value=expr,
|
||||
)
|
||||
)
|
||||
self._aliases.append((node, alias))
|
||||
return alias
|
||||
|
||||
def _build_assert(self, expr: ast.expr, message: str | ast.expr) -> ast.stmt:
|
||||
@@ -349,7 +368,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
self, src_location: Location, expr: ast.expr, type: Type
|
||||
) -> list[ast.stmt]:
|
||||
match type:
|
||||
case UnknownType():
|
||||
case UnknownType() | TopType():
|
||||
return []
|
||||
|
||||
case BaseType(name=name):
|
||||
@@ -480,12 +499,13 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
return asserts
|
||||
|
||||
case (
|
||||
TopType()
|
||||
| Function()
|
||||
Function()
|
||||
| OverloadedFunction()
|
||||
| ComplexType()
|
||||
| ExtensionType()
|
||||
| GenericType()
|
||||
| FrameGroupBy()
|
||||
| ColumnGroupBy()
|
||||
):
|
||||
self.logger.warning(f"Can't make assertion for type {type}")
|
||||
return []
|
||||
@@ -640,3 +660,30 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
body=body,
|
||||
orelse=[],
|
||||
)
|
||||
|
||||
def _convert_assertion(self, assertion: Assertion) -> ast.stmt:
|
||||
inputs: list[ast.expr] = []
|
||||
|
||||
for input in assertion.inputs:
|
||||
converted: ast.expr = self.convert(input)
|
||||
alias: ast.expr = self._make_alias(input, converted)
|
||||
inputs.append(alias)
|
||||
|
||||
test: ast.expr = assertion.builder(*inputs)
|
||||
location: Location = assertion.bound_expr.location
|
||||
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
|
||||
return self._build_assert(
|
||||
test, f"{loc_str}: AssertionError: {assertion.message}"
|
||||
)
|
||||
|
||||
def _apply_assertions(self, expr: p.Expr, assertions: list[Assertion]) -> ast.expr:
|
||||
for assertion in assertions:
|
||||
assert_stmt: ast.stmt
|
||||
assert_stmt = self._convert_assertion(assertion)
|
||||
self._add_assert(assert_stmt)
|
||||
|
||||
# Mutating list in frozen dataclass
|
||||
# Not ideal but easiest way to avoid duplicate assertions
|
||||
self._typed_ast.assertions.remove(assertion)
|
||||
|
||||
return expr.accept(self)
|
||||
|
||||
@@ -6,12 +6,14 @@ from midas.checker.registry import Member, TypesRegistry
|
||||
from midas.checker.types import (
|
||||
AppliedType,
|
||||
BaseType,
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
DataFrameType,
|
||||
DerivedType,
|
||||
ExtensionType,
|
||||
FrameGroupBy,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
@@ -91,6 +93,21 @@ class StubsGenerator:
|
||||
def generate_stub(self, name: str, type: Type):
|
||||
base_type: Type = type
|
||||
|
||||
# TODO: improve
|
||||
match type:
|
||||
case DerivedType(name=name_) | GenericType(name=name_) if name_ == name:
|
||||
pass
|
||||
case UnitType() if name == "None":
|
||||
pass
|
||||
case TopType() if name == "Any":
|
||||
pass
|
||||
case _:
|
||||
alias = ast.Assign(
|
||||
targets=[ast.Name(id=name)], value=self.dump_type(type)
|
||||
)
|
||||
self.add_stub(alias)
|
||||
return
|
||||
|
||||
members: dict[str, Member] = self.types._members.get(name, {})
|
||||
if isinstance(base_type, (BaseType, TopType, UnitType)) and len(members) == 0:
|
||||
return
|
||||
@@ -272,6 +289,32 @@ class StubsGenerator:
|
||||
attr="DataFrame",
|
||||
)
|
||||
|
||||
case FrameGroupBy():
|
||||
self.import_pandas = True
|
||||
return ast.Attribute(
|
||||
value=ast.Attribute(
|
||||
value=ast.Attribute(
|
||||
value=ast.Name(id="pd"),
|
||||
attr="api",
|
||||
),
|
||||
attr="typing",
|
||||
),
|
||||
attr="DataFrameGroupBy",
|
||||
)
|
||||
|
||||
case ColumnGroupBy():
|
||||
self.import_pandas = True
|
||||
return ast.Attribute(
|
||||
value=ast.Attribute(
|
||||
value=ast.Attribute(
|
||||
value=ast.Name(id="pd"),
|
||||
attr="api",
|
||||
),
|
||||
attr="typing",
|
||||
),
|
||||
attr="SeriesGroupBy",
|
||||
)
|
||||
|
||||
case _:
|
||||
assert_never(type)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any, Callable, Optional
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.checker.types import Type
|
||||
from midas.generator.collector import AssertionCollector
|
||||
|
||||
AllowRepeat = Callable[[object], bool]
|
||||
|
||||
@@ -63,3 +64,4 @@ class TypedAST:
|
||||
stmts: list[p.Stmt]
|
||||
judgements: list[tuple[p.Expr, Type]]
|
||||
evaluated_casts: list[p.CastExpr]
|
||||
assertions: AssertionCollector
|
||||
|
||||
43
tests/__main__.py
Normal file
43
tests/__main__.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from typing import Type
|
||||
|
||||
from midas.cli.ansi import Ansi
|
||||
from tests.base import Tester
|
||||
from tests.checker import CheckerTester
|
||||
from tests.generator import GeneratorTester
|
||||
from tests.midas import MidasTester
|
||||
from tests.python import PythonTester
|
||||
|
||||
|
||||
def print_banner(name: str):
|
||||
horizontal: str = "+" + "-" * (len(name) + 2) + "+"
|
||||
print(horizontal)
|
||||
print(f"| {name} |")
|
||||
print(horizontal)
|
||||
|
||||
|
||||
def run_tests(tester_cls: Type[Tester]) -> bool:
|
||||
print_banner(tester_cls.__name__)
|
||||
tester: Tester = tester_cls()
|
||||
success: bool = tester.run_all_tests()
|
||||
print()
|
||||
return success
|
||||
|
||||
|
||||
def main():
|
||||
testers: list[Type[Tester]] = [
|
||||
PythonTester,
|
||||
MidasTester,
|
||||
CheckerTester,
|
||||
GeneratorTester,
|
||||
]
|
||||
|
||||
success: bool = all(map(run_tests, testers))
|
||||
|
||||
if success:
|
||||
print(Ansi.FG(Ansi.BRIGHT_GREEN) + "All tests passed!" + Ansi.RESET)
|
||||
else:
|
||||
print(Ansi.FG(Ansi.BRIGHT_RED) + "Some tests failed!" + Ansi.RESET)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -7,6 +7,8 @@ from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Protocol
|
||||
|
||||
from midas.cli.ansi import Ansi
|
||||
|
||||
|
||||
class CaseResult(Protocol):
|
||||
def dumps(self) -> str: ...
|
||||
@@ -44,8 +46,11 @@ class Tester(ABC):
|
||||
|
||||
print(rule)
|
||||
for i, test in enumerate(tests):
|
||||
print(f"Case {i+1}/{n}: {test.resolve().relative_to(self.CASES_DIR)}")
|
||||
path: Path = test.resolve().relative_to(self.CASES_DIR)
|
||||
print(f"{Ansi.FG(Ansi.BRIGHT_CYAN)}Case {i+1}/{n}: {path}{Ansi.RESET}")
|
||||
print(Ansi.DIM, end="")
|
||||
success: bool = self._run_test(test)
|
||||
print(Ansi.RESET, end="")
|
||||
if success:
|
||||
successes += 1
|
||||
else:
|
||||
@@ -146,8 +151,9 @@ class Tester(ABC):
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
case None:
|
||||
print("No subcommand provided. Available subcommands: run, update")
|
||||
sys.exit(1)
|
||||
success: bool = tester.run_all_tests()
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
case _:
|
||||
print(f"Unknown subcommand '{args.subcommand}'")
|
||||
sys.exit(1)
|
||||
|
||||
117
tests/cases/checker/09_frame_ops.py
Normal file
117
tests/cases/checker/09_frame_ops.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# type: ignore
|
||||
# ruff: disable [F821]
|
||||
|
||||
df1: Frame[a:int, b:float]
|
||||
df2: Frame[a:int, b:float]
|
||||
|
||||
_: Any
|
||||
|
||||
# Arithmetic
|
||||
_ = df1 + df2
|
||||
_ = df1 - df2
|
||||
_ = df1 * df2
|
||||
_ = df1 / df2
|
||||
_ = df1 // df2
|
||||
_ = df1 % df2
|
||||
_ = df1**df2
|
||||
|
||||
# Comparisons
|
||||
_ = df1 < df2
|
||||
_ = df1 > df2
|
||||
_ = df1 <= df2
|
||||
_ = df1 >= df2
|
||||
_ = df1 != df2
|
||||
_ = df1 == df2
|
||||
|
||||
# Aggregate
|
||||
_ = df1.kurt()
|
||||
_ = df1.kurtosis()
|
||||
_ = df1.max()
|
||||
_ = df1.mean()
|
||||
_ = df1.median()
|
||||
_ = df1.min()
|
||||
_ = df1.mode()
|
||||
_ = df1.prod()
|
||||
_ = df1.product()
|
||||
_ = df1.std()
|
||||
_ = df1.sum()
|
||||
_ = df1.var()
|
||||
|
||||
# Groupby
|
||||
df_gb = df1.groupby(by="a")
|
||||
|
||||
_ = df_gb.kurt()
|
||||
_ = df_gb.max()
|
||||
_ = df_gb.mean()
|
||||
_ = df_gb.median()
|
||||
_ = df_gb.min()
|
||||
_ = df_gb.prod()
|
||||
_ = df_gb.std()
|
||||
_ = df_gb.sum()
|
||||
_ = df_gb.var()
|
||||
|
||||
|
||||
# Columns
|
||||
|
||||
col1 = df1["a"]
|
||||
col2 = df1["a"]
|
||||
|
||||
# Arithmetic
|
||||
_ = col1 + col2
|
||||
_ = col1 - col2
|
||||
_ = col1 * col2
|
||||
_ = col1 / col2
|
||||
_ = col1 // col2
|
||||
_ = col1 % col2
|
||||
_ = col1**col2
|
||||
|
||||
# Comparisons
|
||||
_ = col1 < col2
|
||||
_ = col1 > col2
|
||||
_ = col1 <= col2
|
||||
_ = col1 >= col2
|
||||
_ = col1 != col2
|
||||
_ = col1 == col2
|
||||
|
||||
# Aggregate
|
||||
_ = col1.kurt()
|
||||
_ = col1.kurtosis()
|
||||
_ = col1.max()
|
||||
_ = col1.mean()
|
||||
_ = col1.median()
|
||||
_ = col1.min()
|
||||
_ = col1.mode()
|
||||
_ = col1.prod()
|
||||
_ = col1.product()
|
||||
_ = col1.std()
|
||||
_ = col1.sum()
|
||||
_ = col1.var()
|
||||
|
||||
# Groupby
|
||||
col_gb = col1.groupby(level=0)
|
||||
|
||||
_ = col_gb.kurt()
|
||||
_ = col_gb.max()
|
||||
_ = col_gb.mean()
|
||||
_ = col_gb.median()
|
||||
_ = col_gb.min()
|
||||
_ = col_gb.prod()
|
||||
_ = col_gb.std()
|
||||
_ = col_gb.sum()
|
||||
_ = col_gb.var()
|
||||
|
||||
# Attributes
|
||||
_ = df1.ndim # int
|
||||
_ = df1.size # int
|
||||
_ = df1.shape # (int, int)
|
||||
_ = col1.ndim # int
|
||||
_ = col1.size # int
|
||||
_ = col1.shape # (int)
|
||||
_ = col1.T # Column[int]
|
||||
|
||||
|
||||
# Misc
|
||||
_ = df1.head()
|
||||
_ = df1.tail()
|
||||
_ = col1.head()
|
||||
_ = col1.tail()
|
||||
4924
tests/cases/checker/09_frame_ops.py.ref.json
Normal file
4924
tests/cases/checker/09_frame_ops.py.ref.json
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user