Compare commits
83 Commits
d039a8e4b3
...
feat/simpl
| Author | SHA1 | Date | |
|---|---|---|---|
|
a143972ef1
|
|||
|
0c70048b62
|
|||
|
1c0c917873
|
|||
|
1f6189daa4
|
|||
|
66b585c3d6
|
|||
|
819ab3c2bf
|
|||
|
d8c0b17512
|
|||
|
6e06f9078e
|
|||
|
ece2e3a6a3
|
|||
|
74c07c9afb
|
|||
|
be2fd4c837
|
|||
|
1bc4c704c3
|
|||
|
0288a05901
|
|||
|
b14f46d405
|
|||
|
8e8ed62266
|
|||
|
2fce2f4bfc
|
|||
|
640f2d1771
|
|||
|
b48dfe5301
|
|||
|
0d5840a4ce
|
|||
|
3c92f0867d
|
|||
|
b5acae4078
|
|||
|
5d20f8ec3e
|
|||
|
955c2233ed
|
|||
|
ff69b65171
|
|||
|
8df01afd8c
|
|||
|
47b2dfdd73
|
|||
|
bd4d793ce0
|
|||
|
f7a36f61b6
|
|||
|
ad2fabf471
|
|||
|
a59a58d21a
|
|||
| 3260ae4a1e | |||
|
bd1c9581c7
|
|||
|
663642ea6c
|
|||
|
e2abc04fe4
|
|||
|
a4016b55ce
|
|||
|
1ea5da7024
|
|||
|
a017a8cf1f
|
|||
|
8fc5ab623e
|
|||
|
14007db846
|
|||
|
6ad2ce4b68
|
|||
|
9a276c34c7
|
|||
|
6e717a3f9e
|
|||
|
77aadfa264
|
|||
| c81287df7f | |||
|
ffccc1bedd
|
|||
|
d14f208897
|
|||
|
293953a078
|
|||
|
bccc96e4d0
|
|||
|
9db56adf56
|
|||
|
3f99563ac8
|
|||
|
b36896cc7b
|
|||
|
cb75878ae9
|
|||
|
a5fe985eb2
|
|||
|
e324f414e6
|
|||
|
256536562f
|
|||
|
64f4314f0d
|
|||
|
6f6245d283
|
|||
|
3392bc347d
|
|||
|
7e0319906a
|
|||
|
75bd203d4a
|
|||
|
db40198357
|
|||
|
d79e1dee18
|
|||
|
4ea400265c
|
|||
|
24bffdabd4
|
|||
|
d7bb6326de
|
|||
|
dbf6f9e2db
|
|||
|
3cdc9031d3
|
|||
|
00e2ca8fe3
|
|||
|
4efb01285c
|
|||
|
f84a19159f
|
|||
|
946b2e0d2e
|
|||
|
08dd7408ec
|
|||
|
b33fadf768
|
|||
|
7219109e5d
|
|||
|
cdf1725c26
|
|||
|
7074b074bc
|
|||
|
ede7272c09
|
|||
|
87d5e286d2
|
|||
|
c91b206791
|
|||
|
a31d295eb1
|
|||
|
0d20993f02
|
|||
|
5357ca8e58
|
|||
|
556765fd35
|
@@ -678,6 +678,10 @@ In the following example, a runtime check would be generated to ensure that the
|
||||
caption: [Typing of `cast` expression],
|
||||
)
|
||||
|
||||
#gc.warning[
|
||||
Assertions are statements inserted just before a statement using a `cast` expression. This means that the expression is evaluated _before_ its actual intended usage location, which might cause issues if you rely on logical operator short-circuiting. See @eager-eval for more information.
|
||||
]
|
||||
|
||||
There may be some cases where the cost of checking a value at runtime is simply not worth the safety, for example when dealing with a big dataset. If do wish so, you can use `unsafe_cast` which will only tell the type checker the type of the value, without generating a runtime assertion. This maps to the default behavior of `typing`'s own `cast` function.
|
||||
|
||||
If the value passed to `cast` or `unsafe_cast` is a literal (e.g. an integer, a string, a list of literals, etc.), the assertion is evaluated _at compile-time_ and no runtime assertion is generated.
|
||||
@@ -695,3 +699,26 @@ If the value passed to `cast` or `unsafe_cast` is a literal (e.g. an integer, a
|
||||
== Generating Stubs (`stubs`) <cmd-stubs>
|
||||
== Showing Type Judgements (`types`) <cmd-types>
|
||||
== Validating Definitions (`validate`) <cmd-validate>
|
||||
|
||||
= Known limitations <limitations>
|
||||
|
||||
== Eager evaluation in runtime assertions <eager-eval>
|
||||
|
||||
The process of generating assertions to ensure safety at runtime, mainly for `cast` expressions, leads to the creation of aliases for the expressions being casted. These alias definitions eagerly evaluate before the assertion, and most importantly before the real usage location. This means that you should avoid using `cast` expressions inside logical expressions like `and` or `or`, because the normal "short-circuit" behavior will be irrelevant to the evaluations of the operands.
|
||||
|
||||
For example:
|
||||
|
||||
#figure(
|
||||
```py
|
||||
def foo():
|
||||
print("Foo")
|
||||
return True
|
||||
def bar():
|
||||
print("Bar")
|
||||
return True
|
||||
result = foo() or bar()
|
||||
# Foo
|
||||
# Bar
|
||||
```,
|
||||
caption: [Runtime assertions may eagerly evaluate expressions and bypass logical operator's short-circuit],
|
||||
)
|
||||
|
||||
10
gen/midas.py
10
gen/midas.py
@@ -157,4 +157,14 @@ class FunctionType:
|
||||
required: bool
|
||||
|
||||
|
||||
class FrameType:
|
||||
columns: list[Column]
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Column:
|
||||
location: Optional[Location] = None
|
||||
name: Token
|
||||
type: Type
|
||||
|
||||
|
||||
###<
|
||||
|
||||
@@ -15,7 +15,7 @@ from midas.ast.location import Location
|
||||
###> MidasType | Type annotations | node
|
||||
class BaseType:
|
||||
base: str
|
||||
param: Optional[MidasType]
|
||||
args: tuple[MidasType, ...]
|
||||
|
||||
|
||||
class ConstraintType:
|
||||
@@ -174,6 +174,10 @@ class SliceExpr:
|
||||
step: Optional[Expr]
|
||||
|
||||
|
||||
class TupleExpr:
|
||||
items: tuple[Expr, ...]
|
||||
|
||||
|
||||
class RawExpr:
|
||||
expr: ast.expr
|
||||
|
||||
|
||||
@@ -265,6 +265,9 @@ class Type(ABC):
|
||||
@abstractmethod
|
||||
def visit_function_type(self, type: FunctionType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_frame_type(self, type: FrameType) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NamedType(Type):
|
||||
@@ -323,3 +326,17 @@ class FunctionType(Type):
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_function_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FrameType(Type):
|
||||
columns: list[Column]
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Column:
|
||||
location: Optional[Location] = None
|
||||
name: Token
|
||||
type: Type
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_frame_type(self)
|
||||
|
||||
@@ -358,6 +358,25 @@ class MidasAstPrinter(
|
||||
arg.type.accept(self)
|
||||
self._write_line(f"required: {arg.required}", last=True)
|
||||
|
||||
def visit_frame_type(self, type: m.FrameType) -> None:
|
||||
self._write_line("FrameType")
|
||||
with self._child_level(single=True):
|
||||
self._write_line("columns")
|
||||
with self._child_level():
|
||||
for i, column in enumerate(type.columns):
|
||||
self._idx = i
|
||||
if i == len(type.columns) - 1:
|
||||
self._mark_last()
|
||||
self._print_frame_column(column)
|
||||
|
||||
def _print_frame_column(self, column: m.FrameType.Column) -> None:
|
||||
self._write_line("Column")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{column.name.lexeme}"')
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
column.type.accept(self)
|
||||
|
||||
|
||||
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
|
||||
def __init__(self, indent: int = 4):
|
||||
@@ -513,6 +532,23 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
||||
res += "?"
|
||||
return res
|
||||
|
||||
def visit_frame_type(self, type: m.FrameType) -> str:
|
||||
res: str = self.indented("Frame[")
|
||||
if len(type.columns) != 0:
|
||||
res += "\n"
|
||||
self.level += 1
|
||||
columns: list[str] = []
|
||||
for column in type.columns:
|
||||
columns.append(self.indented(self._print_frame_column(column)))
|
||||
res += ",\n".join(columns)
|
||||
self.level -= 1
|
||||
res += "\n"
|
||||
res += "]"
|
||||
return res
|
||||
|
||||
def _print_frame_column(self, column: m.FrameType.Column) -> str:
|
||||
return f"{column.name.lexeme}: {column.type.accept(self)}"
|
||||
|
||||
|
||||
class PythonAstPrinter(
|
||||
AstPrinter,
|
||||
@@ -524,7 +560,13 @@ class PythonAstPrinter(
|
||||
self._write_line("BaseType")
|
||||
with self._child_level():
|
||||
self._write_line(f"base: {node.base}")
|
||||
self._write_optional_child("param", node.param, last=True)
|
||||
self._write_line("args:", last=True)
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(node.args):
|
||||
self._idx = i
|
||||
if i == len(node.args) - 1:
|
||||
self._mark_last()
|
||||
arg.accept(self)
|
||||
|
||||
def visit_constraint_type(self, node: p.ConstraintType) -> None:
|
||||
self._write_line("ConstraintType")
|
||||
@@ -837,6 +879,17 @@ class PythonAstPrinter(
|
||||
self._write_optional_child("upper", expr.upper)
|
||||
self._write_optional_child("step", expr.step, last=True)
|
||||
|
||||
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
|
||||
self._write_line("TupleExpr")
|
||||
with self._child_level():
|
||||
self._write_line("items", last=True)
|
||||
with self._child_level():
|
||||
for i, item in enumerate(expr.items):
|
||||
self._idx = i
|
||||
if i == len(expr.items) - 1:
|
||||
self._mark_last()
|
||||
item.accept(self)
|
||||
|
||||
def visit_raw_expr(self, expr: p.RawExpr) -> None:
|
||||
self._write_line("RawExpr")
|
||||
with self._child_level(single=True):
|
||||
|
||||
@@ -44,7 +44,7 @@ class MidasType(ABC):
|
||||
@dataclass(frozen=True)
|
||||
class BaseType(MidasType):
|
||||
base: str
|
||||
param: Optional[MidasType]
|
||||
args: tuple[MidasType, ...]
|
||||
|
||||
def accept(self, visitor: MidasType.Visitor[T]) -> T:
|
||||
return visitor.visit_base_type(self)
|
||||
@@ -268,6 +268,9 @@ class Expr(ABC):
|
||||
@abstractmethod
|
||||
def visit_slice_expr(self, expr: SliceExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_tuple_expr(self, expr: TupleExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_raw_expr(self, expr: RawExpr) -> T: ...
|
||||
|
||||
@@ -402,6 +405,14 @@ class SliceExpr(Expr):
|
||||
return visitor.visit_slice_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TupleExpr(Expr):
|
||||
items: tuple[Expr, ...]
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_tuple_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RawExpr(Expr):
|
||||
expr: ast.expr
|
||||
|
||||
@@ -178,4 +178,100 @@ extend dict[K, V] {
|
||||
// def __ior__: fn(value: SupportsKeysAndGetItem[K, V], /) -> dict[K, V]
|
||||
// def __ior__: fn(value: Iterable[tuple[K, V]], /) -> dict[K, V]
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
extend str {
|
||||
def capitalize: fn() -> str
|
||||
def casefold: fn() -> str
|
||||
def center: fn(width: int, fillchar: str?, /) -> str
|
||||
def count: fn(sub: str, start: None?, end: None?, /) -> int
|
||||
def count: fn(sub: str, start: int, end: None?, /) -> int
|
||||
def count: fn(sub: str, start: None, end: int, /) -> int
|
||||
def count: fn(sub: str, start: int, end: int, /) -> int
|
||||
def encode: fn(encoding: str?, errors: str?) -> bytes
|
||||
def endswith: fn(suffix: str, start: None?, end: None?, /) -> bool
|
||||
def endswith: fn(suffix: str, start: int, end: None?, /) -> bool
|
||||
def endswith: fn(suffix: str, start: None, end: int, /) -> bool
|
||||
def endswith: fn(suffix: str, start: int, end: int, /) -> bool
|
||||
def expandtabs: fn(tabsize: int?) -> str
|
||||
def find: fn(sub: str, start: None?, end: None?, /) -> int
|
||||
def find: fn(sub: str, start: int, end: None?, /) -> int
|
||||
def find: fn(sub: str, start: None, end: int, /) -> int
|
||||
def find: fn(sub: str, start: int, end: int, /) -> int
|
||||
// def format: fn(*args: object, **kwargs: object) -> str
|
||||
// def format_map: fn(mapping: _FormatMapMapping, /) -> str
|
||||
def index: fn(sub: str, start: None?, end: None?, /) -> int
|
||||
def index: fn(sub: str, start: int, end: None?, /) -> int
|
||||
def index: fn(sub: str, start: None, end: int, /) -> int
|
||||
def index: fn(sub: str, start: int, end: int, /) -> int
|
||||
def isalnum: fn() -> bool
|
||||
def isalpha: fn() -> bool
|
||||
def isascii: fn() -> bool
|
||||
def isdecimal: fn() -> bool
|
||||
def isdigit: fn() -> bool
|
||||
def isidentifier: fn() -> bool
|
||||
def islower: fn() -> bool
|
||||
def isnumeric: fn() -> bool
|
||||
def isprintable: fn() -> bool
|
||||
def isspace: fn() -> bool
|
||||
def istitle: fn() -> bool
|
||||
def isupper: fn() -> bool
|
||||
def join: fn(iterable: list[str], /) -> str // TODO: use Iterable
|
||||
def ljust: fn(width: int, fillchar: str?, /) -> str
|
||||
def lower: fn() -> str
|
||||
def lstrip: fn(chars: None?, /) -> str
|
||||
def lstrip: fn(chars: str, /) -> str
|
||||
def partition: fn(sep: str, /) -> tuple[str, str, str]
|
||||
|
||||
def replace: fn(old: str, new: str, count: int?, /) -> str
|
||||
|
||||
def removeprefix: fn(prefix: str, /) -> str
|
||||
def removesuffix: fn(suffix: str, /) -> str
|
||||
def rfind: fn(sub: str, start: None?, end: None?, /) -> int
|
||||
def rfind: fn(sub: str, start: int, end: None?, /) -> int
|
||||
def rfind: fn(sub: str, start: None, end: int, /) -> int
|
||||
def rfind: fn(sub: str, start: int, end: int, /) -> int
|
||||
def rindex: fn(sub: str, start: None?, end: None?, /) -> int
|
||||
def rindex: fn(sub: str, start: int, end: None?, /) -> int
|
||||
def rindex: fn(sub: str, start: None, end: int, /) -> int
|
||||
def rindex: fn(sub: str, start: int, end: int, /) -> int
|
||||
def rjust: fn(width: int, fillchar: str?, /) -> str
|
||||
def rpartition: fn(sep: str, /) -> tuple[str, str, str]
|
||||
def rsplit: fn(sep: None?, maxsplit: int?) -> list[str]
|
||||
def rsplit: fn(sep: str, maxsplit: int?) -> list[str]
|
||||
def rstrip: fn(chars: None?, /) -> str
|
||||
def rstrip: fn(chars: str, /) -> str
|
||||
def split: fn(sep: None?, maxsplit: int?) -> list[str]
|
||||
def split: fn(sep: str, maxsplit: int?) -> list[str]
|
||||
def splitlines: fn(keepends: bool?) -> list[str]
|
||||
def startswith: fn(prefix: str, start: None?, end: None?, /) -> bool
|
||||
def startswith: fn(prefix: str, start: int, end: None?, /) -> bool
|
||||
def startswith: fn(prefix: str, start: None, end: int, /) -> bool
|
||||
def startswith: fn(prefix: str, start: int, end: int, /) -> bool
|
||||
def strip: fn(chars: None?, /) -> str
|
||||
def strip: fn(chars: str, /) -> str
|
||||
def swapcase: fn() -> str
|
||||
def title: fn() -> str
|
||||
// def translate: fn(table: _TranslateTable, /) -> str
|
||||
def upper: fn() -> str
|
||||
def zfill: fn(width: int, /) -> str
|
||||
def __add__: fn(value: str, /) -> str
|
||||
// Incompatible with Sequence.__contains__
|
||||
def __contains__: fn(key: str, /) -> bool
|
||||
def __eq__: fn(value: object, /) -> bool
|
||||
def __ge__: fn(value: str, /) -> bool
|
||||
def __getitem__: fn(key: slice, /) -> str
|
||||
def __getitem__: fn(key: int, /) -> str
|
||||
def __gt__: fn(value: str, /) -> bool
|
||||
def __hash__: fn() -> int
|
||||
// def __iter__: fn() -> Iterator[str]
|
||||
def __le__: fn(value: str, /) -> bool
|
||||
def __len__: fn() -> int
|
||||
def __lt__: fn(value: str, /) -> bool
|
||||
def __mod__: fn(value: Any, /) -> str
|
||||
def __mul__: fn(value: int, /) -> str
|
||||
def __ne__: fn(value: object, /) -> bool
|
||||
def __rmul__: fn(value: int, /) -> str
|
||||
def __getnewargs__: fn() -> tuple[str]
|
||||
def __format__: fn(format_spec: str, /) -> str
|
||||
}
|
||||
|
||||
@@ -14,8 +14,10 @@ if TYPE_CHECKING:
|
||||
from midas.checker.registry import TypesRegistry
|
||||
|
||||
|
||||
# Hard-coded subtype relationships between builtin types
|
||||
# Circular dependencies and diamond inheritance MUST be avoided
|
||||
BUILTIN_SUBTYPES: dict[str, set[str]] = {
|
||||
"object": {"float", "list", "dict", "str"},
|
||||
"object": {"float", "list", "dict", "str", "bytes", "tuple"},
|
||||
"float": {"int"},
|
||||
"int": {"bool"},
|
||||
}
|
||||
@@ -26,12 +28,15 @@ def define_builtins(reg: TypesRegistry):
|
||||
any = reg.define_type("Any", TopType())
|
||||
unit = reg.define_type("None", UnitType())
|
||||
object = reg.define_type("object", BaseType(name="object"))
|
||||
bytes = reg.define_type("bytes", BaseType(name="bytes"))
|
||||
bool = reg.define_type("bool", BaseType(name="bool"))
|
||||
int = reg.define_type("int", BaseType(name="int"))
|
||||
float = reg.define_type("float", BaseType(name="float"))
|
||||
str = reg.define_type("str", BaseType(name="str"))
|
||||
slice = reg.define_type("slice", BaseType(name="slice"))
|
||||
|
||||
tuple = reg.define_type("tuple", BaseType(name="tuple"))
|
||||
|
||||
list = reg.define_type(
|
||||
"list",
|
||||
GenericType(
|
||||
|
||||
484
midas/checker/dispatcher.py
Normal file
484
midas/checker/dispatcher.py
Normal file
@@ -0,0 +1,484 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Generic, Optional, Protocol, TypeVar, Union
|
||||
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter
|
||||
from midas.checker.types import (
|
||||
AppliedType,
|
||||
DerivedType,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
Type,
|
||||
UnknownType,
|
||||
)
|
||||
from midas.checker.unifier import Unifier
|
||||
|
||||
|
||||
class HasLocation(Protocol):
|
||||
@property
|
||||
def location(self) -> Location: ...
|
||||
|
||||
|
||||
E = TypeVar("E", bound=HasLocation)
|
||||
|
||||
TypedExpr = tuple[E, Type]
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class MappedArgument(Generic[E]):
|
||||
expr: E
|
||||
type: Type
|
||||
argument: Function.Argument
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class OverloadCandidate:
|
||||
function: Function
|
||||
mapped: list[MappedArgument]
|
||||
|
||||
|
||||
class CallError(StrEnum):
|
||||
INVALID_ARGS = "Invalid arguments"
|
||||
NO_MATCHING_OVERLOAD = "No matching overload"
|
||||
IMPOSSIBLE_UNIFICATION = "Parameters unification failed"
|
||||
NOT_CALLABLE = "Not callable"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class CallResult:
|
||||
error: Optional[CallError] = None
|
||||
result: Type = UnknownType()
|
||||
message: Optional[str] = None
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
return self.error is None
|
||||
|
||||
@property
|
||||
def error_message(self) -> str:
|
||||
if self.message is not None:
|
||||
return self.message
|
||||
if self.error is not None:
|
||||
return str(self.error)
|
||||
return ""
|
||||
|
||||
|
||||
class CallDispatcher(Generic[E]):
|
||||
def __init__(self, types: TypesRegistry, reporter: FileReporter) -> None:
|
||||
self.types: TypesRegistry = types
|
||||
self.reporter: FileReporter = reporter
|
||||
self.logger: logging.Logger = logging.getLogger("CallDispatcher")
|
||||
|
||||
def set_reporter(self, reporter: FileReporter):
|
||||
self.reporter = reporter
|
||||
|
||||
def get_result(
|
||||
self,
|
||||
location: Location,
|
||||
callee: Type,
|
||||
positional: list[TypedExpr[E]],
|
||||
keywords: dict[str, TypedExpr[E]],
|
||||
report_errors: bool = True,
|
||||
) -> CallResult:
|
||||
"""Get the result type of a function call
|
||||
|
||||
If the function has overloads, the function will try to resolve the
|
||||
appropriate signature.
|
||||
Argument types are matched to the defined parameters.
|
||||
The function doesn't take the raw expression as a parameter to accommodate
|
||||
for desugared calls such as for operators.
|
||||
|
||||
Args:
|
||||
location (Location): the call location
|
||||
callee (Type): the called function
|
||||
positional (list[TypedExpr]): the list positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Type: the return type of the call, or `None` if either
|
||||
the call is invalid or no overload matched the arguments uniquely
|
||||
"""
|
||||
match callee:
|
||||
case Function() as function:
|
||||
valid: bool
|
||||
mapped: list[MappedArgument[E]]
|
||||
valid, mapped = self.map_call_arguments(
|
||||
function, location, positional, keywords
|
||||
)
|
||||
valid = valid and self._are_arguments_valid(mapped, report_errors)
|
||||
if not valid:
|
||||
return CallResult(error=CallError.INVALID_ARGS)
|
||||
return CallResult(result=function.returns)
|
||||
|
||||
case OverloadedFunction(overloads=overloads):
|
||||
res = self._match_overload(
|
||||
overloads, location, positional, keywords, report_errors
|
||||
)
|
||||
if res[0] is None:
|
||||
return CallResult(
|
||||
error=CallError.NO_MATCHING_OVERLOAD,
|
||||
message=res[1],
|
||||
)
|
||||
return CallResult(result=res[0].returns)
|
||||
|
||||
case AppliedType(body=body):
|
||||
return self.get_result(
|
||||
location, body, positional, keywords, report_errors
|
||||
)
|
||||
|
||||
case UnknownType():
|
||||
return CallResult(result=UnknownType())
|
||||
|
||||
case DerivedType(type=base):
|
||||
return self.get_result(
|
||||
location, base, positional, keywords, report_errors
|
||||
)
|
||||
|
||||
case GenericType():
|
||||
unifier: Unifier = Unifier(self.types)
|
||||
pos: list[Type] = [a[1] for a in positional]
|
||||
kw: dict[str, Type] = {k: v[1] for k, v in keywords.items()}
|
||||
unified: Optional[Type] = unifier.unify_call(callee, pos, kw)
|
||||
if unified is None:
|
||||
pos_str: str = ", ".join(str(t) for t in pos)
|
||||
kw_str: str = ", ".join(f"{k}: {v}" for k, v in kw.items())
|
||||
message: str = (
|
||||
f"Could not unify {callee}={callee.body} with pos=[{pos_str}] and kw={{{kw_str}}}"
|
||||
)
|
||||
if report_errors:
|
||||
self.reporter.error(location, message)
|
||||
return CallResult(
|
||||
error=CallError.IMPOSSIBLE_UNIFICATION,
|
||||
message=message,
|
||||
)
|
||||
return self.get_result(
|
||||
location,
|
||||
unified,
|
||||
positional,
|
||||
keywords,
|
||||
report_errors,
|
||||
)
|
||||
|
||||
case _:
|
||||
message: str = f"{callee} ({callee.__class__.__name__}) is not callable"
|
||||
if report_errors:
|
||||
self.reporter.error(location, message)
|
||||
return CallResult(
|
||||
error=CallError.NOT_CALLABLE,
|
||||
message=message,
|
||||
)
|
||||
|
||||
def _unwrap_function(
|
||||
self,
|
||||
callee: Type,
|
||||
positional: list[TypedExpr[E]],
|
||||
keywords: dict[str, TypedExpr[E]],
|
||||
) -> Union[tuple[Function, None], tuple[None, CallError]]:
|
||||
match callee:
|
||||
case DerivedType(type=base):
|
||||
return self._unwrap_function(base, positional, keywords)
|
||||
|
||||
case GenericType():
|
||||
unifier: Unifier = Unifier(self.types)
|
||||
unified: Optional[Type] = unifier.unify_call(
|
||||
callee,
|
||||
[a[1] for a in positional],
|
||||
{k: v[1] for k, v in keywords.items()},
|
||||
)
|
||||
if unified is None:
|
||||
return None, CallError.IMPOSSIBLE_UNIFICATION
|
||||
return self._unwrap_function(unified, positional, keywords)
|
||||
|
||||
case Function():
|
||||
return callee, None
|
||||
|
||||
case AppliedType(body=body):
|
||||
return self._unwrap_function(body, positional, keywords)
|
||||
|
||||
case _:
|
||||
return None, CallError.NOT_CALLABLE
|
||||
|
||||
def _are_arguments_valid(
|
||||
self,
|
||||
arguments: list[MappedArgument[E]],
|
||||
report_errors: bool = True,
|
||||
) -> bool:
|
||||
"""Check whether the passed argument types correspond to their matched parameter definitions
|
||||
|
||||
Args:
|
||||
arguments (list[MappedArgument]): the list of argument/parameter pairs
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
bool: True if all arguments fit the matching parameter definitions, False otherwise
|
||||
"""
|
||||
valid: bool = True
|
||||
for arg in arguments:
|
||||
if not self.types.is_subtype(arg.type, arg.argument.type):
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
arg.expr.location,
|
||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||
)
|
||||
valid = False
|
||||
return valid
|
||||
|
||||
def _match_overload(
|
||||
self,
|
||||
overloads: list[Type],
|
||||
location: Location,
|
||||
positional: list[TypedExpr[E]],
|
||||
keywords: dict[str, TypedExpr[E]],
|
||||
report_errors: bool = True,
|
||||
) -> Union[tuple[Function, None], tuple[None, str]]:
|
||||
"""Try and resolve the appropriate overload for the given arguments
|
||||
|
||||
Args:
|
||||
overloads (list[Type]): the list of possible overloads
|
||||
location (Location): the call location
|
||||
positional (list[TypedExpr]): the list of positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keywords arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Optional[Function]: the resolved function signature if it can be
|
||||
determined unambiguously, or `None`.
|
||||
"""
|
||||
candidates: list[OverloadCandidate] = []
|
||||
errors: list[CallError] = []
|
||||
for overload in overloads:
|
||||
function, unwrap_error = self._unwrap_function(
|
||||
overload, positional, keywords
|
||||
)
|
||||
if function is None:
|
||||
errors.append(unwrap_error) # type: ignore
|
||||
continue
|
||||
|
||||
valid, mapped = self.map_call_arguments(
|
||||
function=function,
|
||||
location=location,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
report_errors=False,
|
||||
)
|
||||
if valid and self._are_arguments_valid(mapped, report_errors=False):
|
||||
candidates.append(
|
||||
OverloadCandidate(
|
||||
function=function,
|
||||
mapped=mapped,
|
||||
)
|
||||
)
|
||||
|
||||
pos_types: str = ", ".join(str(type) for _, type in positional)
|
||||
kw_types: str = ", ".join(
|
||||
f"{name}: {type}" for name, (_, type) in keywords.items()
|
||||
)
|
||||
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
|
||||
|
||||
n_candidates: int = len(candidates)
|
||||
|
||||
# Exactly 1 match -> return it
|
||||
if n_candidates == 1:
|
||||
return candidates[0].function, None
|
||||
|
||||
# No match -> invalid call
|
||||
if n_candidates == 0:
|
||||
overloads_str: str = ", ".join(map(str, overloads))
|
||||
errors_str: str = ", ".join(errors)
|
||||
message: str = (
|
||||
f"No matching overload in [{overloads_str}] {for_args} (errors: {errors_str})"
|
||||
)
|
||||
if report_errors:
|
||||
self.reporter.error(location, message)
|
||||
return None, message
|
||||
|
||||
# Multiple matches -> see if one <: all others (more specific)
|
||||
for i1, c1 in enumerate(candidates):
|
||||
mapped1: list[MappedArgument[E]] = c1.mapped
|
||||
best_match: bool = True
|
||||
for i2, c2 in enumerate(candidates):
|
||||
if i1 == i2:
|
||||
continue
|
||||
mapped2: list[MappedArgument[E]] = c2.mapped
|
||||
if not self._are_mapped_subtypes(mapped1, mapped2):
|
||||
best_match = False
|
||||
break
|
||||
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
|
||||
if best_match:
|
||||
return c1.function, None
|
||||
|
||||
candidates_str: str = ", ".join(
|
||||
str(candidate.function) for candidate in candidates
|
||||
)
|
||||
message: str = f"Multiple matching overloads {for_args}: {candidates_str}"
|
||||
if report_errors:
|
||||
self.reporter.error(location, message)
|
||||
return None, message
|
||||
|
||||
def map_call_arguments(
|
||||
self,
|
||||
function: Function,
|
||||
location: Location,
|
||||
positional: list[TypedExpr[E]],
|
||||
keywords: dict[str, TypedExpr[E]],
|
||||
report_errors: bool = True,
|
||||
) -> tuple[bool, list[MappedArgument]]:
|
||||
"""Map call arguments to a function's parameters as defined in its signature
|
||||
|
||||
This method maps positional-only, keyword-only and mixed parameter definitions
|
||||
with the arguments passed at the call site
|
||||
|
||||
Any mismatched, missing or unexpected argument is reported as a diagnostic,
|
||||
unless `report_errors` is set to `False`
|
||||
|
||||
Args:
|
||||
function (Function): the function definition
|
||||
location (Location): the call location
|
||||
positional (list[TypedExpr]): the list of positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
tuple[bool, list[MappedArgument]]: a boolean reporting whether
|
||||
the call is valid and the list of mapped arguments
|
||||
"""
|
||||
set_args: set[str] = set()
|
||||
|
||||
required_positional: list[str] = [
|
||||
arg.name for arg in function.pos_args + function.args if arg.required
|
||||
]
|
||||
required_keyword: list[str] = [
|
||||
arg.name for arg in function.kw_args if arg.required
|
||||
]
|
||||
|
||||
mapped: list[MappedArgument[E]] = []
|
||||
|
||||
pos_params: list[Function.Argument] = list(function.pos_args)
|
||||
mixed_params: list[Function.Argument] = list(function.args)
|
||||
kw_params: dict[str, Function.Argument] = {
|
||||
arg.name: arg for arg in function.kw_args
|
||||
}
|
||||
|
||||
valid_call: bool = True
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
for arg in positional:
|
||||
param: Function.Argument
|
||||
if len(pos_params) != 0:
|
||||
param = pos_params.pop(0)
|
||||
elif len(mixed_params) != 0:
|
||||
param = mixed_params.pop(0)
|
||||
else:
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
arg[0].location, "Too many positional arguments"
|
||||
)
|
||||
valid_call = False
|
||||
break
|
||||
name: str = param.name
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
kw_params.update({arg.name: arg for arg in mixed_params})
|
||||
for name, arg in keywords.items():
|
||||
param: Function.Argument
|
||||
if name not in kw_params:
|
||||
if report_errors:
|
||||
if name in set_args:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Multiple values for argument '{name}'"
|
||||
)
|
||||
else:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Unknown keyword argument '{name}'"
|
||||
)
|
||||
valid_call = False
|
||||
continue
|
||||
param = kw_params.pop(name)
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
def join_args(args: list[str]) -> str:
|
||||
args = list(map(lambda a: f"'{a}'", args))
|
||||
if len(args) == 0:
|
||||
return ""
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return ", ".join(args[:-1]) + " and " + args[-1]
|
||||
|
||||
if len(required_positional) != 0:
|
||||
plural: str = "" if len(required_positional) == 1 else "s"
|
||||
args: str = join_args(required_positional)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Missing required positional argument{plural}: {args}",
|
||||
)
|
||||
valid_call = False
|
||||
|
||||
if len(required_keyword) != 0:
|
||||
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||
args: str = join_args(required_keyword)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Missing required keyword argument{plural}: {args}",
|
||||
)
|
||||
valid_call = False
|
||||
|
||||
return valid_call, mapped
|
||||
|
||||
def _are_mapped_subtypes(
|
||||
self, mapped1: list[MappedArgument[E]], mapped2: list[MappedArgument[E]]
|
||||
) -> bool:
|
||||
"""Check whether the given argument mappings are subtype/supertype of one another
|
||||
|
||||
This function checks whether the argument mappings `mapped1` are subtypes
|
||||
of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
|
||||
of the corresponding parameter in `mapped2`, `False` is returned.
|
||||
|
||||
This is used to check whether a given overload is
|
||||
a more specific function/ a subtype of another.
|
||||
|
||||
Args:
|
||||
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
|
||||
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
|
||||
|
||||
Returns:
|
||||
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
|
||||
"""
|
||||
by_expr: dict[E, Type] = {}
|
||||
for arg in mapped1:
|
||||
by_expr[arg.expr] = arg.argument.type
|
||||
|
||||
for arg in mapped2:
|
||||
type2: Type = arg.argument.type
|
||||
type1: Type = by_expr[arg.expr]
|
||||
if not self.types.is_subtype(type1, type2):
|
||||
return False
|
||||
return True
|
||||
70
midas/checker/frames/column_groupby_methods.py
Normal file
70
midas/checker/frames/column_groupby_methods.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallResult
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import ColumnGroupBy, Function, Type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import TypedExpr
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Call:
|
||||
location: Location
|
||||
call_expr: p.Expr
|
||||
groupby: ColumnGroupBy
|
||||
groupby_expr: p.Expr
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
@property
|
||||
def subject(self) -> TypedExpr:
|
||||
return (self.groupby_expr, self.groupby)
|
||||
|
||||
|
||||
class ColumnGroupByMethodRegistry(MethodRegistry[Call]):
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
bool_ = self.types.get_type("bool")
|
||||
signature = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="numeric_only",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="skipna",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=2,
|
||||
name="engine",
|
||||
type=self.types.get_type("str"),
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=3,
|
||||
name="engine_kwargs",
|
||||
type=self.types.get_type("dict"),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
returns=call.groupby.column,
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
62
midas/checker/frames/column_manager.py
Normal file
62
midas/checker/frames/column_manager.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.frames.column_groupby_methods import Call as GroupByCall
|
||||
from midas.checker.frames.column_groupby_methods import ColumnGroupByMethodRegistry
|
||||
from midas.checker.frames.column_methods import Call, ColumnMethodRegistry
|
||||
from midas.checker.types import ColumnGroupBy, ColumnType, Type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import PythonTyper, TypedExpr
|
||||
|
||||
|
||||
class ColumnManager:
|
||||
def __init__(self, typer: PythonTyper) -> None:
|
||||
self.typer: PythonTyper = typer
|
||||
self.method_resolver: ColumnMethodRegistry = ColumnMethodRegistry(self.typer)
|
||||
self.groupby_method_resolver: ColumnGroupByMethodRegistry = (
|
||||
ColumnGroupByMethodRegistry(self.typer)
|
||||
)
|
||||
|
||||
def call(
|
||||
self,
|
||||
method: str,
|
||||
location: Location,
|
||||
call_expr: p.Expr,
|
||||
column: ColumnType,
|
||||
column_expr: p.Expr,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Type:
|
||||
call: Call = Call(
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
column=column,
|
||||
column_expr=column_expr,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
return self.method_resolver.call(method, call)
|
||||
|
||||
def groupby_call(
|
||||
self,
|
||||
method: str,
|
||||
location: Location,
|
||||
call_expr: p.Expr,
|
||||
groupby: ColumnGroupBy,
|
||||
groupby_expr: p.Expr,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Type:
|
||||
call: GroupByCall = GroupByCall(
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
groupby=groupby,
|
||||
groupby_expr=groupby_expr,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
return self.groupby_method_resolver.call(method, call)
|
||||
351
midas/checker/frames/column_methods.py
Normal file
351
midas/checker/frames/column_methods.py
Normal file
@@ -0,0 +1,351 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallResult
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import (
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
Function,
|
||||
GenericType,
|
||||
TopType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnknownType,
|
||||
unfold_type,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import TypedExpr
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Call:
|
||||
location: Location
|
||||
call_expr: p.Expr
|
||||
column: ColumnType
|
||||
column_expr: p.Expr
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
@property
|
||||
def subject(self) -> TypedExpr:
|
||||
return (self.column_expr, self.column)
|
||||
|
||||
|
||||
class ColumnMethodRegistry(MethodRegistry[Call]):
|
||||
def _element_binary_op(self, call: Call, method: str) -> ColumnType:
|
||||
"""Compute the result of an element-wise binary operation
|
||||
|
||||
This function delegates to the inner types for computing the resulting
|
||||
type.
|
||||
|
||||
Args:
|
||||
call (Call): the call that triggered this resolution
|
||||
method (str): the method name
|
||||
|
||||
Returns:
|
||||
ColumnType: the resulting column type
|
||||
"""
|
||||
column2: Optional[ColumnType] = None
|
||||
|
||||
col_type1: Type = call.column.type
|
||||
new_column: Type = ColumnType(type=UnknownType())
|
||||
if len(call.positional) != 0:
|
||||
other: Type = call.positional[0][1]
|
||||
unfolded_other: Type = unfold_type(other)
|
||||
if isinstance(unfolded_other, ColumnType):
|
||||
column2 = unfolded_other
|
||||
col_type2: Type = column2.type
|
||||
|
||||
new_inner_type = self.typer.result_of_binary_op(
|
||||
location=call.location,
|
||||
expr=call.call_expr,
|
||||
left=(call.column_expr, col_type1),
|
||||
right=(call.positional[0][0], col_type2),
|
||||
method=method,
|
||||
)
|
||||
new_column = ColumnType(type=new_inner_type)
|
||||
return new_column
|
||||
|
||||
def _element_wise(self, call: Call, method: str) -> Type:
|
||||
# TODO: support add with scalar
|
||||
|
||||
# Build signature with new column type and generic operand
|
||||
param_type: TypeVar = TypeVar(name="T", bound=None)
|
||||
signature = GenericType(
|
||||
name="add",
|
||||
params=[param_type],
|
||||
body=Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="other",
|
||||
type=ColumnType(type=param_type),
|
||||
required=True,
|
||||
),
|
||||
],
|
||||
returns=self._element_binary_op(call, method),
|
||||
),
|
||||
)
|
||||
|
||||
# Map arguments and compute result type
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
if result.is_valid:
|
||||
self._assert_same_length(
|
||||
call.call_expr, call.column_expr, call.positional[0][0]
|
||||
)
|
||||
|
||||
return result.result
|
||||
|
||||
@method("add", "__add__")
|
||||
def add(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__add__")
|
||||
|
||||
@method("sub", "__sub__")
|
||||
def sub(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__sub__")
|
||||
|
||||
@method("mul", "__mul__")
|
||||
def mul(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__mul__")
|
||||
|
||||
@method("div", "truediv", "__truediv__")
|
||||
def truediv(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__truediv__")
|
||||
|
||||
@method("floordiv", "__floordiv__")
|
||||
def floordiv(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__floordiv__")
|
||||
|
||||
@method("mod", "__mod__")
|
||||
def mod(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__mod__")
|
||||
|
||||
@method("pow", "__pow__")
|
||||
def pow(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__pow__")
|
||||
|
||||
@method("lt", "__lt__")
|
||||
def lt(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__lt__")
|
||||
|
||||
@method("gt", "__gt__")
|
||||
def gt(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__gt__")
|
||||
|
||||
@method("le", "__le__")
|
||||
def le(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__le__")
|
||||
|
||||
@method("ge", "__ge__")
|
||||
def ge(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__ge__")
|
||||
|
||||
@method("ne", "__ne__")
|
||||
def ne(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__ne__")
|
||||
|
||||
@method("eq", "__eq__")
|
||||
def eq(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__eq__")
|
||||
|
||||
def _statistical(self, call: Call, kwargs: list[Function.Argument] = []) -> Type:
|
||||
signature = Function(
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="axis",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
),
|
||||
*kwargs,
|
||||
],
|
||||
returns=ColumnType(type=TopType()),
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method("kurtosis", "kurt")
|
||||
def kurtosis(self, call: Call) -> Type:
|
||||
return self._statistical(call)
|
||||
|
||||
@method()
|
||||
def max(self, call: Call) -> Type:
|
||||
return self._statistical(call)
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
return self._statistical(call)
|
||||
|
||||
@method()
|
||||
def median(self, call: Call) -> Type:
|
||||
return self._statistical(call)
|
||||
|
||||
@method()
|
||||
def min(self, call: Call) -> Type:
|
||||
return self._statistical(call)
|
||||
|
||||
@method()
|
||||
def mode(self, call: Call) -> Type:
|
||||
return self._statistical(call)
|
||||
|
||||
@method("product", "prod")
|
||||
def product(self, call: Call) -> Type:
|
||||
return self._statistical(call)
|
||||
|
||||
@method()
|
||||
def std(self, call: Call) -> Type:
|
||||
return self._statistical(
|
||||
call,
|
||||
[
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="ddof",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def sum(self, call: Call) -> Type:
|
||||
return self._statistical(call)
|
||||
|
||||
@method()
|
||||
def var(self, call: Call) -> Type:
|
||||
return self._statistical(
|
||||
call,
|
||||
[
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="var",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def groupby(self, call: Call) -> Type:
|
||||
bool_: Type = self.types.get_type("bool")
|
||||
function: Function = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="by",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="level",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
pos=2,
|
||||
name="as_index",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=3,
|
||||
name="sort",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=4,
|
||||
name="group_keys",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=5,
|
||||
name="observed",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=6,
|
||||
name="dropna",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
returns=ColumnGroupBy(column=call.column),
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=function,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
def _assert_same_length(self, call_expr: p.Expr, column1: p.Expr, column2: p.Expr):
|
||||
func_name: str = "__midas_column_same_length__"
|
||||
self.assertions.define(
|
||||
func_name,
|
||||
ast.FunctionDef(
|
||||
name=func_name,
|
||||
args=ast.arguments(
|
||||
posonlyargs=[],
|
||||
args=[
|
||||
ast.arg(arg="column1"),
|
||||
ast.arg(arg="column2"),
|
||||
],
|
||||
kwonlyargs=[],
|
||||
defaults=[],
|
||||
kw_defaults=[],
|
||||
),
|
||||
body=[
|
||||
ast.Return(
|
||||
value=ast.Compare(
|
||||
left=ast.Attribute(
|
||||
value=ast.Name(id="column1"),
|
||||
attr="size",
|
||||
),
|
||||
ops=[ast.Eq()],
|
||||
comparators=[
|
||||
ast.Attribute(
|
||||
value=ast.Name(id="column2"),
|
||||
attr="size",
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
],
|
||||
decorator_list=[],
|
||||
),
|
||||
)
|
||||
self.assertions.add(
|
||||
bound_expr=call_expr,
|
||||
inputs=[column1, column2],
|
||||
builder=lambda c1, c2: ast.Call(
|
||||
func=ast.Name(id=func_name),
|
||||
args=[c1, c2],
|
||||
keywords=[],
|
||||
),
|
||||
message="Columns must have the same length",
|
||||
)
|
||||
195
midas/checker/frames/frame_groupby_methods.py
Normal file
195
midas/checker/frames/frame_groupby_methods.py
Normal file
@@ -0,0 +1,195 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallResult
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import FrameGroupBy, Function, Type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import TypedExpr
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Call:
|
||||
location: Location
|
||||
call_expr: p.Expr
|
||||
groupby: FrameGroupBy
|
||||
groupby_expr: p.Expr
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
@property
|
||||
def subject(self) -> TypedExpr:
|
||||
return (self.groupby_expr, self.groupby)
|
||||
|
||||
|
||||
class FrameGroupByMethodRegistry(MethodRegistry[Call]):
|
||||
NAMED_ARGS: dict[str, str] = {
|
||||
"numeric_only": "bool",
|
||||
"skipna": "bool",
|
||||
"engine": "str",
|
||||
"engine_kwargs": "dict",
|
||||
}
|
||||
|
||||
def _aggregate(
|
||||
self, call: Call, args: list[str | tuple[str, str, bool]] = []
|
||||
) -> Type:
|
||||
real_args: list[Function.Argument] = []
|
||||
for i, arg in enumerate(args):
|
||||
match arg:
|
||||
case str() as name:
|
||||
arg = Function.Argument(
|
||||
pos=i,
|
||||
name=name,
|
||||
type=self.types.get_type(self.NAMED_ARGS[name]),
|
||||
required=False,
|
||||
)
|
||||
case (name, type, required):
|
||||
arg = Function.Argument(
|
||||
pos=i,
|
||||
name=name,
|
||||
type=self.types.get_type(type),
|
||||
required=required,
|
||||
)
|
||||
real_args.append(arg)
|
||||
|
||||
signature = Function(
|
||||
args=real_args,
|
||||
returns=call.groupby.frame,
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def kurt(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"skipna",
|
||||
"numeric_only",
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def max(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
["numeric_only", "skipna", "engine", "engine_kwargs"],
|
||||
)
|
||||
|
||||
@method()
|
||||
def median(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
["numeric_only", "skipna"],
|
||||
)
|
||||
|
||||
@method()
|
||||
def min(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def prod(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def std(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
(
|
||||
"ddof",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
"numeric_only",
|
||||
"skipna",
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def sum(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def var(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
(
|
||||
"var",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
"numeric_only",
|
||||
"skipna",
|
||||
],
|
||||
)
|
||||
242
midas/checker/frames/frame_manager.py
Normal file
242
midas/checker/frames/frame_manager.py
Normal file
@@ -0,0 +1,242 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, TypeGuard, cast
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.frames.frame_groupby_methods import Call as GroupByCall
|
||||
from midas.checker.frames.frame_groupby_methods import FrameGroupByMethodRegistry
|
||||
from midas.checker.frames.frame_methods import Call, FrameMethodRegistry
|
||||
from midas.checker.reporter import FileReporter
|
||||
from midas.checker.types import (
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
DataFrameType,
|
||||
FrameGroupBy,
|
||||
TupleType,
|
||||
Type,
|
||||
UnknownType,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import PythonTyper, TypedExpr
|
||||
|
||||
|
||||
def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
|
||||
return all(isinstance(expr, p.LiteralExpr) for expr in exprs)
|
||||
|
||||
|
||||
class FrameManager:
|
||||
def __init__(self, typer: PythonTyper) -> None:
|
||||
self.typer: PythonTyper = typer
|
||||
self.method_resolver: FrameMethodRegistry = FrameMethodRegistry(self.typer)
|
||||
self.groupby_method_resolver: FrameGroupByMethodRegistry = (
|
||||
FrameGroupByMethodRegistry(self.typer)
|
||||
)
|
||||
|
||||
def assign(
|
||||
self,
|
||||
reporter: FileReporter,
|
||||
location: Location,
|
||||
frame: DataFrameType,
|
||||
index: p.Expr,
|
||||
value_type: Type,
|
||||
) -> Type:
|
||||
match index:
|
||||
case p.LiteralExpr(value=str() as name):
|
||||
return self.assign_column(reporter, location, frame, name, value_type)
|
||||
|
||||
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
|
||||
isinstance(index.value, str) for index in indices
|
||||
):
|
||||
names: list[str] = [cast(str, index.value) for index in indices]
|
||||
|
||||
if not isinstance(value_type, TupleType):
|
||||
reporter.error(
|
||||
location,
|
||||
f"Cannot assign {type} to dataframe columns. Must be a tuple of columns",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
if len(names) != len(value_type.items):
|
||||
reporter.error(
|
||||
location,
|
||||
f"Wrong number of columns. Cannot assign {len(value_type.items)} to {len(names)} targets",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
new_frame: Type = frame
|
||||
for name, value in zip(names, value_type.items):
|
||||
new_frame = self.assign_column(
|
||||
reporter,
|
||||
location,
|
||||
new_frame,
|
||||
name,
|
||||
value,
|
||||
)
|
||||
if not isinstance(new_frame, DataFrameType):
|
||||
return new_frame
|
||||
return new_frame
|
||||
|
||||
case _:
|
||||
reporter.error(
|
||||
location, f"Invalid index type {index} on {frame} (assignment)"
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
def assign_column(
|
||||
self,
|
||||
reporter: FileReporter,
|
||||
location: Location,
|
||||
frame: DataFrameType,
|
||||
name: str,
|
||||
type: Type,
|
||||
) -> Type:
|
||||
if not isinstance(type, ColumnType):
|
||||
reporter.error(
|
||||
location,
|
||||
f"Cannot assign {type} to dataframe column. Must be a ColumnType",
|
||||
)
|
||||
return frame
|
||||
return self._set_column(frame, name, type)
|
||||
|
||||
def get(
|
||||
self,
|
||||
reporter: FileReporter,
|
||||
location: Location,
|
||||
frame: DataFrameType,
|
||||
index: p.Expr,
|
||||
) -> Type:
|
||||
match index:
|
||||
case p.LiteralExpr(value=str() as name):
|
||||
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
|
||||
if column is None:
|
||||
reporter.error(location, f"Unknown column '{name}' on {frame}")
|
||||
return UnknownType()
|
||||
return column
|
||||
|
||||
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
|
||||
isinstance(index.value, str) for index in indices
|
||||
):
|
||||
names: list[str] = [cast(str, index.value) for index in indices]
|
||||
columns: list[ColumnType] = []
|
||||
for name in names:
|
||||
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
|
||||
if column is None:
|
||||
reporter.error(location, f"Unknown column '{name}' on {frame}")
|
||||
return UnknownType()
|
||||
columns.append(column)
|
||||
return TupleType(items=tuple(columns))
|
||||
|
||||
case _:
|
||||
reporter.error(
|
||||
location, f"Invalid index type {index} on {frame} (access)"
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
def groupby_get(
|
||||
self,
|
||||
reporter: FileReporter,
|
||||
location: Location,
|
||||
groupby: FrameGroupBy,
|
||||
index: p.Expr,
|
||||
) -> Type:
|
||||
result: Type = self.get(reporter, location, groupby.frame, index)
|
||||
match result:
|
||||
case ColumnType():
|
||||
result = ColumnGroupBy(column=result)
|
||||
case TupleType(items=columns):
|
||||
result = TupleType(
|
||||
items=tuple(
|
||||
ColumnGroupBy(column=cast(ColumnType, column))
|
||||
for column in columns
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _set_column(
|
||||
cls, frame: DataFrameType, name: str, column: ColumnType
|
||||
) -> DataFrameType:
|
||||
new_columns: list[DataFrameType.Column] = []
|
||||
index: int = len(frame.columns)
|
||||
replace: bool = False
|
||||
for i, col in enumerate(frame.columns):
|
||||
if col.name == name:
|
||||
index = i
|
||||
replace = True
|
||||
# TODO: check column type here to prevent changing it
|
||||
new_columns.append(col)
|
||||
|
||||
new_col: DataFrameType.Column = DataFrameType.Column(
|
||||
index=index,
|
||||
name=name,
|
||||
type=column,
|
||||
)
|
||||
if replace:
|
||||
new_columns[index] = new_col
|
||||
else:
|
||||
new_columns.append(new_col)
|
||||
|
||||
return DataFrameType(columns=new_columns)
|
||||
|
||||
@classmethod
|
||||
def _set_columns(
|
||||
cls, frame: DataFrameType, names: list[str], columns: list[ColumnType]
|
||||
) -> DataFrameType:
|
||||
for name, col in zip(names, columns):
|
||||
frame = cls._set_column(frame, name, col)
|
||||
return frame
|
||||
|
||||
@classmethod
|
||||
def _get_column(cls, frame: DataFrameType, name: str) -> Optional[ColumnType]:
|
||||
for col in frame.columns:
|
||||
if col.name == name:
|
||||
return col.type
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _get_columns(
|
||||
cls, frame: DataFrameType, names: list[str]
|
||||
) -> list[Optional[ColumnType]]:
|
||||
return [cls._get_column(frame, name) for name in names]
|
||||
|
||||
def call(
|
||||
self,
|
||||
method: str,
|
||||
location: Location,
|
||||
call_expr: p.Expr,
|
||||
frame: DataFrameType,
|
||||
frame_expr: p.Expr,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Type:
|
||||
call: Call = Call(
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
frame=frame,
|
||||
frame_expr=frame_expr,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
return self.method_resolver.call(method, call)
|
||||
|
||||
def groupby_call(
|
||||
self,
|
||||
method: str,
|
||||
location: Location,
|
||||
call_expr: p.Expr,
|
||||
groupby: FrameGroupBy,
|
||||
groupby_expr: p.Expr,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Type:
|
||||
call: GroupByCall = GroupByCall(
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
groupby=groupby,
|
||||
groupby_expr=groupby_expr,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
return self.groupby_method_resolver.call(method, call)
|
||||
436
midas/checker/frames/frame_methods.py
Normal file
436
midas/checker/frames/frame_methods.py
Normal file
@@ -0,0 +1,436 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallResult
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import (
|
||||
ColumnType,
|
||||
DataFrameType,
|
||||
FrameGroupBy,
|
||||
Function,
|
||||
OverloadedFunction,
|
||||
TopType,
|
||||
Type,
|
||||
UnknownType,
|
||||
unfold_type,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import TypedExpr
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Call:
|
||||
location: Location
|
||||
call_expr: p.Expr
|
||||
frame: DataFrameType
|
||||
frame_expr: p.Expr
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
@property
|
||||
def subject(self) -> TypedExpr:
|
||||
return (self.frame_expr, self.frame)
|
||||
|
||||
|
||||
class FrameMethodRegistry(MethodRegistry[Call]):
|
||||
def _get_method_result(
|
||||
self,
|
||||
call: Call,
|
||||
column1: ColumnType,
|
||||
column2: ColumnType,
|
||||
method: str,
|
||||
) -> ColumnType:
|
||||
"""Get the result of calling a method on a column, passing a second
|
||||
|
||||
This function delegates to the main typer the resolution of the method
|
||||
member, as well as computing the result type. Because we don't have any
|
||||
AST expression for the individual columns, the frame expressions are
|
||||
used instead.
|
||||
|
||||
Args:
|
||||
call (Call): the call that triggered this resolution
|
||||
column1 (ColumnType): the first column, i.e. left operand
|
||||
column2 (ColumnType): the second column, i.e. right operand
|
||||
method (str): the method name
|
||||
|
||||
Returns:
|
||||
ColumnType: the resulting column.
|
||||
If the operation is invalid / doesn't exist,
|
||||
`ColumnType(type=UnknownType())` is returned
|
||||
"""
|
||||
|
||||
result: Type = self.typer.result_of_binary_op(
|
||||
location=call.location,
|
||||
expr=call.call_expr,
|
||||
left=(call.frame_expr, column1),
|
||||
right=(call.positional[0][0], column2),
|
||||
method=method,
|
||||
)
|
||||
|
||||
if not isinstance(result, ColumnType):
|
||||
return ColumnType(type=UnknownType())
|
||||
return result
|
||||
|
||||
def _element_binary_op(self, call: Call, method: str) -> DataFrameType:
|
||||
"""Compute the result of an element-wise binary operation
|
||||
|
||||
This function delegates to the matching columns for computing resulting
|
||||
types. Any column only present in one of the frames is forwarded as a
|
||||
generic `ColumnType(type=UnknownType())`. Columns only in the second
|
||||
frame are append at the end of the schema.
|
||||
|
||||
Args:
|
||||
call (Call): the call that triggered this resolution
|
||||
method (str): the method name
|
||||
|
||||
Returns:
|
||||
DataFrameType: the resulting frame type
|
||||
"""
|
||||
new_columns: list[DataFrameType.Column] = []
|
||||
|
||||
by_name: dict[str, DataFrameType.Column] = {}
|
||||
frame2: Optional[DataFrameType] = None
|
||||
# Get map of operand's columns by name, if there is at least 1 operand, which is a dataframe
|
||||
if len(call.positional) != 0:
|
||||
operand: TypedExpr = call.positional[0]
|
||||
unfolded_other: Type = unfold_type(operand[1])
|
||||
if isinstance(unfolded_other, DataFrameType):
|
||||
frame2 = unfolded_other
|
||||
by_name = {
|
||||
col.name: col for col in frame2.columns if col.name is not None
|
||||
}
|
||||
|
||||
# Compute new schema:
|
||||
# Step 1: for all columns in frame1:
|
||||
# - if present in frame2 -> delegate operation to columns
|
||||
# - if not -> add to schema as unknown
|
||||
in_frame1: set[str] = set()
|
||||
for column in call.frame.columns:
|
||||
if column.name is not None:
|
||||
in_frame1.add(column.name)
|
||||
|
||||
col_type1: ColumnType = column.type
|
||||
col_type: ColumnType = ColumnType(type=UnknownType())
|
||||
if column.name in by_name:
|
||||
column2 = by_name[column.name]
|
||||
col_type2: ColumnType = column2.type
|
||||
|
||||
col_type = self._get_method_result(call, col_type1, col_type2, method)
|
||||
|
||||
new_column = DataFrameType.Column(
|
||||
index=column.index,
|
||||
name=column.name,
|
||||
type=col_type,
|
||||
)
|
||||
new_columns.append(new_column)
|
||||
|
||||
# Step 2: for all columns in frame2
|
||||
# - if not in frame1 -> add to schema as unknown
|
||||
if frame2 is not None:
|
||||
for column in frame2.columns:
|
||||
if column.name in in_frame1:
|
||||
continue
|
||||
new_columns.append(
|
||||
DataFrameType.Column(
|
||||
index=len(new_columns),
|
||||
name=column.name,
|
||||
type=ColumnType(type=UnknownType()),
|
||||
)
|
||||
)
|
||||
|
||||
return DataFrameType(columns=new_columns)
|
||||
|
||||
def _element_wise(self, call: Call, method: str) -> Type:
|
||||
# TODO: support scalar, sequence, Series, dict operand
|
||||
# Build signature with new schema and generic operand
|
||||
signature = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="other",
|
||||
type=DataFrameType(columns=[]),
|
||||
required=True,
|
||||
),
|
||||
],
|
||||
returns=self._element_binary_op(call, method),
|
||||
)
|
||||
|
||||
# Map arguments and compute result type
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
if result.is_valid:
|
||||
self._assert_same_length(
|
||||
call.call_expr, call.frame_expr, call.positional[0][0]
|
||||
)
|
||||
|
||||
return result.result
|
||||
|
||||
@method("add", "__add__")
|
||||
def add(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__add__")
|
||||
|
||||
@method("sub", "__sub__")
|
||||
def sub(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__sub__")
|
||||
|
||||
@method("mul", "__mul__")
|
||||
def mul(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__mul__")
|
||||
|
||||
@method("div", "truediv", "__truediv__")
|
||||
def truediv(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__truediv__")
|
||||
|
||||
@method("floordiv", "__floordiv__")
|
||||
def floordiv(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__floordiv__")
|
||||
|
||||
@method("mod", "__mod__")
|
||||
def mod(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__mod__")
|
||||
|
||||
@method("pow", "__pow__")
|
||||
def pow(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__pow__")
|
||||
|
||||
@method("lt", "__lt__")
|
||||
def lt(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__lt__")
|
||||
|
||||
@method("gt", "__gt__")
|
||||
def gt(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__gt__")
|
||||
|
||||
@method("le", "__le__")
|
||||
def le(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__le__")
|
||||
|
||||
@method("ge", "__ge__")
|
||||
def ge(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__ge__")
|
||||
|
||||
@method("ne", "__ne__")
|
||||
def ne(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__ne__")
|
||||
|
||||
@method("eq", "__eq__")
|
||||
def eq(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__eq__")
|
||||
|
||||
def _aggregate(self, call: Call, kwargs: list[Function.Argument] = []) -> Type:
|
||||
with_axis = Function(
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="axis",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
),
|
||||
*kwargs,
|
||||
],
|
||||
returns=ColumnType(type=TopType()),
|
||||
)
|
||||
without_axis = Function(
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="axis",
|
||||
type=self.types.get_type("None"),
|
||||
required=True,
|
||||
),
|
||||
*kwargs,
|
||||
],
|
||||
returns=TopType(),
|
||||
)
|
||||
overload = OverloadedFunction(
|
||||
overloads=[
|
||||
with_axis,
|
||||
without_axis,
|
||||
]
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=overload,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method("kurtosis", "kurt")
|
||||
def kurtosis(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def max(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def median(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def min(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def mode(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method("product", "prod")
|
||||
def product(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def std(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="ddof",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def sum(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def var(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="var",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def groupby(self, call: Call) -> Type:
|
||||
bool_: Type = self.types.get_type("bool")
|
||||
function: Function = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="by",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="level",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
pos=2,
|
||||
name="as_index",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=3,
|
||||
name="sort",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=4,
|
||||
name="group_keys",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=5,
|
||||
name="observed",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=6,
|
||||
name="dropna",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
returns=FrameGroupBy(frame=call.frame),
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=function,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
def _assert_same_length(self, call_expr: p.Expr, frame1: p.Expr, frame2: p.Expr):
|
||||
func_name: str = "__midas_frame_same_length__"
|
||||
self.assertions.define(
|
||||
func_name,
|
||||
ast.FunctionDef(
|
||||
name=func_name,
|
||||
args=ast.arguments(
|
||||
posonlyargs=[],
|
||||
args=[
|
||||
ast.arg(arg="frame1"),
|
||||
ast.arg(arg="frame2"),
|
||||
],
|
||||
kwonlyargs=[],
|
||||
defaults=[],
|
||||
kw_defaults=[],
|
||||
),
|
||||
body=[
|
||||
ast.Return(
|
||||
value=ast.Compare(
|
||||
left=ast.Attribute(
|
||||
value=ast.Name(id="frame1"),
|
||||
attr="size",
|
||||
),
|
||||
ops=[ast.Eq()],
|
||||
comparators=[
|
||||
ast.Attribute(
|
||||
value=ast.Name(id="frame2"),
|
||||
attr="size",
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
],
|
||||
decorator_list=[],
|
||||
),
|
||||
)
|
||||
self.assertions.add(
|
||||
bound_expr=call_expr,
|
||||
inputs=[frame1, frame2],
|
||||
builder=lambda f1, f2: ast.Call(
|
||||
func=ast.Name(id=func_name),
|
||||
args=[f1, f2],
|
||||
keywords=[],
|
||||
),
|
||||
message="DataFrames must have the same length",
|
||||
)
|
||||
100
midas/checker/frames/utils.py
Normal file
100
midas/checker/frames/utils.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Generic,
|
||||
Optional,
|
||||
Protocol,
|
||||
Self,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallDispatcher
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter
|
||||
from midas.checker.types import Type, UnknownType
|
||||
from midas.generator.collector import AssertionCollector
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import PythonTyper, TypedExpr
|
||||
|
||||
|
||||
class _MethodRegistryMeta(type):
|
||||
_methods: dict[str, Callable[..., Type]] = {}
|
||||
|
||||
def __new__(
|
||||
cls,
|
||||
name: str,
|
||||
bases: tuple[type, ...],
|
||||
namespace: dict[str, Any],
|
||||
):
|
||||
new_class = super().__new__(cls, name, bases, namespace)
|
||||
new_class._methods = {}
|
||||
for attr in namespace.values():
|
||||
if callable(attr) and hasattr(attr, "__method_names__"):
|
||||
for name in attr.__method_names__: # type: ignore
|
||||
new_class._methods[name] = attr # type: ignore
|
||||
return new_class
|
||||
|
||||
|
||||
class MethodCall(Protocol):
|
||||
@property
|
||||
def location(self) -> Location: ...
|
||||
|
||||
@property
|
||||
def call_expr(self) -> p.Expr: ...
|
||||
|
||||
@property
|
||||
def subject(self) -> TypedExpr: ...
|
||||
|
||||
|
||||
T = TypeVar("T", bound=MethodCall)
|
||||
|
||||
|
||||
class MethodRegistry(Generic[T], metaclass=_MethodRegistryMeta):
|
||||
def __init__(self, typer: PythonTyper) -> None:
|
||||
self.typer: PythonTyper = typer
|
||||
|
||||
@property
|
||||
def reporter(self) -> FileReporter:
|
||||
return self.typer.reporter
|
||||
|
||||
@property
|
||||
def types(self) -> TypesRegistry:
|
||||
return self.typer.types
|
||||
|
||||
@property
|
||||
def dispatcher(self) -> CallDispatcher[p.Expr]:
|
||||
return self.typer.dispatcher
|
||||
|
||||
@property
|
||||
def assertions(self) -> AssertionCollector:
|
||||
return self.typer.assertions
|
||||
|
||||
def call(self, method: str, call: T) -> Type:
|
||||
func: Optional[Callable[[Self, T], Type]] = self._methods.get(method)
|
||||
if func is None:
|
||||
self.reporter.warning(
|
||||
call.location, f"Unknown method {method} on {call.subject[1]}"
|
||||
)
|
||||
return UnknownType()
|
||||
return func(self, call)
|
||||
|
||||
|
||||
_Self = TypeVar("_Self", bound=MethodRegistry[Any])
|
||||
Method = Callable[[_Self, T], Type]
|
||||
|
||||
|
||||
def method(*names: str) -> Callable[[Method[_Self, T]], Method[_Self, T]]:
|
||||
def wrapper(func: Method[_Self, T]) -> Method[_Self, T]:
|
||||
names_: tuple[str, ...] = names
|
||||
if len(names_) == 0:
|
||||
names_ = (func.__name__,)
|
||||
setattr(func, "__method_names__", names_)
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
@@ -6,25 +6,25 @@ from typing import Optional
|
||||
import midas.ast.midas as m
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.builtins import define_builtins
|
||||
from midas.checker.dispatcher import CallDispatcher, CallResult
|
||||
from midas.checker.environment import Environment
|
||||
from midas.checker.operators import MIDAS_BINARY_METHODS, MIDAS_UNARY_METHODS
|
||||
from midas.checker.preamble import Preamble
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter, Reporter
|
||||
from midas.checker.types import (
|
||||
AppliedType,
|
||||
ColumnType,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
DataFrameType,
|
||||
DerivedType,
|
||||
ExtensionType,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
Predicate,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnknownType,
|
||||
unfold_type,
|
||||
)
|
||||
from midas.checker.variance import VarianceInferrer
|
||||
from midas.lexer.midas import MidasLexer
|
||||
@@ -39,9 +39,6 @@ class TypedParamSpec:
|
||||
kw: list[Function.Argument]
|
||||
|
||||
|
||||
TypedExpr = tuple[m.Expr, Type]
|
||||
|
||||
|
||||
class ReturnException(Exception):
|
||||
pass
|
||||
|
||||
@@ -65,8 +62,11 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
|
||||
self.logger: logging.Logger = logging.getLogger("MidasTyper")
|
||||
self.reporter: FileReporter = reporter.for_file(None)
|
||||
|
||||
self.types: TypesRegistry = types
|
||||
self.dispatcher: CallDispatcher[m.Expr] = CallDispatcher[m.Expr](
|
||||
self.types, self.reporter
|
||||
)
|
||||
|
||||
self._local_variables: dict[str, TypeVar] = {}
|
||||
|
||||
self._predicate_params: dict[str, Type] = {}
|
||||
@@ -81,8 +81,14 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
|
||||
self._preamble: Environment = Preamble(self.types)
|
||||
|
||||
def set_reporter(self, reporter: FileReporter):
|
||||
self.reporter = reporter
|
||||
self.dispatcher.set_reporter(reporter)
|
||||
|
||||
def process(self, source: str, path: Optional[str]):
|
||||
self.reporter = self.reporter.for_file(path)
|
||||
reporter: FileReporter = self.reporter.for_file(path)
|
||||
self.set_reporter(reporter)
|
||||
|
||||
lexer: MidasLexer = MidasLexer(source)
|
||||
tokens: list[Token] = lexer.process()
|
||||
parser: MidasParser = MidasParser(tokens)
|
||||
@@ -257,13 +263,13 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
result: Optional[Type] = self._get_call_result(
|
||||
location,
|
||||
operation,
|
||||
[(right_expr, right)],
|
||||
{},
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=location,
|
||||
callee=operation,
|
||||
positional=[(right_expr, right)],
|
||||
keywords={},
|
||||
)
|
||||
return result or UnknownType()
|
||||
return result.result
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
|
||||
method: Optional[str] = MIDAS_UNARY_METHODS.get(expr.operator.type)
|
||||
@@ -283,31 +289,29 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
result: Optional[Type] = self._get_call_result(
|
||||
expr.location,
|
||||
operation,
|
||||
[],
|
||||
{},
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=expr.location,
|
||||
callee=operation,
|
||||
positional=[],
|
||||
keywords={},
|
||||
)
|
||||
return result or UnknownType()
|
||||
return result.result
|
||||
|
||||
def visit_call_expr(self, expr: m.CallExpr) -> Type:
|
||||
callee: Type = expr.callee.accept(self)
|
||||
positional: list[TypedExpr] = [
|
||||
positional: list[tuple[m.Expr, Type]] = [
|
||||
(arg, self.type_of(arg)) for arg in expr.arguments
|
||||
]
|
||||
keywords: dict[str, TypedExpr] = {
|
||||
keywords: dict[str, tuple[m.Expr, Type]] = {
|
||||
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
|
||||
}
|
||||
return (
|
||||
self._get_call_result(
|
||||
expr.location,
|
||||
callee,
|
||||
positional,
|
||||
keywords,
|
||||
)
|
||||
or UnknownType()
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=expr.location,
|
||||
callee=callee,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> Type:
|
||||
object: Type = expr.expr.accept(self)
|
||||
@@ -408,6 +412,18 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
kw=[process_arg(arg, i + n_pos + n_mixed) for i, arg in enumerate(spec.kw)],
|
||||
)
|
||||
|
||||
def visit_frame_type(self, type: m.FrameType) -> Type:
|
||||
def process_column(i: int, col: m.FrameType.Column) -> DataFrameType.Column:
|
||||
return DataFrameType.Column(
|
||||
index=i,
|
||||
name=col.name.lexeme,
|
||||
type=ColumnType(type=col.type.accept(self)),
|
||||
)
|
||||
|
||||
return DataFrameType(
|
||||
columns=[process_column(i, col) for i, col in enumerate(type.columns)]
|
||||
)
|
||||
|
||||
def _resolve_type_params(self, params: list[m.TypeParam]):
|
||||
vars: list[TypeVar] = []
|
||||
for param in params:
|
||||
@@ -419,343 +435,3 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
self._local_variables[name] = var
|
||||
vars.append(var)
|
||||
return vars
|
||||
|
||||
def _get_call_result(
|
||||
self,
|
||||
location: Location,
|
||||
callee: Type,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
report_errors: bool = True,
|
||||
) -> Optional[Type]:
|
||||
"""Get the result type of a function call
|
||||
|
||||
If the function has overloads, the function will try to resolve the
|
||||
appropriate signature.
|
||||
Argument types are matched to the defined parameters.
|
||||
The function doesn't take the raw expression as a parameter to accommodate
|
||||
for desugared calls such as for operators.
|
||||
|
||||
Args:
|
||||
location (Location): the call location
|
||||
callee (Type): the called function
|
||||
positional (list[TypedExpr]): the list positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Type: the return type of the call, or `None` if either
|
||||
the call is invalid or no overload matched the arguments uniquely
|
||||
"""
|
||||
match callee:
|
||||
case Function() as function:
|
||||
valid: bool
|
||||
mapped: list[MappedArgument]
|
||||
valid, mapped = self.map_call_arguments(
|
||||
function, location, positional, keywords
|
||||
)
|
||||
valid = valid and self._are_arguments_valid(mapped, report_errors)
|
||||
if not valid:
|
||||
return None
|
||||
return function.returns
|
||||
|
||||
case OverloadedFunction(overloads=overloads):
|
||||
function = self._match_overload(
|
||||
overloads, location, positional, keywords, report_errors
|
||||
)
|
||||
if function is None:
|
||||
return None
|
||||
return function.returns
|
||||
|
||||
case AppliedType(body=body):
|
||||
return self._get_call_result(
|
||||
location, body, positional, keywords, report_errors
|
||||
)
|
||||
|
||||
case UnknownType():
|
||||
return UnknownType()
|
||||
|
||||
case _:
|
||||
if report_errors:
|
||||
self.reporter.error(location, f"{callee} is not callable")
|
||||
return None
|
||||
|
||||
def _are_arguments_valid(
|
||||
self,
|
||||
arguments: list[MappedArgument],
|
||||
report_errors: bool = True,
|
||||
) -> bool:
|
||||
"""Check whether the passed argument types correspond to their matched parameter definitions
|
||||
|
||||
Args:
|
||||
arguments (list[MappedArgument]): the list of argument/parameter pairs
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
bool: True if all arguments fit the matching parameter definitions, False otherwise
|
||||
"""
|
||||
valid: bool = True
|
||||
for arg in arguments:
|
||||
if not self.types.is_subtype(arg.type, arg.argument.type):
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
arg.expr.location,
|
||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||
)
|
||||
valid = False
|
||||
return valid
|
||||
|
||||
def _match_overload(
|
||||
self,
|
||||
overloads: list[Type],
|
||||
location: Location,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
report_errors: bool = True,
|
||||
) -> Optional[Function]:
|
||||
"""Try and resolve the appropriate overload for the given arguments
|
||||
|
||||
Args:
|
||||
overloads (list[Type]): the list of possible overloads
|
||||
location (Location): the call location
|
||||
positional (list[TypedExpr]): the list of positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keywords arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Optional[Function]: the resolved function signature if it can be
|
||||
determined unambiguously, or `None`.
|
||||
"""
|
||||
candidates: list[OverloadCandidate] = []
|
||||
for overload in overloads:
|
||||
function: Type = unfold_type(overload)
|
||||
if not isinstance(function, Function):
|
||||
if report_errors:
|
||||
self.logger.error(
|
||||
f"Overload is not a function: {overload} is {function}"
|
||||
)
|
||||
continue
|
||||
valid, mapped = self.map_call_arguments(
|
||||
function=function,
|
||||
location=location,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
report_errors=False,
|
||||
)
|
||||
if valid and self._are_arguments_valid(mapped, report_errors=False):
|
||||
candidates.append(
|
||||
OverloadCandidate(
|
||||
function=function,
|
||||
mapped=mapped,
|
||||
)
|
||||
)
|
||||
|
||||
pos_types: str = ", ".join(str(type) for _, type in positional)
|
||||
kw_types: str = ", ".join(
|
||||
f"{name}: {type}" for name, (_, type) in keywords.items()
|
||||
)
|
||||
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
|
||||
|
||||
n_candidates: int = len(candidates)
|
||||
|
||||
# Exactly 1 match -> return it
|
||||
if n_candidates == 1:
|
||||
return candidates[0].function
|
||||
|
||||
# No match -> invalid call
|
||||
if n_candidates == 0:
|
||||
overloads_str: str = ", ".join(map(str, overloads))
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"No matching overload in [{overloads_str}] {for_args}",
|
||||
)
|
||||
return None
|
||||
|
||||
# Multiple matches -> see if one <: all others (more specific)
|
||||
for i1, c1 in enumerate(candidates):
|
||||
mapped1: list[MappedArgument] = c1.mapped
|
||||
best_match: bool = True
|
||||
for i2, c2 in enumerate(candidates):
|
||||
if i1 == i2:
|
||||
continue
|
||||
mapped2: list[MappedArgument] = c2.mapped
|
||||
if not self._are_mapped_subtypes(mapped1, mapped2):
|
||||
best_match = False
|
||||
break
|
||||
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
|
||||
if best_match:
|
||||
return c1.function
|
||||
|
||||
candidates_str: str = ", ".join(
|
||||
str(candidate.function) for candidate in candidates
|
||||
)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Multiple matching overloads {for_args}: {candidates_str}",
|
||||
)
|
||||
return None
|
||||
|
||||
def map_call_arguments(
|
||||
self,
|
||||
function: Function,
|
||||
location: Location,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
report_errors: bool = True,
|
||||
) -> tuple[bool, list[MappedArgument]]:
|
||||
"""Map call arguments to a function's parameters as defined in its signature
|
||||
|
||||
This method maps positional-only, keyword-only and mixed parameter definitions
|
||||
with the arguments passed at the call site
|
||||
|
||||
Any mismatched, missing or unexpected argument is reported as a diagnostic,
|
||||
unless `report_errors` is set to `False`
|
||||
|
||||
Args:
|
||||
function (Function): the function definition
|
||||
location (Location): the call location
|
||||
positional (list[TypedExpr]): the list of positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
tuple[bool, list[MappedArgument]]: a boolean reporting whether
|
||||
the call is valid and the list of mapped arguments
|
||||
"""
|
||||
set_args: set[str] = set()
|
||||
|
||||
required_positional: list[str] = [
|
||||
arg.name for arg in function.pos_args + function.args if arg.required
|
||||
]
|
||||
required_keyword: list[str] = [
|
||||
arg.name for arg in function.kw_args if arg.required
|
||||
]
|
||||
|
||||
mapped: list[MappedArgument] = []
|
||||
|
||||
pos_params: list[Function.Argument] = list(function.pos_args)
|
||||
mixed_params: list[Function.Argument] = list(function.args)
|
||||
kw_params: dict[str, Function.Argument] = {
|
||||
arg.name: arg for arg in function.kw_args
|
||||
}
|
||||
|
||||
valid_call: bool = True
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
for arg in positional:
|
||||
param: Function.Argument
|
||||
if len(pos_params) != 0:
|
||||
param = pos_params.pop(0)
|
||||
elif len(mixed_params) != 0:
|
||||
param = mixed_params.pop(0)
|
||||
else:
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
arg[0].location, "Too many positional arguments"
|
||||
)
|
||||
valid_call = False
|
||||
break
|
||||
name: str = param.name
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
kw_params.update({arg.name: arg for arg in mixed_params})
|
||||
for name, arg in keywords.items():
|
||||
param: Function.Argument
|
||||
if name not in kw_params:
|
||||
if report_errors:
|
||||
if name in set_args:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Multiple values for argument '{name}'"
|
||||
)
|
||||
else:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Unknown keyword argument '{name}'"
|
||||
)
|
||||
valid_call = False
|
||||
continue
|
||||
param = kw_params.pop(name)
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
def join_args(args: list[str]) -> str:
|
||||
args = list(map(lambda a: f"'{a}'", args))
|
||||
if len(args) == 0:
|
||||
return ""
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return ", ".join(args[:-1]) + " and " + args[-1]
|
||||
|
||||
if len(required_positional) != 0:
|
||||
plural: str = "" if len(required_positional) == 1 else "s"
|
||||
args: str = join_args(required_positional)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Missing required positional argument{plural}: {args}",
|
||||
)
|
||||
valid_call = False
|
||||
|
||||
if len(required_keyword) != 0:
|
||||
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||
args: str = join_args(required_keyword)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Missing required keyword argument{plural}: {args}",
|
||||
)
|
||||
valid_call = False
|
||||
|
||||
return valid_call, mapped
|
||||
|
||||
def _are_mapped_subtypes(
|
||||
self, mapped1: list[MappedArgument], mapped2: list[MappedArgument]
|
||||
) -> bool:
|
||||
"""Check whether the given argument mappings are subtype/supertype of one another
|
||||
|
||||
This function checks whether the argument mappings `mapped1` are subtypes
|
||||
of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
|
||||
of the corresponding parameter in `mapped2`, `False` is returned.
|
||||
|
||||
This is used to check whether a given overload is
|
||||
a more specific function/ a subtype of another.
|
||||
|
||||
Args:
|
||||
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
|
||||
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
|
||||
|
||||
Returns:
|
||||
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
|
||||
"""
|
||||
by_expr: dict[m.Expr, Type] = {}
|
||||
for arg in mapped1:
|
||||
by_expr[arg.expr] = arg.argument.type
|
||||
|
||||
for arg in mapped2:
|
||||
type2: Type = arg.argument.type
|
||||
type1: Type = by_expr[arg.expr]
|
||||
if not self.types.is_subtype(type1, type2):
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -41,7 +41,7 @@ PY_UNARY_METHODS: dict[Type[ast.unaryop], str] = {
|
||||
|
||||
|
||||
MIDAS_BINARY_METHODS: dict[TokenType, str] = {
|
||||
# TokenType.PLUS: "__add__",
|
||||
TokenType.PLUS: "__add__",
|
||||
TokenType.MINUS: "__sub__",
|
||||
TokenType.STAR: "__mul__",
|
||||
TokenType.SLASH: "__truediv__",
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from midas.checker.environment import Environment
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import Function, GenericType, TopType, Type, TypeVar, UnitType
|
||||
from midas.checker.types import (
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
TopType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnitType,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -17,7 +25,7 @@ class Preamble(Environment):
|
||||
def __init__(self, types: TypesRegistry) -> None:
|
||||
super().__init__()
|
||||
self._types: TypesRegistry = types
|
||||
self._python_funcs: dict[str, Callable] = {}
|
||||
self._python_funcs: dict[str, Callable[..., Any]] = {}
|
||||
|
||||
self._def_type_constructor("object", object)
|
||||
self._def_type_constructor("float", float)
|
||||
@@ -34,7 +42,7 @@ class Preamble(Environment):
|
||||
# TODO: use sink
|
||||
self._def_function(
|
||||
name="print",
|
||||
pos=[Param("object", TopType())],
|
||||
pos=[Param("object", TopType(), required=False)],
|
||||
returns=UnitType(),
|
||||
py_function=print,
|
||||
)
|
||||
@@ -64,11 +72,48 @@ class Preamble(Environment):
|
||||
pos=[Param("prompt", TopType(), required=False)],
|
||||
returns=self._types.get_type("str"),
|
||||
)
|
||||
self._def_function(
|
||||
name="len",
|
||||
pos=[Param("object", TopType())],
|
||||
returns=self._types.get_type("int"),
|
||||
)
|
||||
|
||||
T = TypeVar(name="T", bound=None)
|
||||
self._def_overloads(
|
||||
name="max",
|
||||
py_function=max,
|
||||
signatures=[
|
||||
(
|
||||
[Param("arg1", T), Param("arg2", T)],
|
||||
[],
|
||||
[],
|
||||
T,
|
||||
[T],
|
||||
),
|
||||
([Param("iterable", self._list_of(T))], [], [], T, [T]),
|
||||
],
|
||||
)
|
||||
self._def_overloads(
|
||||
name="min",
|
||||
py_function=min,
|
||||
signatures=[
|
||||
(
|
||||
[Param("arg1", T), Param("arg2", T)],
|
||||
[],
|
||||
[],
|
||||
T,
|
||||
[T],
|
||||
),
|
||||
([Param("iterable", self._list_of(T))], [], [], T, [T]),
|
||||
],
|
||||
)
|
||||
|
||||
def _list_of(self, item_type: Type) -> Type:
|
||||
return self._types.apply_generic(self._types.get_type("list"), [item_type])
|
||||
|
||||
def _def_type_constructor(self, name: str, py_function: Optional[Callable] = None):
|
||||
def _def_type_constructor(
|
||||
self, name: str, py_function: Optional[Callable[..., Any]] = None
|
||||
):
|
||||
# TODO: more specific arg types
|
||||
self._def_function(
|
||||
name=name,
|
||||
@@ -121,7 +166,7 @@ class Preamble(Environment):
|
||||
kw: list[Param] = [],
|
||||
returns: Type = UnitType(),
|
||||
type_vars: list[TypeVar] = [],
|
||||
py_function: Optional[Callable] = None,
|
||||
py_function: Optional[Callable[..., Any]] = None,
|
||||
):
|
||||
function: Type = self._make_function(
|
||||
name=name,
|
||||
@@ -135,5 +180,31 @@ class Preamble(Environment):
|
||||
if py_function is not None:
|
||||
self._python_funcs[name] = py_function
|
||||
|
||||
def get_py_func(self, name: str) -> Optional[Callable]:
|
||||
def _def_overloads(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
signatures: list[
|
||||
tuple[list[Param], list[Param], list[Param], Type, list[TypeVar]]
|
||||
],
|
||||
py_function: Optional[Callable[..., Any]] = None,
|
||||
):
|
||||
overloads: list[Type] = []
|
||||
for pos, mixed, kw, returns, type_vars in signatures:
|
||||
overloads.append(
|
||||
self._make_function(
|
||||
name=name,
|
||||
pos=pos,
|
||||
mixed=mixed,
|
||||
kw=kw,
|
||||
returns=returns,
|
||||
type_vars=type_vars,
|
||||
)
|
||||
)
|
||||
function: Type = OverloadedFunction(overloads=overloads)
|
||||
self.define(name, function)
|
||||
if py_function is not None:
|
||||
self._python_funcs[name] = py_function
|
||||
|
||||
def get_py_func(self, name: str) -> Optional[Callable[..., Any]]:
|
||||
return self._python_funcs.get(name)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -7,8 +7,10 @@ from midas.checker.builtins import BUILTIN_SUBTYPES
|
||||
from midas.checker.types import (
|
||||
AppliedType,
|
||||
BaseType,
|
||||
ColumnType,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
DataFrameType,
|
||||
DerivedType,
|
||||
ExtensionType,
|
||||
Function,
|
||||
@@ -16,6 +18,7 @@ from midas.checker.types import (
|
||||
OverloadedFunction,
|
||||
Predicate,
|
||||
TopType,
|
||||
TupleType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnknownType,
|
||||
@@ -110,6 +113,15 @@ class TypesRegistry:
|
||||
raise ValueError(f"Predicate {name} already defined")
|
||||
self._predicates[name] = predicate
|
||||
|
||||
def is_builtin_subtype(self, name1: str, name2: str) -> bool:
|
||||
subtypes: set[str] = BUILTIN_SUBTYPES.get(name2, set())
|
||||
if name1 in subtypes:
|
||||
return True
|
||||
for subtype in subtypes:
|
||||
if self.is_builtin_subtype(name1, subtype):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||
"""Check whether `type1` is a subtype of `type2`
|
||||
|
||||
@@ -147,7 +159,7 @@ class TypesRegistry:
|
||||
return self.is_subtype(base1, type2)
|
||||
|
||||
case (BaseType(name=name1), BaseType(name=name2)):
|
||||
return name1 in BUILTIN_SUBTYPES.get(name2, set())
|
||||
return self.is_builtin_subtype(name1, name2)
|
||||
|
||||
case (ComplexType(properties=props1), ComplexType(properties=props2)):
|
||||
for k, t in props2.items():
|
||||
@@ -157,6 +169,24 @@ class TypesRegistry:
|
||||
return False
|
||||
return True
|
||||
|
||||
case (DataFrameType(columns=columns1), DataFrameType(columns=columns2)):
|
||||
# TODO: check order?
|
||||
by_name1: dict[str, DataFrameType.Column] = {
|
||||
col.name: col for col in columns1 if col.name is not None
|
||||
}
|
||||
for col2 in columns2:
|
||||
if col2.name not in by_name1:
|
||||
return False
|
||||
if not self.is_subtype(by_name1[col2.name].type, col2.type):
|
||||
return False
|
||||
return True
|
||||
|
||||
case (ColumnType(type=inner1), ColumnType(type=inner2)):
|
||||
# TODO: invariant, replace ColumnType with simple GenericType
|
||||
if not self.are_equivalent(inner1, inner2):
|
||||
return False
|
||||
return True
|
||||
|
||||
case (Function(), Function()):
|
||||
return self.is_func_subtype(type1, type2)
|
||||
|
||||
@@ -187,6 +217,9 @@ class TypesRegistry:
|
||||
|
||||
return False
|
||||
|
||||
def are_equivalent(self, type1: Type, type2: Type) -> bool:
|
||||
return self.is_subtype(type1, type2) and self.is_subtype(type2, type1)
|
||||
|
||||
# TODO: verify the logic in here
|
||||
def is_func_subtype(self, func1: Function, func2: Function) -> bool:
|
||||
"""Check whether a function is a subtype of another
|
||||
@@ -323,6 +356,9 @@ class TypesRegistry:
|
||||
body=substitute_typevars(body, substitutions),
|
||||
)
|
||||
|
||||
case BaseType(name="tuple"):
|
||||
return TupleType(items=tuple(args))
|
||||
|
||||
case _:
|
||||
raise ValueError(f"{type} is not a generic type")
|
||||
|
||||
|
||||
@@ -128,6 +128,10 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
||||
|
||||
case p.GetExpr():
|
||||
target.accept(self)
|
||||
|
||||
case p.SubscriptExpr():
|
||||
target.accept(self)
|
||||
|
||||
case _:
|
||||
raise Exception(f"Unsupported assignment to {target}")
|
||||
|
||||
@@ -232,5 +236,9 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
||||
if expr.step is not None:
|
||||
self.resolve(expr.step)
|
||||
|
||||
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
|
||||
for item in expr.items:
|
||||
self.resolve(item)
|
||||
|
||||
def visit_raw_expr(self, expr: p.RawExpr) -> None:
|
||||
pass
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
from typing import Optional, assert_never
|
||||
from typing import Optional, assert_never, cast
|
||||
|
||||
import midas.ast.midas as m
|
||||
from midas.ast.printer import MidasPrinter
|
||||
@@ -156,6 +156,53 @@ class ConstraintType:
|
||||
return f"{self.type} where {printer.print(self.constraint)}"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class TupleType:
|
||||
items: tuple[Type, ...]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"({', '.join(map(str, self.items))})"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ColumnType:
|
||||
type: Type
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Column[{self.type}]"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class DataFrameType:
|
||||
columns: list[Column]
|
||||
|
||||
def __str__(self) -> str:
|
||||
schema: list[str] = [f"{col.name}: {col.type}" for col in self.columns]
|
||||
return f"Frame[{', '.join(schema)}]"
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Column:
|
||||
index: int
|
||||
name: Optional[str]
|
||||
type: ColumnType
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class FrameGroupBy:
|
||||
frame: DataFrameType
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"FrameGroupBy[{self.frame}]"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ColumnGroupBy:
|
||||
column: ColumnType
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ColumnGroupBy[{self.column}]"
|
||||
|
||||
|
||||
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||
def sub_argument(arg: Function.Argument):
|
||||
return Function.Argument(
|
||||
@@ -165,6 +212,13 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||
required=arg.required,
|
||||
)
|
||||
|
||||
def sub_column(col: DataFrameType.Column):
|
||||
return DataFrameType.Column(
|
||||
index=col.index,
|
||||
name=col.name,
|
||||
type=cast(ColumnType, substitute_typevars(col.type, substitutions)),
|
||||
)
|
||||
|
||||
match type:
|
||||
case TopType():
|
||||
return type
|
||||
@@ -252,6 +306,31 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||
body=substitute_typevars(body, substitutions),
|
||||
)
|
||||
|
||||
case TupleType(items=items):
|
||||
return TupleType(
|
||||
items=tuple(substitute_typevars(item, substitutions) for item in items),
|
||||
)
|
||||
|
||||
case ColumnType(type=items_type):
|
||||
return ColumnType(
|
||||
type=substitute_typevars(items_type, substitutions),
|
||||
)
|
||||
|
||||
case DataFrameType(columns=columns):
|
||||
return DataFrameType(
|
||||
columns=list(map(sub_column, columns)),
|
||||
)
|
||||
|
||||
case FrameGroupBy(frame=frame):
|
||||
return FrameGroupBy(
|
||||
frame=cast(DataFrameType, substitute_typevars(frame, substitutions))
|
||||
)
|
||||
|
||||
case ColumnGroupBy(column=column):
|
||||
return ColumnGroupBy(
|
||||
column=cast(ColumnType, substitute_typevars(column, substitutions))
|
||||
)
|
||||
|
||||
case UnknownType() | UnitType():
|
||||
return type
|
||||
|
||||
@@ -319,6 +398,21 @@ def to_annotation(type: Type) -> str:
|
||||
case ConstraintType():
|
||||
return str(type)
|
||||
|
||||
case TupleType(items=items):
|
||||
return f"Tuple[{', '.join(map(to_annotation, items))}]"
|
||||
|
||||
case ColumnType():
|
||||
return "pd.Series"
|
||||
|
||||
case DataFrameType():
|
||||
return "pd.DataFrame"
|
||||
|
||||
case FrameGroupBy():
|
||||
return "pd.api.typing.DataFrameGroupBy"
|
||||
|
||||
case ColumnGroupBy():
|
||||
return "pd.api.typing.SeriesGroupBy"
|
||||
|
||||
case _:
|
||||
assert_never(type)
|
||||
|
||||
@@ -344,4 +438,9 @@ Type = (
|
||||
| GenericType
|
||||
| AppliedType
|
||||
| ConstraintType
|
||||
| TupleType
|
||||
| ColumnType
|
||||
| DataFrameType
|
||||
| FrameGroupBy
|
||||
| ColumnGroupBy
|
||||
)
|
||||
|
||||
@@ -4,6 +4,8 @@ from typing import Optional
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import (
|
||||
AppliedType,
|
||||
ColumnType,
|
||||
DataFrameType,
|
||||
Function,
|
||||
GenericType,
|
||||
TopType,
|
||||
@@ -98,6 +100,30 @@ class Unifier:
|
||||
|
||||
return substitutions
|
||||
|
||||
case (
|
||||
DataFrameType(columns=template_columns),
|
||||
DataFrameType(columns=concrete_columns),
|
||||
) if len(template_columns) == len(concrete_columns):
|
||||
substitutions: dict[str, Type] = {}
|
||||
for template_column, concrete_column in zip(
|
||||
template_columns, concrete_columns
|
||||
):
|
||||
if template_column.index != concrete_column or (
|
||||
template_column.name != concrete_column.name
|
||||
):
|
||||
self.logger.debug(
|
||||
f"Column mismatch: template={template_column}, concrete={concrete_column}"
|
||||
)
|
||||
raise UnificationError
|
||||
new_substistutions: dict[str, Type] = self.match(
|
||||
template_column.type, concrete_column.type
|
||||
)
|
||||
substitutions = self.merge(substitutions, new_substistutions)
|
||||
return substitutions
|
||||
|
||||
case (ColumnType(type=template_column), ColumnType(type=concrete_column)):
|
||||
return self.match(template_column, concrete_column)
|
||||
|
||||
case (Function(), Function()):
|
||||
mapped: list[tuple[Function.Argument, Function.Argument]] = (
|
||||
self.map_params(template, concrete)
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import TextIO
|
||||
from typing import Optional, TextIO
|
||||
|
||||
import click
|
||||
|
||||
@@ -19,18 +19,23 @@ from midas.utils import TypedAST
|
||||
@click.command(help="Compile source")
|
||||
@click.argument("file", type=click.File("r"))
|
||||
@click.option("-t", "--types", type=click.File("r"), multiple=True)
|
||||
@click.option("-s", "--stubs", type=str, multiple=True)
|
||||
@click.option("--ignore-errors", is_flag=True)
|
||||
def compile(
|
||||
file: TextIO,
|
||||
types: tuple[TextIO],
|
||||
stubs: tuple[str],
|
||||
ignore_errors: bool,
|
||||
):
|
||||
source: str = file.read()
|
||||
source_path: Path = Path(file.name).resolve()
|
||||
|
||||
checker = TypeChecker()
|
||||
for types_file in types:
|
||||
checker.import_midas(Path(types_file.name).resolve())
|
||||
type_files: list[tuple[Path, Optional[str]]] = []
|
||||
for i, types_file in enumerate(types):
|
||||
in_path: Path = Path(types_file.name).resolve()
|
||||
checker.import_midas(in_path)
|
||||
type_files.append((in_path, stubs[i] if i < len(stubs) else None))
|
||||
|
||||
typed_ast: TypedAST = checker.type_check_source(source, str(source_path))
|
||||
diagnostics: list[Diagnostic] = checker.diagnostics.copy()
|
||||
@@ -43,4 +48,4 @@ def compile(
|
||||
sys.exit(1)
|
||||
|
||||
generator = Generator(workdir=source_path.parent, types=checker.types)
|
||||
generator.generate(typed_ast, source_path)
|
||||
generator.generate(typed_ast, source_path, type_files=type_files)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import ast
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import TextIO
|
||||
from typing import Optional, TextIO
|
||||
|
||||
import black
|
||||
import click
|
||||
@@ -38,15 +38,17 @@ class Handler(FileSystemEventHandler):
|
||||
|
||||
@click.command(help="Generate stubs from Midas definitions")
|
||||
@click.argument("file", type=click.File("r"))
|
||||
@click.option("-o", "--output", type=click.File("w"), default="-")
|
||||
@click.option("-o", "--output", type=click.File("w"))
|
||||
@click.option("-w", "--watch", is_flag=True)
|
||||
def stubs(
|
||||
file: TextIO,
|
||||
output: TextIO,
|
||||
output: Optional[TextIO],
|
||||
watch: bool,
|
||||
):
|
||||
source_path: Path = Path(file.name).resolve()
|
||||
out_path: Path = Path(output.name).resolve()
|
||||
out_path: Path = source_path.with_suffix(".pyi")
|
||||
if output is not None:
|
||||
out_path = Path(output.name).resolve()
|
||||
generate_stubs(source_path, out_path)
|
||||
|
||||
if watch:
|
||||
|
||||
@@ -134,9 +134,9 @@ class PythonHighlighter(
|
||||
|
||||
def visit_base_type(self, node: p.BaseType) -> None:
|
||||
self.wrap(node, "base-type")
|
||||
if node.param is not None:
|
||||
self.wrap(node.param, "param")
|
||||
node.param.accept(self)
|
||||
for arg in node.args:
|
||||
self.wrap(arg, "arg")
|
||||
arg.accept(self)
|
||||
|
||||
def visit_constraint_type(self, node: p.ConstraintType) -> None:
|
||||
self.wrap(node, "constraint-type")
|
||||
@@ -247,6 +247,10 @@ class PythonHighlighter(
|
||||
if expr.step is not None:
|
||||
expr.step.accept(self)
|
||||
|
||||
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
|
||||
for item in expr.items:
|
||||
item.accept(self)
|
||||
|
||||
def visit_raw_expr(self, expr: p.RawExpr) -> None: ...
|
||||
|
||||
def visit_raw_stmt(self, stmt: p.RawStmt) -> None: ...
|
||||
@@ -350,6 +354,14 @@ class MidasHighlighter(
|
||||
for param in spec.pos + spec.mixed + spec.kw:
|
||||
param.type.accept(self)
|
||||
|
||||
def visit_frame_type(self, type: m.FrameType) -> None:
|
||||
self.wrap(type, "frame")
|
||||
for column in type.columns:
|
||||
self._visit_frame_column(column)
|
||||
|
||||
def _visit_frame_column(self, column: m.FrameType.Column) -> None:
|
||||
self.wrap(column, "column")
|
||||
|
||||
|
||||
class DiagnosticsHighlighter(Highlighter):
|
||||
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"
|
||||
|
||||
@@ -3,7 +3,7 @@ span {
|
||||
--col: 108, 233, 108;
|
||||
}
|
||||
|
||||
&.param {
|
||||
&.arg {
|
||||
--col: 103, 192, 224;
|
||||
}
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ class DiagnosticPrinter:
|
||||
|
||||
loc: Location = diagnostic.location
|
||||
if loc.lineno != loc.end_lineno:
|
||||
print(diagnostic)
|
||||
self.print_multiline(lines, diagnostic, indent)
|
||||
return
|
||||
|
||||
start_offset: int = loc.col_offset
|
||||
@@ -95,3 +95,27 @@ class DiagnosticPrinter:
|
||||
print(indent_str + before + subject + after)
|
||||
print(indent_str + cursor)
|
||||
print()
|
||||
|
||||
def print_multiline(
|
||||
self, all_lines: list[str], diagnostic: Diagnostic, indent: int = 4
|
||||
):
|
||||
loc: Location = diagnostic.location
|
||||
lines: list[str] = all_lines[loc.lineno - 1 : loc.end_lineno]
|
||||
|
||||
start_offset: int = loc.col_offset
|
||||
end_offset: int = loc.end_col_offset or (start_offset + 1)
|
||||
|
||||
indent_str: str = " " * indent
|
||||
color: int = self.COLORS.get(diagnostic.type, Ansi.WHITE)
|
||||
res: str = indent_str + lines[0][:start_offset]
|
||||
res += Ansi.FG(color) + lines[0][start_offset:]
|
||||
for line in lines[1:-1]:
|
||||
res += "\n" + indent_str + line
|
||||
res += "\n" + indent_str + lines[-1][:end_offset]
|
||||
res += Ansi.RESET + lines[-1][end_offset:]
|
||||
|
||||
print(diagnostic.location_str + ":")
|
||||
print(res)
|
||||
print()
|
||||
print(Ansi.FG(color) + diagnostic.message + Ansi.RESET)
|
||||
print()
|
||||
|
||||
59
midas/generator/collector.py
Normal file
59
midas/generator/collector.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
|
||||
import midas.ast.python as p
|
||||
|
||||
AssertionBuilder = Callable[..., ast.expr]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Assertion:
|
||||
bound_expr: p.Expr
|
||||
inputs: list[p.Expr]
|
||||
builder: AssertionBuilder
|
||||
message: str
|
||||
|
||||
def is_bound_to(self, expr: p.Expr) -> bool:
|
||||
return expr == self.bound_expr
|
||||
|
||||
|
||||
class AssertionCollector:
|
||||
def __init__(self):
|
||||
self.assertions: list[Assertion] = []
|
||||
self.definitions: dict[str, ast.stmt] = {}
|
||||
|
||||
def add(
|
||||
self,
|
||||
bound_expr: p.Expr,
|
||||
inputs: list[p.Expr],
|
||||
builder: AssertionBuilder,
|
||||
message: str,
|
||||
):
|
||||
self.assertions.append(
|
||||
Assertion(
|
||||
bound_expr=bound_expr,
|
||||
inputs=inputs,
|
||||
builder=builder,
|
||||
message=message,
|
||||
)
|
||||
)
|
||||
|
||||
def remove(self, assertion: Assertion):
|
||||
try:
|
||||
self.assertions.remove(assertion)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def define(self, name: str, stmt: ast.stmt):
|
||||
if name not in self.definitions:
|
||||
self.definitions[name] = stmt
|
||||
|
||||
def get_definitions(self) -> list[ast.stmt]:
|
||||
return list(self.definitions.values())
|
||||
|
||||
def get_assertions(self) -> list[Assertion]:
|
||||
return self.assertions
|
||||
|
||||
def get_assertions_for(self, expr: p.Expr) -> list[Assertion]:
|
||||
return list(filter(lambda a: a.is_bound_to(expr), self.assertions))
|
||||
@@ -1,4 +1,5 @@
|
||||
import ast
|
||||
import logging
|
||||
import shutil
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
@@ -8,65 +9,96 @@ import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.ast.printer import MidasPrinter
|
||||
from midas.checker.checker import TypeChecker
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import (
|
||||
AppliedType,
|
||||
BaseType,
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
DataFrameType,
|
||||
DerivedType,
|
||||
ExtensionType,
|
||||
FrameGroupBy,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
TopType,
|
||||
TupleType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnitType,
|
||||
UnknownType,
|
||||
)
|
||||
from midas.generator.collector import Assertion, AssertionCollector
|
||||
from midas.generator.constraints import ConstraintGenerator
|
||||
from midas.generator.stubs import StubsGenerator
|
||||
from midas.utils import TypedAST
|
||||
|
||||
|
||||
@dataclass
|
||||
class Scope:
|
||||
pre_assertions: list[ast.stmt] = field(default_factory=list)
|
||||
aliases: list[str] = field(default_factory=list)
|
||||
pre_assertions: list[ast.stmt] = field(default_factory=list[ast.stmt])
|
||||
aliases: list[str] = field(default_factory=list[str])
|
||||
|
||||
|
||||
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
IS_DATAFRAME_FUNC = "__midas_is_dataframe__"
|
||||
IS_COLUMN_FUNC = "__midas_is_column__"
|
||||
|
||||
def __init__(self, workdir: Path, types: TypesRegistry) -> None:
|
||||
self.workdir: Path = workdir.resolve()
|
||||
self.build_dir: Path = self.workdir / "build" / "midas"
|
||||
self.rel_src_path: Path = Path()
|
||||
self.logger: logging.Logger = logging.getLogger("Generator")
|
||||
|
||||
self._typed_ast: TypedAST = TypedAST(
|
||||
stmts=[],
|
||||
judgements=[],
|
||||
evaluated_casts=[],
|
||||
assertions=AssertionCollector(),
|
||||
)
|
||||
self._alias_count: int = 0
|
||||
self._predicate_count: int = 0
|
||||
self._scopes: list[Scope] = []
|
||||
self._aliases: list[tuple[p.Expr, ast.expr]] = []
|
||||
|
||||
self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
|
||||
self._constraints: list[tuple[m.Expr, ast.expr]] = []
|
||||
|
||||
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
|
||||
self.rel_src_path = src_path.resolve().relative_to(self.workdir)
|
||||
self.define_is_dataframe: bool = False
|
||||
self.define_is_column: bool = False
|
||||
|
||||
def set_src_path(self, path: Path):
|
||||
self.rel_src_path = path.resolve().relative_to(self.workdir)
|
||||
|
||||
def generate_ast(self, typed_ast: TypedAST) -> ast.AST:
|
||||
self._typed_ast = typed_ast
|
||||
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
|
||||
body: list[ast.stmt] = self._visit_body(typed_ast.stmts, can_be_empty=True)
|
||||
predicates: list[ast.stmt] = self._constraint_generator.get_definitions()
|
||||
module = ast.Module(body=predicates + body, type_ignores=[])
|
||||
|
||||
body = predicates + body
|
||||
|
||||
if self.define_is_dataframe:
|
||||
body = [self._is_dataframe_definition()] + body
|
||||
|
||||
if self.define_is_column:
|
||||
body = [self._is_column_definition()] + body
|
||||
|
||||
module = ast.Module(body=body, type_ignores=[])
|
||||
module = ast.fix_missing_locations(module)
|
||||
return module
|
||||
|
||||
def generate(
|
||||
self, typed_ast: TypedAST, src_path: Path, out_path: Optional[Path] = None
|
||||
self,
|
||||
typed_ast: TypedAST,
|
||||
src_path: Path,
|
||||
out_path: Optional[Path] = None,
|
||||
type_files: Optional[list[tuple[Path, Optional[str]]]] = None,
|
||||
) -> Path:
|
||||
module: ast.AST = self.generate_ast(typed_ast, src_path)
|
||||
compiled: str = ast.unparse(module)
|
||||
self.set_src_path(src_path)
|
||||
if out_path is None:
|
||||
if self.build_dir.exists():
|
||||
shutil.rmtree(self.build_dir)
|
||||
@@ -78,43 +110,72 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
raise ValueError(
|
||||
f"Directory traversal, {self.rel_src_path} points outside of parent directory"
|
||||
)
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
out_dir: Path = out_path.parent
|
||||
out_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if type_files is not None:
|
||||
for in_path, out_name in type_files:
|
||||
if out_name is None:
|
||||
out_name = in_path.stem
|
||||
self.generate_stubs(in_path, out_dir / f"{out_name}.py")
|
||||
|
||||
module: ast.AST = self.generate_ast(typed_ast)
|
||||
compiled: str = ast.unparse(module)
|
||||
|
||||
out_path.write_text(compiled)
|
||||
return out_path
|
||||
|
||||
def generate_stubs(self, in_path: Path, out_path: Path):
|
||||
checker = TypeChecker()
|
||||
checker.import_midas(in_path)
|
||||
generator = StubsGenerator(checker.types)
|
||||
module: ast.Module = generator.generate_stubs()
|
||||
module = ast.fix_missing_locations(module)
|
||||
output: str = ast.unparse(module)
|
||||
out_path.write_text(output)
|
||||
|
||||
def convert(self, expr: p.Expr) -> ast.expr:
|
||||
for expr2, alias in self._aliases:
|
||||
if expr2 == expr:
|
||||
return alias
|
||||
assertions = self._typed_ast.assertions.get_assertions_for(expr)
|
||||
if len(assertions) != 0:
|
||||
return self._apply_assertions(expr, assertions)
|
||||
return expr.accept(self)
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> ast.expr:
|
||||
return ast.BinOp(
|
||||
left=expr.left.accept(self),
|
||||
left=self.convert(expr.left),
|
||||
op=expr.operator,
|
||||
right=expr.right.accept(self),
|
||||
right=self.convert(expr.right),
|
||||
)
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> ast.expr:
|
||||
return ast.Compare(
|
||||
left=expr.left.accept(self),
|
||||
left=self.convert(expr.left),
|
||||
ops=[expr.operator],
|
||||
comparators=[expr.right.accept(self)],
|
||||
comparators=[self.convert(expr.right)],
|
||||
)
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> ast.expr:
|
||||
return ast.UnaryOp(
|
||||
op=expr.operator,
|
||||
operand=expr.right.accept(self),
|
||||
operand=self.convert(expr.right),
|
||||
)
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> ast.expr:
|
||||
return ast.Call(
|
||||
func=expr.callee.accept(self),
|
||||
args=[arg.accept(self) for arg in expr.arguments],
|
||||
func=self.convert(expr.callee),
|
||||
args=[self.convert(arg) for arg in expr.arguments],
|
||||
keywords=[
|
||||
ast.keyword(arg=name, value=arg.accept(self))
|
||||
ast.keyword(arg=name, value=self.convert(arg))
|
||||
for name, arg in expr.keywords.items()
|
||||
],
|
||||
)
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> ast.expr:
|
||||
return ast.Attribute(
|
||||
value=expr.object.accept(self),
|
||||
value=self.convert(expr.object),
|
||||
attr=expr.name,
|
||||
)
|
||||
|
||||
@@ -127,51 +188,58 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> ast.expr:
|
||||
return ast.BoolOp(
|
||||
op=expr.operator,
|
||||
values=[expr.left.accept(self), expr.right.accept(self)],
|
||||
values=[self.convert(expr.left), self.convert(expr.right)],
|
||||
)
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
|
||||
expr2: ast.expr = expr.expr.accept(self)
|
||||
expr2: ast.expr = self.convert(expr.expr)
|
||||
|
||||
if expr in self._typed_ast.evaluated_casts or expr.unsafe:
|
||||
return expr2
|
||||
|
||||
alias: ast.expr = self._make_alias(expr2)
|
||||
alias: ast.expr = self._make_alias(expr.expr, expr2)
|
||||
|
||||
type: Type = self._get_expr_type(expr)
|
||||
self._make_cast_asserts(expr.location, alias, type)
|
||||
asserts: list[ast.stmt] = self._make_cast_asserts(expr.location, alias, type)
|
||||
for assert_ in asserts:
|
||||
self._add_assert(assert_)
|
||||
|
||||
return alias
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> ast.expr:
|
||||
return ast.IfExp(
|
||||
test=expr.test.accept(self),
|
||||
body=expr.if_true.accept(self),
|
||||
orelse=expr.if_false.accept(self),
|
||||
test=self.convert(expr.test),
|
||||
body=self.convert(expr.if_true),
|
||||
orelse=self.convert(expr.if_false),
|
||||
)
|
||||
|
||||
def visit_list_expr(self, expr: p.ListExpr) -> ast.expr:
|
||||
return ast.List(
|
||||
elts=[item.accept(self) for item in expr.items],
|
||||
elts=[self.convert(item) for item in expr.items],
|
||||
)
|
||||
|
||||
def visit_dict_expr(self, expr: p.DictExpr) -> ast.expr:
|
||||
return ast.Dict(
|
||||
keys=[key.accept(self) if key is not None else None for key in expr.keys],
|
||||
values=[value.accept(self) for value in expr.values],
|
||||
keys=[self.convert(key) if key is not None else None for key in expr.keys],
|
||||
values=[self.convert(value) for value in expr.values],
|
||||
)
|
||||
|
||||
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> ast.expr:
|
||||
return ast.Subscript(
|
||||
value=expr.object.accept(self),
|
||||
slice=expr.index.accept(self),
|
||||
value=self.convert(expr.object),
|
||||
slice=self.convert(expr.index),
|
||||
)
|
||||
|
||||
def visit_slice_expr(self, expr: p.SliceExpr) -> ast.expr:
|
||||
return ast.Slice(
|
||||
lower=expr.lower.accept(self) if expr.lower is not None else None,
|
||||
upper=expr.upper.accept(self) if expr.upper is not None else None,
|
||||
step=expr.step.accept(self) if expr.step is not None else None,
|
||||
lower=self.convert(expr.lower) if expr.lower is not None else None,
|
||||
upper=self.convert(expr.upper) if expr.upper is not None else None,
|
||||
step=self.convert(expr.step) if expr.step is not None else None,
|
||||
)
|
||||
|
||||
def visit_tuple_expr(self, expr: p.TupleExpr) -> ast.expr:
|
||||
return ast.Tuple(
|
||||
elts=[self.convert(item) for item in expr.items],
|
||||
)
|
||||
|
||||
def visit_raw_expr(self, expr: p.RawExpr) -> ast.expr:
|
||||
@@ -179,7 +247,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
|
||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> ast.stmt:
|
||||
return ast.Expr(
|
||||
value=stmt.expr.accept(self),
|
||||
value=self.convert(stmt.expr),
|
||||
)
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> ast.stmt:
|
||||
@@ -192,12 +260,12 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
kwonlyargs=[ast.arg(arg=arg.name) for arg in stmt.kwonlyargs],
|
||||
kwarg=None,
|
||||
defaults=[
|
||||
arg.default.accept(self)
|
||||
self.convert(arg.default)
|
||||
for arg in stmt.posonlyargs + stmt.args
|
||||
if arg.default is not None
|
||||
],
|
||||
kw_defaults=[
|
||||
arg.default.accept(self) if arg.default is not None else None
|
||||
self.convert(arg.default) if arg.default is not None else None
|
||||
for arg in stmt.kwonlyargs
|
||||
],
|
||||
),
|
||||
@@ -211,20 +279,20 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
|
||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> ast.stmt:
|
||||
return ast.Assign(
|
||||
targets=[target.accept(self) for target in stmt.targets],
|
||||
value=stmt.value.accept(self),
|
||||
targets=[self.convert(target) for target in stmt.targets],
|
||||
value=self.convert(stmt.value),
|
||||
)
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> ast.stmt:
|
||||
return ast.Return(
|
||||
value=stmt.value.accept(self) if stmt.value is not None else None,
|
||||
value=self.convert(stmt.value) if stmt.value is not None else None,
|
||||
)
|
||||
|
||||
def visit_if_stmt(self, stmt: p.IfStmt) -> ast.stmt:
|
||||
return ast.If(
|
||||
test=stmt.test.accept(self),
|
||||
test=self.convert(stmt.test),
|
||||
body=self._visit_body(stmt.body),
|
||||
orelse=self._visit_body(stmt.orelse),
|
||||
orelse=self._visit_body(stmt.orelse, can_be_empty=True),
|
||||
)
|
||||
|
||||
def visit_pass(self, stmt: p.Pass) -> ast.stmt:
|
||||
@@ -232,8 +300,8 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
|
||||
def visit_for_stmt(self, stmt: p.ForStmt) -> ast.stmt:
|
||||
return ast.For(
|
||||
target=stmt.target.accept(self),
|
||||
iter=stmt.iterator.accept(self),
|
||||
target=self.convert(stmt.target),
|
||||
iter=self.convert(stmt.iterator),
|
||||
body=self._visit_body(stmt.body),
|
||||
orelse=[],
|
||||
)
|
||||
@@ -241,7 +309,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
def visit_raw_stmt(self, stmt: p.RawStmt) -> ast.stmt:
|
||||
return stmt.stmt
|
||||
|
||||
def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]:
|
||||
def _visit_body(
|
||||
self, stmts: list[p.Stmt], can_be_empty: bool = False
|
||||
) -> list[ast.stmt]:
|
||||
generated: list[ast.stmt] = []
|
||||
for stmt in stmts:
|
||||
scope = Scope()
|
||||
@@ -259,9 +329,11 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
# Remove redundant pass statements
|
||||
if len(generated) > 1:
|
||||
generated = [stmt for stmt in generated if not isinstance(stmt, ast.Pass)]
|
||||
if len(generated) == 0 and not can_be_empty:
|
||||
generated = [ast.Pass()]
|
||||
return generated
|
||||
|
||||
def _make_alias(self, expr: ast.expr) -> ast.expr:
|
||||
def _make_alias(self, node: p.Expr, expr: ast.expr) -> ast.expr:
|
||||
name: str = f"__midas_a{self._alias_count}__"
|
||||
alias = ast.Name(id=name)
|
||||
self._alias_count += 1
|
||||
@@ -272,82 +344,182 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
value=expr,
|
||||
)
|
||||
)
|
||||
self._aliases.append((node, alias))
|
||||
return alias
|
||||
|
||||
def _add_assert(self, expr: ast.expr, message: str | ast.expr):
|
||||
def _build_assert(self, expr: ast.expr, message: str | ast.expr) -> ast.stmt:
|
||||
if isinstance(message, str):
|
||||
message = ast.Constant(value=message)
|
||||
self._scopes[-1].pre_assertions.append(
|
||||
ast.Assert(
|
||||
test=expr,
|
||||
msg=message,
|
||||
)
|
||||
return ast.Assert(
|
||||
test=expr,
|
||||
msg=message,
|
||||
)
|
||||
|
||||
def _add_assert(self, assertion: ast.stmt):
|
||||
self._scopes[-1].pre_assertions.append(assertion)
|
||||
|
||||
def _get_expr_type(self, query: p.Expr) -> Type:
|
||||
for expr, type in self._typed_ast.judgements:
|
||||
if expr == query:
|
||||
return type
|
||||
raise RuntimeError(f"Cannot get type judgement for {query}")
|
||||
|
||||
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type):
|
||||
def _make_cast_asserts(
|
||||
self, src_location: Location, expr: ast.expr, type: Type
|
||||
) -> list[ast.stmt]:
|
||||
match type:
|
||||
case UnknownType():
|
||||
pass
|
||||
case UnknownType() | TopType():
|
||||
return []
|
||||
|
||||
case BaseType(name=name):
|
||||
self._add_assert(
|
||||
ast.Call(
|
||||
func=ast.Name(id="isinstance"),
|
||||
args=[expr, ast.Name(id=name)],
|
||||
keywords=[],
|
||||
),
|
||||
self._make_cast_assert_message(src_location, expr, type),
|
||||
)
|
||||
return [
|
||||
self._build_assert(
|
||||
ast.Call(
|
||||
func=ast.Name(id="isinstance"),
|
||||
args=[expr, ast.Name(id=name)],
|
||||
keywords=[],
|
||||
),
|
||||
self._make_cast_assert_message(src_location, expr, type),
|
||||
)
|
||||
]
|
||||
|
||||
case DerivedType(type=base):
|
||||
self._make_cast_asserts(src_location, expr, base)
|
||||
return self._make_cast_asserts(src_location, expr, base)
|
||||
|
||||
case UnitType():
|
||||
self._add_assert(
|
||||
ast.Compare(
|
||||
left=expr,
|
||||
ops=[ast.Is()],
|
||||
comparators=[
|
||||
ast.Constant(value=None),
|
||||
],
|
||||
return [
|
||||
self._build_assert(
|
||||
ast.Compare(
|
||||
left=expr,
|
||||
ops=[ast.Is()],
|
||||
comparators=[
|
||||
ast.Constant(value=None),
|
||||
],
|
||||
),
|
||||
self._make_cast_assert_message(src_location, expr, type),
|
||||
),
|
||||
self._make_cast_assert_message(src_location, expr, type),
|
||||
)
|
||||
]
|
||||
|
||||
case AppliedType(body=body):
|
||||
self._make_cast_asserts(src_location, expr, body)
|
||||
return self._make_cast_asserts(src_location, expr, body)
|
||||
|
||||
case ConstraintType(type=base, constraint=constraint):
|
||||
self._make_cast_asserts(src_location, expr, base)
|
||||
self._make_constraint_assert(src_location, expr, constraint)
|
||||
asserts: list[ast.stmt] = self._make_cast_asserts(
|
||||
src_location, expr, base
|
||||
)
|
||||
asserts.append(
|
||||
self._make_constraint_assert(src_location, expr, constraint)
|
||||
)
|
||||
return asserts
|
||||
|
||||
case TypeVar(bound=bound):
|
||||
# TODO: check with type from arguments / use call-site context
|
||||
if bound is not None:
|
||||
self._make_cast_asserts(src_location, expr, bound)
|
||||
if bound is None:
|
||||
return []
|
||||
return self._make_cast_asserts(src_location, expr, bound)
|
||||
|
||||
case TupleType(items=items):
|
||||
asserts: list[ast.stmt] = [
|
||||
self._build_assert(
|
||||
ast.Call(
|
||||
func=ast.Name(id="isinstance"),
|
||||
args=[expr, ast.Name(id="tuple")],
|
||||
keywords=[],
|
||||
),
|
||||
self._make_cast_assert_message(src_location, expr, type),
|
||||
),
|
||||
]
|
||||
assert isinstance(expr, ast.Tuple)
|
||||
for item, item_type in zip(expr.elts, items):
|
||||
asserts.extend(
|
||||
self._make_cast_asserts(src_location, item, item_type)
|
||||
)
|
||||
return asserts
|
||||
|
||||
case DataFrameType(columns=columns):
|
||||
self.define_is_dataframe = True
|
||||
asserts: list[ast.stmt] = [
|
||||
self._build_assert(
|
||||
ast.Call(
|
||||
func=ast.Name(id=self.IS_DATAFRAME_FUNC),
|
||||
args=[expr],
|
||||
keywords=[],
|
||||
),
|
||||
self._make_cast_assert_message(
|
||||
src_location, expr, type, ": Not a dataframe"
|
||||
),
|
||||
),
|
||||
]
|
||||
for column in columns:
|
||||
asserts.append(
|
||||
self._build_assert(
|
||||
ast.Compare(
|
||||
left=ast.Constant(value=column.name),
|
||||
ops=[ast.In()],
|
||||
comparators=[expr],
|
||||
),
|
||||
self._make_cast_assert_message(
|
||||
src_location,
|
||||
expr,
|
||||
type,
|
||||
f": Missing column {column.name}",
|
||||
),
|
||||
)
|
||||
)
|
||||
asserts.extend(
|
||||
self._make_cast_asserts(
|
||||
src_location,
|
||||
ast.Subscript(
|
||||
value=expr, slice=ast.Constant(value=column.name)
|
||||
),
|
||||
column.type,
|
||||
)
|
||||
)
|
||||
return asserts
|
||||
|
||||
case ColumnType():
|
||||
self.define_is_column = True
|
||||
asserts: list[ast.stmt] = [
|
||||
self._build_assert(
|
||||
ast.Call(
|
||||
func=ast.Name(id=self.IS_COLUMN_FUNC),
|
||||
args=[expr],
|
||||
keywords=[],
|
||||
),
|
||||
self._make_cast_assert_message(
|
||||
src_location, expr, type, ": Not a column"
|
||||
),
|
||||
),
|
||||
]
|
||||
inner_assert: Optional[ast.stmt] = self._make_column_inner_assert(
|
||||
src_location, expr, type
|
||||
)
|
||||
if inner_assert is not None:
|
||||
asserts.append(inner_assert)
|
||||
return asserts
|
||||
|
||||
case (
|
||||
TopType()
|
||||
| Function()
|
||||
Function()
|
||||
| OverloadedFunction()
|
||||
| ComplexType()
|
||||
| ExtensionType()
|
||||
| GenericType()
|
||||
| FrameGroupBy()
|
||||
| ColumnGroupBy()
|
||||
):
|
||||
raise NotImplementedError(f"Can't make assertion for type {type}")
|
||||
self.logger.warning(f"Can't make assertion for type {type}")
|
||||
return []
|
||||
|
||||
# Ensure exhaustiveness
|
||||
case _:
|
||||
assert_never(type)
|
||||
|
||||
def _make_cast_assert_message(
|
||||
self, location: Location, expr: ast.expr, type: Type
|
||||
self,
|
||||
location: Location,
|
||||
expr: ast.expr,
|
||||
type: Type,
|
||||
extra: Optional[str] = None,
|
||||
) -> ast.expr:
|
||||
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
|
||||
# f"file.py:L1:1: CastError: Cannot cast {type(expr).__name__} to Type"
|
||||
@@ -365,15 +537,15 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
),
|
||||
conversion=-1,
|
||||
),
|
||||
ast.Constant(f" to {type}"),
|
||||
ast.Constant(f" to {type}{extra or ''}"),
|
||||
]
|
||||
)
|
||||
|
||||
def _make_constraint_assert(
|
||||
self, src_location: Location, expr: ast.expr, constraint: m.Expr
|
||||
):
|
||||
) -> ast.stmt:
|
||||
test_func: ast.expr = self._get_constraint(constraint)
|
||||
self._add_assert(
|
||||
return self._build_assert(
|
||||
ast.Call(
|
||||
func=test_func,
|
||||
args=[expr],
|
||||
@@ -401,3 +573,117 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
constraint: ast.expr = self._constraint_generator.generate(expr)
|
||||
self._constraints.append((expr, constraint))
|
||||
return constraint
|
||||
|
||||
def _is_dataframe_definition(self) -> ast.stmt:
|
||||
"""
|
||||
def IS_DATAFRAME_FUNC(obj) -> bool:
|
||||
import pandas as pd
|
||||
return isinstance(obj, pd.DataFrame)
|
||||
"""
|
||||
|
||||
return ast.FunctionDef(
|
||||
name=self.IS_DATAFRAME_FUNC,
|
||||
args=ast.arguments(
|
||||
posonlyargs=[ast.arg(arg="obj")],
|
||||
args=[],
|
||||
kwonlyargs=[],
|
||||
defaults=[],
|
||||
kw_defaults=[],
|
||||
),
|
||||
body=[
|
||||
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
|
||||
ast.Return(
|
||||
value=ast.Call(
|
||||
func=ast.Name(id="isinstance"),
|
||||
args=[
|
||||
ast.Name(id="obj"),
|
||||
ast.Attribute(
|
||||
value=ast.Name(id="pd"),
|
||||
attr="DataFrame",
|
||||
),
|
||||
],
|
||||
keywords=[],
|
||||
)
|
||||
),
|
||||
],
|
||||
decorator_list=[],
|
||||
returns=ast.Name(id="bool"),
|
||||
)
|
||||
|
||||
def _is_column_definition(self) -> ast.stmt:
|
||||
"""
|
||||
def IS_COLUMN_FUNC(obj) -> bool:
|
||||
import pandas as pd
|
||||
return isinstance(obj, pd.Series)
|
||||
"""
|
||||
|
||||
return ast.FunctionDef(
|
||||
name=self.IS_COLUMN_FUNC,
|
||||
args=ast.arguments(
|
||||
posonlyargs=[ast.arg(arg="obj")],
|
||||
args=[],
|
||||
kwonlyargs=[],
|
||||
defaults=[],
|
||||
kw_defaults=[],
|
||||
),
|
||||
body=[
|
||||
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
|
||||
ast.Return(
|
||||
value=ast.Call(
|
||||
func=ast.Name(id="isinstance"),
|
||||
args=[
|
||||
ast.Name(id="obj"),
|
||||
ast.Attribute(
|
||||
value=ast.Name(id="pd"),
|
||||
attr="Series",
|
||||
),
|
||||
],
|
||||
keywords=[],
|
||||
)
|
||||
),
|
||||
],
|
||||
decorator_list=[],
|
||||
returns=ast.Name(id="bool"),
|
||||
)
|
||||
|
||||
def _make_column_inner_assert(
|
||||
self, src_location: Location, column: ast.expr, type: ColumnType
|
||||
) -> Optional[ast.stmt]:
|
||||
# TODO: improve message, maybe chain contexts
|
||||
col: ast.expr = ast.Name(id="col")
|
||||
body: list[ast.stmt] = self._make_cast_asserts(src_location, col, type.type)
|
||||
if len(body) == 0:
|
||||
return None
|
||||
return ast.For(
|
||||
target=col,
|
||||
iter=column,
|
||||
body=body,
|
||||
orelse=[],
|
||||
)
|
||||
|
||||
def _convert_assertion(self, assertion: Assertion) -> ast.stmt:
|
||||
inputs: list[ast.expr] = []
|
||||
|
||||
for input in assertion.inputs:
|
||||
converted: ast.expr = self.convert(input)
|
||||
alias: ast.expr = self._make_alias(input, converted)
|
||||
inputs.append(alias)
|
||||
|
||||
test: ast.expr = assertion.builder(*inputs)
|
||||
location: Location = assertion.bound_expr.location
|
||||
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
|
||||
return self._build_assert(
|
||||
test, f"{loc_str}: AssertionError: {assertion.message}"
|
||||
)
|
||||
|
||||
def _apply_assertions(self, expr: p.Expr, assertions: list[Assertion]) -> ast.expr:
|
||||
for assertion in assertions:
|
||||
assert_stmt: ast.stmt
|
||||
assert_stmt = self._convert_assertion(assertion)
|
||||
self._add_assert(assert_stmt)
|
||||
|
||||
# Mutating list in frozen dataclass
|
||||
# Not ideal but easiest way to avoid duplicate assertions
|
||||
self._typed_ast.assertions.remove(assertion)
|
||||
|
||||
return expr.accept(self)
|
||||
|
||||
@@ -6,14 +6,19 @@ from midas.checker.registry import Member, TypesRegistry
|
||||
from midas.checker.types import (
|
||||
AppliedType,
|
||||
BaseType,
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
DataFrameType,
|
||||
DerivedType,
|
||||
ExtensionType,
|
||||
FrameGroupBy,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
TopType,
|
||||
TupleType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnitType,
|
||||
@@ -30,6 +35,7 @@ class StubsGenerator:
|
||||
self.types: TypesRegistry = types
|
||||
self.stubs: list[ast.stmt] = []
|
||||
self.typing_imports: set[str] = set()
|
||||
self.import_pandas: bool = False
|
||||
self.protocol_idx: int = 0
|
||||
self.stub_idx: int = 0
|
||||
self.type_var_idx: int = 0
|
||||
@@ -38,6 +44,7 @@ class StubsGenerator:
|
||||
def generate_stubs(self) -> ast.Module:
|
||||
self.stubs = []
|
||||
self.typing_imports = set()
|
||||
self.import_pandas = False
|
||||
for name, type in self.types._types.items():
|
||||
# Skip builtin types, not just based on name so the user can override
|
||||
# TODO: check if added members on builtin type
|
||||
@@ -53,7 +60,7 @@ class StubsGenerator:
|
||||
continue
|
||||
self.generate_stub(name, type)
|
||||
|
||||
imports = [
|
||||
imports: list[ast.stmt] = [
|
||||
ast.ImportFrom(
|
||||
module="__future__",
|
||||
names=[ast.alias(name="annotations")],
|
||||
@@ -70,11 +77,37 @@ class StubsGenerator:
|
||||
level=0,
|
||||
)
|
||||
)
|
||||
if self.import_pandas:
|
||||
imports.append(
|
||||
ast.Import(
|
||||
names=[
|
||||
ast.alias(
|
||||
name="pandas",
|
||||
asname="pd",
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
return ast.Module(body=imports + self.stubs, type_ignores=[])
|
||||
|
||||
def generate_stub(self, name: str, type: Type):
|
||||
base_type: Type = type
|
||||
|
||||
# TODO: improve
|
||||
match type:
|
||||
case DerivedType(name=name_) | GenericType(name=name_) if name_ == name:
|
||||
pass
|
||||
case UnitType() if name == "None":
|
||||
pass
|
||||
case TopType() if name == "Any":
|
||||
pass
|
||||
case _:
|
||||
alias = ast.Assign(
|
||||
targets=[ast.Name(id=name)], value=self.dump_type(type)
|
||||
)
|
||||
self.add_stub(alias)
|
||||
return
|
||||
|
||||
members: dict[str, Member] = self.types._members.get(name, {})
|
||||
if isinstance(base_type, (BaseType, TopType, UnitType)) and len(members) == 0:
|
||||
return
|
||||
@@ -231,6 +264,57 @@ class StubsGenerator:
|
||||
case ConstraintType():
|
||||
return self.dump_type(type.type)
|
||||
|
||||
case TupleType(items=items):
|
||||
return ast.Subscript(
|
||||
value=ast.Name(id="tuple"),
|
||||
slice=ast.Tuple(
|
||||
elts=[self.dump_type(item) for item in items],
|
||||
),
|
||||
)
|
||||
|
||||
case ColumnType(type=inner):
|
||||
self.import_pandas = True
|
||||
return ast.Subscript(
|
||||
value=ast.Attribute(
|
||||
value=ast.Name(id="pd"),
|
||||
attr="Series",
|
||||
),
|
||||
slice=self.dump_type(inner),
|
||||
)
|
||||
|
||||
case DataFrameType():
|
||||
self.import_pandas = True
|
||||
return ast.Attribute(
|
||||
value=ast.Name(id="pd"),
|
||||
attr="DataFrame",
|
||||
)
|
||||
|
||||
case FrameGroupBy():
|
||||
self.import_pandas = True
|
||||
return ast.Attribute(
|
||||
value=ast.Attribute(
|
||||
value=ast.Attribute(
|
||||
value=ast.Name(id="pd"),
|
||||
attr="api",
|
||||
),
|
||||
attr="typing",
|
||||
),
|
||||
attr="DataFrameGroupBy",
|
||||
)
|
||||
|
||||
case ColumnGroupBy():
|
||||
self.import_pandas = True
|
||||
return ast.Attribute(
|
||||
value=ast.Attribute(
|
||||
value=ast.Attribute(
|
||||
value=ast.Name(id="pd"),
|
||||
attr="api",
|
||||
),
|
||||
attr="typing",
|
||||
),
|
||||
attr="SeriesGroupBy",
|
||||
)
|
||||
|
||||
case _:
|
||||
assert_never(type)
|
||||
|
||||
|
||||
@@ -46,8 +46,8 @@ class MidasLexer(Lexer):
|
||||
self.add_token(TokenType.UNDERSCORE)
|
||||
case "-" if self.match(">"):
|
||||
self.add_token(TokenType.ARROW)
|
||||
# case "+":
|
||||
# self.add_token(TokenType.PLUS)
|
||||
case "+":
|
||||
self.add_token(TokenType.PLUS)
|
||||
case "-":
|
||||
self.add_token(TokenType.MINUS)
|
||||
case "*":
|
||||
|
||||
@@ -25,7 +25,7 @@ class TokenType(Enum):
|
||||
DOT = auto()
|
||||
|
||||
# Operators
|
||||
# PLUS = auto()
|
||||
PLUS = auto()
|
||||
MINUS = auto()
|
||||
STAR = auto()
|
||||
SLASH = auto()
|
||||
|
||||
@@ -10,6 +10,7 @@ from midas.ast.midas import (
|
||||
Expr,
|
||||
ExtendStmt,
|
||||
ExtensionType,
|
||||
FrameType,
|
||||
FunctionType,
|
||||
GenericType,
|
||||
GetExpr,
|
||||
@@ -226,8 +227,10 @@ class MidasParser(Parser):
|
||||
return self.generic_type()
|
||||
|
||||
def generic_type(self) -> Type:
|
||||
type: Type = self.named_type()
|
||||
type: NamedType = self.named_type()
|
||||
if self.check(TokenType.LEFT_BRACKET):
|
||||
if type.name.lexeme == "Frame":
|
||||
return self.frame_type()
|
||||
args: list[Type] = self.type_args()
|
||||
return GenericType(
|
||||
location=Location.span(type.location, self.previous().get_location()),
|
||||
@@ -246,7 +249,7 @@ class MidasParser(Parser):
|
||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments")
|
||||
return args
|
||||
|
||||
def named_type(self) -> Type:
|
||||
def named_type(self) -> NamedType:
|
||||
name: Token = self.consume_identifier("Expected type name")
|
||||
return NamedType(
|
||||
location=name.get_location(),
|
||||
@@ -281,6 +284,32 @@ class MidasParser(Parser):
|
||||
members=members,
|
||||
)
|
||||
|
||||
def frame_type(self) -> FrameType:
|
||||
keyword: Token = self.previous()
|
||||
self.consume(TokenType.LEFT_BRACKET, "Expected '[' to start frame schema")
|
||||
|
||||
columns: list[FrameType.Column] = []
|
||||
while not self.check(TokenType.RIGHT_BRACKET) and not self.is_at_end():
|
||||
name: Token = self.advance()
|
||||
self.consume(TokenType.COLON, "Expected ':' between column name and type")
|
||||
type: Type = self.type_expr()
|
||||
columns.append(
|
||||
FrameType.Column(
|
||||
location=name.location_to(self.previous()),
|
||||
name=name,
|
||||
type=type,
|
||||
)
|
||||
)
|
||||
if not self.match(TokenType.COMMA):
|
||||
break
|
||||
|
||||
self.consume(TokenType.RIGHT_BRACKET, "Unclosed frame schema")
|
||||
|
||||
return FrameType(
|
||||
location=keyword.location_to(self.previous()),
|
||||
columns=columns,
|
||||
)
|
||||
|
||||
def constraint(self) -> Expr:
|
||||
"""Parse a constraint
|
||||
|
||||
@@ -332,13 +361,35 @@ class MidasParser(Parser):
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
"""
|
||||
expr: Expr = self.unary()
|
||||
expr: Expr = self.term()
|
||||
while self.match(
|
||||
TokenType.LESS,
|
||||
TokenType.LESS_EQUAL,
|
||||
TokenType.GREATER,
|
||||
TokenType.GREATER_EQUAL,
|
||||
):
|
||||
operator: Token = self.previous()
|
||||
right: Expr = self.term()
|
||||
location: Location = Location.span(expr.location, right.location)
|
||||
expr = BinaryExpr(
|
||||
location=location, left=expr, operator=operator, right=right
|
||||
)
|
||||
return expr
|
||||
|
||||
def term(self) -> Expr:
|
||||
expr: Expr = self.factor()
|
||||
while self.match(TokenType.PLUS, TokenType.MINUS):
|
||||
operator: Token = self.previous()
|
||||
right: Expr = self.factor()
|
||||
location: Location = Location.span(expr.location, right.location)
|
||||
expr = BinaryExpr(
|
||||
location=location, left=expr, operator=operator, right=right
|
||||
)
|
||||
return expr
|
||||
|
||||
def factor(self) -> Expr:
|
||||
expr: Expr = self.unary()
|
||||
while self.match(TokenType.STAR, TokenType.SLASH):
|
||||
operator: Token = self.previous()
|
||||
right: Expr = self.unary()
|
||||
location: Location = Location.span(expr.location, right.location)
|
||||
@@ -370,7 +421,7 @@ class MidasParser(Parser):
|
||||
pos_args: list[Expr] = []
|
||||
kw_args: dict[str, Expr] = {}
|
||||
keywords: bool = False
|
||||
while not self.match(TokenType.RIGHT_PAREN):
|
||||
while not self.check(TokenType.RIGHT_PAREN):
|
||||
if self.check_identifier() and self.check_next(TokenType.EQUAL):
|
||||
keywords = True
|
||||
keyword: Token = self.advance()
|
||||
|
||||
@@ -30,6 +30,7 @@ from midas.ast.python import (
|
||||
Stmt,
|
||||
SubscriptExpr,
|
||||
TernaryExpr,
|
||||
TupleExpr,
|
||||
TypeAssign,
|
||||
UnaryExpr,
|
||||
VariableExpr,
|
||||
@@ -300,26 +301,28 @@ class PythonParser:
|
||||
case ast.Subscript(value=ast.Name(id="Frame"), slice=schema):
|
||||
return self._parse_frame_type(schema)
|
||||
|
||||
case ast.Subscript(value=ast.Name(id=name), slice=param):
|
||||
case ast.Subscript(value=ast.Name(id=name), slice=arg):
|
||||
args: tuple[MidasType, ...] = (
|
||||
tuple(self._parse_type(a) for a in arg.elts)
|
||||
if isinstance(arg, ast.Tuple)
|
||||
else (self._parse_type(arg),)
|
||||
)
|
||||
return BaseType(
|
||||
location=loc,
|
||||
base=name,
|
||||
param=self._parse_type(param),
|
||||
args=args,
|
||||
)
|
||||
|
||||
case ast.Name(id=name):
|
||||
return BaseType(
|
||||
location=loc,
|
||||
base=name,
|
||||
param=None,
|
||||
args=(),
|
||||
)
|
||||
|
||||
case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr):
|
||||
left = self._parse_type(left_expr)
|
||||
match left:
|
||||
case None:
|
||||
raise InvalidSyntaxError()
|
||||
|
||||
# If chained constraints, separate base type and rebuild constraint
|
||||
case ConstraintType(type=left_type, constraint=left_constraint):
|
||||
constraint = ast.BinOp(
|
||||
@@ -345,7 +348,7 @@ class PythonParser:
|
||||
return BaseType(
|
||||
location=loc,
|
||||
base="None",
|
||||
param=None,
|
||||
args=(),
|
||||
)
|
||||
|
||||
case _:
|
||||
@@ -477,6 +480,12 @@ class PythonParser:
|
||||
step=self.parse_expr(step) if step is not None else None,
|
||||
)
|
||||
|
||||
case ast.Tuple(elts=items):
|
||||
return TupleExpr(
|
||||
location=location,
|
||||
items=tuple(self.parse_expr(item) for item in items),
|
||||
)
|
||||
|
||||
case _:
|
||||
print(f"Unsupported expression: {ast.unparse(node)}")
|
||||
return RawExpr(location=location, expr=node)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from typing import Generic, TypeVar
|
||||
from typing import cast as typing_cast
|
||||
|
||||
cast = typing_cast
|
||||
@@ -32,3 +33,20 @@ This operation is unsound, use at your own risk!
|
||||
|
||||
_**Internal Python documentation**_
|
||||
"""
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Frame(Generic[T]):
|
||||
"""A `Frame` is the abstract type implemented by `DataFrame`
|
||||
|
||||
A frame contains any number of named columns (see :class:`Column`)
|
||||
"""
|
||||
|
||||
|
||||
class Column(Generic[T]):
|
||||
"""A `Column` is the abstract type implemented by `Series`
|
||||
|
||||
A column contains a any number of values of the same type
|
||||
"""
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any, Callable, Optional
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.checker.types import Type
|
||||
from midas.generator.collector import AssertionCollector
|
||||
|
||||
AllowRepeat = Callable[[object], bool]
|
||||
|
||||
@@ -63,3 +64,4 @@ class TypedAST:
|
||||
stmts: list[p.Stmt]
|
||||
judgements: list[tuple[p.Expr, Type]]
|
||||
evaluated_casts: list[p.CastExpr]
|
||||
assertions: AssertionCollector
|
||||
|
||||
@@ -4,7 +4,35 @@
|
||||
"type": "Warning",
|
||||
"location": {
|
||||
"start": [
|
||||
6,
|
||||
8,
|
||||
12
|
||||
],
|
||||
"end": [
|
||||
8,
|
||||
43
|
||||
]
|
||||
},
|
||||
"message": "ConstraintType not yet supported"
|
||||
},
|
||||
{
|
||||
"type": "Warning",
|
||||
"location": {
|
||||
"start": [
|
||||
10,
|
||||
10
|
||||
],
|
||||
"end": [
|
||||
10,
|
||||
18
|
||||
]
|
||||
},
|
||||
"message": "Unknown type 'datetime'"
|
||||
},
|
||||
{
|
||||
"type": "Warning",
|
||||
"location": {
|
||||
"start": [
|
||||
13,
|
||||
4
|
||||
],
|
||||
"end": [
|
||||
@@ -12,7 +40,7 @@
|
||||
5
|
||||
]
|
||||
},
|
||||
"message": "FrameType not yet supported"
|
||||
"message": "Unknown type '_'"
|
||||
}
|
||||
],
|
||||
"judgments": []
|
||||
|
||||
@@ -328,6 +328,19 @@
|
||||
},
|
||||
"type": {}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:9",
|
||||
"to": "L6:10"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 1
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:5",
|
||||
@@ -373,19 +386,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:9",
|
||||
"to": "L6:10"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 1
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:5",
|
||||
@@ -407,6 +407,32 @@
|
||||
},
|
||||
"type": {}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L7:9",
|
||||
"to": "L7:10"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 1
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L7:12",
|
||||
"to": "L7:15"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 2.0
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L7:5",
|
||||
@@ -452,32 +478,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L7:9",
|
||||
"to": "L7:10"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 1
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L7:12",
|
||||
"to": "L7:15"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 2.0
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L7:5",
|
||||
@@ -503,6 +503,32 @@
|
||||
},
|
||||
"type": {}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L8:9",
|
||||
"to": "L8:10"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 1
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L8:14",
|
||||
"to": "L8:17"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 2.0
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L8:5",
|
||||
@@ -548,32 +574,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L8:9",
|
||||
"to": "L8:10"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 1
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L8:14",
|
||||
"to": "L8:17"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 2.0
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L8:5",
|
||||
@@ -600,6 +600,45 @@
|
||||
},
|
||||
"type": {}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L9:9",
|
||||
"to": "L9:10"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 1
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L9:12",
|
||||
"to": "L9:15"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 2.0
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L9:17",
|
||||
"to": "L9:23"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": "test"
|
||||
},
|
||||
"type": {
|
||||
"name": "str"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L9:5",
|
||||
@@ -645,45 +684,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L9:9",
|
||||
"to": "L9:10"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 1
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L9:12",
|
||||
"to": "L9:15"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 2.0
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L9:17",
|
||||
"to": "L9:23"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": "test"
|
||||
},
|
||||
"type": {
|
||||
"name": "str"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L9:5",
|
||||
@@ -713,6 +713,45 @@
|
||||
},
|
||||
"type": {}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L10:9",
|
||||
"to": "L10:10"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 1
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L10:12",
|
||||
"to": "L10:15"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 2.0
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L10:19",
|
||||
"to": "L10:22"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 3.0
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L10:5",
|
||||
@@ -758,45 +797,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L10:9",
|
||||
"to": "L10:10"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 1
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L10:12",
|
||||
"to": "L10:15"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 2.0
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L10:19",
|
||||
"to": "L10:22"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 3.0
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L10:5",
|
||||
@@ -827,6 +827,19 @@
|
||||
},
|
||||
"type": {}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:11",
|
||||
"to": "L11:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 1
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:5",
|
||||
@@ -872,19 +885,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:11",
|
||||
"to": "L11:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 1
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:5",
|
||||
@@ -906,6 +906,19 @@
|
||||
},
|
||||
"type": {}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L12:11",
|
||||
"to": "L12:17"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": "test"
|
||||
},
|
||||
"type": {
|
||||
"name": "str"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L12:5",
|
||||
@@ -951,19 +964,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L12:11",
|
||||
"to": "L12:17"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": "test"
|
||||
},
|
||||
"type": {
|
||||
"name": "str"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L12:5",
|
||||
@@ -985,6 +985,45 @@
|
||||
},
|
||||
"type": {}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L14:10",
|
||||
"to": "L14:11"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 1
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L14:13",
|
||||
"to": "L14:16"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 2.0
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L14:20",
|
||||
"to": "L14:26"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": "test"
|
||||
},
|
||||
"type": {
|
||||
"name": "str"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L14:6",
|
||||
@@ -1030,45 +1069,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L14:10",
|
||||
"to": "L14:11"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 1
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L14:13",
|
||||
"to": "L14:16"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 2.0
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L14:20",
|
||||
"to": "L14:26"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": "test"
|
||||
},
|
||||
"type": {
|
||||
"name": "str"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L14:6",
|
||||
@@ -1101,6 +1101,45 @@
|
||||
"name": "bool"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L15:10",
|
||||
"to": "L15:11"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 1
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L15:15",
|
||||
"to": "L15:18"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 2.0
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L15:22",
|
||||
"to": "L15:28"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": "test"
|
||||
},
|
||||
"type": {
|
||||
"name": "str"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L15:6",
|
||||
@@ -1146,45 +1185,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L15:10",
|
||||
"to": "L15:11"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 1
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L15:15",
|
||||
"to": "L15:18"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 2.0
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L15:22",
|
||||
"to": "L15:28"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": "test"
|
||||
},
|
||||
"type": {
|
||||
"name": "str"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L15:6",
|
||||
@@ -1217,6 +1217,45 @@
|
||||
"name": "bool"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L16:10",
|
||||
"to": "L16:11"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 1
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L16:15",
|
||||
"to": "L16:21"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": "test"
|
||||
},
|
||||
"type": {
|
||||
"name": "str"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L16:25",
|
||||
"to": "L16:28"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 2.0
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L16:6",
|
||||
@@ -1262,45 +1301,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L16:10",
|
||||
"to": "L16:11"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 1
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L16:15",
|
||||
"to": "L16:21"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": "test"
|
||||
},
|
||||
"type": {
|
||||
"name": "str"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L16:25",
|
||||
"to": "L16:28"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 2.0
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L16:6",
|
||||
@@ -1333,6 +1333,45 @@
|
||||
"name": "bool"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L18:10",
|
||||
"to": "L18:13"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "str"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L18:15",
|
||||
"to": "L18:16"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 3
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L18:20",
|
||||
"to": "L18:25"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": false
|
||||
},
|
||||
"type": {
|
||||
"name": "bool"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L18:6",
|
||||
@@ -1378,45 +1417,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L18:10",
|
||||
"to": "L18:13"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "str"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L18:15",
|
||||
"to": "L18:16"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 3
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L18:20",
|
||||
"to": "L18:25"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": false
|
||||
},
|
||||
"type": {
|
||||
"name": "bool"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L18:6",
|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Meter",
|
||||
"param": null
|
||||
"args": []
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
@@ -62,7 +62,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Second",
|
||||
"param": null
|
||||
"args": []
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
|
||||
@@ -100,6 +100,32 @@
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:13",
|
||||
"to": "L11:15"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "v1"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:17",
|
||||
"to": "L11:19"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "v2"
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:5",
|
||||
@@ -135,32 +161,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:13",
|
||||
"to": "L11:15"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "v1"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:17",
|
||||
"to": "L11:19"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "v2"
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:5",
|
||||
|
||||
@@ -72,29 +72,6 @@
|
||||
}
|
||||
],
|
||||
"judgments": [
|
||||
{
|
||||
"location": {
|
||||
"from": "L26:0",
|
||||
"to": "L26:5"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "print"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "object",
|
||||
"type": {},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"returns": {}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L27:4",
|
||||
@@ -325,6 +302,29 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L26:0",
|
||||
"to": "L26:5"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "print"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "object",
|
||||
"type": {},
|
||||
"required": false
|
||||
}
|
||||
],
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"returns": {}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L26:0",
|
||||
|
||||
@@ -63,31 +63,6 @@
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:11",
|
||||
"to": "L6:15"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "bool"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "object",
|
||||
"type": {},
|
||||
"required": false
|
||||
}
|
||||
],
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"returns": {
|
||||
"name": "bool"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:16",
|
||||
@@ -135,6 +110,31 @@
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:11",
|
||||
"to": "L6:15"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "bool"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "object",
|
||||
"type": {},
|
||||
"required": false
|
||||
}
|
||||
],
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"returns": {
|
||||
"name": "bool"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:11",
|
||||
@@ -367,6 +367,54 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L12:21",
|
||||
"to": "L12:27"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "double"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [],
|
||||
"args": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "value",
|
||||
"type": {
|
||||
"name": "float"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [],
|
||||
"returns": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L12:29",
|
||||
"to": "L12:35"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "floats"
|
||||
},
|
||||
"type": {
|
||||
"name": "list",
|
||||
"args": [
|
||||
{
|
||||
"name": "float"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"name": "list"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L12:17",
|
||||
@@ -455,54 +503,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L12:21",
|
||||
"to": "L12:27"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "double"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [],
|
||||
"args": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "value",
|
||||
"type": {
|
||||
"name": "float"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [],
|
||||
"returns": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L12:29",
|
||||
"to": "L12:35"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "floats"
|
||||
},
|
||||
"type": {
|
||||
"name": "list",
|
||||
"args": [
|
||||
{
|
||||
"name": "float"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"name": "list"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L12:17",
|
||||
@@ -538,6 +538,54 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L13:19",
|
||||
"to": "L13:25"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "double"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [],
|
||||
"args": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "value",
|
||||
"type": {
|
||||
"name": "float"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [],
|
||||
"returns": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L13:27",
|
||||
"to": "L13:31"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "ints"
|
||||
},
|
||||
"type": {
|
||||
"name": "list",
|
||||
"args": [
|
||||
{
|
||||
"name": "int"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"name": "list"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L13:15",
|
||||
@@ -626,54 +674,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L13:19",
|
||||
"to": "L13:25"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "double"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [],
|
||||
"args": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "value",
|
||||
"type": {
|
||||
"name": "float"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [],
|
||||
"returns": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L13:27",
|
||||
"to": "L13:31"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "ints"
|
||||
},
|
||||
"type": {
|
||||
"name": "list",
|
||||
"args": [
|
||||
{
|
||||
"name": "int"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"name": "list"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L13:15",
|
||||
@@ -699,6 +699,54 @@
|
||||
},
|
||||
"type": {}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L14:15",
|
||||
"to": "L14:21"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "is_odd"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [],
|
||||
"args": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "value",
|
||||
"type": {
|
||||
"name": "int"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [],
|
||||
"returns": {
|
||||
"name": "bool"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L14:23",
|
||||
"to": "L14:27"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "ints"
|
||||
},
|
||||
"type": {
|
||||
"name": "list",
|
||||
"args": [
|
||||
{
|
||||
"name": "int"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"name": "list"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L14:11",
|
||||
@@ -787,54 +835,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L14:15",
|
||||
"to": "L14:21"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "is_odd"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [],
|
||||
"args": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "value",
|
||||
"type": {
|
||||
"name": "int"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [],
|
||||
"returns": {
|
||||
"name": "bool"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L14:23",
|
||||
"to": "L14:27"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "ints"
|
||||
},
|
||||
"type": {
|
||||
"name": "list",
|
||||
"args": [
|
||||
{
|
||||
"name": "int"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"name": "list"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L14:11",
|
||||
|
||||
51
tests/cases/checker/09_frame_ops.py
Normal file
51
tests/cases/checker/09_frame_ops.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# type: ignore
|
||||
# ruff: disable [F821]
|
||||
|
||||
df1: Frame[a:int, b:float]
|
||||
df2: Frame[a:int, b:float]
|
||||
|
||||
_: Any
|
||||
|
||||
# Arithmetic
|
||||
_ = df1 + df2
|
||||
_ = df1 - df2
|
||||
_ = df1 * df2
|
||||
_ = df1 / df2
|
||||
_ = df1 // df2
|
||||
_ = df1 % df2
|
||||
_ = df1**df2
|
||||
|
||||
# Comparisons
|
||||
_ = df1 < df2
|
||||
_ = df1 > df2
|
||||
_ = df1 <= df2
|
||||
_ = df1 >= df2
|
||||
_ = df1 != df2
|
||||
_ = df1 == df2
|
||||
|
||||
# Aggregate
|
||||
_ = df1.kurt()
|
||||
_ = df1.kurtosis()
|
||||
_ = df1.max()
|
||||
_ = df1.mean()
|
||||
_ = df1.median()
|
||||
_ = df1.min()
|
||||
_ = df1.mode()
|
||||
_ = df1.prod()
|
||||
_ = df1.product()
|
||||
_ = df1.std()
|
||||
_ = df1.sum()
|
||||
_ = df1.var()
|
||||
|
||||
# Groupby
|
||||
gb = df1.groupby(by="a")
|
||||
|
||||
_ = gb.kurt()
|
||||
_ = gb.max()
|
||||
_ = gb.mean()
|
||||
_ = gb.median()
|
||||
_ = gb.min()
|
||||
_ = gb.prod()
|
||||
_ = gb.std()
|
||||
_ = gb.sum()
|
||||
_ = gb.var()
|
||||
2771
tests/cases/checker/09_frame_ops.py.ref.json
Normal file
2771
tests/cases/checker/09_frame_ops.py.ref.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -16,7 +16,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "bool",
|
||||
"param": null
|
||||
"args": []
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -25,7 +25,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "int",
|
||||
"param": null
|
||||
"args": []
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -36,7 +36,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
"args": []
|
||||
},
|
||||
"constraint": "(_ > 0) + (_ < 250)"
|
||||
}
|
||||
@@ -47,7 +47,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "str",
|
||||
"param": null
|
||||
"args": []
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -56,7 +56,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "datetime",
|
||||
"param": null
|
||||
"args": []
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -65,7 +65,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
"args": []
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -79,7 +79,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "_",
|
||||
"param": null
|
||||
"args": []
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "GeoLocation",
|
||||
"param": null
|
||||
"args": []
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -28,11 +28,13 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "BaseType",
|
||||
"base": "GeoLocation",
|
||||
"param": null
|
||||
}
|
||||
"args": [
|
||||
{
|
||||
"_type": "BaseType",
|
||||
"base": "GeoLocation",
|
||||
"args": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -65,11 +67,13 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "BaseType",
|
||||
"base": "GeoLocation",
|
||||
"param": null
|
||||
}
|
||||
"args": [
|
||||
{
|
||||
"_type": "BaseType",
|
||||
"base": "GeoLocation",
|
||||
"args": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -117,7 +121,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Latitude",
|
||||
"param": null
|
||||
"args": []
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -146,7 +150,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Latitude",
|
||||
"param": null
|
||||
"args": []
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -175,11 +179,13 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Difference",
|
||||
"param": {
|
||||
"_type": "BaseType",
|
||||
"base": "Latitude",
|
||||
"param": null
|
||||
}
|
||||
"args": [
|
||||
{
|
||||
"_type": "BaseType",
|
||||
"base": "Latitude",
|
||||
"args": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -217,7 +223,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "int",
|
||||
"param": null
|
||||
"args": []
|
||||
},
|
||||
"constraint": "_ >= 0"
|
||||
}
|
||||
@@ -230,7 +236,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
"args": []
|
||||
},
|
||||
"constraint": "_ >= 0"
|
||||
}
|
||||
@@ -252,7 +258,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "int",
|
||||
"param": null
|
||||
"args": []
|
||||
},
|
||||
"constraint": "Positive"
|
||||
}
|
||||
@@ -265,7 +271,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
"args": []
|
||||
},
|
||||
"constraint": "Positive"
|
||||
}
|
||||
|
||||
@@ -14,15 +14,17 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "0 <= _ <= 1"
|
||||
}
|
||||
"args": [
|
||||
{
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"args": []
|
||||
},
|
||||
"constraint": "0 <= _ <= 1"
|
||||
}
|
||||
]
|
||||
},
|
||||
"default": null
|
||||
},
|
||||
@@ -31,15 +33,17 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "0 <= _ <= 1"
|
||||
}
|
||||
"args": [
|
||||
{
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"args": []
|
||||
},
|
||||
"constraint": "0 <= _ <= 1"
|
||||
}
|
||||
]
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
@@ -50,15 +54,17 @@
|
||||
"returns": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "0 <= _ <= 2"
|
||||
}
|
||||
"args": [
|
||||
{
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"args": []
|
||||
},
|
||||
"constraint": "0 <= _ <= 2"
|
||||
}
|
||||
]
|
||||
},
|
||||
"body": [
|
||||
{
|
||||
@@ -67,15 +73,17 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "0 <= _ <= 2"
|
||||
}
|
||||
"args": [
|
||||
{
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"args": []
|
||||
},
|
||||
"constraint": "0 <= _ <= 2"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -117,7 +125,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "int",
|
||||
"param": null
|
||||
"args": []
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
@@ -128,7 +136,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
"args": []
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
@@ -140,7 +148,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "str",
|
||||
"param": null
|
||||
"args": []
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
|
||||
@@ -46,7 +46,8 @@ class GeneratorTester(Tester):
|
||||
|
||||
if not any(d.type == DiagnosticType.ERROR for d in checker.diagnostics):
|
||||
generator = Generator(workdir=path.parent, types=checker.types)
|
||||
result.compiled_ast = generator.generate_ast(typed_ast, path)
|
||||
generator.set_src_path(path)
|
||||
result.compiled_ast = generator.generate_ast(typed_ast)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from midas.ast.midas import (
|
||||
AliasStmt,
|
||||
BinaryExpr,
|
||||
CallExpr,
|
||||
ComplexType,
|
||||
@@ -8,6 +9,7 @@ from midas.ast.midas import (
|
||||
Expr,
|
||||
ExtendStmt,
|
||||
ExtensionType,
|
||||
FrameType,
|
||||
FunctionType,
|
||||
GenericType,
|
||||
GetExpr,
|
||||
@@ -60,6 +62,13 @@ class MidasAstJsonSerializer(
|
||||
"bound": self._serialize_optional(param.bound),
|
||||
}
|
||||
|
||||
def visit_alias_stmt(self, stmt: AliasStmt) -> dict:
|
||||
return {
|
||||
"_type": "AliasStmt",
|
||||
"name": stmt.name.lexeme,
|
||||
"type": stmt.type.accept(self),
|
||||
}
|
||||
|
||||
def visit_member_stmt(self, stmt: MemberStmt) -> dict:
|
||||
return {
|
||||
"_type": "MemberStmt",
|
||||
@@ -197,3 +206,15 @@ class MidasAstJsonSerializer(
|
||||
"base": type.base.accept(self),
|
||||
"extension": type.extension.accept(self),
|
||||
}
|
||||
|
||||
def visit_frame_type(self, type: FrameType) -> dict:
|
||||
return {
|
||||
"_type": "FrameType",
|
||||
"columns": [self._serialize_column(col) for col in type.columns],
|
||||
}
|
||||
|
||||
def _serialize_column(self, column: FrameType.Column):
|
||||
return {
|
||||
"name": column.name.lexeme,
|
||||
"type": column.type.accept(self),
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ from midas.ast.python import (
|
||||
Stmt,
|
||||
SubscriptExpr,
|
||||
TernaryExpr,
|
||||
TupleExpr,
|
||||
TypeAssign,
|
||||
UnaryExpr,
|
||||
VariableExpr,
|
||||
@@ -98,7 +99,7 @@ class PythonAstJsonSerializer(
|
||||
return {
|
||||
"_type": "BaseType",
|
||||
"base": node.base,
|
||||
"param": self._serialize_optional(node.param),
|
||||
"args": self._serialize_list(node.args),
|
||||
}
|
||||
|
||||
def visit_constraint_type(self, node: ConstraintType) -> dict:
|
||||
@@ -302,6 +303,12 @@ class PythonAstJsonSerializer(
|
||||
"step": self._serialize_optional(expr.step),
|
||||
}
|
||||
|
||||
def visit_tuple_expr(self, expr: TupleExpr) -> dict:
|
||||
return {
|
||||
"_type": "TupleExpr",
|
||||
"items": [item.accept(self) for item in expr.items],
|
||||
}
|
||||
|
||||
def visit_raw_expr(self, expr: RawExpr) -> dict:
|
||||
return {
|
||||
"_type": "RawExpr",
|
||||
|
||||
Reference in New Issue
Block a user