90 Commits

Author SHA1 Message Date
c6cc38bfeb Merge pull request 'Frame / column operations' (#27) from feat/simple-frame-ops into main
Reviewed-on: #27
2026-07-03 10:29:32 +00:00
4d3e3f44a1 fix(checker): correctly check length of frame/column 2026-07-03 12:28:39 +02:00
ec80b1e92e feat(checker): add head/tail methods 2026-07-03 12:13:30 +02:00
4ea15519f3 feate(checker): add some frame/column attributes 2026-07-03 12:07:36 +02:00
7a6e01cff8 fix(checker): delegate frame aggregate methods to columns 2026-07-03 11:42:35 +02:00
733c8736b8 feat(checker): add aggregation ops on column groupby 2026-07-03 11:25:06 +02:00
20173a0b07 feat(tests): add colors and run all tests in base module 2026-07-03 10:58:28 +02:00
a143972ef1 feat(checker): add aggregation ops on frame groupby 2026-07-03 02:20:51 +02:00
0c70048b62 feat(checker): add statistical ops on columns 2026-07-03 01:34:58 +02:00
1c0c917873 feat(checker): add statistical ops on frames 2026-07-03 01:27:16 +02:00
1f6189daa4 feat(checker): add comparison binary ops on columns 2026-07-03 01:05:24 +02:00
66b585c3d6 fix(checker): recursively check builtin subtypes 2026-07-03 01:04:45 +02:00
819ab3c2bf tests: add dataframe operations test 2026-07-03 00:58:29 +02:00
d8c0b17512 feat(checker): add comparison binary ops on frames 2026-07-03 00:57:27 +02:00
6e06f9078e fix(checker): improve unknown method message 2026-07-03 00:57:10 +02:00
ece2e3a6a3 feat(checker): add arithmetic binary ops on columns 2026-07-03 00:42:00 +02:00
74c07c9afb feat(checker): add arithmetic binary ops on frames 2026-07-03 00:38:56 +02:00
be2fd4c837 feat(checker): delegate element operation to inner type
delegate element-wise binary operation on columns to their inner types
2026-07-03 00:05:40 +02:00
1bc4c704c3 feat(checker): delegate element operation to columns
delegate element-wise binary operation on frames to columns
2026-07-02 23:41:08 +02:00
0288a05901 feat(checker): handle assignment to multiple columns 2026-07-02 23:29:10 +02:00
b14f46d405 feat(checker): handle calls on group-bys 2026-07-02 19:53:58 +02:00
8e8ed62266 feat(checker): add add/mean/groupby on columns 2026-07-02 19:30:43 +02:00
2fce2f4bfc feat(checker): add column method registry 2026-07-02 19:23:23 +02:00
640f2d1771 feat(checker): support unification of frames and columns 2026-07-02 19:22:28 +02:00
b48dfe5301 refactor: make MethodRegistry generic on Call 2026-07-02 18:27:26 +02:00
0d5840a4ce refactor: restructure frame method registry in submodule 2026-07-02 18:20:10 +02:00
3c92f0867d feat(types): add ColumnGroupBy 2026-07-02 18:00:25 +02:00
b5acae4078 feat(types): add FrameGroupBy type 2026-07-02 17:45:18 +02:00
5d20f8ec3e docs: mention eager evaluation in manual 2026-07-02 17:22:28 +02:00
955c2233ed feat(checker): statically evaluate casts to Any and None 2026-07-02 17:14:30 +02:00
ff69b65171 feat(checker): add same length assertion on frames
safely adding two dataframes is only possible if the sizes are the same, or null values could be added dynamically to pad the shortest dataframe
2026-07-02 17:14:05 +02:00
8df01afd8c feat(gen): materialize assertions from collector 2026-07-02 17:10:27 +02:00
47b2dfdd73 feat(gen): add assertion collector to TypedAST 2026-07-02 17:09:50 +02:00
bd4d793ce0 feat(gen): add Assertion class 2026-07-02 17:08:43 +02:00
f7a36f61b6 fix(checker): pass AST expression to method registry 2026-07-01 22:34:02 +02:00
ad2fabf471 feat(checker): add assertion collector 2026-07-01 22:32:13 +02:00
a59a58d21a feat(gen): generate alias stubs 2026-07-01 14:43:30 +02:00
3260ae4a1e Merge pull request 'Call dispatcher' (#26) from feat/call-dispatcher into main
Reviewed-on: #26
2026-07-01 12:22:11 +00:00
bd1c9581c7 fix(checker): use dispatcher in frame method registry 2026-07-01 14:17:10 +02:00
663642ea6c fix(tests): serialize alias statements 2026-07-01 14:13:27 +02:00
e2abc04fe4 feat(checker): define min/max in preamble 2026-07-01 14:10:19 +02:00
a4016b55ce feat(checker): handle calls to AppliedType 2026-07-01 14:10:19 +02:00
1ea5da7024 feat(parser): parse binary operations in Midas 2026-07-01 14:10:18 +02:00
a017a8cf1f feat(checker): catch errors when evaluating constraint 2026-07-01 14:10:17 +02:00
8fc5ab623e feat(checker): evaluate literal cast to list/dict 2026-07-01 14:10:16 +02:00
14007db846 feat(checker): evaluate unary op on literals 2026-07-01 14:10:15 +02:00
6ad2ce4b68 feat(checker): improve function unwrapping 2026-07-01 14:10:15 +02:00
9a276c34c7 refactor: reuse CallDispatcher 2026-07-01 11:32:41 +02:00
6e717a3f9e refactor: use CallDispatcher in Midas typer 2026-07-01 11:24:09 +02:00
77aadfa264 refactor: extract function call methods to CallDispatcher 2026-07-01 11:14:08 +02:00
c81287df7f Merge pull request 'Initial dataframe implementation' (#25) from feat/dataframes into main
Reviewed-on: #25
2026-07-01 08:24:36 +00:00
ffccc1bedd feat(cli): generate stubs in build dir when compiling 2026-07-01 10:16:13 +02:00
d14f208897 feat(gen): add tuple expr to generator 2026-07-01 10:16:13 +02:00
293953a078 tests: update with multi-parameter generics 2026-07-01 10:16:12 +02:00
bccc96e4d0 fix: minor fixes 2026-07-01 10:16:11 +02:00
9db56adf56 feat: add Python tuple expression 2026-07-01 10:16:10 +02:00
3f99563ac8 feat: handle multi-parameter generic in Python 2026-07-01 10:16:10 +02:00
b36896cc7b feat(checker): add len() 2026-07-01 10:16:09 +02:00
cb75878ae9 fix(checker): allow some assignments to unknown 2026-07-01 10:16:08 +02:00
a5fe985eb2 feat(checker): add methods on str 2026-07-01 10:16:08 +02:00
e324f414e6 feat(checker): type check tuple instantiation in Midas 2026-07-01 10:16:07 +02:00
256536562f fix(parser): parse empty calls 2026-07-01 10:16:06 +02:00
64f4314f0d fix(gen): prevent empty loop for column asserts 2026-07-01 10:16:06 +02:00
6f6245d283 fix(checker): allow iterating on unknown 2026-07-01 10:16:05 +02:00
3392bc347d fix(checker): allow subtypes and unknown as if test 2026-07-01 10:16:04 +02:00
7e0319906a feat(gen): assertions for column values 2026-07-01 10:16:03 +02:00
75bd203d4a fix(checker): allow calling unknown method on dataframes 2026-07-01 10:15:16 +02:00
db40198357 feat(gen): generate asserts for dataframes and columns 2026-07-01 10:15:16 +02:00
d79e1dee18 fix(checker): change heterogeneous errors to warnings 2026-07-01 10:15:15 +02:00
4ea400265c feat(checker): add mean method on frames 2026-07-01 10:15:14 +02:00
24bffdabd4 fix(checker): type check None literal 2026-07-01 10:15:13 +02:00
d7bb6326de feat(checker): lookup dunders on dataframes 2026-07-01 10:15:12 +02:00
dbf6f9e2db tests: update with reordered argument typing 2026-07-01 10:15:12 +02:00
3cdc9031d3 refactor: use metaclass to collect frame methods 2026-07-01 10:15:11 +02:00
00e2ca8fe3 refactor: add MethodResolver class 2026-07-01 10:15:10 +02:00
4efb01285c feat: add dummy classes for typing frames and columns 2026-07-01 10:15:10 +02:00
f84a19159f fix(checker): improve heterogeneous error message 2026-07-01 10:15:09 +02:00
946b2e0d2e feat(checker): lookup dataframe methods 2026-07-01 10:15:08 +02:00
08dd7408ec feat(checker): defined add method of dataframes 2026-07-01 10:15:07 +02:00
b33fadf768 feat(checker): add structural subtyping rule for dataframes 2026-07-01 10:15:06 +02:00
7219109e5d feat(cli): print context for multiline diagnostics 2026-07-01 10:14:48 +02:00
cdf1725c26 feat(checker): process frame type definitions 2026-07-01 10:14:48 +02:00
7074b074bc feat(cli): add frame type to highlighter 2026-07-01 10:14:17 +02:00
ede7272c09 feat(parser): add frame type to midas syntax 2026-07-01 10:14:16 +02:00
87d5e286d2 feat(gen): add support for tuples and dataframes 2026-07-01 10:14:16 +02:00
c91b206791 feat(checker): handle setting dataframe column 2026-07-01 10:13:30 +02:00
a31d295eb1 feat(checker): type check subscript on dataframes 2026-07-01 10:13:30 +02:00
0d20993f02 feat(types): add TupleType 2026-07-01 10:13:28 +02:00
5357ca8e58 fix(types): add str methods to dataframe types 2026-07-01 10:13:28 +02:00
556765fd35 feat(types): add DataFrameType and ColumnType 2026-07-01 10:13:27 +02:00
54 changed files with 9495 additions and 1564 deletions

View File

@@ -678,6 +678,10 @@ In the following example, a runtime check would be generated to ensure that the
caption: [Typing of `cast` expression],
)
#gc.warning[
Assertions are statements inserted just before a statement using a `cast` expression. This means that the expression is evaluated _before_ its actual intended usage location, which might cause issues if you rely on logical operator short-circuiting. See @eager-eval for more information.
]
There may be some cases where the cost of checking a value at runtime is simply not worth the safety, for example when dealing with a big dataset. If do wish so, you can use `unsafe_cast` which will only tell the type checker the type of the value, without generating a runtime assertion. This maps to the default behavior of `typing`'s own `cast` function.
If the value passed to `cast` or `unsafe_cast` is a literal (e.g. an integer, a string, a list of literals, etc.), the assertion is evaluated _at compile-time_ and no runtime assertion is generated.
@@ -695,3 +699,26 @@ If the value passed to `cast` or `unsafe_cast` is a literal (e.g. an integer, a
== Generating Stubs (`stubs`) <cmd-stubs>
== Showing Type Judgements (`types`) <cmd-types>
== Validating Definitions (`validate`) <cmd-validate>
= Known limitations <limitations>
== Eager evaluation in runtime assertions <eager-eval>
The process of generating assertions to ensure safety at runtime, mainly for `cast` expressions, leads to the creation of aliases for the expressions being casted. These alias definitions eagerly evaluate before the assertion, and most importantly before the real usage location. This means that you should avoid using `cast` expressions inside logical expressions like `and` or `or`, because the normal "short-circuit" behavior will be irrelevant to the evaluations of the operands.
For example:
#figure(
```py
def foo():
print("Foo")
return True
def bar():
print("Bar")
return True
result = foo() or bar()
# Foo
# Bar
```,
caption: [Runtime assertions may eagerly evaluate expressions and bypass logical operator's short-circuit],
)

View File

@@ -157,4 +157,14 @@ class FunctionType:
required: bool
class FrameType:
columns: list[Column]
@dataclass(frozen=True, kw_only=True)
class Column:
location: Optional[Location] = None
name: Token
type: Type
###<

View File

@@ -15,7 +15,7 @@ from midas.ast.location import Location
###> MidasType | Type annotations | node
class BaseType:
base: str
param: Optional[MidasType]
args: tuple[MidasType, ...]
class ConstraintType:
@@ -174,6 +174,10 @@ class SliceExpr:
step: Optional[Expr]
class TupleExpr:
items: tuple[Expr, ...]
class RawExpr:
expr: ast.expr

View File

@@ -265,6 +265,9 @@ class Type(ABC):
@abstractmethod
def visit_function_type(self, type: FunctionType) -> T: ...
@abstractmethod
def visit_frame_type(self, type: FrameType) -> T: ...
@dataclass(frozen=True)
class NamedType(Type):
@@ -323,3 +326,17 @@ class FunctionType(Type):
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_function_type(self)
@dataclass(frozen=True)
class FrameType(Type):
columns: list[Column]
@dataclass(frozen=True, kw_only=True)
class Column:
location: Optional[Location] = None
name: Token
type: Type
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_frame_type(self)

View File

@@ -358,6 +358,25 @@ class MidasAstPrinter(
arg.type.accept(self)
self._write_line(f"required: {arg.required}", last=True)
def visit_frame_type(self, type: m.FrameType) -> None:
self._write_line("FrameType")
with self._child_level(single=True):
self._write_line("columns")
with self._child_level():
for i, column in enumerate(type.columns):
self._idx = i
if i == len(type.columns) - 1:
self._mark_last()
self._print_frame_column(column)
def _print_frame_column(self, column: m.FrameType.Column) -> None:
self._write_line("Column")
with self._child_level():
self._write_line(f'name: "{column.name.lexeme}"')
self._write_line("type")
with self._child_level(single=True):
column.type.accept(self)
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
def __init__(self, indent: int = 4):
@@ -513,6 +532,23 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
res += "?"
return res
def visit_frame_type(self, type: m.FrameType) -> str:
res: str = self.indented("Frame[")
if len(type.columns) != 0:
res += "\n"
self.level += 1
columns: list[str] = []
for column in type.columns:
columns.append(self.indented(self._print_frame_column(column)))
res += ",\n".join(columns)
self.level -= 1
res += "\n"
res += "]"
return res
def _print_frame_column(self, column: m.FrameType.Column) -> str:
return f"{column.name.lexeme}: {column.type.accept(self)}"
class PythonAstPrinter(
AstPrinter,
@@ -524,7 +560,13 @@ class PythonAstPrinter(
self._write_line("BaseType")
with self._child_level():
self._write_line(f"base: {node.base}")
self._write_optional_child("param", node.param, last=True)
self._write_line("args:", last=True)
with self._child_level():
for i, arg in enumerate(node.args):
self._idx = i
if i == len(node.args) - 1:
self._mark_last()
arg.accept(self)
def visit_constraint_type(self, node: p.ConstraintType) -> None:
self._write_line("ConstraintType")
@@ -837,6 +879,17 @@ class PythonAstPrinter(
self._write_optional_child("upper", expr.upper)
self._write_optional_child("step", expr.step, last=True)
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
self._write_line("TupleExpr")
with self._child_level():
self._write_line("items", last=True)
with self._child_level():
for i, item in enumerate(expr.items):
self._idx = i
if i == len(expr.items) - 1:
self._mark_last()
item.accept(self)
def visit_raw_expr(self, expr: p.RawExpr) -> None:
self._write_line("RawExpr")
with self._child_level(single=True):

View File

@@ -44,7 +44,7 @@ class MidasType(ABC):
@dataclass(frozen=True)
class BaseType(MidasType):
base: str
param: Optional[MidasType]
args: tuple[MidasType, ...]
def accept(self, visitor: MidasType.Visitor[T]) -> T:
return visitor.visit_base_type(self)
@@ -268,6 +268,9 @@ class Expr(ABC):
@abstractmethod
def visit_slice_expr(self, expr: SliceExpr) -> T: ...
@abstractmethod
def visit_tuple_expr(self, expr: TupleExpr) -> T: ...
@abstractmethod
def visit_raw_expr(self, expr: RawExpr) -> T: ...
@@ -402,6 +405,14 @@ class SliceExpr(Expr):
return visitor.visit_slice_expr(self)
@dataclass(frozen=True)
class TupleExpr(Expr):
items: tuple[Expr, ...]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_tuple_expr(self)
@dataclass(frozen=True)
class RawExpr(Expr):
expr: ast.expr

View File

@@ -178,4 +178,100 @@ extend dict[K, V] {
// def __ior__: fn(value: SupportsKeysAndGetItem[K, V], /) -> dict[K, V]
// def __ior__: fn(value: Iterable[tuple[K, V]], /) -> dict[K, V]
}
}
extend str {
def capitalize: fn() -> str
def casefold: fn() -> str
def center: fn(width: int, fillchar: str?, /) -> str
def count: fn(sub: str, start: None?, end: None?, /) -> int
def count: fn(sub: str, start: int, end: None?, /) -> int
def count: fn(sub: str, start: None, end: int, /) -> int
def count: fn(sub: str, start: int, end: int, /) -> int
def encode: fn(encoding: str?, errors: str?) -> bytes
def endswith: fn(suffix: str, start: None?, end: None?, /) -> bool
def endswith: fn(suffix: str, start: int, end: None?, /) -> bool
def endswith: fn(suffix: str, start: None, end: int, /) -> bool
def endswith: fn(suffix: str, start: int, end: int, /) -> bool
def expandtabs: fn(tabsize: int?) -> str
def find: fn(sub: str, start: None?, end: None?, /) -> int
def find: fn(sub: str, start: int, end: None?, /) -> int
def find: fn(sub: str, start: None, end: int, /) -> int
def find: fn(sub: str, start: int, end: int, /) -> int
// def format: fn(*args: object, **kwargs: object) -> str
// def format_map: fn(mapping: _FormatMapMapping, /) -> str
def index: fn(sub: str, start: None?, end: None?, /) -> int
def index: fn(sub: str, start: int, end: None?, /) -> int
def index: fn(sub: str, start: None, end: int, /) -> int
def index: fn(sub: str, start: int, end: int, /) -> int
def isalnum: fn() -> bool
def isalpha: fn() -> bool
def isascii: fn() -> bool
def isdecimal: fn() -> bool
def isdigit: fn() -> bool
def isidentifier: fn() -> bool
def islower: fn() -> bool
def isnumeric: fn() -> bool
def isprintable: fn() -> bool
def isspace: fn() -> bool
def istitle: fn() -> bool
def isupper: fn() -> bool
def join: fn(iterable: list[str], /) -> str // TODO: use Iterable
def ljust: fn(width: int, fillchar: str?, /) -> str
def lower: fn() -> str
def lstrip: fn(chars: None?, /) -> str
def lstrip: fn(chars: str, /) -> str
def partition: fn(sep: str, /) -> tuple[str, str, str]
def replace: fn(old: str, new: str, count: int?, /) -> str
def removeprefix: fn(prefix: str, /) -> str
def removesuffix: fn(suffix: str, /) -> str
def rfind: fn(sub: str, start: None?, end: None?, /) -> int
def rfind: fn(sub: str, start: int, end: None?, /) -> int
def rfind: fn(sub: str, start: None, end: int, /) -> int
def rfind: fn(sub: str, start: int, end: int, /) -> int
def rindex: fn(sub: str, start: None?, end: None?, /) -> int
def rindex: fn(sub: str, start: int, end: None?, /) -> int
def rindex: fn(sub: str, start: None, end: int, /) -> int
def rindex: fn(sub: str, start: int, end: int, /) -> int
def rjust: fn(width: int, fillchar: str?, /) -> str
def rpartition: fn(sep: str, /) -> tuple[str, str, str]
def rsplit: fn(sep: None?, maxsplit: int?) -> list[str]
def rsplit: fn(sep: str, maxsplit: int?) -> list[str]
def rstrip: fn(chars: None?, /) -> str
def rstrip: fn(chars: str, /) -> str
def split: fn(sep: None?, maxsplit: int?) -> list[str]
def split: fn(sep: str, maxsplit: int?) -> list[str]
def splitlines: fn(keepends: bool?) -> list[str]
def startswith: fn(prefix: str, start: None?, end: None?, /) -> bool
def startswith: fn(prefix: str, start: int, end: None?, /) -> bool
def startswith: fn(prefix: str, start: None, end: int, /) -> bool
def startswith: fn(prefix: str, start: int, end: int, /) -> bool
def strip: fn(chars: None?, /) -> str
def strip: fn(chars: str, /) -> str
def swapcase: fn() -> str
def title: fn() -> str
// def translate: fn(table: _TranslateTable, /) -> str
def upper: fn() -> str
def zfill: fn(width: int, /) -> str
def __add__: fn(value: str, /) -> str
// Incompatible with Sequence.__contains__
def __contains__: fn(key: str, /) -> bool
def __eq__: fn(value: object, /) -> bool
def __ge__: fn(value: str, /) -> bool
def __getitem__: fn(key: slice, /) -> str
def __getitem__: fn(key: int, /) -> str
def __gt__: fn(value: str, /) -> bool
def __hash__: fn() -> int
// def __iter__: fn() -> Iterator[str]
def __le__: fn(value: str, /) -> bool
def __len__: fn() -> int
def __lt__: fn(value: str, /) -> bool
def __mod__: fn(value: Any, /) -> str
def __mul__: fn(value: int, /) -> str
def __ne__: fn(value: object, /) -> bool
def __rmul__: fn(value: int, /) -> str
def __getnewargs__: fn() -> tuple[str]
def __format__: fn(format_spec: str, /) -> str
}

View File

@@ -14,8 +14,10 @@ if TYPE_CHECKING:
from midas.checker.registry import TypesRegistry
# Hard-coded subtype relationships between builtin types
# Circular dependencies and diamond inheritance MUST be avoided
BUILTIN_SUBTYPES: dict[str, set[str]] = {
"object": {"float", "list", "dict", "str"},
"object": {"float", "list", "dict", "str", "bytes", "tuple"},
"float": {"int"},
"int": {"bool"},
}
@@ -26,12 +28,15 @@ def define_builtins(reg: TypesRegistry):
any = reg.define_type("Any", TopType())
unit = reg.define_type("None", UnitType())
object = reg.define_type("object", BaseType(name="object"))
bytes = reg.define_type("bytes", BaseType(name="bytes"))
bool = reg.define_type("bool", BaseType(name="bool"))
int = reg.define_type("int", BaseType(name="int"))
float = reg.define_type("float", BaseType(name="float"))
str = reg.define_type("str", BaseType(name="str"))
slice = reg.define_type("slice", BaseType(name="slice"))
tuple = reg.define_type("tuple", BaseType(name="tuple"))
list = reg.define_type(
"list",
GenericType(

484
midas/checker/dispatcher.py Normal file
View File

@@ -0,0 +1,484 @@
import logging
from dataclasses import dataclass
from enum import StrEnum
from typing import Generic, Optional, Protocol, TypeVar, Union
from midas.ast.location import Location
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import (
AppliedType,
DerivedType,
Function,
GenericType,
OverloadedFunction,
Type,
UnknownType,
)
from midas.checker.unifier import Unifier
class HasLocation(Protocol):
@property
def location(self) -> Location: ...
E = TypeVar("E", bound=HasLocation)
TypedExpr = tuple[E, Type]
@dataclass(frozen=True, kw_only=True)
class MappedArgument(Generic[E]):
expr: E
type: Type
argument: Function.Argument
@dataclass(frozen=True, kw_only=True)
class OverloadCandidate:
function: Function
mapped: list[MappedArgument]
class CallError(StrEnum):
INVALID_ARGS = "Invalid arguments"
NO_MATCHING_OVERLOAD = "No matching overload"
IMPOSSIBLE_UNIFICATION = "Parameters unification failed"
NOT_CALLABLE = "Not callable"
@dataclass(frozen=True, kw_only=True)
class CallResult:
error: Optional[CallError] = None
result: Type = UnknownType()
message: Optional[str] = None
@property
def is_valid(self) -> bool:
return self.error is None
@property
def error_message(self) -> str:
if self.message is not None:
return self.message
if self.error is not None:
return str(self.error)
return ""
class CallDispatcher(Generic[E]):
def __init__(self, types: TypesRegistry, reporter: FileReporter) -> None:
self.types: TypesRegistry = types
self.reporter: FileReporter = reporter
self.logger: logging.Logger = logging.getLogger("CallDispatcher")
def set_reporter(self, reporter: FileReporter):
self.reporter = reporter
def get_result(
self,
location: Location,
callee: Type,
positional: list[TypedExpr[E]],
keywords: dict[str, TypedExpr[E]],
report_errors: bool = True,
) -> CallResult:
"""Get the result type of a function call
If the function has overloads, the function will try to resolve the
appropriate signature.
Argument types are matched to the defined parameters.
The function doesn't take the raw expression as a parameter to accommodate
for desugared calls such as for operators.
Args:
location (Location): the call location
callee (Type): the called function
positional (list[TypedExpr]): the list positional arguments
keywords (dict[str, TypedExpr]): the map of keyword arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
Type: the return type of the call, or `None` if either
the call is invalid or no overload matched the arguments uniquely
"""
match callee:
case Function() as function:
valid: bool
mapped: list[MappedArgument[E]]
valid, mapped = self.map_call_arguments(
function, location, positional, keywords
)
valid = valid and self._are_arguments_valid(mapped, report_errors)
if not valid:
return CallResult(error=CallError.INVALID_ARGS)
return CallResult(result=function.returns)
case OverloadedFunction(overloads=overloads):
res = self._match_overload(
overloads, location, positional, keywords, report_errors
)
if res[0] is None:
return CallResult(
error=CallError.NO_MATCHING_OVERLOAD,
message=res[1],
)
return CallResult(result=res[0].returns)
case AppliedType(body=body):
return self.get_result(
location, body, positional, keywords, report_errors
)
case UnknownType():
return CallResult(result=UnknownType())
case DerivedType(type=base):
return self.get_result(
location, base, positional, keywords, report_errors
)
case GenericType():
unifier: Unifier = Unifier(self.types)
pos: list[Type] = [a[1] for a in positional]
kw: dict[str, Type] = {k: v[1] for k, v in keywords.items()}
unified: Optional[Type] = unifier.unify_call(callee, pos, kw)
if unified is None:
pos_str: str = ", ".join(str(t) for t in pos)
kw_str: str = ", ".join(f"{k}: {v}" for k, v in kw.items())
message: str = (
f"Could not unify {callee}={callee.body} with pos=[{pos_str}] and kw={{{kw_str}}}"
)
if report_errors:
self.reporter.error(location, message)
return CallResult(
error=CallError.IMPOSSIBLE_UNIFICATION,
message=message,
)
return self.get_result(
location,
unified,
positional,
keywords,
report_errors,
)
case _:
message: str = f"{callee} ({callee.__class__.__name__}) is not callable"
if report_errors:
self.reporter.error(location, message)
return CallResult(
error=CallError.NOT_CALLABLE,
message=message,
)
def _unwrap_function(
self,
callee: Type,
positional: list[TypedExpr[E]],
keywords: dict[str, TypedExpr[E]],
) -> Union[tuple[Function, None], tuple[None, CallError]]:
match callee:
case DerivedType(type=base):
return self._unwrap_function(base, positional, keywords)
case GenericType():
unifier: Unifier = Unifier(self.types)
unified: Optional[Type] = unifier.unify_call(
callee,
[a[1] for a in positional],
{k: v[1] for k, v in keywords.items()},
)
if unified is None:
return None, CallError.IMPOSSIBLE_UNIFICATION
return self._unwrap_function(unified, positional, keywords)
case Function():
return callee, None
case AppliedType(body=body):
return self._unwrap_function(body, positional, keywords)
case _:
return None, CallError.NOT_CALLABLE
def _are_arguments_valid(
self,
arguments: list[MappedArgument[E]],
report_errors: bool = True,
) -> bool:
"""Check whether the passed argument types correspond to their matched parameter definitions
Args:
arguments (list[MappedArgument]): the list of argument/parameter pairs
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
bool: True if all arguments fit the matching parameter definitions, False otherwise
"""
valid: bool = True
for arg in arguments:
if not self.types.is_subtype(arg.type, arg.argument.type):
if report_errors:
self.reporter.error(
arg.expr.location,
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
)
valid = False
return valid
def _match_overload(
self,
overloads: list[Type],
location: Location,
positional: list[TypedExpr[E]],
keywords: dict[str, TypedExpr[E]],
report_errors: bool = True,
) -> Union[tuple[Function, None], tuple[None, str]]:
"""Try and resolve the appropriate overload for the given arguments
Args:
overloads (list[Type]): the list of possible overloads
location (Location): the call location
positional (list[TypedExpr]): the list of positional arguments
keywords (dict[str, TypedExpr]): the map of keywords arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
Optional[Function]: the resolved function signature if it can be
determined unambiguously, or `None`.
"""
candidates: list[OverloadCandidate] = []
errors: list[CallError] = []
for overload in overloads:
function, unwrap_error = self._unwrap_function(
overload, positional, keywords
)
if function is None:
errors.append(unwrap_error) # type: ignore
continue
valid, mapped = self.map_call_arguments(
function=function,
location=location,
positional=positional,
keywords=keywords,
report_errors=False,
)
if valid and self._are_arguments_valid(mapped, report_errors=False):
candidates.append(
OverloadCandidate(
function=function,
mapped=mapped,
)
)
pos_types: str = ", ".join(str(type) for _, type in positional)
kw_types: str = ", ".join(
f"{name}: {type}" for name, (_, type) in keywords.items()
)
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
n_candidates: int = len(candidates)
# Exactly 1 match -> return it
if n_candidates == 1:
return candidates[0].function, None
# No match -> invalid call
if n_candidates == 0:
overloads_str: str = ", ".join(map(str, overloads))
errors_str: str = ", ".join(errors)
message: str = (
f"No matching overload in [{overloads_str}] {for_args} (errors: {errors_str})"
)
if report_errors:
self.reporter.error(location, message)
return None, message
# Multiple matches -> see if one <: all others (more specific)
for i1, c1 in enumerate(candidates):
mapped1: list[MappedArgument[E]] = c1.mapped
best_match: bool = True
for i2, c2 in enumerate(candidates):
if i1 == i2:
continue
mapped2: list[MappedArgument[E]] = c2.mapped
if not self._are_mapped_subtypes(mapped1, mapped2):
best_match = False
break
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
if best_match:
return c1.function, None
candidates_str: str = ", ".join(
str(candidate.function) for candidate in candidates
)
message: str = f"Multiple matching overloads {for_args}: {candidates_str}"
if report_errors:
self.reporter.error(location, message)
return None, message
def map_call_arguments(
self,
function: Function,
location: Location,
positional: list[TypedExpr[E]],
keywords: dict[str, TypedExpr[E]],
report_errors: bool = True,
) -> tuple[bool, list[MappedArgument]]:
"""Map call arguments to a function's parameters as defined in its signature
This method maps positional-only, keyword-only and mixed parameter definitions
with the arguments passed at the call site
Any mismatched, missing or unexpected argument is reported as a diagnostic,
unless `report_errors` is set to `False`
Args:
function (Function): the function definition
location (Location): the call location
positional (list[TypedExpr]): the list of positional arguments
keywords (dict[str, TypedExpr]): the map of keyword arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
tuple[bool, list[MappedArgument]]: a boolean reporting whether
the call is valid and the list of mapped arguments
"""
set_args: set[str] = set()
required_positional: list[str] = [
arg.name for arg in function.pos_args + function.args if arg.required
]
required_keyword: list[str] = [
arg.name for arg in function.kw_args if arg.required
]
mapped: list[MappedArgument[E]] = []
pos_params: list[Function.Argument] = list(function.pos_args)
mixed_params: list[Function.Argument] = list(function.args)
kw_params: dict[str, Function.Argument] = {
arg.name: arg for arg in function.kw_args
}
valid_call: bool = True
# TODO: handle *args and **kwargs sinks
for arg in positional:
param: Function.Argument
if len(pos_params) != 0:
param = pos_params.pop(0)
elif len(mixed_params) != 0:
param = mixed_params.pop(0)
else:
if report_errors:
self.reporter.error(
arg[0].location, "Too many positional arguments"
)
valid_call = False
break
name: str = param.name
if name in required_positional:
required_positional.remove(name)
if name in required_keyword:
required_keyword.remove(name)
set_args.add(name)
mapped.append(
MappedArgument(
expr=arg[0],
type=arg[1],
argument=param,
)
)
kw_params.update({arg.name: arg for arg in mixed_params})
for name, arg in keywords.items():
param: Function.Argument
if name not in kw_params:
if report_errors:
if name in set_args:
self.reporter.error(
arg[0].location, f"Multiple values for argument '{name}'"
)
else:
self.reporter.error(
arg[0].location, f"Unknown keyword argument '{name}'"
)
valid_call = False
continue
param = kw_params.pop(name)
if name in required_positional:
required_positional.remove(name)
if name in required_keyword:
required_keyword.remove(name)
set_args.add(name)
mapped.append(
MappedArgument(
expr=arg[0],
type=arg[1],
argument=param,
)
)
def join_args(args: list[str]) -> str:
args = list(map(lambda a: f"'{a}'", args))
if len(args) == 0:
return ""
if len(args) == 1:
return args[0]
return ", ".join(args[:-1]) + " and " + args[-1]
if len(required_positional) != 0:
plural: str = "" if len(required_positional) == 1 else "s"
args: str = join_args(required_positional)
if report_errors:
self.reporter.error(
location,
f"Missing required positional argument{plural}: {args}",
)
valid_call = False
if len(required_keyword) != 0:
plural: str = "" if len(required_keyword) == 1 else "s"
args: str = join_args(required_keyword)
if report_errors:
self.reporter.error(
location,
f"Missing required keyword argument{plural}: {args}",
)
valid_call = False
return valid_call, mapped
def _are_mapped_subtypes(
self, mapped1: list[MappedArgument[E]], mapped2: list[MappedArgument[E]]
) -> bool:
"""Check whether the given argument mappings are subtype/supertype of one another
This function checks whether the argument mappings `mapped1` are subtypes
of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
of the corresponding parameter in `mapped2`, `False` is returned.
This is used to check whether a given overload is
a more specific function/ a subtype of another.
Args:
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
Returns:
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
"""
by_expr: dict[E, Type] = {}
for arg in mapped1:
by_expr[arg.expr] = arg.argument.type
for arg in mapped2:
type2: Type = arg.argument.type
type1: Type = by_expr[arg.expr]
if not self.types.is_subtype(type1, type2):
return False
return True

View File

@@ -0,0 +1,203 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.dispatcher import CallResult
from midas.checker.frames.utils import MethodRegistry, method
from midas.checker.types import ColumnGroupBy, ColumnType, Function, TopType, Type
if TYPE_CHECKING:
from midas.checker.python import TypedExpr
@dataclass(frozen=True, kw_only=True)
class Call:
location: Location
call_expr: p.Expr
groupby: ColumnGroupBy
groupby_expr: p.Expr
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
@property
def subject(self) -> TypedExpr:
return (self.groupby_expr, self.groupby)
class ColumnGroupByMethodRegistry(MethodRegistry[Call]):
NAMED_ARGS: dict[str, str] = {
"numeric_only": "bool",
"skipna": "bool",
"engine": "str",
"engine_kwargs": "dict",
}
def _aggregate(
self,
call: Call,
args: list[str | tuple[str, str, bool]] = [],
*,
preserve_inner_type: bool = False,
) -> Type:
real_args: list[Function.Argument] = []
for i, arg in enumerate(args):
match arg:
case str() as name:
arg = Function.Argument(
pos=i,
name=name,
type=self.types.get_type(self.NAMED_ARGS[name]),
required=False,
)
case (name, type, required):
arg = Function.Argument(
pos=i,
name=name,
type=self.types.get_type(type),
required=required,
)
real_args.append(arg)
signature = Function(
args=real_args,
returns=(
call.groupby.column
if preserve_inner_type
else ColumnType(type=TopType())
),
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
return result.result
@method()
def kurt(self, call: Call) -> Type:
return self._aggregate(
call,
["skipna", "numeric_only"],
)
@method()
def max(self, call: Call) -> Type:
return self._aggregate(
call,
[
"numeric_only",
(
"min_count",
"int",
False,
),
"skipna",
"engine",
"engine_kwargs",
],
preserve_inner_type=True,
)
@method()
def mean(self, call: Call) -> Type:
return self._aggregate(
call,
["numeric_only", "skipna", "engine", "engine_kwargs"],
)
@method()
def median(self, call: Call) -> Type:
return self._aggregate(
call,
["numeric_only", "skipna"],
preserve_inner_type=True,
)
@method()
def min(self, call: Call) -> Type:
return self._aggregate(
call,
[
"numeric_only",
(
"min_count",
"int",
False,
),
"skipna",
"engine",
"engine_kwargs",
],
preserve_inner_type=True,
)
@method()
def prod(self, call: Call) -> Type:
return self._aggregate(
call,
[
"numeric_only",
(
"min_count",
"int",
False,
),
"skipna",
],
)
@method()
def std(self, call: Call) -> Type:
return self._aggregate(
call,
[
(
"ddof",
"int",
False,
),
"engine",
"engine_kwargs",
"numeric_only",
"skipna",
],
)
@method()
def sum(self, call: Call) -> Type:
return self._aggregate(
call,
[
"numeric_only",
(
"min_count",
"int",
False,
),
"skipna",
"engine",
"engine_kwargs",
],
)
@method()
def var(self, call: Call) -> Type:
return self._aggregate(
call,
[
(
"var",
"int",
False,
),
"engine",
"engine_kwargs",
"numeric_only",
"skipna",
],
)

View File

@@ -0,0 +1,78 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.frames.column_groupby_methods import Call as GroupByCall
from midas.checker.frames.column_groupby_methods import ColumnGroupByMethodRegistry
from midas.checker.frames.column_methods import Call, ColumnMethodRegistry
from midas.checker.registry import TypesRegistry
from midas.checker.types import ColumnGroupBy, ColumnType, Type
if TYPE_CHECKING:
from midas.checker.python import PythonTyper, TypedExpr
class ColumnManager:
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
self.method_resolver: ColumnMethodRegistry = ColumnMethodRegistry(self.typer)
self.groupby_method_resolver: ColumnGroupByMethodRegistry = (
ColumnGroupByMethodRegistry(self.typer)
)
def call(
self,
method: str,
location: Location,
call_expr: p.Expr,
column: ColumnType,
column_expr: p.Expr,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
call: Call = Call(
location=location,
call_expr=call_expr,
column=column,
column_expr=column_expr,
positional=positional,
keywords=keywords,
)
return self.method_resolver.call(method, call)
def groupby_call(
self,
method: str,
location: Location,
call_expr: p.Expr,
groupby: ColumnGroupBy,
groupby_expr: p.Expr,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
call: GroupByCall = GroupByCall(
location=location,
call_expr=call_expr,
groupby=groupby,
groupby_expr=groupby_expr,
positional=positional,
keywords=keywords,
)
return self.groupby_method_resolver.call(method, call)
def get_attribute(self, column: ColumnType, name: str) -> Optional[Type]:
types: TypesRegistry = self.typer.types
match name:
case "ndim" | "size":
return types.get_type("int")
case "shape":
return types.tuple_of("int")
case "T":
return column
case _:
return None

View File

@@ -0,0 +1,410 @@
from __future__ import annotations
import ast
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.dispatcher import CallResult
from midas.checker.frames.utils import MethodRegistry, method
from midas.checker.types import (
ColumnGroupBy,
ColumnType,
Function,
GenericType,
TopType,
Type,
TypeVar,
UnknownType,
unfold_type,
)
if TYPE_CHECKING:
from midas.checker.python import TypedExpr
@dataclass(frozen=True, kw_only=True)
class Call:
location: Location
call_expr: p.Expr
column: ColumnType
column_expr: p.Expr
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
@property
def subject(self) -> TypedExpr:
return (self.column_expr, self.column)
class ColumnMethodRegistry(MethodRegistry[Call]):
def _element_binary_op(self, call: Call, method: str) -> ColumnType:
"""Compute the result of an element-wise binary operation
This function delegates to the inner types for computing the resulting
type.
Args:
call (Call): the call that triggered this resolution
method (str): the method name
Returns:
ColumnType: the resulting column type
"""
column2: Optional[ColumnType] = None
col_type1: Type = call.column.type
new_column: Type = ColumnType(type=UnknownType())
if len(call.positional) != 0:
other: Type = call.positional[0][1]
unfolded_other: Type = unfold_type(other)
if isinstance(unfolded_other, ColumnType):
column2 = unfolded_other
col_type2: Type = column2.type
new_inner_type = self.typer.result_of_binary_op(
location=call.location,
expr=call.call_expr,
left=(call.column_expr, col_type1),
right=(call.positional[0][0], col_type2),
method=method,
)
new_column = ColumnType(type=new_inner_type)
return new_column
def _element_wise(self, call: Call, method: str) -> Type:
# TODO: support add with scalar
# Build signature with new column type and generic operand
param_type: TypeVar = TypeVar(name="T", bound=None)
signature = GenericType(
name="add",
params=[param_type],
body=Function(
args=[
Function.Argument(
pos=0,
name="other",
type=ColumnType(type=param_type),
required=True,
),
],
returns=self._element_binary_op(call, method),
),
)
# Map arguments and compute result type
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
if result.is_valid:
self._assert_same_length(
call.call_expr, call.column_expr, call.positional[0][0]
)
return result.result
@method("add", "__add__")
def add(self, call: Call) -> Type:
return self._element_wise(call, "__add__")
@method("sub", "__sub__")
def sub(self, call: Call) -> Type:
return self._element_wise(call, "__sub__")
@method("mul", "__mul__")
def mul(self, call: Call) -> Type:
return self._element_wise(call, "__mul__")
@method("div", "truediv", "__truediv__")
def truediv(self, call: Call) -> Type:
return self._element_wise(call, "__truediv__")
@method("floordiv", "__floordiv__")
def floordiv(self, call: Call) -> Type:
return self._element_wise(call, "__floordiv__")
@method("mod", "__mod__")
def mod(self, call: Call) -> Type:
return self._element_wise(call, "__mod__")
@method("pow", "__pow__")
def pow(self, call: Call) -> Type:
return self._element_wise(call, "__pow__")
@method("lt", "__lt__")
def lt(self, call: Call) -> Type:
return self._element_wise(call, "__lt__")
@method("gt", "__gt__")
def gt(self, call: Call) -> Type:
return self._element_wise(call, "__gt__")
@method("le", "__le__")
def le(self, call: Call) -> Type:
return self._element_wise(call, "__le__")
@method("ge", "__ge__")
def ge(self, call: Call) -> Type:
return self._element_wise(call, "__ge__")
@method("ne", "__ne__")
def ne(self, call: Call) -> Type:
return self._element_wise(call, "__ne__")
@method("eq", "__eq__")
def eq(self, call: Call) -> Type:
return self._element_wise(call, "__eq__")
def _aggregate(
self,
call: Call,
kwargs: list[Function.Argument] = [],
*,
preserve_inner_type: bool = False,
) -> Type:
signature = Function(
kw_args=[
Function.Argument(
pos=0,
name="axis",
type=TopType(),
required=False,
),
*kwargs,
],
returns=call.column if preserve_inner_type else ColumnType(type=TopType()),
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
return result.result
@method("kurtosis", "kurt")
def kurtosis(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def max(self, call: Call) -> Type:
return self._aggregate(call, preserve_inner_type=True)
@method()
def mean(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def median(self, call: Call) -> Type:
return self._aggregate(call, preserve_inner_type=True)
@method()
def min(self, call: Call) -> Type:
return self._aggregate(call, preserve_inner_type=True)
@method()
def mode(self, call: Call) -> Type:
return self._aggregate(call, preserve_inner_type=True)
@method("product", "prod")
def product(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def std(self, call: Call) -> Type:
return self._aggregate(
call,
[
Function.Argument(
pos=1,
name="ddof",
type=self.types.get_type("int"),
required=False,
)
],
)
@method()
def sum(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def var(self, call: Call) -> Type:
return self._aggregate(
call,
[
Function.Argument(
pos=1,
name="var",
type=self.types.get_type("int"),
required=False,
)
],
)
@method()
def head(self, call: Call) -> Type:
signature = Function(
args=[
Function.Argument(
pos=0,
name="n",
type=self.types.get_type("int"),
required=False,
),
],
returns=call.column,
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
return result.result
@method()
def tail(self, call: Call) -> Type:
signature = Function(
args=[
Function.Argument(
pos=0,
name="n",
type=self.types.get_type("int"),
required=False,
),
],
returns=call.column,
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
return result.result
@method()
def groupby(self, call: Call) -> Type:
bool_: Type = self.types.get_type("bool")
function: Function = Function(
args=[
Function.Argument(
pos=0,
name="by",
type=TopType(),
required=False,
),
Function.Argument(
pos=1,
name="level",
type=TopType(),
required=False,
),
],
kw_args=[
Function.Argument(
pos=2,
name="as_index",
type=bool_,
required=False,
),
Function.Argument(
pos=3,
name="sort",
type=bool_,
required=False,
),
Function.Argument(
pos=4,
name="group_keys",
type=bool_,
required=False,
),
Function.Argument(
pos=5,
name="observed",
type=bool_,
required=False,
),
Function.Argument(
pos=6,
name="dropna",
type=bool_,
required=False,
),
],
returns=ColumnGroupBy(column=call.column),
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=function,
positional=call.positional,
keywords=call.keywords,
)
return result.result
def _assert_same_length(self, call_expr: p.Expr, column1: p.Expr, column2: p.Expr):
func_name: str = "__midas_column_same_length__"
# Efficiently compute length
# https://stackoverflow.com/a/15943975/11109181
def len_of_col(col: ast.expr) -> ast.expr:
return ast.Call(
func=ast.Name(id="len"),
args=[
ast.Attribute(
value=col,
attr="index",
)
],
keywords=[],
)
self.assertions.define(
func_name,
ast.FunctionDef(
name=func_name,
args=ast.arguments(
posonlyargs=[],
args=[
ast.arg(arg="column1"),
ast.arg(arg="column2"),
],
kwonlyargs=[],
defaults=[],
kw_defaults=[],
),
body=[
ast.Return(
value=ast.Compare(
left=len_of_col(ast.Name(id="column1")),
ops=[ast.Eq()],
comparators=[
len_of_col(ast.Name(id="column2")),
],
)
)
],
decorator_list=[],
),
)
self.assertions.add(
bound_expr=call_expr,
inputs=[column1, column2],
builder=lambda c1, c2: ast.Call(
func=ast.Name(id=func_name),
args=[c1, c2],
keywords=[],
),
message="Columns must have the same length",
)

View File

@@ -0,0 +1,103 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.frames.utils import MethodRegistry, method
from midas.checker.types import (
ColumnGroupBy,
ColumnType,
DataFrameType,
FrameGroupBy,
Type,
UnknownType,
)
if TYPE_CHECKING:
from midas.checker.python import TypedExpr
@dataclass(frozen=True, kw_only=True)
class Call:
location: Location
call_expr: p.Expr
groupby: FrameGroupBy
groupby_expr: p.Expr
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
@property
def subject(self) -> TypedExpr:
return (self.groupby_expr, self.groupby)
class FrameGroupByMethodRegistry(MethodRegistry[Call]):
NAMED_ARGS: dict[str, str] = {
"numeric_only": "bool",
"skipna": "bool",
"engine": "str",
"engine_kwargs": "dict",
}
def _aggregate(self, call: Call, method: str) -> Type:
new_columns: list[DataFrameType.Column] = []
for column in call.groupby.frame.columns:
column_groupby: ColumnGroupBy = ColumnGroupBy(column=column.type)
result_type: Type = self.typer.call_method(
location=call.location,
call_expr=call.call_expr,
obj=(call.groupby_expr, column_groupby),
method_name=method,
positional=call.positional,
keywords=call.keywords,
)
if not isinstance(result_type, ColumnType):
result_type = ColumnType(type=UnknownType())
new_columns.append(
DataFrameType.Column(
index=column.index,
name=column.name,
type=result_type,
)
)
return DataFrameType(columns=new_columns)
@method()
def kurt(self, call: Call) -> Type:
return self._aggregate(call, "kurt")
@method()
def max(self, call: Call) -> Type:
return self._aggregate(call, "max")
@method()
def mean(self, call: Call) -> Type:
return self._aggregate(call, "mean")
@method()
def median(self, call: Call) -> Type:
return self._aggregate(call, "median")
@method()
def min(self, call: Call) -> Type:
return self._aggregate(call, "min")
@method()
def prod(self, call: Call) -> Type:
return self._aggregate(call, "prod")
@method()
def std(self, call: Call) -> Type:
return self._aggregate(call, "std")
@method()
def sum(self, call: Call) -> Type:
return self._aggregate(call, "sum")
@method()
def var(self, call: Call) -> Type:
return self._aggregate(call, "var")

View File

@@ -0,0 +1,255 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, TypeGuard, cast
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.frames.frame_groupby_methods import Call as GroupByCall
from midas.checker.frames.frame_groupby_methods import FrameGroupByMethodRegistry
from midas.checker.frames.frame_methods import Call, FrameMethodRegistry
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import (
ColumnGroupBy,
ColumnType,
DataFrameType,
FrameGroupBy,
TupleType,
Type,
UnknownType,
)
if TYPE_CHECKING:
from midas.checker.python import PythonTyper, TypedExpr
def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
return all(isinstance(expr, p.LiteralExpr) for expr in exprs)
class FrameManager:
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
self.method_resolver: FrameMethodRegistry = FrameMethodRegistry(self.typer)
self.groupby_method_resolver: FrameGroupByMethodRegistry = (
FrameGroupByMethodRegistry(self.typer)
)
def assign(
self,
reporter: FileReporter,
location: Location,
frame: DataFrameType,
index: p.Expr,
value_type: Type,
) -> Type:
match index:
case p.LiteralExpr(value=str() as name):
return self.assign_column(reporter, location, frame, name, value_type)
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
isinstance(index.value, str) for index in indices
):
names: list[str] = [cast(str, index.value) for index in indices]
if not isinstance(value_type, TupleType):
reporter.error(
location,
f"Cannot assign {type} to dataframe columns. Must be a tuple of columns",
)
return UnknownType()
if len(names) != len(value_type.items):
reporter.error(
location,
f"Wrong number of columns. Cannot assign {len(value_type.items)} to {len(names)} targets",
)
return UnknownType()
new_frame: Type = frame
for name, value in zip(names, value_type.items):
new_frame = self.assign_column(
reporter,
location,
new_frame,
name,
value,
)
if not isinstance(new_frame, DataFrameType):
return new_frame
return new_frame
case _:
reporter.error(
location, f"Invalid index type {index} on {frame} (assignment)"
)
return UnknownType()
def assign_column(
self,
reporter: FileReporter,
location: Location,
frame: DataFrameType,
name: str,
type: Type,
) -> Type:
if not isinstance(type, ColumnType):
reporter.error(
location,
f"Cannot assign {type} to dataframe column. Must be a ColumnType",
)
return frame
return self._set_column(frame, name, type)
def get(
self,
reporter: FileReporter,
location: Location,
frame: DataFrameType,
index: p.Expr,
) -> Type:
match index:
case p.LiteralExpr(value=str() as name):
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
if column is None:
reporter.error(location, f"Unknown column '{name}' on {frame}")
return UnknownType()
return column
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
isinstance(index.value, str) for index in indices
):
names: list[str] = [cast(str, index.value) for index in indices]
columns: list[ColumnType] = []
for name in names:
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
if column is None:
reporter.error(location, f"Unknown column '{name}' on {frame}")
return UnknownType()
columns.append(column)
return TupleType(items=tuple(columns))
case _:
reporter.error(
location, f"Invalid index type {index} on {frame} (access)"
)
return UnknownType()
def groupby_get(
self,
reporter: FileReporter,
location: Location,
groupby: FrameGroupBy,
index: p.Expr,
) -> Type:
result: Type = self.get(reporter, location, groupby.frame, index)
match result:
case ColumnType():
result = ColumnGroupBy(column=result)
case TupleType(items=columns):
result = TupleType(
items=tuple(
ColumnGroupBy(column=cast(ColumnType, column))
for column in columns
)
)
return result
@classmethod
def _set_column(
cls, frame: DataFrameType, name: str, column: ColumnType
) -> DataFrameType:
new_columns: list[DataFrameType.Column] = []
index: int = len(frame.columns)
replace: bool = False
for i, col in enumerate(frame.columns):
if col.name == name:
index = i
replace = True
# TODO: check column type here to prevent changing it
new_columns.append(col)
new_col: DataFrameType.Column = DataFrameType.Column(
index=index,
name=name,
type=column,
)
if replace:
new_columns[index] = new_col
else:
new_columns.append(new_col)
return DataFrameType(columns=new_columns)
@classmethod
def _set_columns(
cls, frame: DataFrameType, names: list[str], columns: list[ColumnType]
) -> DataFrameType:
for name, col in zip(names, columns):
frame = cls._set_column(frame, name, col)
return frame
@classmethod
def _get_column(cls, frame: DataFrameType, name: str) -> Optional[ColumnType]:
for col in frame.columns:
if col.name == name:
return col.type
return None
@classmethod
def _get_columns(
cls, frame: DataFrameType, names: list[str]
) -> list[Optional[ColumnType]]:
return [cls._get_column(frame, name) for name in names]
def call(
self,
method: str,
location: Location,
call_expr: p.Expr,
frame: DataFrameType,
frame_expr: p.Expr,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
call: Call = Call(
location=location,
call_expr=call_expr,
frame=frame,
frame_expr=frame_expr,
positional=positional,
keywords=keywords,
)
return self.method_resolver.call(method, call)
def groupby_call(
self,
method: str,
location: Location,
call_expr: p.Expr,
groupby: FrameGroupBy,
groupby_expr: p.Expr,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
call: GroupByCall = GroupByCall(
location=location,
call_expr=call_expr,
groupby=groupby,
groupby_expr=groupby_expr,
positional=positional,
keywords=keywords,
)
return self.groupby_method_resolver.call(method, call)
def get_attribute(self, frame: DataFrameType, name: str) -> Optional[Type]:
types: TypesRegistry = self.typer.types
match name:
case "ndim" | "size":
return types.get_type("int")
case "shape":
return types.tuple_of("int", "int")
case _:
return None

View File

@@ -0,0 +1,487 @@
from __future__ import annotations
import ast
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.dispatcher import CallResult
from midas.checker.frames.utils import MethodRegistry, method
from midas.checker.types import (
ColumnType,
DataFrameType,
FrameGroupBy,
Function,
OverloadedFunction,
TopType,
Type,
UnknownType,
unfold_type,
)
if TYPE_CHECKING:
from midas.checker.python import TypedExpr
@dataclass(frozen=True, kw_only=True)
class Call:
location: Location
call_expr: p.Expr
frame: DataFrameType
frame_expr: p.Expr
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
@property
def subject(self) -> TypedExpr:
return (self.frame_expr, self.frame)
class FrameMethodRegistry(MethodRegistry[Call]):
def _get_method_result(
self,
call: Call,
column1: ColumnType,
column2: ColumnType,
method: str,
) -> ColumnType:
"""Get the result of calling a method on a column, passing a second
This function delegates to the main typer the resolution of the method
member, as well as computing the result type. Because we don't have any
AST expression for the individual columns, the frame expressions are
used instead.
Args:
call (Call): the call that triggered this resolution
column1 (ColumnType): the first column, i.e. left operand
column2 (ColumnType): the second column, i.e. right operand
method (str): the method name
Returns:
ColumnType: the resulting column.
If the operation is invalid / doesn't exist,
`ColumnType(type=UnknownType())` is returned
"""
result: Type = self.typer.result_of_binary_op(
location=call.location,
expr=call.call_expr,
left=(call.frame_expr, column1),
right=(call.positional[0][0], column2),
method=method,
)
if not isinstance(result, ColumnType):
return ColumnType(type=UnknownType())
return result
def _element_binary_op(self, call: Call, method: str) -> DataFrameType:
"""Compute the result of an element-wise binary operation
This function delegates to the matching columns for computing resulting
types. Any column only present in one of the frames is forwarded as a
generic `ColumnType(type=UnknownType())`. Columns only in the second
frame are append at the end of the schema.
Args:
call (Call): the call that triggered this resolution
method (str): the method name
Returns:
DataFrameType: the resulting frame type
"""
new_columns: list[DataFrameType.Column] = []
by_name: dict[str, DataFrameType.Column] = {}
frame2: Optional[DataFrameType] = None
# Get map of operand's columns by name, if there is at least 1 operand, which is a dataframe
if len(call.positional) != 0:
operand: TypedExpr = call.positional[0]
unfolded_other: Type = unfold_type(operand[1])
if isinstance(unfolded_other, DataFrameType):
frame2 = unfolded_other
by_name = {
col.name: col for col in frame2.columns if col.name is not None
}
# Compute new schema:
# Step 1: for all columns in frame1:
# - if present in frame2 -> delegate operation to columns
# - if not -> add to schema as unknown
in_frame1: set[str] = set()
for column in call.frame.columns:
if column.name is not None:
in_frame1.add(column.name)
col_type1: ColumnType = column.type
col_type: ColumnType = ColumnType(type=UnknownType())
if column.name in by_name:
column2 = by_name[column.name]
col_type2: ColumnType = column2.type
col_type = self._get_method_result(call, col_type1, col_type2, method)
new_column = DataFrameType.Column(
index=column.index,
name=column.name,
type=col_type,
)
new_columns.append(new_column)
# Step 2: for all columns in frame2
# - if not in frame1 -> add to schema as unknown
if frame2 is not None:
for column in frame2.columns:
if column.name in in_frame1:
continue
new_columns.append(
DataFrameType.Column(
index=len(new_columns),
name=column.name,
type=ColumnType(type=UnknownType()),
)
)
return DataFrameType(columns=new_columns)
def _element_wise(self, call: Call, method: str) -> Type:
# TODO: support scalar, sequence, Series, dict operand
# Build signature with new schema and generic operand
signature = Function(
args=[
Function.Argument(
pos=0,
name="other",
type=DataFrameType(columns=[]),
required=True,
),
],
returns=self._element_binary_op(call, method),
)
# Map arguments and compute result type
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
if result.is_valid:
self._assert_same_length(
call.call_expr, call.frame_expr, call.positional[0][0]
)
return result.result
@method("add", "__add__")
def add(self, call: Call) -> Type:
return self._element_wise(call, "__add__")
@method("sub", "__sub__")
def sub(self, call: Call) -> Type:
return self._element_wise(call, "__sub__")
@method("mul", "__mul__")
def mul(self, call: Call) -> Type:
return self._element_wise(call, "__mul__")
@method("div", "truediv", "__truediv__")
def truediv(self, call: Call) -> Type:
return self._element_wise(call, "__truediv__")
@method("floordiv", "__floordiv__")
def floordiv(self, call: Call) -> Type:
return self._element_wise(call, "__floordiv__")
@method("mod", "__mod__")
def mod(self, call: Call) -> Type:
return self._element_wise(call, "__mod__")
@method("pow", "__pow__")
def pow(self, call: Call) -> Type:
return self._element_wise(call, "__pow__")
@method("lt", "__lt__")
def lt(self, call: Call) -> Type:
return self._element_wise(call, "__lt__")
@method("gt", "__gt__")
def gt(self, call: Call) -> Type:
return self._element_wise(call, "__gt__")
@method("le", "__le__")
def le(self, call: Call) -> Type:
return self._element_wise(call, "__le__")
@method("ge", "__ge__")
def ge(self, call: Call) -> Type:
return self._element_wise(call, "__ge__")
@method("ne", "__ne__")
def ne(self, call: Call) -> Type:
return self._element_wise(call, "__ne__")
@method("eq", "__eq__")
def eq(self, call: Call) -> Type:
return self._element_wise(call, "__eq__")
def _aggregate(self, call: Call, kwargs: list[Function.Argument] = []) -> Type:
with_axis = Function(
kw_args=[
Function.Argument(
pos=0,
name="axis",
type=self.types.get_type("int"),
required=False,
),
*kwargs,
],
returns=ColumnType(type=TopType()),
)
without_axis = Function(
kw_args=[
Function.Argument(
pos=0,
name="axis",
type=self.types.get_type("None"),
required=True,
),
*kwargs,
],
returns=TopType(),
)
overload = OverloadedFunction(
overloads=[
with_axis,
without_axis,
]
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=overload,
positional=call.positional,
keywords=call.keywords,
)
return result.result
@method("kurtosis", "kurt")
def kurtosis(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def max(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def mean(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def median(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def min(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def mode(self, call: Call) -> Type:
return self._aggregate(call)
@method("product", "prod")
def product(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def std(self, call: Call) -> Type:
return self._aggregate(
call,
[
Function.Argument(
pos=1,
name="ddof",
type=self.types.get_type("int"),
required=False,
)
],
)
@method()
def sum(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def var(self, call: Call) -> Type:
return self._aggregate(
call,
[
Function.Argument(
pos=1,
name="var",
type=self.types.get_type("int"),
required=False,
)
],
)
@method()
def head(self, call: Call) -> Type:
signature = Function(
args=[
Function.Argument(
pos=0,
name="n",
type=self.types.get_type("int"),
required=False,
),
],
returns=call.frame,
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
return result.result
@method()
def tail(self, call: Call) -> Type:
signature = Function(
args=[
Function.Argument(
pos=0,
name="n",
type=self.types.get_type("int"),
required=False,
),
],
returns=call.frame,
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
return result.result
@method()
def groupby(self, call: Call) -> Type:
bool_: Type = self.types.get_type("bool")
function: Function = Function(
args=[
Function.Argument(
pos=0,
name="by",
type=TopType(),
required=False,
),
Function.Argument(
pos=1,
name="level",
type=TopType(),
required=False,
),
],
kw_args=[
Function.Argument(
pos=2,
name="as_index",
type=bool_,
required=False,
),
Function.Argument(
pos=3,
name="sort",
type=bool_,
required=False,
),
Function.Argument(
pos=4,
name="group_keys",
type=bool_,
required=False,
),
Function.Argument(
pos=5,
name="observed",
type=bool_,
required=False,
),
Function.Argument(
pos=6,
name="dropna",
type=bool_,
required=False,
),
],
returns=FrameGroupBy(frame=call.frame),
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=function,
positional=call.positional,
keywords=call.keywords,
)
return result.result
def _assert_same_length(self, call_expr: p.Expr, frame1: p.Expr, frame2: p.Expr):
func_name: str = "__midas_frame_same_length__"
# Efficiently compute length
# https://stackoverflow.com/a/15943975/11109181
def len_of_df(df: ast.expr) -> ast.expr:
return ast.Call(
func=ast.Name(id="len"),
args=[
ast.Attribute(
value=df,
attr="index",
)
],
keywords=[],
)
self.assertions.define(
func_name,
ast.FunctionDef(
name=func_name,
args=ast.arguments(
posonlyargs=[],
args=[
ast.arg(arg="frame1"),
ast.arg(arg="frame2"),
],
kwonlyargs=[],
defaults=[],
kw_defaults=[],
),
body=[
ast.Return(
value=ast.Compare(
left=len_of_df(ast.Name(id="frame1")),
ops=[ast.Eq()],
comparators=[len_of_df(ast.Name(id="frame2"))],
)
)
],
decorator_list=[],
),
)
self.assertions.add(
bound_expr=call_expr,
inputs=[frame1, frame2],
builder=lambda f1, f2: ast.Call(
func=ast.Name(id=func_name),
args=[f1, f2],
keywords=[],
),
message="DataFrames must have the same length",
)

View File

@@ -0,0 +1,100 @@
from __future__ import annotations
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Optional,
Protocol,
Self,
TypeVar,
)
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.dispatcher import CallDispatcher
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import Type, UnknownType
from midas.generator.collector import AssertionCollector
if TYPE_CHECKING:
from midas.checker.python import PythonTyper, TypedExpr
class _MethodRegistryMeta(type):
_methods: dict[str, Callable[..., Type]] = {}
def __new__(
cls,
name: str,
bases: tuple[type, ...],
namespace: dict[str, Any],
):
new_class = super().__new__(cls, name, bases, namespace)
new_class._methods = {}
for attr in namespace.values():
if callable(attr) and hasattr(attr, "__method_names__"):
for name in attr.__method_names__: # type: ignore
new_class._methods[name] = attr # type: ignore
return new_class
class MethodCall(Protocol):
@property
def location(self) -> Location: ...
@property
def call_expr(self) -> p.Expr: ...
@property
def subject(self) -> TypedExpr: ...
T = TypeVar("T", bound=MethodCall)
class MethodRegistry(Generic[T], metaclass=_MethodRegistryMeta):
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
@property
def reporter(self) -> FileReporter:
return self.typer.reporter
@property
def types(self) -> TypesRegistry:
return self.typer.types
@property
def dispatcher(self) -> CallDispatcher[p.Expr]:
return self.typer.dispatcher
@property
def assertions(self) -> AssertionCollector:
return self.typer.assertions
def call(self, method: str, call: T) -> Type:
func: Optional[Callable[[Self, T], Type]] = self._methods.get(method)
if func is None:
self.reporter.warning(
call.location, f"Unknown method {method} on {call.subject[1]}"
)
return UnknownType()
return func(self, call)
_Self = TypeVar("_Self", bound=MethodRegistry[Any])
Method = Callable[[_Self, T], Type]
def method(*names: str) -> Callable[[Method[_Self, T]], Method[_Self, T]]:
def wrapper(func: Method[_Self, T]) -> Method[_Self, T]:
names_: tuple[str, ...] = names
if len(names_) == 0:
names_ = (func.__name__,)
setattr(func, "__method_names__", names_)
return func
return wrapper

View File

@@ -6,25 +6,25 @@ from typing import Optional
import midas.ast.midas as m
from midas.ast.location import Location
from midas.checker.builtins import define_builtins
from midas.checker.dispatcher import CallDispatcher, CallResult
from midas.checker.environment import Environment
from midas.checker.operators import MIDAS_BINARY_METHODS, MIDAS_UNARY_METHODS
from midas.checker.preamble import Preamble
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter
from midas.checker.types import (
AppliedType,
ColumnType,
ComplexType,
ConstraintType,
DataFrameType,
DerivedType,
ExtensionType,
Function,
GenericType,
OverloadedFunction,
Predicate,
Type,
TypeVar,
UnknownType,
unfold_type,
)
from midas.checker.variance import VarianceInferrer
from midas.lexer.midas import MidasLexer
@@ -39,9 +39,6 @@ class TypedParamSpec:
kw: list[Function.Argument]
TypedExpr = tuple[m.Expr, Type]
class ReturnException(Exception):
pass
@@ -65,8 +62,11 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
self.logger: logging.Logger = logging.getLogger("MidasTyper")
self.reporter: FileReporter = reporter.for_file(None)
self.types: TypesRegistry = types
self.dispatcher: CallDispatcher[m.Expr] = CallDispatcher[m.Expr](
self.types, self.reporter
)
self._local_variables: dict[str, TypeVar] = {}
self._predicate_params: dict[str, Type] = {}
@@ -81,8 +81,14 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
self._preamble: Environment = Preamble(self.types)
def set_reporter(self, reporter: FileReporter):
self.reporter = reporter
self.dispatcher.set_reporter(reporter)
def process(self, source: str, path: Optional[str]):
self.reporter = self.reporter.for_file(path)
reporter: FileReporter = self.reporter.for_file(path)
self.set_reporter(reporter)
lexer: MidasLexer = MidasLexer(source)
tokens: list[Token] = lexer.process()
parser: MidasParser = MidasParser(tokens)
@@ -257,13 +263,13 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
)
return UnknownType()
result: Optional[Type] = self._get_call_result(
location,
operation,
[(right_expr, right)],
{},
result: CallResult = self.dispatcher.get_result(
location=location,
callee=operation,
positional=[(right_expr, right)],
keywords={},
)
return result or UnknownType()
return result.result
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
method: Optional[str] = MIDAS_UNARY_METHODS.get(expr.operator.type)
@@ -283,31 +289,29 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
)
return UnknownType()
result: Optional[Type] = self._get_call_result(
expr.location,
operation,
[],
{},
result: CallResult = self.dispatcher.get_result(
location=expr.location,
callee=operation,
positional=[],
keywords={},
)
return result or UnknownType()
return result.result
def visit_call_expr(self, expr: m.CallExpr) -> Type:
callee: Type = expr.callee.accept(self)
positional: list[TypedExpr] = [
positional: list[tuple[m.Expr, Type]] = [
(arg, self.type_of(arg)) for arg in expr.arguments
]
keywords: dict[str, TypedExpr] = {
keywords: dict[str, tuple[m.Expr, Type]] = {
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
}
return (
self._get_call_result(
expr.location,
callee,
positional,
keywords,
)
or UnknownType()
result: CallResult = self.dispatcher.get_result(
location=expr.location,
callee=callee,
positional=positional,
keywords=keywords,
)
return result.result
def visit_get_expr(self, expr: m.GetExpr) -> Type:
object: Type = expr.expr.accept(self)
@@ -408,6 +412,18 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
kw=[process_arg(arg, i + n_pos + n_mixed) for i, arg in enumerate(spec.kw)],
)
def visit_frame_type(self, type: m.FrameType) -> Type:
def process_column(i: int, col: m.FrameType.Column) -> DataFrameType.Column:
return DataFrameType.Column(
index=i,
name=col.name.lexeme,
type=ColumnType(type=col.type.accept(self)),
)
return DataFrameType(
columns=[process_column(i, col) for i, col in enumerate(type.columns)]
)
def _resolve_type_params(self, params: list[m.TypeParam]):
vars: list[TypeVar] = []
for param in params:
@@ -419,343 +435,3 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
self._local_variables[name] = var
vars.append(var)
return vars
def _get_call_result(
self,
location: Location,
callee: Type,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
report_errors: bool = True,
) -> Optional[Type]:
"""Get the result type of a function call
If the function has overloads, the function will try to resolve the
appropriate signature.
Argument types are matched to the defined parameters.
The function doesn't take the raw expression as a parameter to accommodate
for desugared calls such as for operators.
Args:
location (Location): the call location
callee (Type): the called function
positional (list[TypedExpr]): the list positional arguments
keywords (dict[str, TypedExpr]): the map of keyword arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
Type: the return type of the call, or `None` if either
the call is invalid or no overload matched the arguments uniquely
"""
match callee:
case Function() as function:
valid: bool
mapped: list[MappedArgument]
valid, mapped = self.map_call_arguments(
function, location, positional, keywords
)
valid = valid and self._are_arguments_valid(mapped, report_errors)
if not valid:
return None
return function.returns
case OverloadedFunction(overloads=overloads):
function = self._match_overload(
overloads, location, positional, keywords, report_errors
)
if function is None:
return None
return function.returns
case AppliedType(body=body):
return self._get_call_result(
location, body, positional, keywords, report_errors
)
case UnknownType():
return UnknownType()
case _:
if report_errors:
self.reporter.error(location, f"{callee} is not callable")
return None
def _are_arguments_valid(
self,
arguments: list[MappedArgument],
report_errors: bool = True,
) -> bool:
"""Check whether the passed argument types correspond to their matched parameter definitions
Args:
arguments (list[MappedArgument]): the list of argument/parameter pairs
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
bool: True if all arguments fit the matching parameter definitions, False otherwise
"""
valid: bool = True
for arg in arguments:
if not self.types.is_subtype(arg.type, arg.argument.type):
if report_errors:
self.reporter.error(
arg.expr.location,
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
)
valid = False
return valid
def _match_overload(
self,
overloads: list[Type],
location: Location,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
report_errors: bool = True,
) -> Optional[Function]:
"""Try and resolve the appropriate overload for the given arguments
Args:
overloads (list[Type]): the list of possible overloads
location (Location): the call location
positional (list[TypedExpr]): the list of positional arguments
keywords (dict[str, TypedExpr]): the map of keywords arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
Optional[Function]: the resolved function signature if it can be
determined unambiguously, or `None`.
"""
candidates: list[OverloadCandidate] = []
for overload in overloads:
function: Type = unfold_type(overload)
if not isinstance(function, Function):
if report_errors:
self.logger.error(
f"Overload is not a function: {overload} is {function}"
)
continue
valid, mapped = self.map_call_arguments(
function=function,
location=location,
positional=positional,
keywords=keywords,
report_errors=False,
)
if valid and self._are_arguments_valid(mapped, report_errors=False):
candidates.append(
OverloadCandidate(
function=function,
mapped=mapped,
)
)
pos_types: str = ", ".join(str(type) for _, type in positional)
kw_types: str = ", ".join(
f"{name}: {type}" for name, (_, type) in keywords.items()
)
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
n_candidates: int = len(candidates)
# Exactly 1 match -> return it
if n_candidates == 1:
return candidates[0].function
# No match -> invalid call
if n_candidates == 0:
overloads_str: str = ", ".join(map(str, overloads))
if report_errors:
self.reporter.error(
location,
f"No matching overload in [{overloads_str}] {for_args}",
)
return None
# Multiple matches -> see if one <: all others (more specific)
for i1, c1 in enumerate(candidates):
mapped1: list[MappedArgument] = c1.mapped
best_match: bool = True
for i2, c2 in enumerate(candidates):
if i1 == i2:
continue
mapped2: list[MappedArgument] = c2.mapped
if not self._are_mapped_subtypes(mapped1, mapped2):
best_match = False
break
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
if best_match:
return c1.function
candidates_str: str = ", ".join(
str(candidate.function) for candidate in candidates
)
if report_errors:
self.reporter.error(
location,
f"Multiple matching overloads {for_args}: {candidates_str}",
)
return None
def map_call_arguments(
self,
function: Function,
location: Location,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
report_errors: bool = True,
) -> tuple[bool, list[MappedArgument]]:
"""Map call arguments to a function's parameters as defined in its signature
This method maps positional-only, keyword-only and mixed parameter definitions
with the arguments passed at the call site
Any mismatched, missing or unexpected argument is reported as a diagnostic,
unless `report_errors` is set to `False`
Args:
function (Function): the function definition
location (Location): the call location
positional (list[TypedExpr]): the list of positional arguments
keywords (dict[str, TypedExpr]): the map of keyword arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
tuple[bool, list[MappedArgument]]: a boolean reporting whether
the call is valid and the list of mapped arguments
"""
set_args: set[str] = set()
required_positional: list[str] = [
arg.name for arg in function.pos_args + function.args if arg.required
]
required_keyword: list[str] = [
arg.name for arg in function.kw_args if arg.required
]
mapped: list[MappedArgument] = []
pos_params: list[Function.Argument] = list(function.pos_args)
mixed_params: list[Function.Argument] = list(function.args)
kw_params: dict[str, Function.Argument] = {
arg.name: arg for arg in function.kw_args
}
valid_call: bool = True
# TODO: handle *args and **kwargs sinks
for arg in positional:
param: Function.Argument
if len(pos_params) != 0:
param = pos_params.pop(0)
elif len(mixed_params) != 0:
param = mixed_params.pop(0)
else:
if report_errors:
self.reporter.error(
arg[0].location, "Too many positional arguments"
)
valid_call = False
break
name: str = param.name
if name in required_positional:
required_positional.remove(name)
if name in required_keyword:
required_keyword.remove(name)
set_args.add(name)
mapped.append(
MappedArgument(
expr=arg[0],
type=arg[1],
argument=param,
)
)
kw_params.update({arg.name: arg for arg in mixed_params})
for name, arg in keywords.items():
param: Function.Argument
if name not in kw_params:
if report_errors:
if name in set_args:
self.reporter.error(
arg[0].location, f"Multiple values for argument '{name}'"
)
else:
self.reporter.error(
arg[0].location, f"Unknown keyword argument '{name}'"
)
valid_call = False
continue
param = kw_params.pop(name)
if name in required_positional:
required_positional.remove(name)
if name in required_keyword:
required_keyword.remove(name)
set_args.add(name)
mapped.append(
MappedArgument(
expr=arg[0],
type=arg[1],
argument=param,
)
)
def join_args(args: list[str]) -> str:
args = list(map(lambda a: f"'{a}'", args))
if len(args) == 0:
return ""
if len(args) == 1:
return args[0]
return ", ".join(args[:-1]) + " and " + args[-1]
if len(required_positional) != 0:
plural: str = "" if len(required_positional) == 1 else "s"
args: str = join_args(required_positional)
if report_errors:
self.reporter.error(
location,
f"Missing required positional argument{plural}: {args}",
)
valid_call = False
if len(required_keyword) != 0:
plural: str = "" if len(required_keyword) == 1 else "s"
args: str = join_args(required_keyword)
if report_errors:
self.reporter.error(
location,
f"Missing required keyword argument{plural}: {args}",
)
valid_call = False
return valid_call, mapped
def _are_mapped_subtypes(
self, mapped1: list[MappedArgument], mapped2: list[MappedArgument]
) -> bool:
"""Check whether the given argument mappings are subtype/supertype of one another
This function checks whether the argument mappings `mapped1` are subtypes
of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
of the corresponding parameter in `mapped2`, `False` is returned.
This is used to check whether a given overload is
a more specific function/ a subtype of another.
Args:
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
Returns:
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
"""
by_expr: dict[m.Expr, Type] = {}
for arg in mapped1:
by_expr[arg.expr] = arg.argument.type
for arg in mapped2:
type2: Type = arg.argument.type
type1: Type = by_expr[arg.expr]
if not self.types.is_subtype(type1, type2):
return False
return True

View File

@@ -41,7 +41,7 @@ PY_UNARY_METHODS: dict[Type[ast.unaryop], str] = {
MIDAS_BINARY_METHODS: dict[TokenType, str] = {
# TokenType.PLUS: "__add__",
TokenType.PLUS: "__add__",
TokenType.MINUS: "__sub__",
TokenType.STAR: "__mul__",
TokenType.SLASH: "__truediv__",

View File

@@ -1,9 +1,17 @@
from dataclasses import dataclass
from typing import Callable, Optional
from typing import Any, Callable, Optional
from midas.checker.environment import Environment
from midas.checker.registry import TypesRegistry
from midas.checker.types import Function, GenericType, TopType, Type, TypeVar, UnitType
from midas.checker.types import (
Function,
GenericType,
OverloadedFunction,
TopType,
Type,
TypeVar,
UnitType,
)
@dataclass(frozen=True)
@@ -17,7 +25,7 @@ class Preamble(Environment):
def __init__(self, types: TypesRegistry) -> None:
super().__init__()
self._types: TypesRegistry = types
self._python_funcs: dict[str, Callable] = {}
self._python_funcs: dict[str, Callable[..., Any]] = {}
self._def_type_constructor("object", object)
self._def_type_constructor("float", float)
@@ -34,7 +42,7 @@ class Preamble(Environment):
# TODO: use sink
self._def_function(
name="print",
pos=[Param("object", TopType())],
pos=[Param("object", TopType(), required=False)],
returns=UnitType(),
py_function=print,
)
@@ -64,11 +72,48 @@ class Preamble(Environment):
pos=[Param("prompt", TopType(), required=False)],
returns=self._types.get_type("str"),
)
self._def_function(
name="len",
pos=[Param("object", TopType())],
returns=self._types.get_type("int"),
)
def _list_of(self, item_type: Type) -> Type:
return self._types.apply_generic(self._types.get_type("list"), [item_type])
T = TypeVar(name="T", bound=None)
self._def_overloads(
name="max",
py_function=max,
signatures=[
(
[Param("arg1", T), Param("arg2", T)],
[],
[],
T,
[T],
),
([Param("iterable", self._list_of(T))], [], [], T, [T]),
],
)
self._def_overloads(
name="min",
py_function=min,
signatures=[
(
[Param("arg1", T), Param("arg2", T)],
[],
[],
T,
[T],
),
([Param("iterable", self._list_of(T))], [], [], T, [T]),
],
)
def _def_type_constructor(self, name: str, py_function: Optional[Callable] = None):
def _list_of(self, item_type: str | Type) -> Type:
return self._types.list_of(item_type)
def _def_type_constructor(
self, name: str, py_function: Optional[Callable[..., Any]] = None
):
# TODO: more specific arg types
self._def_function(
name=name,
@@ -121,7 +166,7 @@ class Preamble(Environment):
kw: list[Param] = [],
returns: Type = UnitType(),
type_vars: list[TypeVar] = [],
py_function: Optional[Callable] = None,
py_function: Optional[Callable[..., Any]] = None,
):
function: Type = self._make_function(
name=name,
@@ -135,5 +180,31 @@ class Preamble(Environment):
if py_function is not None:
self._python_funcs[name] = py_function
def get_py_func(self, name: str) -> Optional[Callable]:
def _def_overloads(
self,
*,
name: str,
signatures: list[
tuple[list[Param], list[Param], list[Param], Type, list[TypeVar]]
],
py_function: Optional[Callable[..., Any]] = None,
):
overloads: list[Type] = []
for pos, mixed, kw, returns, type_vars in signatures:
overloads.append(
self._make_function(
name=name,
pos=pos,
mixed=mixed,
kw=kw,
returns=returns,
type_vars=type_vars,
)
)
function: Type = OverloadedFunction(overloads=overloads)
self.define(name, function)
if py_function is not None:
self._python_funcs[name] = py_function
def get_py_func(self, name: str) -> Optional[Callable[..., Any]]:
return self._python_funcs.get(name)

File diff suppressed because it is too large Load Diff

View File

@@ -7,8 +7,10 @@ from midas.checker.builtins import BUILTIN_SUBTYPES
from midas.checker.types import (
AppliedType,
BaseType,
ColumnType,
ComplexType,
ConstraintType,
DataFrameType,
DerivedType,
ExtensionType,
Function,
@@ -16,6 +18,7 @@ from midas.checker.types import (
OverloadedFunction,
Predicate,
TopType,
TupleType,
Type,
TypeVar,
UnknownType,
@@ -110,6 +113,15 @@ class TypesRegistry:
raise ValueError(f"Predicate {name} already defined")
self._predicates[name] = predicate
def is_builtin_subtype(self, name1: str, name2: str) -> bool:
subtypes: set[str] = BUILTIN_SUBTYPES.get(name2, set())
if name1 in subtypes:
return True
for subtype in subtypes:
if self.is_builtin_subtype(name1, subtype):
return True
return False
def is_subtype(self, type1: Type, type2: Type) -> bool:
"""Check whether `type1` is a subtype of `type2`
@@ -147,7 +159,7 @@ class TypesRegistry:
return self.is_subtype(base1, type2)
case (BaseType(name=name1), BaseType(name=name2)):
return name1 in BUILTIN_SUBTYPES.get(name2, set())
return self.is_builtin_subtype(name1, name2)
case (ComplexType(properties=props1), ComplexType(properties=props2)):
for k, t in props2.items():
@@ -157,6 +169,24 @@ class TypesRegistry:
return False
return True
case (DataFrameType(columns=columns1), DataFrameType(columns=columns2)):
# TODO: check order?
by_name1: dict[str, DataFrameType.Column] = {
col.name: col for col in columns1 if col.name is not None
}
for col2 in columns2:
if col2.name not in by_name1:
return False
if not self.is_subtype(by_name1[col2.name].type, col2.type):
return False
return True
case (ColumnType(type=inner1), ColumnType(type=inner2)):
# TODO: invariant, replace ColumnType with simple GenericType
if not self.are_equivalent(inner1, inner2):
return False
return True
case (Function(), Function()):
return self.is_func_subtype(type1, type2)
@@ -187,6 +217,9 @@ class TypesRegistry:
return False
def are_equivalent(self, type1: Type, type2: Type) -> bool:
return self.is_subtype(type1, type2) and self.is_subtype(type2, type1)
# TODO: verify the logic in here
def is_func_subtype(self, func1: Function, func2: Function) -> bool:
"""Check whether a function is a subtype of another
@@ -323,6 +356,9 @@ class TypesRegistry:
body=substitute_typevars(body, substitutions),
)
case BaseType(name="tuple"):
return TupleType(items=tuple(args))
case _:
raise ValueError(f"{type} is not a generic type")
@@ -416,3 +452,29 @@ class TypesRegistry:
def lookup_predicate(self, name: str) -> Optional[Predicate]:
return self._predicates.get(name)
def _by_name_or_type(self, name_or_type: str | Type) -> Type:
if isinstance(name_or_type, str):
return self.get_type(name_or_type)
return name_or_type
def list_of(self, item_type: str | Type) -> Type:
list_ = self.get_type("list")
return self.apply_generic(list_, [self._by_name_or_type(item_type)])
def tuple_of(self, *item_types: str | Type) -> Type:
tuple_ = self.get_type("tuple")
return self.apply_generic(
tuple_,
[self._by_name_or_type(item_type) for item_type in item_types],
)
def dict_of(self, key_type: str | Type, value_type: str | Type) -> Type:
dict_ = self.get_type("dict")
return self.apply_generic(
dict_,
[
self._by_name_or_type(key_type),
self._by_name_or_type(value_type),
],
)

View File

@@ -128,6 +128,10 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
case p.GetExpr():
target.accept(self)
case p.SubscriptExpr():
target.accept(self)
case _:
raise Exception(f"Unsupported assignment to {target}")
@@ -232,5 +236,9 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
if expr.step is not None:
self.resolve(expr.step)
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
for item in expr.items:
self.resolve(item)
def visit_raw_expr(self, expr: p.RawExpr) -> None:
pass

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Optional, assert_never
from typing import Optional, assert_never, cast
import midas.ast.midas as m
from midas.ast.printer import MidasPrinter
@@ -156,6 +156,53 @@ class ConstraintType:
return f"{self.type} where {printer.print(self.constraint)}"
@dataclass(frozen=True, kw_only=True)
class TupleType:
items: tuple[Type, ...]
def __str__(self) -> str:
return f"({', '.join(map(str, self.items))})"
@dataclass(frozen=True, kw_only=True)
class ColumnType:
type: Type
def __str__(self) -> str:
return f"Column[{self.type}]"
@dataclass(frozen=True, kw_only=True)
class DataFrameType:
columns: list[Column]
def __str__(self) -> str:
schema: list[str] = [f"{col.name}: {col.type}" for col in self.columns]
return f"Frame[{', '.join(schema)}]"
@dataclass(frozen=True, kw_only=True)
class Column:
index: int
name: Optional[str]
type: ColumnType
@dataclass(frozen=True, kw_only=True)
class FrameGroupBy:
frame: DataFrameType
def __str__(self) -> str:
return f"FrameGroupBy[{self.frame}]"
@dataclass(frozen=True, kw_only=True)
class ColumnGroupBy:
column: ColumnType
def __str__(self) -> str:
return f"ColumnGroupBy[{self.column}]"
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
def sub_argument(arg: Function.Argument):
return Function.Argument(
@@ -165,6 +212,13 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
required=arg.required,
)
def sub_column(col: DataFrameType.Column):
return DataFrameType.Column(
index=col.index,
name=col.name,
type=cast(ColumnType, substitute_typevars(col.type, substitutions)),
)
match type:
case TopType():
return type
@@ -252,6 +306,31 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
body=substitute_typevars(body, substitutions),
)
case TupleType(items=items):
return TupleType(
items=tuple(substitute_typevars(item, substitutions) for item in items),
)
case ColumnType(type=items_type):
return ColumnType(
type=substitute_typevars(items_type, substitutions),
)
case DataFrameType(columns=columns):
return DataFrameType(
columns=list(map(sub_column, columns)),
)
case FrameGroupBy(frame=frame):
return FrameGroupBy(
frame=cast(DataFrameType, substitute_typevars(frame, substitutions))
)
case ColumnGroupBy(column=column):
return ColumnGroupBy(
column=cast(ColumnType, substitute_typevars(column, substitutions))
)
case UnknownType() | UnitType():
return type
@@ -319,6 +398,21 @@ def to_annotation(type: Type) -> str:
case ConstraintType():
return str(type)
case TupleType(items=items):
return f"Tuple[{', '.join(map(to_annotation, items))}]"
case ColumnType():
return "pd.Series"
case DataFrameType():
return "pd.DataFrame"
case FrameGroupBy():
return "pd.api.typing.DataFrameGroupBy"
case ColumnGroupBy():
return "pd.api.typing.SeriesGroupBy"
case _:
assert_never(type)
@@ -344,4 +438,9 @@ Type = (
| GenericType
| AppliedType
| ConstraintType
| TupleType
| ColumnType
| DataFrameType
| FrameGroupBy
| ColumnGroupBy
)

View File

@@ -4,6 +4,8 @@ from typing import Optional
from midas.checker.registry import TypesRegistry
from midas.checker.types import (
AppliedType,
ColumnType,
DataFrameType,
Function,
GenericType,
TopType,
@@ -98,6 +100,30 @@ class Unifier:
return substitutions
case (
DataFrameType(columns=template_columns),
DataFrameType(columns=concrete_columns),
) if len(template_columns) == len(concrete_columns):
substitutions: dict[str, Type] = {}
for template_column, concrete_column in zip(
template_columns, concrete_columns
):
if template_column.index != concrete_column or (
template_column.name != concrete_column.name
):
self.logger.debug(
f"Column mismatch: template={template_column}, concrete={concrete_column}"
)
raise UnificationError
new_substistutions: dict[str, Type] = self.match(
template_column.type, concrete_column.type
)
substitutions = self.merge(substitutions, new_substistutions)
return substitutions
case (ColumnType(type=template_column), ColumnType(type=concrete_column)):
return self.match(template_column, concrete_column)
case (Function(), Function()):
mapped: list[tuple[Function.Argument, Function.Argument]] = (
self.map_params(template, concrete)

View File

@@ -5,7 +5,7 @@
import sys
from pathlib import Path
from typing import TextIO
from typing import Optional, TextIO
import click
@@ -19,18 +19,23 @@ from midas.utils import TypedAST
@click.command(help="Compile source")
@click.argument("file", type=click.File("r"))
@click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.option("-s", "--stubs", type=str, multiple=True)
@click.option("--ignore-errors", is_flag=True)
def compile(
file: TextIO,
types: tuple[TextIO],
stubs: tuple[str],
ignore_errors: bool,
):
source: str = file.read()
source_path: Path = Path(file.name).resolve()
checker = TypeChecker()
for types_file in types:
checker.import_midas(Path(types_file.name).resolve())
type_files: list[tuple[Path, Optional[str]]] = []
for i, types_file in enumerate(types):
in_path: Path = Path(types_file.name).resolve()
checker.import_midas(in_path)
type_files.append((in_path, stubs[i] if i < len(stubs) else None))
typed_ast: TypedAST = checker.type_check_source(source, str(source_path))
diagnostics: list[Diagnostic] = checker.diagnostics.copy()
@@ -43,4 +48,4 @@ def compile(
sys.exit(1)
generator = Generator(workdir=source_path.parent, types=checker.types)
generator.generate(typed_ast, source_path)
generator.generate(typed_ast, source_path, type_files=type_files)

View File

@@ -1,7 +1,7 @@
import ast
import time
from pathlib import Path
from typing import TextIO
from typing import Optional, TextIO
import black
import click
@@ -38,15 +38,17 @@ class Handler(FileSystemEventHandler):
@click.command(help="Generate stubs from Midas definitions")
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-")
@click.option("-o", "--output", type=click.File("w"))
@click.option("-w", "--watch", is_flag=True)
def stubs(
file: TextIO,
output: TextIO,
output: Optional[TextIO],
watch: bool,
):
source_path: Path = Path(file.name).resolve()
out_path: Path = Path(output.name).resolve()
out_path: Path = source_path.with_suffix(".pyi")
if output is not None:
out_path = Path(output.name).resolve()
generate_stubs(source_path, out_path)
if watch:

View File

@@ -134,9 +134,9 @@ class PythonHighlighter(
def visit_base_type(self, node: p.BaseType) -> None:
self.wrap(node, "base-type")
if node.param is not None:
self.wrap(node.param, "param")
node.param.accept(self)
for arg in node.args:
self.wrap(arg, "arg")
arg.accept(self)
def visit_constraint_type(self, node: p.ConstraintType) -> None:
self.wrap(node, "constraint-type")
@@ -247,6 +247,10 @@ class PythonHighlighter(
if expr.step is not None:
expr.step.accept(self)
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
for item in expr.items:
item.accept(self)
def visit_raw_expr(self, expr: p.RawExpr) -> None: ...
def visit_raw_stmt(self, stmt: p.RawStmt) -> None: ...
@@ -350,6 +354,14 @@ class MidasHighlighter(
for param in spec.pos + spec.mixed + spec.kw:
param.type.accept(self)
def visit_frame_type(self, type: m.FrameType) -> None:
self.wrap(type, "frame")
for column in type.columns:
self._visit_frame_column(column)
def _visit_frame_column(self, column: m.FrameType.Column) -> None:
self.wrap(column, "column")
class DiagnosticsHighlighter(Highlighter):
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"

View File

@@ -3,7 +3,7 @@ span {
--col: 108, 233, 108;
}
&.param {
&.arg {
--col: 103, 192, 224;
}

View File

@@ -68,7 +68,7 @@ class DiagnosticPrinter:
loc: Location = diagnostic.location
if loc.lineno != loc.end_lineno:
print(diagnostic)
self.print_multiline(lines, diagnostic, indent)
return
start_offset: int = loc.col_offset
@@ -95,3 +95,27 @@ class DiagnosticPrinter:
print(indent_str + before + subject + after)
print(indent_str + cursor)
print()
def print_multiline(
self, all_lines: list[str], diagnostic: Diagnostic, indent: int = 4
):
loc: Location = diagnostic.location
lines: list[str] = all_lines[loc.lineno - 1 : loc.end_lineno]
start_offset: int = loc.col_offset
end_offset: int = loc.end_col_offset or (start_offset + 1)
indent_str: str = " " * indent
color: int = self.COLORS.get(diagnostic.type, Ansi.WHITE)
res: str = indent_str + lines[0][:start_offset]
res += Ansi.FG(color) + lines[0][start_offset:]
for line in lines[1:-1]:
res += "\n" + indent_str + line
res += "\n" + indent_str + lines[-1][:end_offset]
res += Ansi.RESET + lines[-1][end_offset:]
print(diagnostic.location_str + ":")
print(res)
print()
print(Ansi.FG(color) + diagnostic.message + Ansi.RESET)
print()

View File

@@ -0,0 +1,59 @@
import ast
from dataclasses import dataclass
from typing import Callable
import midas.ast.python as p
AssertionBuilder = Callable[..., ast.expr]
@dataclass
class Assertion:
bound_expr: p.Expr
inputs: list[p.Expr]
builder: AssertionBuilder
message: str
def is_bound_to(self, expr: p.Expr) -> bool:
return expr == self.bound_expr
class AssertionCollector:
def __init__(self):
self.assertions: list[Assertion] = []
self.definitions: dict[str, ast.stmt] = {}
def add(
self,
bound_expr: p.Expr,
inputs: list[p.Expr],
builder: AssertionBuilder,
message: str,
):
self.assertions.append(
Assertion(
bound_expr=bound_expr,
inputs=inputs,
builder=builder,
message=message,
)
)
def remove(self, assertion: Assertion):
try:
self.assertions.remove(assertion)
except ValueError:
pass
def define(self, name: str, stmt: ast.stmt):
if name not in self.definitions:
self.definitions[name] = stmt
def get_definitions(self) -> list[ast.stmt]:
return list(self.definitions.values())
def get_assertions(self) -> list[Assertion]:
return self.assertions
def get_assertions_for(self, expr: p.Expr) -> list[Assertion]:
return list(filter(lambda a: a.is_bound_to(expr), self.assertions))

View File

@@ -1,4 +1,5 @@
import ast
import logging
import shutil
from dataclasses import dataclass, field
from pathlib import Path
@@ -8,65 +9,96 @@ import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.ast.printer import MidasPrinter
from midas.checker.checker import TypeChecker
from midas.checker.registry import TypesRegistry
from midas.checker.types import (
AppliedType,
BaseType,
ColumnGroupBy,
ColumnType,
ComplexType,
ConstraintType,
DataFrameType,
DerivedType,
ExtensionType,
FrameGroupBy,
Function,
GenericType,
OverloadedFunction,
TopType,
TupleType,
Type,
TypeVar,
UnitType,
UnknownType,
)
from midas.generator.collector import Assertion, AssertionCollector
from midas.generator.constraints import ConstraintGenerator
from midas.generator.stubs import StubsGenerator
from midas.utils import TypedAST
@dataclass
class Scope:
pre_assertions: list[ast.stmt] = field(default_factory=list)
aliases: list[str] = field(default_factory=list)
pre_assertions: list[ast.stmt] = field(default_factory=list[ast.stmt])
aliases: list[str] = field(default_factory=list[str])
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
IS_DATAFRAME_FUNC = "__midas_is_dataframe__"
IS_COLUMN_FUNC = "__midas_is_column__"
def __init__(self, workdir: Path, types: TypesRegistry) -> None:
self.workdir: Path = workdir.resolve()
self.build_dir: Path = self.workdir / "build" / "midas"
self.rel_src_path: Path = Path()
self.logger: logging.Logger = logging.getLogger("Generator")
self._typed_ast: TypedAST = TypedAST(
stmts=[],
judgements=[],
evaluated_casts=[],
assertions=AssertionCollector(),
)
self._alias_count: int = 0
self._predicate_count: int = 0
self._scopes: list[Scope] = []
self._aliases: list[tuple[p.Expr, ast.expr]] = []
self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
self._constraints: list[tuple[m.Expr, ast.expr]] = []
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
self.rel_src_path = src_path.resolve().relative_to(self.workdir)
self.define_is_dataframe: bool = False
self.define_is_column: bool = False
def set_src_path(self, path: Path):
self.rel_src_path = path.resolve().relative_to(self.workdir)
def generate_ast(self, typed_ast: TypedAST) -> ast.AST:
self._typed_ast = typed_ast
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
body: list[ast.stmt] = self._visit_body(typed_ast.stmts, can_be_empty=True)
predicates: list[ast.stmt] = self._constraint_generator.get_definitions()
module = ast.Module(body=predicates + body, type_ignores=[])
body = predicates + body
if self.define_is_dataframe:
body = [self._is_dataframe_definition()] + body
if self.define_is_column:
body = [self._is_column_definition()] + body
module = ast.Module(body=body, type_ignores=[])
module = ast.fix_missing_locations(module)
return module
def generate(
self, typed_ast: TypedAST, src_path: Path, out_path: Optional[Path] = None
self,
typed_ast: TypedAST,
src_path: Path,
out_path: Optional[Path] = None,
type_files: Optional[list[tuple[Path, Optional[str]]]] = None,
) -> Path:
module: ast.AST = self.generate_ast(typed_ast, src_path)
compiled: str = ast.unparse(module)
self.set_src_path(src_path)
if out_path is None:
if self.build_dir.exists():
shutil.rmtree(self.build_dir)
@@ -78,43 +110,72 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
raise ValueError(
f"Directory traversal, {self.rel_src_path} points outside of parent directory"
)
out_path.parent.mkdir(parents=True, exist_ok=True)
out_dir: Path = out_path.parent
out_dir.parent.mkdir(parents=True, exist_ok=True)
if type_files is not None:
for in_path, out_name in type_files:
if out_name is None:
out_name = in_path.stem
self.generate_stubs(in_path, out_dir / f"{out_name}.py")
module: ast.AST = self.generate_ast(typed_ast)
compiled: str = ast.unparse(module)
out_path.write_text(compiled)
return out_path
def generate_stubs(self, in_path: Path, out_path: Path):
checker = TypeChecker()
checker.import_midas(in_path)
generator = StubsGenerator(checker.types)
module: ast.Module = generator.generate_stubs()
module = ast.fix_missing_locations(module)
output: str = ast.unparse(module)
out_path.write_text(output)
def convert(self, expr: p.Expr) -> ast.expr:
for expr2, alias in self._aliases:
if expr2 == expr:
return alias
assertions = self._typed_ast.assertions.get_assertions_for(expr)
if len(assertions) != 0:
return self._apply_assertions(expr, assertions)
return expr.accept(self)
def visit_binary_expr(self, expr: p.BinaryExpr) -> ast.expr:
return ast.BinOp(
left=expr.left.accept(self),
left=self.convert(expr.left),
op=expr.operator,
right=expr.right.accept(self),
right=self.convert(expr.right),
)
def visit_compare_expr(self, expr: p.CompareExpr) -> ast.expr:
return ast.Compare(
left=expr.left.accept(self),
left=self.convert(expr.left),
ops=[expr.operator],
comparators=[expr.right.accept(self)],
comparators=[self.convert(expr.right)],
)
def visit_unary_expr(self, expr: p.UnaryExpr) -> ast.expr:
return ast.UnaryOp(
op=expr.operator,
operand=expr.right.accept(self),
operand=self.convert(expr.right),
)
def visit_call_expr(self, expr: p.CallExpr) -> ast.expr:
return ast.Call(
func=expr.callee.accept(self),
args=[arg.accept(self) for arg in expr.arguments],
func=self.convert(expr.callee),
args=[self.convert(arg) for arg in expr.arguments],
keywords=[
ast.keyword(arg=name, value=arg.accept(self))
ast.keyword(arg=name, value=self.convert(arg))
for name, arg in expr.keywords.items()
],
)
def visit_get_expr(self, expr: p.GetExpr) -> ast.expr:
return ast.Attribute(
value=expr.object.accept(self),
value=self.convert(expr.object),
attr=expr.name,
)
@@ -127,51 +188,58 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def visit_logical_expr(self, expr: p.LogicalExpr) -> ast.expr:
return ast.BoolOp(
op=expr.operator,
values=[expr.left.accept(self), expr.right.accept(self)],
values=[self.convert(expr.left), self.convert(expr.right)],
)
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
expr2: ast.expr = expr.expr.accept(self)
expr2: ast.expr = self.convert(expr.expr)
if expr in self._typed_ast.evaluated_casts or expr.unsafe:
return expr2
alias: ast.expr = self._make_alias(expr2)
alias: ast.expr = self._make_alias(expr.expr, expr2)
type: Type = self._get_expr_type(expr)
self._make_cast_asserts(expr.location, alias, type)
asserts: list[ast.stmt] = self._make_cast_asserts(expr.location, alias, type)
for assert_ in asserts:
self._add_assert(assert_)
return alias
def visit_ternary_expr(self, expr: p.TernaryExpr) -> ast.expr:
return ast.IfExp(
test=expr.test.accept(self),
body=expr.if_true.accept(self),
orelse=expr.if_false.accept(self),
test=self.convert(expr.test),
body=self.convert(expr.if_true),
orelse=self.convert(expr.if_false),
)
def visit_list_expr(self, expr: p.ListExpr) -> ast.expr:
return ast.List(
elts=[item.accept(self) for item in expr.items],
elts=[self.convert(item) for item in expr.items],
)
def visit_dict_expr(self, expr: p.DictExpr) -> ast.expr:
return ast.Dict(
keys=[key.accept(self) if key is not None else None for key in expr.keys],
values=[value.accept(self) for value in expr.values],
keys=[self.convert(key) if key is not None else None for key in expr.keys],
values=[self.convert(value) for value in expr.values],
)
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> ast.expr:
return ast.Subscript(
value=expr.object.accept(self),
slice=expr.index.accept(self),
value=self.convert(expr.object),
slice=self.convert(expr.index),
)
def visit_slice_expr(self, expr: p.SliceExpr) -> ast.expr:
return ast.Slice(
lower=expr.lower.accept(self) if expr.lower is not None else None,
upper=expr.upper.accept(self) if expr.upper is not None else None,
step=expr.step.accept(self) if expr.step is not None else None,
lower=self.convert(expr.lower) if expr.lower is not None else None,
upper=self.convert(expr.upper) if expr.upper is not None else None,
step=self.convert(expr.step) if expr.step is not None else None,
)
def visit_tuple_expr(self, expr: p.TupleExpr) -> ast.expr:
return ast.Tuple(
elts=[self.convert(item) for item in expr.items],
)
def visit_raw_expr(self, expr: p.RawExpr) -> ast.expr:
@@ -179,7 +247,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> ast.stmt:
return ast.Expr(
value=stmt.expr.accept(self),
value=self.convert(stmt.expr),
)
def visit_function(self, stmt: p.Function) -> ast.stmt:
@@ -192,12 +260,12 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
kwonlyargs=[ast.arg(arg=arg.name) for arg in stmt.kwonlyargs],
kwarg=None,
defaults=[
arg.default.accept(self)
self.convert(arg.default)
for arg in stmt.posonlyargs + stmt.args
if arg.default is not None
],
kw_defaults=[
arg.default.accept(self) if arg.default is not None else None
self.convert(arg.default) if arg.default is not None else None
for arg in stmt.kwonlyargs
],
),
@@ -211,20 +279,20 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def visit_assign_stmt(self, stmt: p.AssignStmt) -> ast.stmt:
return ast.Assign(
targets=[target.accept(self) for target in stmt.targets],
value=stmt.value.accept(self),
targets=[self.convert(target) for target in stmt.targets],
value=self.convert(stmt.value),
)
def visit_return_stmt(self, stmt: p.ReturnStmt) -> ast.stmt:
return ast.Return(
value=stmt.value.accept(self) if stmt.value is not None else None,
value=self.convert(stmt.value) if stmt.value is not None else None,
)
def visit_if_stmt(self, stmt: p.IfStmt) -> ast.stmt:
return ast.If(
test=stmt.test.accept(self),
test=self.convert(stmt.test),
body=self._visit_body(stmt.body),
orelse=self._visit_body(stmt.orelse),
orelse=self._visit_body(stmt.orelse, can_be_empty=True),
)
def visit_pass(self, stmt: p.Pass) -> ast.stmt:
@@ -232,8 +300,8 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def visit_for_stmt(self, stmt: p.ForStmt) -> ast.stmt:
return ast.For(
target=stmt.target.accept(self),
iter=stmt.iterator.accept(self),
target=self.convert(stmt.target),
iter=self.convert(stmt.iterator),
body=self._visit_body(stmt.body),
orelse=[],
)
@@ -241,7 +309,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def visit_raw_stmt(self, stmt: p.RawStmt) -> ast.stmt:
return stmt.stmt
def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]:
def _visit_body(
self, stmts: list[p.Stmt], can_be_empty: bool = False
) -> list[ast.stmt]:
generated: list[ast.stmt] = []
for stmt in stmts:
scope = Scope()
@@ -259,9 +329,11 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
# Remove redundant pass statements
if len(generated) > 1:
generated = [stmt for stmt in generated if not isinstance(stmt, ast.Pass)]
if len(generated) == 0 and not can_be_empty:
generated = [ast.Pass()]
return generated
def _make_alias(self, expr: ast.expr) -> ast.expr:
def _make_alias(self, node: p.Expr, expr: ast.expr) -> ast.expr:
name: str = f"__midas_a{self._alias_count}__"
alias = ast.Name(id=name)
self._alias_count += 1
@@ -272,82 +344,182 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
value=expr,
)
)
self._aliases.append((node, alias))
return alias
def _add_assert(self, expr: ast.expr, message: str | ast.expr):
def _build_assert(self, expr: ast.expr, message: str | ast.expr) -> ast.stmt:
if isinstance(message, str):
message = ast.Constant(value=message)
self._scopes[-1].pre_assertions.append(
ast.Assert(
test=expr,
msg=message,
)
return ast.Assert(
test=expr,
msg=message,
)
def _add_assert(self, assertion: ast.stmt):
self._scopes[-1].pre_assertions.append(assertion)
def _get_expr_type(self, query: p.Expr) -> Type:
for expr, type in self._typed_ast.judgements:
if expr == query:
return type
raise RuntimeError(f"Cannot get type judgement for {query}")
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type):
def _make_cast_asserts(
self, src_location: Location, expr: ast.expr, type: Type
) -> list[ast.stmt]:
match type:
case UnknownType():
pass
case UnknownType() | TopType():
return []
case BaseType(name=name):
self._add_assert(
ast.Call(
func=ast.Name(id="isinstance"),
args=[expr, ast.Name(id=name)],
keywords=[],
),
self._make_cast_assert_message(src_location, expr, type),
)
return [
self._build_assert(
ast.Call(
func=ast.Name(id="isinstance"),
args=[expr, ast.Name(id=name)],
keywords=[],
),
self._make_cast_assert_message(src_location, expr, type),
)
]
case DerivedType(type=base):
self._make_cast_asserts(src_location, expr, base)
return self._make_cast_asserts(src_location, expr, base)
case UnitType():
self._add_assert(
ast.Compare(
left=expr,
ops=[ast.Is()],
comparators=[
ast.Constant(value=None),
],
return [
self._build_assert(
ast.Compare(
left=expr,
ops=[ast.Is()],
comparators=[
ast.Constant(value=None),
],
),
self._make_cast_assert_message(src_location, expr, type),
),
self._make_cast_assert_message(src_location, expr, type),
)
]
case AppliedType(body=body):
self._make_cast_asserts(src_location, expr, body)
return self._make_cast_asserts(src_location, expr, body)
case ConstraintType(type=base, constraint=constraint):
self._make_cast_asserts(src_location, expr, base)
self._make_constraint_assert(src_location, expr, constraint)
asserts: list[ast.stmt] = self._make_cast_asserts(
src_location, expr, base
)
asserts.append(
self._make_constraint_assert(src_location, expr, constraint)
)
return asserts
case TypeVar(bound=bound):
# TODO: check with type from arguments / use call-site context
if bound is not None:
self._make_cast_asserts(src_location, expr, bound)
if bound is None:
return []
return self._make_cast_asserts(src_location, expr, bound)
case TupleType(items=items):
asserts: list[ast.stmt] = [
self._build_assert(
ast.Call(
func=ast.Name(id="isinstance"),
args=[expr, ast.Name(id="tuple")],
keywords=[],
),
self._make_cast_assert_message(src_location, expr, type),
),
]
assert isinstance(expr, ast.Tuple)
for item, item_type in zip(expr.elts, items):
asserts.extend(
self._make_cast_asserts(src_location, item, item_type)
)
return asserts
case DataFrameType(columns=columns):
self.define_is_dataframe = True
asserts: list[ast.stmt] = [
self._build_assert(
ast.Call(
func=ast.Name(id=self.IS_DATAFRAME_FUNC),
args=[expr],
keywords=[],
),
self._make_cast_assert_message(
src_location, expr, type, ": Not a dataframe"
),
),
]
for column in columns:
asserts.append(
self._build_assert(
ast.Compare(
left=ast.Constant(value=column.name),
ops=[ast.In()],
comparators=[expr],
),
self._make_cast_assert_message(
src_location,
expr,
type,
f": Missing column {column.name}",
),
)
)
asserts.extend(
self._make_cast_asserts(
src_location,
ast.Subscript(
value=expr, slice=ast.Constant(value=column.name)
),
column.type,
)
)
return asserts
case ColumnType():
self.define_is_column = True
asserts: list[ast.stmt] = [
self._build_assert(
ast.Call(
func=ast.Name(id=self.IS_COLUMN_FUNC),
args=[expr],
keywords=[],
),
self._make_cast_assert_message(
src_location, expr, type, ": Not a column"
),
),
]
inner_assert: Optional[ast.stmt] = self._make_column_inner_assert(
src_location, expr, type
)
if inner_assert is not None:
asserts.append(inner_assert)
return asserts
case (
TopType()
| Function()
Function()
| OverloadedFunction()
| ComplexType()
| ExtensionType()
| GenericType()
| FrameGroupBy()
| ColumnGroupBy()
):
raise NotImplementedError(f"Can't make assertion for type {type}")
self.logger.warning(f"Can't make assertion for type {type}")
return []
# Ensure exhaustiveness
case _:
assert_never(type)
def _make_cast_assert_message(
self, location: Location, expr: ast.expr, type: Type
self,
location: Location,
expr: ast.expr,
type: Type,
extra: Optional[str] = None,
) -> ast.expr:
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
# f"file.py:L1:1: CastError: Cannot cast {type(expr).__name__} to Type"
@@ -365,15 +537,15 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
),
conversion=-1,
),
ast.Constant(f" to {type}"),
ast.Constant(f" to {type}{extra or ''}"),
]
)
def _make_constraint_assert(
self, src_location: Location, expr: ast.expr, constraint: m.Expr
):
) -> ast.stmt:
test_func: ast.expr = self._get_constraint(constraint)
self._add_assert(
return self._build_assert(
ast.Call(
func=test_func,
args=[expr],
@@ -401,3 +573,117 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
constraint: ast.expr = self._constraint_generator.generate(expr)
self._constraints.append((expr, constraint))
return constraint
def _is_dataframe_definition(self) -> ast.stmt:
"""
def IS_DATAFRAME_FUNC(obj) -> bool:
import pandas as pd
return isinstance(obj, pd.DataFrame)
"""
return ast.FunctionDef(
name=self.IS_DATAFRAME_FUNC,
args=ast.arguments(
posonlyargs=[ast.arg(arg="obj")],
args=[],
kwonlyargs=[],
defaults=[],
kw_defaults=[],
),
body=[
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
ast.Return(
value=ast.Call(
func=ast.Name(id="isinstance"),
args=[
ast.Name(id="obj"),
ast.Attribute(
value=ast.Name(id="pd"),
attr="DataFrame",
),
],
keywords=[],
)
),
],
decorator_list=[],
returns=ast.Name(id="bool"),
)
def _is_column_definition(self) -> ast.stmt:
"""
def IS_COLUMN_FUNC(obj) -> bool:
import pandas as pd
return isinstance(obj, pd.Series)
"""
return ast.FunctionDef(
name=self.IS_COLUMN_FUNC,
args=ast.arguments(
posonlyargs=[ast.arg(arg="obj")],
args=[],
kwonlyargs=[],
defaults=[],
kw_defaults=[],
),
body=[
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
ast.Return(
value=ast.Call(
func=ast.Name(id="isinstance"),
args=[
ast.Name(id="obj"),
ast.Attribute(
value=ast.Name(id="pd"),
attr="Series",
),
],
keywords=[],
)
),
],
decorator_list=[],
returns=ast.Name(id="bool"),
)
def _make_column_inner_assert(
self, src_location: Location, column: ast.expr, type: ColumnType
) -> Optional[ast.stmt]:
# TODO: improve message, maybe chain contexts
col: ast.expr = ast.Name(id="col")
body: list[ast.stmt] = self._make_cast_asserts(src_location, col, type.type)
if len(body) == 0:
return None
return ast.For(
target=col,
iter=column,
body=body,
orelse=[],
)
def _convert_assertion(self, assertion: Assertion) -> ast.stmt:
inputs: list[ast.expr] = []
for input in assertion.inputs:
converted: ast.expr = self.convert(input)
alias: ast.expr = self._make_alias(input, converted)
inputs.append(alias)
test: ast.expr = assertion.builder(*inputs)
location: Location = assertion.bound_expr.location
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
return self._build_assert(
test, f"{loc_str}: AssertionError: {assertion.message}"
)
def _apply_assertions(self, expr: p.Expr, assertions: list[Assertion]) -> ast.expr:
for assertion in assertions:
assert_stmt: ast.stmt
assert_stmt = self._convert_assertion(assertion)
self._add_assert(assert_stmt)
# Mutating list in frozen dataclass
# Not ideal but easiest way to avoid duplicate assertions
self._typed_ast.assertions.remove(assertion)
return expr.accept(self)

View File

@@ -6,14 +6,19 @@ from midas.checker.registry import Member, TypesRegistry
from midas.checker.types import (
AppliedType,
BaseType,
ColumnGroupBy,
ColumnType,
ComplexType,
ConstraintType,
DataFrameType,
DerivedType,
ExtensionType,
FrameGroupBy,
Function,
GenericType,
OverloadedFunction,
TopType,
TupleType,
Type,
TypeVar,
UnitType,
@@ -30,6 +35,7 @@ class StubsGenerator:
self.types: TypesRegistry = types
self.stubs: list[ast.stmt] = []
self.typing_imports: set[str] = set()
self.import_pandas: bool = False
self.protocol_idx: int = 0
self.stub_idx: int = 0
self.type_var_idx: int = 0
@@ -38,6 +44,7 @@ class StubsGenerator:
def generate_stubs(self) -> ast.Module:
self.stubs = []
self.typing_imports = set()
self.import_pandas = False
for name, type in self.types._types.items():
# Skip builtin types, not just based on name so the user can override
# TODO: check if added members on builtin type
@@ -53,7 +60,7 @@ class StubsGenerator:
continue
self.generate_stub(name, type)
imports = [
imports: list[ast.stmt] = [
ast.ImportFrom(
module="__future__",
names=[ast.alias(name="annotations")],
@@ -70,11 +77,37 @@ class StubsGenerator:
level=0,
)
)
if self.import_pandas:
imports.append(
ast.Import(
names=[
ast.alias(
name="pandas",
asname="pd",
)
],
)
)
return ast.Module(body=imports + self.stubs, type_ignores=[])
def generate_stub(self, name: str, type: Type):
base_type: Type = type
# TODO: improve
match type:
case DerivedType(name=name_) | GenericType(name=name_) if name_ == name:
pass
case UnitType() if name == "None":
pass
case TopType() if name == "Any":
pass
case _:
alias = ast.Assign(
targets=[ast.Name(id=name)], value=self.dump_type(type)
)
self.add_stub(alias)
return
members: dict[str, Member] = self.types._members.get(name, {})
if isinstance(base_type, (BaseType, TopType, UnitType)) and len(members) == 0:
return
@@ -231,6 +264,57 @@ class StubsGenerator:
case ConstraintType():
return self.dump_type(type.type)
case TupleType(items=items):
return ast.Subscript(
value=ast.Name(id="tuple"),
slice=ast.Tuple(
elts=[self.dump_type(item) for item in items],
),
)
case ColumnType(type=inner):
self.import_pandas = True
return ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="pd"),
attr="Series",
),
slice=self.dump_type(inner),
)
case DataFrameType():
self.import_pandas = True
return ast.Attribute(
value=ast.Name(id="pd"),
attr="DataFrame",
)
case FrameGroupBy():
self.import_pandas = True
return ast.Attribute(
value=ast.Attribute(
value=ast.Attribute(
value=ast.Name(id="pd"),
attr="api",
),
attr="typing",
),
attr="DataFrameGroupBy",
)
case ColumnGroupBy():
self.import_pandas = True
return ast.Attribute(
value=ast.Attribute(
value=ast.Attribute(
value=ast.Name(id="pd"),
attr="api",
),
attr="typing",
),
attr="SeriesGroupBy",
)
case _:
assert_never(type)

View File

@@ -46,8 +46,8 @@ class MidasLexer(Lexer):
self.add_token(TokenType.UNDERSCORE)
case "-" if self.match(">"):
self.add_token(TokenType.ARROW)
# case "+":
# self.add_token(TokenType.PLUS)
case "+":
self.add_token(TokenType.PLUS)
case "-":
self.add_token(TokenType.MINUS)
case "*":

View File

@@ -25,7 +25,7 @@ class TokenType(Enum):
DOT = auto()
# Operators
# PLUS = auto()
PLUS = auto()
MINUS = auto()
STAR = auto()
SLASH = auto()

View File

@@ -10,6 +10,7 @@ from midas.ast.midas import (
Expr,
ExtendStmt,
ExtensionType,
FrameType,
FunctionType,
GenericType,
GetExpr,
@@ -226,8 +227,10 @@ class MidasParser(Parser):
return self.generic_type()
def generic_type(self) -> Type:
type: Type = self.named_type()
type: NamedType = self.named_type()
if self.check(TokenType.LEFT_BRACKET):
if type.name.lexeme == "Frame":
return self.frame_type()
args: list[Type] = self.type_args()
return GenericType(
location=Location.span(type.location, self.previous().get_location()),
@@ -246,7 +249,7 @@ class MidasParser(Parser):
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments")
return args
def named_type(self) -> Type:
def named_type(self) -> NamedType:
name: Token = self.consume_identifier("Expected type name")
return NamedType(
location=name.get_location(),
@@ -281,6 +284,32 @@ class MidasParser(Parser):
members=members,
)
def frame_type(self) -> FrameType:
keyword: Token = self.previous()
self.consume(TokenType.LEFT_BRACKET, "Expected '[' to start frame schema")
columns: list[FrameType.Column] = []
while not self.check(TokenType.RIGHT_BRACKET) and not self.is_at_end():
name: Token = self.advance()
self.consume(TokenType.COLON, "Expected ':' between column name and type")
type: Type = self.type_expr()
columns.append(
FrameType.Column(
location=name.location_to(self.previous()),
name=name,
type=type,
)
)
if not self.match(TokenType.COMMA):
break
self.consume(TokenType.RIGHT_BRACKET, "Unclosed frame schema")
return FrameType(
location=keyword.location_to(self.previous()),
columns=columns,
)
def constraint(self) -> Expr:
"""Parse a constraint
@@ -332,13 +361,35 @@ class MidasParser(Parser):
Returns:
Expr: the parsed expression
"""
expr: Expr = self.unary()
expr: Expr = self.term()
while self.match(
TokenType.LESS,
TokenType.LESS_EQUAL,
TokenType.GREATER,
TokenType.GREATER_EQUAL,
):
operator: Token = self.previous()
right: Expr = self.term()
location: Location = Location.span(expr.location, right.location)
expr = BinaryExpr(
location=location, left=expr, operator=operator, right=right
)
return expr
def term(self) -> Expr:
expr: Expr = self.factor()
while self.match(TokenType.PLUS, TokenType.MINUS):
operator: Token = self.previous()
right: Expr = self.factor()
location: Location = Location.span(expr.location, right.location)
expr = BinaryExpr(
location=location, left=expr, operator=operator, right=right
)
return expr
def factor(self) -> Expr:
expr: Expr = self.unary()
while self.match(TokenType.STAR, TokenType.SLASH):
operator: Token = self.previous()
right: Expr = self.unary()
location: Location = Location.span(expr.location, right.location)
@@ -370,7 +421,7 @@ class MidasParser(Parser):
pos_args: list[Expr] = []
kw_args: dict[str, Expr] = {}
keywords: bool = False
while not self.match(TokenType.RIGHT_PAREN):
while not self.check(TokenType.RIGHT_PAREN):
if self.check_identifier() and self.check_next(TokenType.EQUAL):
keywords = True
keyword: Token = self.advance()

View File

@@ -30,6 +30,7 @@ from midas.ast.python import (
Stmt,
SubscriptExpr,
TernaryExpr,
TupleExpr,
TypeAssign,
UnaryExpr,
VariableExpr,
@@ -300,26 +301,28 @@ class PythonParser:
case ast.Subscript(value=ast.Name(id="Frame"), slice=schema):
return self._parse_frame_type(schema)
case ast.Subscript(value=ast.Name(id=name), slice=param):
case ast.Subscript(value=ast.Name(id=name), slice=arg):
args: tuple[MidasType, ...] = (
tuple(self._parse_type(a) for a in arg.elts)
if isinstance(arg, ast.Tuple)
else (self._parse_type(arg),)
)
return BaseType(
location=loc,
base=name,
param=self._parse_type(param),
args=args,
)
case ast.Name(id=name):
return BaseType(
location=loc,
base=name,
param=None,
args=(),
)
case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr):
left = self._parse_type(left_expr)
match left:
case None:
raise InvalidSyntaxError()
# If chained constraints, separate base type and rebuild constraint
case ConstraintType(type=left_type, constraint=left_constraint):
constraint = ast.BinOp(
@@ -345,7 +348,7 @@ class PythonParser:
return BaseType(
location=loc,
base="None",
param=None,
args=(),
)
case _:
@@ -477,6 +480,12 @@ class PythonParser:
step=self.parse_expr(step) if step is not None else None,
)
case ast.Tuple(elts=items):
return TupleExpr(
location=location,
items=tuple(self.parse_expr(item) for item in items),
)
case _:
print(f"Unsupported expression: {ast.unparse(node)}")
return RawExpr(location=location, expr=node)

View File

@@ -1,3 +1,4 @@
from typing import Generic, TypeVar
from typing import cast as typing_cast
cast = typing_cast
@@ -32,3 +33,20 @@ This operation is unsound, use at your own risk!
_**Internal Python documentation**_
"""
T = TypeVar("T")
class Frame(Generic[T]):
"""A `Frame` is the abstract type implemented by `DataFrame`
A frame contains any number of named columns (see :class:`Column`)
"""
class Column(Generic[T]):
"""A `Column` is the abstract type implemented by `Series`
A column contains a any number of values of the same type
"""

View File

@@ -3,6 +3,7 @@ from typing import Any, Callable, Optional
import midas.ast.python as p
from midas.checker.types import Type
from midas.generator.collector import AssertionCollector
AllowRepeat = Callable[[object], bool]
@@ -63,3 +64,4 @@ class TypedAST:
stmts: list[p.Stmt]
judgements: list[tuple[p.Expr, Type]]
evaluated_casts: list[p.CastExpr]
assertions: AssertionCollector

43
tests/__main__.py Normal file
View File

@@ -0,0 +1,43 @@
from typing import Type
from midas.cli.ansi import Ansi
from tests.base import Tester
from tests.checker import CheckerTester
from tests.generator import GeneratorTester
from tests.midas import MidasTester
from tests.python import PythonTester
def print_banner(name: str):
horizontal: str = "+" + "-" * (len(name) + 2) + "+"
print(horizontal)
print(f"| {name} |")
print(horizontal)
def run_tests(tester_cls: Type[Tester]) -> bool:
print_banner(tester_cls.__name__)
tester: Tester = tester_cls()
success: bool = tester.run_all_tests()
print()
return success
def main():
testers: list[Type[Tester]] = [
PythonTester,
MidasTester,
CheckerTester,
GeneratorTester,
]
success: bool = all(map(run_tests, testers))
if success:
print(Ansi.FG(Ansi.BRIGHT_GREEN) + "All tests passed!" + Ansi.RESET)
else:
print(Ansi.FG(Ansi.BRIGHT_RED) + "Some tests failed!" + Ansi.RESET)
if __name__ == "__main__":
main()

View File

@@ -7,6 +7,8 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import Iterator, Protocol
from midas.cli.ansi import Ansi
class CaseResult(Protocol):
def dumps(self) -> str: ...
@@ -44,8 +46,11 @@ class Tester(ABC):
print(rule)
for i, test in enumerate(tests):
print(f"Case {i+1}/{n}: {test.resolve().relative_to(self.CASES_DIR)}")
path: Path = test.resolve().relative_to(self.CASES_DIR)
print(f"{Ansi.FG(Ansi.BRIGHT_CYAN)}Case {i+1}/{n}: {path}{Ansi.RESET}")
print(Ansi.DIM, end="")
success: bool = self._run_test(test)
print(Ansi.RESET, end="")
if success:
successes += 1
else:
@@ -146,8 +151,9 @@ class Tester(ABC):
if not success:
sys.exit(1)
case None:
print("No subcommand provided. Available subcommands: run, update")
sys.exit(1)
success: bool = tester.run_all_tests()
if not success:
sys.exit(1)
case _:
print(f"Unknown subcommand '{args.subcommand}'")
sys.exit(1)

View File

@@ -4,7 +4,35 @@
"type": "Warning",
"location": {
"start": [
6,
8,
12
],
"end": [
8,
43
]
},
"message": "ConstraintType not yet supported"
},
{
"type": "Warning",
"location": {
"start": [
10,
10
],
"end": [
10,
18
]
},
"message": "Unknown type 'datetime'"
},
{
"type": "Warning",
"location": {
"start": [
13,
4
],
"end": [
@@ -12,7 +40,7 @@
5
]
},
"message": "FrameType not yet supported"
"message": "Unknown type '_'"
}
],
"judgments": []

View File

@@ -328,6 +328,19 @@
},
"type": {}
},
{
"location": {
"from": "L6:9",
"to": "L6:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L6:5",
@@ -373,19 +386,6 @@
}
}
},
{
"location": {
"from": "L6:9",
"to": "L6:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L6:5",
@@ -407,6 +407,32 @@
},
"type": {}
},
{
"location": {
"from": "L7:9",
"to": "L7:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L7:12",
"to": "L7:15"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L7:5",
@@ -452,32 +478,6 @@
}
}
},
{
"location": {
"from": "L7:9",
"to": "L7:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L7:12",
"to": "L7:15"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L7:5",
@@ -503,6 +503,32 @@
},
"type": {}
},
{
"location": {
"from": "L8:9",
"to": "L8:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L8:14",
"to": "L8:17"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L8:5",
@@ -548,32 +574,6 @@
}
}
},
{
"location": {
"from": "L8:9",
"to": "L8:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L8:14",
"to": "L8:17"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L8:5",
@@ -600,6 +600,45 @@
},
"type": {}
},
{
"location": {
"from": "L9:9",
"to": "L9:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L9:12",
"to": "L9:15"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L9:17",
"to": "L9:23"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L9:5",
@@ -645,45 +684,6 @@
}
}
},
{
"location": {
"from": "L9:9",
"to": "L9:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L9:12",
"to": "L9:15"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L9:17",
"to": "L9:23"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L9:5",
@@ -713,6 +713,45 @@
},
"type": {}
},
{
"location": {
"from": "L10:9",
"to": "L10:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L10:12",
"to": "L10:15"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L10:19",
"to": "L10:22"
},
"expr": {
"_type": "LiteralExpr",
"value": 3.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L10:5",
@@ -758,45 +797,6 @@
}
}
},
{
"location": {
"from": "L10:9",
"to": "L10:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L10:12",
"to": "L10:15"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L10:19",
"to": "L10:22"
},
"expr": {
"_type": "LiteralExpr",
"value": 3.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L10:5",
@@ -827,6 +827,19 @@
},
"type": {}
},
{
"location": {
"from": "L11:11",
"to": "L11:12"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L11:5",
@@ -872,19 +885,6 @@
}
}
},
{
"location": {
"from": "L11:11",
"to": "L11:12"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L11:5",
@@ -906,6 +906,19 @@
},
"type": {}
},
{
"location": {
"from": "L12:11",
"to": "L12:17"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L12:5",
@@ -951,19 +964,6 @@
}
}
},
{
"location": {
"from": "L12:11",
"to": "L12:17"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L12:5",
@@ -985,6 +985,45 @@
},
"type": {}
},
{
"location": {
"from": "L14:10",
"to": "L14:11"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L14:13",
"to": "L14:16"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L14:20",
"to": "L14:26"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L14:6",
@@ -1030,45 +1069,6 @@
}
}
},
{
"location": {
"from": "L14:10",
"to": "L14:11"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L14:13",
"to": "L14:16"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L14:20",
"to": "L14:26"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L14:6",
@@ -1101,6 +1101,45 @@
"name": "bool"
}
},
{
"location": {
"from": "L15:10",
"to": "L15:11"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L15:15",
"to": "L15:18"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L15:22",
"to": "L15:28"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L15:6",
@@ -1146,45 +1185,6 @@
}
}
},
{
"location": {
"from": "L15:10",
"to": "L15:11"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L15:15",
"to": "L15:18"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L15:22",
"to": "L15:28"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L15:6",
@@ -1217,6 +1217,45 @@
"name": "bool"
}
},
{
"location": {
"from": "L16:10",
"to": "L16:11"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L16:15",
"to": "L16:21"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L16:25",
"to": "L16:28"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L16:6",
@@ -1262,45 +1301,6 @@
}
}
},
{
"location": {
"from": "L16:10",
"to": "L16:11"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L16:15",
"to": "L16:21"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L16:25",
"to": "L16:28"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L16:6",
@@ -1333,6 +1333,45 @@
"name": "bool"
}
},
{
"location": {
"from": "L18:10",
"to": "L18:13"
},
"expr": {
"_type": "LiteralExpr",
"value": "a"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L18:15",
"to": "L18:16"
},
"expr": {
"_type": "LiteralExpr",
"value": 3
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L18:20",
"to": "L18:25"
},
"expr": {
"_type": "LiteralExpr",
"value": false
},
"type": {
"name": "bool"
}
},
{
"location": {
"from": "L18:6",
@@ -1378,45 +1417,6 @@
}
}
},
{
"location": {
"from": "L18:10",
"to": "L18:13"
},
"expr": {
"_type": "LiteralExpr",
"value": "a"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L18:15",
"to": "L18:16"
},
"expr": {
"_type": "LiteralExpr",
"value": 3
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L18:20",
"to": "L18:25"
},
"expr": {
"_type": "LiteralExpr",
"value": false
},
"type": {
"name": "bool"
}
},
{
"location": {
"from": "L18:6",

View File

@@ -24,7 +24,7 @@
"type": {
"_type": "BaseType",
"base": "Meter",
"param": null
"args": []
},
"expr": {
"_type": "LiteralExpr",
@@ -62,7 +62,7 @@
"type": {
"_type": "BaseType",
"base": "Second",
"param": null
"args": []
},
"expr": {
"_type": "LiteralExpr",

View File

@@ -100,6 +100,32 @@
"name": "float"
}
},
{
"location": {
"from": "L11:13",
"to": "L11:15"
},
"expr": {
"_type": "VariableExpr",
"name": "v1"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L11:17",
"to": "L11:19"
},
"expr": {
"_type": "VariableExpr",
"name": "v2"
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L11:5",
@@ -135,32 +161,6 @@
}
}
},
{
"location": {
"from": "L11:13",
"to": "L11:15"
},
"expr": {
"_type": "VariableExpr",
"name": "v1"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L11:17",
"to": "L11:19"
},
"expr": {
"_type": "VariableExpr",
"name": "v2"
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L11:5",

View File

@@ -72,29 +72,6 @@
}
],
"judgments": [
{
"location": {
"from": "L26:0",
"to": "L26:5"
},
"expr": {
"_type": "VariableExpr",
"name": "print"
},
"type": {
"pos_args": [
{
"pos": 0,
"name": "object",
"type": {},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {}
}
},
{
"location": {
"from": "L27:4",
@@ -325,6 +302,29 @@
}
}
},
{
"location": {
"from": "L26:0",
"to": "L26:5"
},
"expr": {
"_type": "VariableExpr",
"name": "print"
},
"type": {
"pos_args": [
{
"pos": 0,
"name": "object",
"type": {},
"required": false
}
],
"args": [],
"kw_args": [],
"returns": {}
}
},
{
"location": {
"from": "L26:0",

View File

@@ -63,31 +63,6 @@
"name": "float"
}
},
{
"location": {
"from": "L6:11",
"to": "L6:15"
},
"expr": {
"_type": "VariableExpr",
"name": "bool"
},
"type": {
"pos_args": [
{
"pos": 0,
"name": "object",
"type": {},
"required": false
}
],
"args": [],
"kw_args": [],
"returns": {
"name": "bool"
}
}
},
{
"location": {
"from": "L6:16",
@@ -135,6 +110,31 @@
"name": "int"
}
},
{
"location": {
"from": "L6:11",
"to": "L6:15"
},
"expr": {
"_type": "VariableExpr",
"name": "bool"
},
"type": {
"pos_args": [
{
"pos": 0,
"name": "object",
"type": {},
"required": false
}
],
"args": [],
"kw_args": [],
"returns": {
"name": "bool"
}
}
},
{
"location": {
"from": "L6:11",
@@ -367,6 +367,54 @@
}
}
},
{
"location": {
"from": "L12:21",
"to": "L12:27"
},
"expr": {
"_type": "VariableExpr",
"name": "double"
},
"type": {
"pos_args": [],
"args": [
{
"pos": 0,
"name": "value",
"type": {
"name": "float"
},
"required": true
}
],
"kw_args": [],
"returns": {
"name": "float"
}
}
},
{
"location": {
"from": "L12:29",
"to": "L12:35"
},
"expr": {
"_type": "VariableExpr",
"name": "floats"
},
"type": {
"name": "list",
"args": [
{
"name": "float"
}
],
"body": {
"name": "list"
}
}
},
{
"location": {
"from": "L12:17",
@@ -455,54 +503,6 @@
}
}
},
{
"location": {
"from": "L12:21",
"to": "L12:27"
},
"expr": {
"_type": "VariableExpr",
"name": "double"
},
"type": {
"pos_args": [],
"args": [
{
"pos": 0,
"name": "value",
"type": {
"name": "float"
},
"required": true
}
],
"kw_args": [],
"returns": {
"name": "float"
}
}
},
{
"location": {
"from": "L12:29",
"to": "L12:35"
},
"expr": {
"_type": "VariableExpr",
"name": "floats"
},
"type": {
"name": "list",
"args": [
{
"name": "float"
}
],
"body": {
"name": "list"
}
}
},
{
"location": {
"from": "L12:17",
@@ -538,6 +538,54 @@
}
}
},
{
"location": {
"from": "L13:19",
"to": "L13:25"
},
"expr": {
"_type": "VariableExpr",
"name": "double"
},
"type": {
"pos_args": [],
"args": [
{
"pos": 0,
"name": "value",
"type": {
"name": "float"
},
"required": true
}
],
"kw_args": [],
"returns": {
"name": "float"
}
}
},
{
"location": {
"from": "L13:27",
"to": "L13:31"
},
"expr": {
"_type": "VariableExpr",
"name": "ints"
},
"type": {
"name": "list",
"args": [
{
"name": "int"
}
],
"body": {
"name": "list"
}
}
},
{
"location": {
"from": "L13:15",
@@ -626,54 +674,6 @@
}
}
},
{
"location": {
"from": "L13:19",
"to": "L13:25"
},
"expr": {
"_type": "VariableExpr",
"name": "double"
},
"type": {
"pos_args": [],
"args": [
{
"pos": 0,
"name": "value",
"type": {
"name": "float"
},
"required": true
}
],
"kw_args": [],
"returns": {
"name": "float"
}
}
},
{
"location": {
"from": "L13:27",
"to": "L13:31"
},
"expr": {
"_type": "VariableExpr",
"name": "ints"
},
"type": {
"name": "list",
"args": [
{
"name": "int"
}
],
"body": {
"name": "list"
}
}
},
{
"location": {
"from": "L13:15",
@@ -699,6 +699,54 @@
},
"type": {}
},
{
"location": {
"from": "L14:15",
"to": "L14:21"
},
"expr": {
"_type": "VariableExpr",
"name": "is_odd"
},
"type": {
"pos_args": [],
"args": [
{
"pos": 0,
"name": "value",
"type": {
"name": "int"
},
"required": true
}
],
"kw_args": [],
"returns": {
"name": "bool"
}
}
},
{
"location": {
"from": "L14:23",
"to": "L14:27"
},
"expr": {
"_type": "VariableExpr",
"name": "ints"
},
"type": {
"name": "list",
"args": [
{
"name": "int"
}
],
"body": {
"name": "list"
}
}
},
{
"location": {
"from": "L14:11",
@@ -787,54 +835,6 @@
}
}
},
{
"location": {
"from": "L14:15",
"to": "L14:21"
},
"expr": {
"_type": "VariableExpr",
"name": "is_odd"
},
"type": {
"pos_args": [],
"args": [
{
"pos": 0,
"name": "value",
"type": {
"name": "int"
},
"required": true
}
],
"kw_args": [],
"returns": {
"name": "bool"
}
}
},
{
"location": {
"from": "L14:23",
"to": "L14:27"
},
"expr": {
"_type": "VariableExpr",
"name": "ints"
},
"type": {
"name": "list",
"args": [
{
"name": "int"
}
],
"body": {
"name": "list"
}
}
},
{
"location": {
"from": "L14:11",

View File

@@ -0,0 +1,117 @@
# type: ignore
# ruff: disable [F821]
df1: Frame[a:int, b:float]
df2: Frame[a:int, b:float]
_: Any
# Arithmetic
_ = df1 + df2
_ = df1 - df2
_ = df1 * df2
_ = df1 / df2
_ = df1 // df2
_ = df1 % df2
_ = df1**df2
# Comparisons
_ = df1 < df2
_ = df1 > df2
_ = df1 <= df2
_ = df1 >= df2
_ = df1 != df2
_ = df1 == df2
# Aggregate
_ = df1.kurt()
_ = df1.kurtosis()
_ = df1.max()
_ = df1.mean()
_ = df1.median()
_ = df1.min()
_ = df1.mode()
_ = df1.prod()
_ = df1.product()
_ = df1.std()
_ = df1.sum()
_ = df1.var()
# Groupby
df_gb = df1.groupby(by="a")
_ = df_gb.kurt()
_ = df_gb.max()
_ = df_gb.mean()
_ = df_gb.median()
_ = df_gb.min()
_ = df_gb.prod()
_ = df_gb.std()
_ = df_gb.sum()
_ = df_gb.var()
# Columns
col1 = df1["a"]
col2 = df1["a"]
# Arithmetic
_ = col1 + col2
_ = col1 - col2
_ = col1 * col2
_ = col1 / col2
_ = col1 // col2
_ = col1 % col2
_ = col1**col2
# Comparisons
_ = col1 < col2
_ = col1 > col2
_ = col1 <= col2
_ = col1 >= col2
_ = col1 != col2
_ = col1 == col2
# Aggregate
_ = col1.kurt()
_ = col1.kurtosis()
_ = col1.max()
_ = col1.mean()
_ = col1.median()
_ = col1.min()
_ = col1.mode()
_ = col1.prod()
_ = col1.product()
_ = col1.std()
_ = col1.sum()
_ = col1.var()
# Groupby
col_gb = col1.groupby(level=0)
_ = col_gb.kurt()
_ = col_gb.max()
_ = col_gb.mean()
_ = col_gb.median()
_ = col_gb.min()
_ = col_gb.prod()
_ = col_gb.std()
_ = col_gb.sum()
_ = col_gb.var()
# Attributes
_ = df1.ndim # int
_ = df1.size # int
_ = df1.shape # (int, int)
_ = col1.ndim # int
_ = col1.size # int
_ = col1.shape # (int)
_ = col1.T # Column[int]
# Misc
_ = df1.head()
_ = df1.tail()
_ = col1.head()
_ = col1.tail()

File diff suppressed because it is too large Load Diff

View File

@@ -16,7 +16,7 @@
"type": {
"_type": "BaseType",
"base": "bool",
"param": null
"args": []
}
},
{
@@ -25,7 +25,7 @@
"type": {
"_type": "BaseType",
"base": "int",
"param": null
"args": []
}
},
{
@@ -36,7 +36,7 @@
"type": {
"_type": "BaseType",
"base": "float",
"param": null
"args": []
},
"constraint": "(_ > 0) + (_ < 250)"
}
@@ -47,7 +47,7 @@
"type": {
"_type": "BaseType",
"base": "str",
"param": null
"args": []
}
},
{
@@ -56,7 +56,7 @@
"type": {
"_type": "BaseType",
"base": "datetime",
"param": null
"args": []
}
},
{
@@ -65,7 +65,7 @@
"type": {
"_type": "BaseType",
"base": "float",
"param": null
"args": []
}
},
{
@@ -79,7 +79,7 @@
"type": {
"_type": "BaseType",
"base": "_",
"param": null
"args": []
}
}
]

View File

@@ -16,7 +16,7 @@
"type": {
"_type": "BaseType",
"base": "GeoLocation",
"param": null
"args": []
}
}
]
@@ -28,11 +28,13 @@
"type": {
"_type": "BaseType",
"base": "Column",
"param": {
"_type": "BaseType",
"base": "GeoLocation",
"param": null
}
"args": [
{
"_type": "BaseType",
"base": "GeoLocation",
"args": []
}
]
}
},
{
@@ -65,11 +67,13 @@
"type": {
"_type": "BaseType",
"base": "Column",
"param": {
"_type": "BaseType",
"base": "GeoLocation",
"param": null
}
"args": [
{
"_type": "BaseType",
"base": "GeoLocation",
"args": []
}
]
}
},
{
@@ -117,7 +121,7 @@
"type": {
"_type": "BaseType",
"base": "Latitude",
"param": null
"args": []
}
},
{
@@ -146,7 +150,7 @@
"type": {
"_type": "BaseType",
"base": "Latitude",
"param": null
"args": []
}
},
{
@@ -175,11 +179,13 @@
"type": {
"_type": "BaseType",
"base": "Difference",
"param": {
"_type": "BaseType",
"base": "Latitude",
"param": null
}
"args": [
{
"_type": "BaseType",
"base": "Latitude",
"args": []
}
]
}
},
{
@@ -217,7 +223,7 @@
"type": {
"_type": "BaseType",
"base": "int",
"param": null
"args": []
},
"constraint": "_ >= 0"
}
@@ -230,7 +236,7 @@
"type": {
"_type": "BaseType",
"base": "float",
"param": null
"args": []
},
"constraint": "_ >= 0"
}
@@ -252,7 +258,7 @@
"type": {
"_type": "BaseType",
"base": "int",
"param": null
"args": []
},
"constraint": "Positive"
}
@@ -265,7 +271,7 @@
"type": {
"_type": "BaseType",
"base": "float",
"param": null
"args": []
},
"constraint": "Positive"
}

View File

@@ -14,15 +14,17 @@
"type": {
"_type": "BaseType",
"base": "Column",
"param": {
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"param": null
},
"constraint": "0 <= _ <= 1"
}
"args": [
{
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"args": []
},
"constraint": "0 <= _ <= 1"
}
]
},
"default": null
},
@@ -31,15 +33,17 @@
"type": {
"_type": "BaseType",
"base": "Column",
"param": {
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"param": null
},
"constraint": "0 <= _ <= 1"
}
"args": [
{
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"args": []
},
"constraint": "0 <= _ <= 1"
}
]
},
"default": null
}
@@ -50,15 +54,17 @@
"returns": {
"_type": "BaseType",
"base": "Column",
"param": {
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"param": null
},
"constraint": "0 <= _ <= 2"
}
"args": [
{
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"args": []
},
"constraint": "0 <= _ <= 2"
}
]
},
"body": [
{
@@ -67,15 +73,17 @@
"type": {
"_type": "BaseType",
"base": "Column",
"param": {
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"param": null
},
"constraint": "0 <= _ <= 2"
}
"args": [
{
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"args": []
},
"constraint": "0 <= _ <= 2"
}
]
}
},
{
@@ -117,7 +125,7 @@
"type": {
"_type": "BaseType",
"base": "int",
"param": null
"args": []
},
"default": null
}
@@ -128,7 +136,7 @@
"type": {
"_type": "BaseType",
"base": "float",
"param": null
"args": []
},
"default": null
}
@@ -140,7 +148,7 @@
"type": {
"_type": "BaseType",
"base": "str",
"param": null
"args": []
},
"default": null
}

View File

@@ -46,7 +46,8 @@ class GeneratorTester(Tester):
if not any(d.type == DiagnosticType.ERROR for d in checker.diagnostics):
generator = Generator(workdir=path.parent, types=checker.types)
result.compiled_ast = generator.generate_ast(typed_ast, path)
generator.set_src_path(path)
result.compiled_ast = generator.generate_ast(typed_ast)
return result

View File

@@ -1,6 +1,7 @@
from typing import Optional, Sequence
from midas.ast.midas import (
AliasStmt,
BinaryExpr,
CallExpr,
ComplexType,
@@ -8,6 +9,7 @@ from midas.ast.midas import (
Expr,
ExtendStmt,
ExtensionType,
FrameType,
FunctionType,
GenericType,
GetExpr,
@@ -60,6 +62,13 @@ class MidasAstJsonSerializer(
"bound": self._serialize_optional(param.bound),
}
def visit_alias_stmt(self, stmt: AliasStmt) -> dict:
return {
"_type": "AliasStmt",
"name": stmt.name.lexeme,
"type": stmt.type.accept(self),
}
def visit_member_stmt(self, stmt: MemberStmt) -> dict:
return {
"_type": "MemberStmt",
@@ -197,3 +206,15 @@ class MidasAstJsonSerializer(
"base": type.base.accept(self),
"extension": type.extension.accept(self),
}
def visit_frame_type(self, type: FrameType) -> dict:
return {
"_type": "FrameType",
"columns": [self._serialize_column(col) for col in type.columns],
}
def _serialize_column(self, column: FrameType.Column):
return {
"name": column.name.lexeme,
"type": column.type.accept(self),
}

View File

@@ -30,6 +30,7 @@ from midas.ast.python import (
Stmt,
SubscriptExpr,
TernaryExpr,
TupleExpr,
TypeAssign,
UnaryExpr,
VariableExpr,
@@ -98,7 +99,7 @@ class PythonAstJsonSerializer(
return {
"_type": "BaseType",
"base": node.base,
"param": self._serialize_optional(node.param),
"args": self._serialize_list(node.args),
}
def visit_constraint_type(self, node: ConstraintType) -> dict:
@@ -302,6 +303,12 @@ class PythonAstJsonSerializer(
"step": self._serialize_optional(expr.step),
}
def visit_tuple_expr(self, expr: TupleExpr) -> dict:
return {
"_type": "TupleExpr",
"items": [item.accept(self) for item in expr.items],
}
def visit_raw_expr(self, expr: RawExpr) -> dict:
return {
"_type": "RawExpr",