21 Commits

Author SHA1 Message Date
5b0c5c01ad feat(checker): add mean method on frames 2026-06-26 11:21:38 +02:00
43e40396a1 fix(checker): type check None literal 2026-06-26 11:21:17 +02:00
0d265ef24c feat(checker): lookup dunders on dataframes 2026-06-26 10:35:50 +02:00
88c56c9d15 tests: update with reordered argument typing 2026-06-26 10:28:12 +02:00
d1c217a335 refactor: use metaclass to collect frame methods 2026-06-25 22:31:59 +02:00
5b3e87afcb refactor: add MethodResolver class 2026-06-25 22:14:25 +02:00
894d5a7196 feat: add dummy classes for typing frames and columns 2026-06-25 21:35:47 +02:00
eb809c6341 fix(checker): improve heterogeneous error message 2026-06-25 21:35:19 +02:00
bd68d1003f feat(checker): lookup dataframe methods 2026-06-25 21:34:59 +02:00
72c9236650 feat(checker): defined add method of dataframes 2026-06-25 21:34:00 +02:00
90051c7981 feat(checker): add structural subtyping rule for dataframes 2026-06-25 21:09:14 +02:00
dd1e2e693c feat(cli): print context for multiline diagnostics 2026-06-25 16:32:15 +02:00
78e10e0895 feat(checker): process frame type definitions 2026-06-24 14:36:53 +02:00
c81e4a9560 feat(cli): add frame type to highlighter 2026-06-24 14:36:53 +02:00
6d0cf1a055 feat(parser): add frame type to midas syntax 2026-06-24 14:36:52 +02:00
cc5e7af143 feat(gen): add support for tuples and dataframes 2026-06-24 14:36:51 +02:00
3bdbc80079 feat(checker): handle setting dataframe column 2026-06-24 14:36:51 +02:00
c1b5284f72 feat(checker): type check subscript on dataframes 2026-06-24 14:36:28 +02:00
5e9ccd4e13 feat(types): add TupleType 2026-06-24 14:36:04 +02:00
cf083fc0c3 fix(types): add str methods to dataframe types 2026-06-24 14:35:31 +02:00
a80da5db2c feat(types): add DataFrameType and ColumnType 2026-06-24 14:35:30 +02:00
53 changed files with 1254 additions and 9565 deletions

View File

@@ -1,117 +0,0 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<!-- Created with Inkscape (http://www.inkscape.org/) -->
<svg
width="128"
height="128"
viewBox="0 0 128 128"
version="1.1"
id="svg1"
inkscape:export-filename="logo.png"
inkscape:export-xdpi="96"
inkscape:export-ydpi="96"
inkscape:version="1.4.4 (1:1.4.4+202605061436+dcaf3e7d9e)"
sodipodi:docname="logo.svg"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns:xlink="http://www.w3.org/1999/xlink"
xmlns="http://www.w3.org/2000/svg"
xmlns:svg="http://www.w3.org/2000/svg">
<sodipodi:namedview
id="namedview1"
pagecolor="#ffffff"
bordercolor="#000000"
borderopacity="0.25"
inkscape:showpageshadow="2"
inkscape:pageopacity="0.0"
inkscape:pagecheckerboard="0"
inkscape:deskcolor="#d1d1d1"
inkscape:document-units="mm"
showgrid="true"
inkscape:zoom="1.9332778"
inkscape:cx="-8.2760999"
inkscape:cy="112.2446"
inkscape:window-width="2584"
inkscape:window-height="1028"
inkscape:window-x="0"
inkscape:window-y="24"
inkscape:window-maximized="1"
inkscape:current-layer="layer1">
<inkscape:grid
id="grid1"
units="px"
originx="0"
originy="0"
spacingx="4"
spacingy="4"
empcolor="#0099e5"
empopacity="0.30196078"
color="#0099e5"
opacity="0.14901961"
empspacing="4"
enabled="true"
visible="true" />
</sodipodi:namedview>
<defs
id="defs1">
<linearGradient
inkscape:collect="always"
xlink:href="#linearGradient4689"
id="linearGradient1478"
gradientUnits="userSpaceOnUse"
gradientTransform="matrix(0.562541,0,0,0.567972,-9.399749,-5.305317)"
x1="26.648937"
y1="20.603781"
x2="135.66525"
y2="114.39767" />
<linearGradient
id="linearGradient4689">
<stop
style="stop-color:#e1be1e;stop-opacity:1;"
offset="0"
id="stop4691" />
<stop
style="stop-color:#ffeb82;stop-opacity:1;"
offset="1"
id="stop4693" />
</linearGradient>
<linearGradient
inkscape:collect="always"
xlink:href="#linearGradient4671"
id="linearGradient1475"
gradientUnits="userSpaceOnUse"
gradientTransform="matrix(0.562541,0,0,0.567972,-9.399749,-5.305317)"
x1="150.96111"
y1="192.35176"
x2="112.03144"
y2="137.27299" />
<linearGradient
id="linearGradient4671">
<stop
style="stop-color:#ffdc21;stop-opacity:1;"
offset="0"
id="stop4673" />
<stop
style="stop-color:#ffeb82;stop-opacity:1;"
offset="1"
id="stop4675" />
</linearGradient>
</defs>
<g
inkscape:label="Calque 1"
inkscape:groupmode="layer"
id="layer1">
<g
id="g1"
transform="translate(2.911719,3.414527)">
<path
style="fill:url(#linearGradient1478);fill-opacity:1"
d="m 60.510156,6.3979729 c -4.583653,0.021298 -8.960939,0.4122177 -12.8125,1.09375 C 36.35144,9.4962267 34.291407,13.691825 34.291406,21.429223 v 10.21875 h 26.8125 v 3.40625 h -26.8125 -10.0625 c -7.792459,0 -14.6157592,4.683717 -16.7500002,13.59375 -2.46182,10.212966 -2.5710151,16.586023 0,27.25 1.9059283,7.937852 6.4575432,13.593748 14.2500002,13.59375 h 9.21875 v -12.25 c 0,-8.849902 7.657144,-16.656248 16.75,-16.65625 h 26.78125 c 7.454951,0 13.406253,-6.138164 13.40625,-13.625 v -25.53125 c 0,-7.266339 -6.12998,-12.7247775 -13.40625,-13.9375001 -4.605987,-0.7667253 -9.385097,-1.1150483 -13.96875,-1.09375 z m -14.5,8.2187501 c 2.769547,0 5.03125,2.298646 5.03125,5.125 -2e-6,2.816336 -2.261703,5.09375 -5.03125,5.09375 -2.779476,-1e-6 -5.03125,-2.277415 -5.03125,-5.09375 -1e-6,-2.826353 2.251774,-5.125 5.03125,-5.125 z"
id="path1948" />
<path
style="fill:url(#linearGradient1475);fill-opacity:1"
d="m 91.228906,35.054223 v 11.90625 c 0,9.230755 -7.825895,16.999999 -16.75,17 h -26.78125 c -7.335833,0 -13.406249,6.278483 -13.40625,13.625 v 25.531247 c 0,7.26634 6.318588,11.54032 13.40625,13.625 8.487331,2.49561 16.626237,2.94663 26.78125,0 6.750155,-1.95439 13.406253,-5.88761 13.40625,-13.625 V 92.897973 h -26.78125 v -3.40625 h 26.78125 13.406254 c 7.79246,0 10.69625,-5.435408 13.40624,-13.59375 2.79933,-8.398886 2.68022,-16.475776 0,-27.25 -1.92578,-7.757441 -5.60387,-13.59375 -13.40624,-13.59375 z m -15.0625,64.65625 c 2.779478,3e-6 5.03125,2.277417 5.03125,5.093747 -2e-6,2.82635 -2.251775,5.125 -5.03125,5.125 -2.76955,0 -5.03125,-2.29865 -5.03125,-5.125 2e-6,-2.81633 2.261697,-5.093747 5.03125,-5.093747 z"
id="path1950" />
</g>
</g>
</svg>

Before

Width:  |  Height:  |  Size: 4.7 KiB

View File

@@ -1,724 +0,0 @@
//#import "@preview/codly:1.3.0": codly, codly-init
// Fix unaligned highlights in v0.15.0 ()
// See https://github.com/Dherse/codly/pull/132
#import "@local/codly:1.3.1": codly, codly-init
#import "@preview/codly-languages:0.1.10": codly-languages
#import "template.typ": TODO, project
#import "@preview/gentle-clues:1.3.1" as gc
#let midas-version = toml("../pyproject.toml").project.version
#let head-ref = read("../.git/HEAD").split(":").at(1).trim()
#let commit-hash = read("../.git/" + head-ref).slice(0, 8)
#show: project.with(
title: [Midas User Manual],
author: "Louis Heredero",
version: midas-version,
hash: commit-hash,
icon-path: path("../assets/icon.svg"),
)
#show: codly-init
#codly(
languages: codly-languages
+ (
midas: (
name: "Midas",
color: rgb("#eedd47"),
icon: box(
image(
"../assets/icon.svg",
height: 130%,
fit: "contain",
),
),
),
),
)
= Introduction
Python is a very popular programming language, especially in data sciences.
However, it has been designed for simplicity, distancing itself from typed languages such as Java or C to embrace dynamic typing.
What this means is that in Python, type checks are deferred to runtime when operations are concretely executed.
For developers, it might seem like a great way of simplifying the language and making it very flexible, but it does come with a cost.
Indeed, type errors are very easy to make in Python. While passing an integer where a string is expected might not be an issue in some cases, these are the sort of thing that can cause crashes or incorrect results without a clear diagnostic to help the user fix it.
Fortunately, developers using IDEs or properly configured text editors can benefit from external type checkers such as MyPy which will perform static type analysis of their Python code. Some can also be configured to be very strict, forcing the user to make the whole code typeable statically, thus avoiding any runtime type errors.
This is not the end of the problem though. Some parts of a program, especially in data related fields, may not be available at "compile-time". For example, a dataset can be loaded from an external file, or data can be fetched from an API, with no guarantees of having the expected format when analyzing the code statically.
In turn, that can cause a range of loud and silent errors at runtime. A malformed number will probably crash the program when trying to convert it, but a NaN in a series of value might just produce wrong results without any exception. Combine this with often long-running data-processing pipelines and this is how developers can waste hours of precious computation time.
Midas is a type system which can be used on top of Python to provide better type checking capabilities and gradual typing.
It aims at providing optional but strict type annotations and casting operations which can produce runtime assertions. It also allows the user to define dependent types with value constraints that are translated into runtime checks.
= Installation
Midas comes as a very light Python package that you can install on your system in a few simple steps.
== Requirements
Here below are the requirements for installing Midas. All Python dependencies will be installed by `uv` in the installation process described in @install-steps.
- Python 3.11+
- `uv`
== Steps <install-steps>
1. Clone the repository
```bash
git clone https://git.kb28.ch/HEL/midas.git
```
2. Navigate inside the directory
```bash
cd midas
```
3. Install Midas as a tool in your local user space
```bash
uv tool install .
```
And that's it ! You can now use Midas commands anywhere, like this:
```bash
midas --help
```
= Quick Start
This chapter will give you the keys to quickly start using Midas in your project.
== Defining custom types
To begin with, you might want to define some custom types for your project, to avoid handling anonymous float values everywhere. To do so, create a `*.midas` file in your project, and write some definitions for your types. See @midas-ref for more information on syntax and features.
@qs-midas shows a simple example of what it might look like.
#codly(header: [types.midas])
#figure(
```midas
type Meter = float
extend Meter {
def __add__: fn(Meter, /) -> Meter
def __sub__: fn(Meter, /) -> Meter
}
type Coordinate = object
extend Coordinate {
prop x: Meter
prop y: Meter
}
```,
caption: [Example Midas type definitions],
) <qs-midas>
You can check for any syntax error using the following command:
```bash
midas validate types.midas
```
When you are happy with your definitions, you can generate Python stubs to use in your source code. This allows other type checkers like MyPy to recognize your custom types and avoid reporting them as undefined. It can also help catch some type errors in your IDE.
```bash
midas stubs types.midas -o stubs.pyi
```
This command will generate a file as shown in @qs-stubs, providing stub classes to represent the type lattice including methods and properties.
#codly(header: [stubs.pyi])
#figure(
```pyi
from __future__ import annotations
class Meter(float):
def __add__(self, _0: Meter, /) -> Meter: ...
def __sub__(self, _0: Meter, /) -> Meter: ...
class Coordinate(object):
x: Meter
y: Meter
```,
caption: [Generated stubs from example definitions of @qs-midas],
) <qs-stubs>
== Using Midas in Python
You can now write your Python program as you would normally. You can import your custom types from the generated stubs file and use them in type annotations.
You can also import the `cast` and `unsafe_cast` functions from `midas.typing` to explicitly cast a value to a specific type (see @cast for more information).
An example Python script is shown in @qs-python, demonstrating how you can use custom types in type annotations. Notice the comments describing errors that will be caught by the type checker in @qs-type-checking.
#codly(header: [script.py])
#figure(
```python
from lib import load_coordinate
from midas.typing import cast
from stubs import Coordinate, Meter
p1 = cast(Coordinate, load_coordinate(0))
p2 = cast(Coordinate, load_coordinate(1))
diff_x = p2.x - p1.x
diff_y = p2.y - p1.y
dist = diff_x + diff_y
p2.x += cast(Meter, 1)
p2.y = True # invalid, wrong type
p2.z = 3 # invalid, no property 'z' on Coordinate
p2.x.a = 3 # invalid, no properties on Meter
```,
caption: [Example Python script],
) <qs-python>
== Type checking <qs-type-checking>
Now that you have defined some types and written a script, you can run the type checker with the following command. You can also skip this step and directly run the compilation command in @qs-compilation.
```bash
midas check -t types.midas script.py
```
== Compiling <qs-compilation>
The final step is to compile your code. This step will produce a runnable Python script, including runtime assertions generated by `cast` expressions.
```bash
midas compile -t types.midas script.py
python3 build/midas/script.py
```
= Midas Language Reference <midas-ref>
In this chapter, you will find a complete reference for the Midas definition language.
A `*.midas` file contains a number of statements, which can be:
- *`type`* statements (see @type-stmt): to define a new type
- *`extend`* statements (see @extend-stmt): to define member of a type
- *`predicate`* statements (see @predicate-stmt): to define named predicates that can be used in constraint types
== Type Statement <type-stmt>
A *`type`* statement lets you define a new type. It requires a unique name and base type.
The simplest form of a *`type`* statement is:
#figure(
```midas
type MyType = float
```,
caption: [Simple `type` statement declaring a new type "`MyType`" as a subtype of `float`],
) <midas-simple-alias>
This statement defines a new type called `MyType` which is a subtype of `float`. `MyType` is a `float` but a `float` is not necessarily `MyType`.
=== Builtin / base types
A number of base types are provided out of the box, which can be used to derive other types.
They correspond to Python's builtin types:
```py object```,
```py str```,
```py float```,
```py int```,
```py bool```,
```py list```,
```py dict```,
```py None```.
Some differences are to be noted however.
1. ```py bool``` is not a subtype of ```py int```
2. ```py list``` are homogeneous, i.e. all items must be of the same type
3. ```py dict``` keys and values are homogeneous, i.e. all keys must be of the same type and all values must be of the same type (can be different from keys).
=== Function types
A function type is written in a similar notation to Python function definitions:
#figure(
```midas
type Repeater = fn(text: str, count: int) -> str
```,
caption: [Simple function type definition],
)
Midas supports positional-only, keyword-only and mixed arguments (using the `/` and `*` separators). You may omit the name of positional-only arguments. The return type is required.
Optional parameters can be indicated by adding a question mark (`?`) after their type:
#figure(
```midas
type Repeater = fn(text: str, count: int, *, sep: str?) -> str
```,
caption: [Function type definition with an optional keyword-only parameter],
)
#gc.warning[
Sink arguments (`*args`, `**kwargs`) are not currently supported.
]
=== Constraint types
A useful feature provided by Midas is the possibility to combine types with custom value constraints. For example, you might want to define a type for positive amounts of money:
#figure(
```midas
type Money = float
type Income = Money where _ >= 0
```,
caption: [Simple constraint type definition],
)
Constraints can be combined with any type using the `where` keyword, followed by a constraint expression (see @constraint-expr).
=== Generic types
For more complex types, you might want to use type parameters. For example, to define a container, we might write:
#figure(
```midas
type Container[T] = object
```,
caption: [Simple generic container type definition],
)
To better refine a generic type, you can also bound type parameters using the following syntax:
#figure(
```midas
type Container[T <: float] = object
```,
caption: [Generic container type definition with a bound],
)
This can be read as "`Container` is a generic type which takes one type parameter `T` that must be a subtype of `float`".
You can use a generic type, i.e. instantiate it, by using a similar syntax with concrete type as arguments:
#figure(
```midas
type MyContainer = Container[MyType]
```,
caption: [Application of a generic type],
)
Generic types can also take multiple parameters, which are then separated by commas:
#figure(
```midas
type ZipCodeRegistry = dict[int, str]
```,
caption: [Application of a multi-parameter generic type],
)
The _body_ of a generic type, i.e. the right-hand side of the definition, can contain or even be equal to any number of its parameters.#footnote[The latter is not something that is expressible in standard Python, yet it brings a semantic distinction on top of structurally equivalent values.] For example, the following is a valid type statement:
#figure(
```midas
type Price[T <: Currency] = T where _ > 0
```,
caption: [Type parameters in a generic type's body],
)
#pagebreak()
== Extend Statement <extend-stmt>
Type statements allow you to define new types, kind of like type aliases. However, a type might have properties or methods of its own. These might override those of the parent type or be brand new members.
This is where the `extend` statement comes into play. It allows defining members on a given type. Members can either be properties (`prop`) or methods (`def`). The only difference between the two is that methods must be functions and can be overloaded.
Here is a simple example showing how to define a property and a method on a custom type:
#figure(
```midas
type MyType = float
extend MyType {
prop norm: float
def double: fn() -> MyType
}
```,
caption: [Simple `extend` statement defining a property and a method],
)
An `extend` statement can appear anywhere after the type it extends has been defined.
You may want to override Python's dunder methods to implement type checking for some basic operators, like `__add__` for the `+` operator.
#figure(
```midas
type Money = float
extend Money {
def __add__(Money, /) -> Money
def __mul__(float, /) -> Money
}
```,
caption: [Simple `extend` statement overriding some dunder methods],
)
When extending generic type, you must specify the whole type, including its parameter(s):
#figure(
```midas
type Container[T <: float] = object
extend Container[T <: float] {
prop content: T
def set_content: fn(content: T) -> None
}
```,
caption: [Generic `extend` statement using type parameters in the declared members],
)
#pagebreak()
== Predicate Statement <predicate-stmt>
A *`predicate`* statement lets you define a named constraint expression, like a function, which can then be used in other constraint expressions (either in other predicate statements or in constraint types). See @constraint-expr for more information about the syntax of constraint expressions.
The left-hand side of a predicate statement is written as a function signature, without a return type. The right-hand side is a constraint expression. For example:
#figure(
```midas
predicate is_positive(v: float) = v >= 0
```,
caption: [Simple `predicate` statement defining an `is_positive` predicate],
)
The left-hand side can also be curried to allow partial application. For example:
#figure(
```midas
predicate in_range(mn: float, mx: float)(v: float) = mn <= v & v <= mx
predicate is_ratio = in_range(0.0, 1.0)
```,
caption: [Curried `predicate` statement and partial application],
) <midas-predicate-partial>
Notice that the second predicate statement doesn't take any parameters. This is simply a partial application of another predicate, kind of like an alias. You can use it in other expressions to finalize the call:
#figure(
```midas
type Efficiency = float where is_ratio(_)
```,
caption: [Constraint type definition using the partially applied predicate from @midas-predicate-partial],
)
Of course you can also directly call `in_range`:
#figure(
```midas
type Efficiency = float where in_range(0.0, 1.0)(_)
```,
caption: [Full call of curried predicate from @midas-predicate-partial],
)
When compiled, named predicates are translated to Python functions which are used in runtime assertions. Only predicates that are referenced are compiled.
#pagebreak()
== Constraint Expressions <constraint-expr>
*Constraint expressions* are Python-like expressions which can appear in *`predicate`* statements or in constraint types.
They can contain comparisons, simple computations, logical operations and must evaluate to a boolean value.
Context is quite restricted inside these expressions. You can only reference some builtin functions, such as type constructors (`float(...)`, `str(...)`, etc.), parameters of predicate statements, and named predicates. In constraint type, the special variable `_` can be used to reference the value targeted by the type. For example:
#figure(
```midas
predicate not_nan(v: float) = v != float("nan")
type RealFloat = float where not_nan(_)
```,
caption: [Example constraint expressions],
) <ex-constraint-expr>
In the predicate statement (@ex-constraint-expr:1), we reference the parameter `v` and the builtin `float` function.
In the constraint type definition (@ex-constraint-expr:2), we then reference the named predicate `not_nan`, passing the value targeted by the type itself ( `_` )
= Supported Python Syntax <python-ref>
Midas integrates naturally in Python via type annotations. Through generated stubs, even other type checker can detect your custom types (see @cmd-stubs).
It has been designed to leave the user free of typing any amount of their code but be strict about the parts that are annotated. By default, any untyped Python expression is assigned `UnknownType`.
Any operation is permitted on `UnknownType` and will result in `UnknownType` values.
The moment an expression can be typed, that be thanks to an annotation or a literal value, the type checker kicks in and will validate your statements.
Because Python is very flexible language with many features, some expressions and statements might be more complex to properly type check, thus only a subset of the Python language is fully supported. This chapter lists all supported features of Python and how they affect type checking.
Some examples are presented in the following sections in the form of code blocks. Highlights in the code blocks indicate the type assigned to each expression by the type checker. Some types may be omitted for readability. For example:
#codly(
highlights: (
(
line: 1,
start: 5,
fill: green,
tag: [_int_],
),
(
line: 2,
start: 7,
end: 7,
fill: green,
tag: [_int_],
),
),
)
```python
v = 3
print(v)
```
== Literals
Literal Python values are type checked using builtin types. Lists and dictionaries of literals are also typed liked literals. This does not include comprehension lists/dicts (```py [. for . in .]```), nor formatted strings (```py f"..."```). @supported-literals shows the list of supported literal values and their type.
#let supported-literals = table(
columns: 2,
table.header[*Example value*][*Judged Type*],
```py 42```, ```py int```,
```py 3.14```, ```py float```,
```py True```, ```py bool```,
```py "Midas"```, ```py str```,
```py None```, ```py None```,
```py [1, 2, 3]```, ```py list[int]```,
```py {1: "One", 2: "Two"}```, ```py dict[int, str]```,
```py ("1", 1, True)```, ```py tuple[str, int, bool]```,
)
#figure(
supported-literals,
caption: [Supported literal values and their judged types],
) <supported-literals>
== Assignments
Variable assignments allow assigning a new value to a variable. For the type checker, this implies two things:
1. If the variable was not already declared in the current scope, it is declared at that point with the type of the right-hand side expression
2. If the variable was already declared, the type of the right-hand side expression is checked against the declared type of the variable. Only a subtype of the variable's type can be assigned to it
Once a variable has been given a type, it cannot be changed in the same scope.
The walrus operator (```py :=```) is not currently supported.
A simple annotation declaration, without assigning a value, is enough to declare a variable. For example:
#figure(
```python
var: float
```,
caption: [Bare Python variable annotation without assignment],
)
Because unpacking is not supported, assigning to multiple values is also not handled by the type checker.
== Arithmetic
- All basic binary operators are supported, through dunder methods.
- All comparison operators except ```py in``` are supported.
- All unary operators are supported (`+`, `-`, `~`).
- All logical operators are supported (```py and```, ```py or```, ```py not```).
== Ternary operator
The ternary operator ```py . if . else .``` is supported. As for `if` statements (see @if-else), the test expression must be a boolean. Additionally, both branches must be of the same type.
For example:
#codly(
highlights: (
(
line: 1,
start: 10,
end: 44,
tag: [_str_],
fill: blue,
),
(
line: 1,
start: 11,
end: 16,
tag: [_str_],
fill: green,
),
(
line: 1,
start: 39,
end: 43,
tag: [_str_],
fill: green,
),
(
line: 1,
start: 21,
end: 32,
tag: [_bool_],
fill: green,
),
),
)
#figure(
```python
parity = ("even" if num % 2 == 0 else "odd")
```,
caption: [Typing of ternary operator],
)
== Control flow
Some control flow features are supported. For the limited code of this project, not all constructs are supported. The following are those currently handled and typ checked by Midas.
=== `if` / `elif` / `else` <if-else>
Conditional statements are checked relatively strictly by Midas. The test expression, i.e. what comes after the ```py if``` keyword, must be a boolean. While Python allows introducing and leaking new variables from inside an ```py if``` statement, Midas will strictly forbid leaks by restraining bindings to the scope they are defined in. For example, the following Python code will not compile with Midas:
#figure(
```python
age = 22
if age >= 18:
msg = "You're an adult"
else:
msg = "You're still a child"
print(msg) # -> unknown variable 'msg'
```,
caption: [`if`/`else` statement cannot leak variables],
)
=== `for` loops
Simple forms of `for` loops can be used, that is using a single variable and iterating over an object implementing the `__getitem__` method. Like above in @if-else, leaking variables from inside the loop is ignored.
The `for`-`else` statements are not supported. `while` loops are also not not supported.
== Functions
You can define functions as usual and the type checker will do its best to type it. Apart from argument sinks (`*args`, `**kwargs`), all forms of parameter specifications are supported (positional-only, keyword-only, mixed, optional).
As for the rest of your code, type annotations are optional, but recommended. If you omit the return type hint, the type checker will try to infer it from the function body and its return statements. If you did specify a return type, all return paths must return values that are subtypes of the type hint.
#codly(
highlights: (
(
line: 2,
start: 12,
end: 16,
tag: [_float_],
fill: green,
),
(
line: 2,
start: 12,
tag: [_float_],
fill: blue,
),
(
line: 3,
start: 10,
end: 15,
tag: [_(value: float) -> float_],
fill: green,
),
(
line: 3,
start: 17,
end: 19,
tag: [_float_],
fill: green,
),
(
line: 3,
start: 10,
tag: [_float_],
fill: blue,
),
),
)
#figure(
```python
def double(value: float) -> float:
return value * 2
result = double(4.0)
```,
caption: [Typing of function's body and call],
)
Anonymous functions (```py lambda```) are not yet supported
== Casts <cast>
#gc.info[
The functions discussed in this section are provided by the `midas.typing` submodule. You can import them in your script like so:
#figure(
```python
from midas.typing import cast, unsafe_cast
```,
caption: [Importing cast functions],
)
]
Sometimes, you may want to use a value whose type is not known to the type checker in a place where it expects a particular type. In that case, if you do know that the runtime type will correspond to what is expected, you can use a `cast` expression.
Similar to the `cast` function from the `typing` package of Python's Standard Library, it allows telling the type checker that a value has a given type. While `typing`'s function doesn't have any runtime side-effect, Midas' will generate runtime assertions, ensuring that your statement is true when running the code. What cannot be checked statically is checked at runtime.
In the following example, a runtime check would be generated to ensure that the value is indeed a `float` and that it satisfies the type's constraint (i.e. `>= 0`):
#codly(
highlights: (
(
line: 1,
start: 35,
end: 47,
tag: [_UnknownType_],
fill: red,
),
(
line: 2,
start: 7,
end: 17,
tag: [_PositiveFloat_],
fill: green,
),
),
)
#figure(
```python
typed_value = cast(PositiveFloat, unknown_value)
print(typed_value)
```,
caption: [Typing of `cast` expression],
)
#gc.warning[
Assertions are statements inserted just before a statement using a `cast` expression. This means that the expression is evaluated _before_ its actual intended usage location, which might cause issues if you rely on logical operator short-circuiting. See @eager-eval for more information.
]
There may be some cases where the cost of checking a value at runtime is simply not worth the safety, for example when dealing with a big dataset. If do wish so, you can use `unsafe_cast` which will only tell the type checker the type of the value, without generating a runtime assertion. This maps to the default behavior of `typing`'s own `cast` function.
If the value passed to `cast` or `unsafe_cast` is a literal (e.g. an integer, a string, a list of literals, etc.), the assertion is evaluated _at compile-time_ and no runtime assertion is generated.
= Commands <commands>
#TODO
== Type Checking (`check`) <cmd-check>
== Compiling (`compile`) <cmd-compile>
== Formatting (`format`) <cmd-format>
== Highlighting (`highlight`) <cmd-highlight>
== Dumping the AST (`parse`) <cmd-parse>
== Dumping the Registry (`dump-registry`) <cmd-registry>
== Generating Stubs (`stubs`) <cmd-stubs>
== Showing Type Judgements (`types`) <cmd-types>
== Validating Definitions (`validate`) <cmd-validate>
= Known limitations <limitations>
== Eager evaluation in runtime assertions <eager-eval>
The process of generating assertions to ensure safety at runtime, mainly for `cast` expressions, leads to the creation of aliases for the expressions being casted. These alias definitions eagerly evaluate before the assertion, and most importantly before the real usage location. This means that you should avoid using `cast` expressions inside logical expressions like `and` or `or`, because the normal "short-circuit" behavior will be irrelevant to the evaluations of the operands.
For example:
#figure(
```py
def foo():
print("Foo")
return True
def bar():
print("Bar")
return True
result = foo() or bar()
# Foo
# Bar
```,
caption: [Runtime assertions may eagerly evaluate expressions and bypass logical operator's short-circuit],
)

View File

@@ -1,180 +0,0 @@
%YAML 1.2
---
name: Midas
file_extensions:
- midas
scope: source.midas
variables:
identifier: "[a-zA-Z_][a-zA-Z0-9_]*"
contexts:
prototype:
- include: comments
main:
- include: keywords
- include: types
comments:
- match: "//"
scope: punctuation.definition.comment.midas
push:
- meta_scope: comment.line.midas
- match: $
pop: true
- match: /\*
scope: punctuation.definition.comment.midas
push:
- meta_scope: comment.block.midas
- match: \*/
pop: true
string:
- meta_include_prototype: false
- meta_scope: string.quoted.double.c
- match: '"'
pop: true
keywords:
- match: \btype\b
scope: keyword.declaration.midas
push: type-stmt
- match: \bextend\b
scope: keyword.declaration.midas
push: extend-stmt
- match: \bpredicate\b
scope: keyword.declaration.midas
push: predicate-stmt
type-stmt:
- match: "{{identifier}}"
scope: entity.name.type
- match: \[
push: type-params
- match: "="
scope: keyword.operator.equal.midas
push: type-expr
- match: $
pop: true
type-expr:
- match: \b(fn)\s*(\()
captures:
1: keyword.other.midas
2: punctuation.section.group.begin
push: fn-params
- match: \b(where)\b
scope: keyword.other.midas
set: constraint
- match: "{{identifier}}"
scope: entity.name.type
- match: $
pop: 2
fn-params:
- match: "({{identifier}})(:)"
captures:
1: variable.parameter.midas
2: punctuation.separator.annotation.midas
push:
- include: type-expr
- match: \?
scope: keyword.operator.qmark.midas
- match: "(?=,)"
scope: punctuation.separator.midas
pop: true
- match: '(?=\))'
pop: true
- include: type-expr
- match: '\)'
set:
- match: "->"
scope: keyword.operator.arrow.midas
set: type-expr
constraint:
- match: $
pop: 2
- match: \d+(\.\d+)?
scope: constant.numeric.midas
- match: \b(true|false|none)\b
scope: constant.language.midas
- match: '"'
push: string
- match: (<=|>=|<|>|==|!=|&)
scope: keyword.operator
- match: _
scope: variable.language.midas
- match: '{{identifier}}(?=\s*\()'
scope: variable.function.midas
- match: "{{identifier}}"
scope: variable.other.readwrite.midas
type-params:
- match: "<:"
scope: keyword.operator.subtype.midas
- match: "[a-zA-Z][a-zA-Z_0-9]*"
scope: entity.name.type
- match: "]"
pop: true
extend-stmt:
- match: "{{identifier}}"
scope: entity.name.type
- match: \[
push: type-params
- match: \{
scope: punctuation.section.block.begin
set: extend-body
extend-body:
- include: member-stmt
- match: \}
scope: punctuation.section.block.end
pop: true
member-stmt:
- match: \b(prop|def)\b
scope: keyword.other.midas
push:
- match: "{{identifier}}"
scope: variable.other.member
- match: ":"
push: type-expr
- match: $
pop: true
predicate-stmt:
- match: "{{identifier}}"
scope: entity.name.function.midas
- match: '\('
push: predicate-params
- match: "="
scope: keyword.operator.equal.midas
set: constraint
- match: $
pop: true
predicate-params:
- match: "({{identifier}})(:)"
captures:
1: variable.parameter.midas
2: punctuation.separator.annotation.midas
push:
- include: type-expr
- match: "(?=,)"
scope: punctuation.separator.midas
pop: true
- match: '(?=\))'
pop: true
- match: '\)'
pop: true

View File

@@ -1,143 +0,0 @@
#import "@preview/modpattern:0.2.0": modpattern
#let TODO = block(
width: 6em,
height: 3em,
stroke: red,
fill: modpattern(
size: (10pt, 10pt),
line(
start: (0%, 0%),
end: (100%, 100%),
stroke: gray.transparentize(60%) + 2pt,
),
),
align(
center + horizon,
text(fill: red, size: 1.5em)[*TODO*],
),
)
#let _render-header(version, hash) = {
let last-heading = query(heading.where(level: 1).before(here())).last(default: none)
let next-heading = query(heading.where(level: 1).after(here())).first(default: none)
let current-heading = if next-heading != none and next-heading.location().page() == here().page() {
next-heading
} else if last-heading != none {
last-heading
} else { none }
let chapter = if current-heading != none {
let body = current-heading.body
if current-heading.numbering != none {
let num = counter(heading).display(current-heading.numbering, at: current-heading.location())
body = [#num #body]
}
body
} else []
grid(
columns: (1fr, auto, 1fr),
align: (left, center, right),
document.title, [v#version - #hash], chapter,
)
}
#let _unshift-prefix(prefix, content) = context {
pad(left: -measure(prefix).width, prefix + content)
}
#let project(
title: none,
author: none,
version: "0.0.1",
hash: "abcdefgh",
icon-path: none,
doc,
) = {
assert(title != none, message: "Please provide a title")
set document(
title: title,
author: author,
)
set text(
font: "Source Sans 3",
)
set raw(syntaxes: path("midas.sublime-syntax"))
let front-page() = {
align(center)[
#{
set text(size: 1.5em)
std.title()
}
v#version - #hash
#if icon-path != none {
v(1cm)
image(icon-path)
}
]
pagebreak()
}
let outlines() = {
outline()
pagebreak()
outline(
title: [List of Listings],
target: figure.where(kind: raw),
)
outline(
title: [List of Tables],
target: figure.where(kind: table),
)
}
let main() = {
// Adapted from https://github.com/hei-templates/hei-synd-thesis/blob/7d2b941197babae0bf3afd4e5914754e09a64001/lib/template-thesis.typ#L242-L261
show heading.where(level: 1): it => {
pagebreak()
set text(size: 1.5em)
set block(above: 1.2em, below: 1.2em)
if it.numbering != none {
let num = numbering(it.numbering, ..counter(heading).at(it.location()))
let prefix = num + h(1em)
_unshift-prefix(prefix, it.body)
} else {
it
}
}
show heading.where(level: 2): it => {
if it.numbering != none {
let num = numbering(it.numbering, ..counter(heading).at(it.location()))
_unshift-prefix(num + h(0.8em), it.body)
} else {
it
}
}
set page(
header: context _render-header(version, hash),
footer: context if page.numbering != none {
align(center, counter(page).display(page.numbering, both: true))
},
numbering: "1 / 1",
)
show heading: set heading(numbering: "I.1.")
counter(page).update(1)
doc
}
front-page()
outlines()
main()
}

View File

@@ -44,11 +44,6 @@ class TypeStmt:
type: Type
class AliasStmt:
name: Token
type: Type
class MemberStmt:
name: Token
type: Type

View File

@@ -15,7 +15,7 @@ from midas.ast.location import Location
###> MidasType | Type annotations | node
class BaseType:
base: str
args: tuple[MidasType, ...]
param: Optional[MidasType]
class ConstraintType:
@@ -174,10 +174,6 @@ class SliceExpr:
step: Optional[Expr]
class TupleExpr:
items: tuple[Expr, ...]
class RawExpr:
expr: ast.expr

View File

@@ -51,9 +51,6 @@ class Stmt(ABC):
@abstractmethod
def visit_type_stmt(self, stmt: TypeStmt) -> T: ...
@abstractmethod
def visit_alias_stmt(self, stmt: AliasStmt) -> T: ...
@abstractmethod
def visit_member_stmt(self, stmt: MemberStmt) -> T: ...
@@ -74,15 +71,6 @@ class TypeStmt(Stmt):
return visitor.visit_type_stmt(self)
@dataclass(frozen=True)
class AliasStmt(Stmt):
name: Token
type: Type
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_alias_stmt(self)
@dataclass(frozen=True)
class MemberStmt(Stmt):
name: Token

View File

@@ -105,14 +105,6 @@ class MidasAstPrinter(
with self._child_level(single=True):
stmt.type.accept(self)
def visit_alias_stmt(self, stmt: m.AliasStmt) -> None:
self._write_line("AliasStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("type", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
def _print_type_param(self, param: m.TypeParam) -> None:
self._write_line("Param")
with self._child_level():
@@ -398,9 +390,6 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
return self.indented(res)
def visit_alias_stmt(self, stmt: m.AliasStmt) -> str:
return self.indented(f"alias {stmt.name.lexeme} = {stmt.type.accept(self)}")
def _print_type_param(self, param: m.TypeParam) -> str:
res: str = param.name.lexeme
if param.bound is not None:
@@ -560,13 +549,7 @@ class PythonAstPrinter(
self._write_line("BaseType")
with self._child_level():
self._write_line(f"base: {node.base}")
self._write_line("args:", last=True)
with self._child_level():
for i, arg in enumerate(node.args):
self._idx = i
if i == len(node.args) - 1:
self._mark_last()
arg.accept(self)
self._write_optional_child("param", node.param, last=True)
def visit_constraint_type(self, node: p.ConstraintType) -> None:
self._write_line("ConstraintType")
@@ -879,17 +862,6 @@ class PythonAstPrinter(
self._write_optional_child("upper", expr.upper)
self._write_optional_child("step", expr.step, last=True)
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
self._write_line("TupleExpr")
with self._child_level():
self._write_line("items", last=True)
with self._child_level():
for i, item in enumerate(expr.items):
self._idx = i
if i == len(expr.items) - 1:
self._mark_last()
item.accept(self)
def visit_raw_expr(self, expr: p.RawExpr) -> None:
self._write_line("RawExpr")
with self._child_level(single=True):

View File

@@ -44,7 +44,7 @@ class MidasType(ABC):
@dataclass(frozen=True)
class BaseType(MidasType):
base: str
args: tuple[MidasType, ...]
param: Optional[MidasType]
def accept(self, visitor: MidasType.Visitor[T]) -> T:
return visitor.visit_base_type(self)
@@ -268,9 +268,6 @@ class Expr(ABC):
@abstractmethod
def visit_slice_expr(self, expr: SliceExpr) -> T: ...
@abstractmethod
def visit_tuple_expr(self, expr: TupleExpr) -> T: ...
@abstractmethod
def visit_raw_expr(self, expr: RawExpr) -> T: ...
@@ -405,14 +402,6 @@ class SliceExpr(Expr):
return visitor.visit_slice_expr(self)
@dataclass(frozen=True)
class TupleExpr(Expr):
items: tuple[Expr, ...]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_tuple_expr(self)
@dataclass(frozen=True)
class RawExpr(Expr):
expr: ast.expr

View File

@@ -178,100 +178,4 @@ extend dict[K, V] {
// def __ior__: fn(value: SupportsKeysAndGetItem[K, V], /) -> dict[K, V]
// def __ior__: fn(value: Iterable[tuple[K, V]], /) -> dict[K, V]
}
extend str {
def capitalize: fn() -> str
def casefold: fn() -> str
def center: fn(width: int, fillchar: str?, /) -> str
def count: fn(sub: str, start: None?, end: None?, /) -> int
def count: fn(sub: str, start: int, end: None?, /) -> int
def count: fn(sub: str, start: None, end: int, /) -> int
def count: fn(sub: str, start: int, end: int, /) -> int
def encode: fn(encoding: str?, errors: str?) -> bytes
def endswith: fn(suffix: str, start: None?, end: None?, /) -> bool
def endswith: fn(suffix: str, start: int, end: None?, /) -> bool
def endswith: fn(suffix: str, start: None, end: int, /) -> bool
def endswith: fn(suffix: str, start: int, end: int, /) -> bool
def expandtabs: fn(tabsize: int?) -> str
def find: fn(sub: str, start: None?, end: None?, /) -> int
def find: fn(sub: str, start: int, end: None?, /) -> int
def find: fn(sub: str, start: None, end: int, /) -> int
def find: fn(sub: str, start: int, end: int, /) -> int
// def format: fn(*args: object, **kwargs: object) -> str
// def format_map: fn(mapping: _FormatMapMapping, /) -> str
def index: fn(sub: str, start: None?, end: None?, /) -> int
def index: fn(sub: str, start: int, end: None?, /) -> int
def index: fn(sub: str, start: None, end: int, /) -> int
def index: fn(sub: str, start: int, end: int, /) -> int
def isalnum: fn() -> bool
def isalpha: fn() -> bool
def isascii: fn() -> bool
def isdecimal: fn() -> bool
def isdigit: fn() -> bool
def isidentifier: fn() -> bool
def islower: fn() -> bool
def isnumeric: fn() -> bool
def isprintable: fn() -> bool
def isspace: fn() -> bool
def istitle: fn() -> bool
def isupper: fn() -> bool
def join: fn(iterable: list[str], /) -> str // TODO: use Iterable
def ljust: fn(width: int, fillchar: str?, /) -> str
def lower: fn() -> str
def lstrip: fn(chars: None?, /) -> str
def lstrip: fn(chars: str, /) -> str
def partition: fn(sep: str, /) -> tuple[str, str, str]
def replace: fn(old: str, new: str, count: int?, /) -> str
def removeprefix: fn(prefix: str, /) -> str
def removesuffix: fn(suffix: str, /) -> str
def rfind: fn(sub: str, start: None?, end: None?, /) -> int
def rfind: fn(sub: str, start: int, end: None?, /) -> int
def rfind: fn(sub: str, start: None, end: int, /) -> int
def rfind: fn(sub: str, start: int, end: int, /) -> int
def rindex: fn(sub: str, start: None?, end: None?, /) -> int
def rindex: fn(sub: str, start: int, end: None?, /) -> int
def rindex: fn(sub: str, start: None, end: int, /) -> int
def rindex: fn(sub: str, start: int, end: int, /) -> int
def rjust: fn(width: int, fillchar: str?, /) -> str
def rpartition: fn(sep: str, /) -> tuple[str, str, str]
def rsplit: fn(sep: None?, maxsplit: int?) -> list[str]
def rsplit: fn(sep: str, maxsplit: int?) -> list[str]
def rstrip: fn(chars: None?, /) -> str
def rstrip: fn(chars: str, /) -> str
def split: fn(sep: None?, maxsplit: int?) -> list[str]
def split: fn(sep: str, maxsplit: int?) -> list[str]
def splitlines: fn(keepends: bool?) -> list[str]
def startswith: fn(prefix: str, start: None?, end: None?, /) -> bool
def startswith: fn(prefix: str, start: int, end: None?, /) -> bool
def startswith: fn(prefix: str, start: None, end: int, /) -> bool
def startswith: fn(prefix: str, start: int, end: int, /) -> bool
def strip: fn(chars: None?, /) -> str
def strip: fn(chars: str, /) -> str
def swapcase: fn() -> str
def title: fn() -> str
// def translate: fn(table: _TranslateTable, /) -> str
def upper: fn() -> str
def zfill: fn(width: int, /) -> str
def __add__: fn(value: str, /) -> str
// Incompatible with Sequence.__contains__
def __contains__: fn(key: str, /) -> bool
def __eq__: fn(value: object, /) -> bool
def __ge__: fn(value: str, /) -> bool
def __getitem__: fn(key: slice, /) -> str
def __getitem__: fn(key: int, /) -> str
def __gt__: fn(value: str, /) -> bool
def __hash__: fn() -> int
// def __iter__: fn() -> Iterator[str]
def __le__: fn(value: str, /) -> bool
def __len__: fn() -> int
def __lt__: fn(value: str, /) -> bool
def __mod__: fn(value: Any, /) -> str
def __mul__: fn(value: int, /) -> str
def __ne__: fn(value: object, /) -> bool
def __rmul__: fn(value: int, /) -> str
def __getnewargs__: fn() -> tuple[str]
def __format__: fn(format_spec: str, /) -> str
}
}

View File

@@ -14,10 +14,8 @@ if TYPE_CHECKING:
from midas.checker.registry import TypesRegistry
# Hard-coded subtype relationships between builtin types
# Circular dependencies and diamond inheritance MUST be avoided
BUILTIN_SUBTYPES: dict[str, set[str]] = {
"object": {"float", "list", "dict", "str", "bytes", "tuple"},
"object": {"float", "list", "dict", "str"},
"float": {"int"},
"int": {"bool"},
}
@@ -28,15 +26,12 @@ def define_builtins(reg: TypesRegistry):
any = reg.define_type("Any", TopType())
unit = reg.define_type("None", UnitType())
object = reg.define_type("object", BaseType(name="object"))
bytes = reg.define_type("bytes", BaseType(name="bytes"))
bool = reg.define_type("bool", BaseType(name="bool"))
int = reg.define_type("int", BaseType(name="int"))
float = reg.define_type("float", BaseType(name="float"))
str = reg.define_type("str", BaseType(name="str"))
slice = reg.define_type("slice", BaseType(name="slice"))
tuple = reg.define_type("tuple", BaseType(name="tuple"))
list = reg.define_type(
"list",
GenericType(

View File

@@ -1,484 +0,0 @@
import logging
from dataclasses import dataclass
from enum import StrEnum
from typing import Generic, Optional, Protocol, TypeVar, Union
from midas.ast.location import Location
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import (
AppliedType,
DerivedType,
Function,
GenericType,
OverloadedFunction,
Type,
UnknownType,
)
from midas.checker.unifier import Unifier
class HasLocation(Protocol):
@property
def location(self) -> Location: ...
E = TypeVar("E", bound=HasLocation)
TypedExpr = tuple[E, Type]
@dataclass(frozen=True, kw_only=True)
class MappedArgument(Generic[E]):
expr: E
type: Type
argument: Function.Argument
@dataclass(frozen=True, kw_only=True)
class OverloadCandidate:
function: Function
mapped: list[MappedArgument]
class CallError(StrEnum):
INVALID_ARGS = "Invalid arguments"
NO_MATCHING_OVERLOAD = "No matching overload"
IMPOSSIBLE_UNIFICATION = "Parameters unification failed"
NOT_CALLABLE = "Not callable"
@dataclass(frozen=True, kw_only=True)
class CallResult:
error: Optional[CallError] = None
result: Type = UnknownType()
message: Optional[str] = None
@property
def is_valid(self) -> bool:
return self.error is None
@property
def error_message(self) -> str:
if self.message is not None:
return self.message
if self.error is not None:
return str(self.error)
return ""
class CallDispatcher(Generic[E]):
def __init__(self, types: TypesRegistry, reporter: FileReporter) -> None:
self.types: TypesRegistry = types
self.reporter: FileReporter = reporter
self.logger: logging.Logger = logging.getLogger("CallDispatcher")
def set_reporter(self, reporter: FileReporter):
self.reporter = reporter
def get_result(
self,
location: Location,
callee: Type,
positional: list[TypedExpr[E]],
keywords: dict[str, TypedExpr[E]],
report_errors: bool = True,
) -> CallResult:
"""Get the result type of a function call
If the function has overloads, the function will try to resolve the
appropriate signature.
Argument types are matched to the defined parameters.
The function doesn't take the raw expression as a parameter to accommodate
for desugared calls such as for operators.
Args:
location (Location): the call location
callee (Type): the called function
positional (list[TypedExpr]): the list positional arguments
keywords (dict[str, TypedExpr]): the map of keyword arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
Type: the return type of the call, or `None` if either
the call is invalid or no overload matched the arguments uniquely
"""
match callee:
case Function() as function:
valid: bool
mapped: list[MappedArgument[E]]
valid, mapped = self.map_call_arguments(
function, location, positional, keywords
)
valid = valid and self._are_arguments_valid(mapped, report_errors)
if not valid:
return CallResult(error=CallError.INVALID_ARGS)
return CallResult(result=function.returns)
case OverloadedFunction(overloads=overloads):
res = self._match_overload(
overloads, location, positional, keywords, report_errors
)
if res[0] is None:
return CallResult(
error=CallError.NO_MATCHING_OVERLOAD,
message=res[1],
)
return CallResult(result=res[0].returns)
case AppliedType(body=body):
return self.get_result(
location, body, positional, keywords, report_errors
)
case UnknownType():
return CallResult(result=UnknownType())
case DerivedType(type=base):
return self.get_result(
location, base, positional, keywords, report_errors
)
case GenericType():
unifier: Unifier = Unifier(self.types)
pos: list[Type] = [a[1] for a in positional]
kw: dict[str, Type] = {k: v[1] for k, v in keywords.items()}
unified: Optional[Type] = unifier.unify_call(callee, pos, kw)
if unified is None:
pos_str: str = ", ".join(str(t) for t in pos)
kw_str: str = ", ".join(f"{k}: {v}" for k, v in kw.items())
message: str = (
f"Could not unify {callee}={callee.body} with pos=[{pos_str}] and kw={{{kw_str}}}"
)
if report_errors:
self.reporter.error(location, message)
return CallResult(
error=CallError.IMPOSSIBLE_UNIFICATION,
message=message,
)
return self.get_result(
location,
unified,
positional,
keywords,
report_errors,
)
case _:
message: str = f"{callee} ({callee.__class__.__name__}) is not callable"
if report_errors:
self.reporter.error(location, message)
return CallResult(
error=CallError.NOT_CALLABLE,
message=message,
)
def _unwrap_function(
self,
callee: Type,
positional: list[TypedExpr[E]],
keywords: dict[str, TypedExpr[E]],
) -> Union[tuple[Function, None], tuple[None, CallError]]:
match callee:
case DerivedType(type=base):
return self._unwrap_function(base, positional, keywords)
case GenericType():
unifier: Unifier = Unifier(self.types)
unified: Optional[Type] = unifier.unify_call(
callee,
[a[1] for a in positional],
{k: v[1] for k, v in keywords.items()},
)
if unified is None:
return None, CallError.IMPOSSIBLE_UNIFICATION
return self._unwrap_function(unified, positional, keywords)
case Function():
return callee, None
case AppliedType(body=body):
return self._unwrap_function(body, positional, keywords)
case _:
return None, CallError.NOT_CALLABLE
def _are_arguments_valid(
self,
arguments: list[MappedArgument[E]],
report_errors: bool = True,
) -> bool:
"""Check whether the passed argument types correspond to their matched parameter definitions
Args:
arguments (list[MappedArgument]): the list of argument/parameter pairs
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
bool: True if all arguments fit the matching parameter definitions, False otherwise
"""
valid: bool = True
for arg in arguments:
if not self.types.is_subtype(arg.type, arg.argument.type):
if report_errors:
self.reporter.error(
arg.expr.location,
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
)
valid = False
return valid
def _match_overload(
self,
overloads: list[Type],
location: Location,
positional: list[TypedExpr[E]],
keywords: dict[str, TypedExpr[E]],
report_errors: bool = True,
) -> Union[tuple[Function, None], tuple[None, str]]:
"""Try and resolve the appropriate overload for the given arguments
Args:
overloads (list[Type]): the list of possible overloads
location (Location): the call location
positional (list[TypedExpr]): the list of positional arguments
keywords (dict[str, TypedExpr]): the map of keywords arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
Optional[Function]: the resolved function signature if it can be
determined unambiguously, or `None`.
"""
candidates: list[OverloadCandidate] = []
errors: list[CallError] = []
for overload in overloads:
function, unwrap_error = self._unwrap_function(
overload, positional, keywords
)
if function is None:
errors.append(unwrap_error) # type: ignore
continue
valid, mapped = self.map_call_arguments(
function=function,
location=location,
positional=positional,
keywords=keywords,
report_errors=False,
)
if valid and self._are_arguments_valid(mapped, report_errors=False):
candidates.append(
OverloadCandidate(
function=function,
mapped=mapped,
)
)
pos_types: str = ", ".join(str(type) for _, type in positional)
kw_types: str = ", ".join(
f"{name}: {type}" for name, (_, type) in keywords.items()
)
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
n_candidates: int = len(candidates)
# Exactly 1 match -> return it
if n_candidates == 1:
return candidates[0].function, None
# No match -> invalid call
if n_candidates == 0:
overloads_str: str = ", ".join(map(str, overloads))
errors_str: str = ", ".join(errors)
message: str = (
f"No matching overload in [{overloads_str}] {for_args} (errors: {errors_str})"
)
if report_errors:
self.reporter.error(location, message)
return None, message
# Multiple matches -> see if one <: all others (more specific)
for i1, c1 in enumerate(candidates):
mapped1: list[MappedArgument[E]] = c1.mapped
best_match: bool = True
for i2, c2 in enumerate(candidates):
if i1 == i2:
continue
mapped2: list[MappedArgument[E]] = c2.mapped
if not self._are_mapped_subtypes(mapped1, mapped2):
best_match = False
break
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
if best_match:
return c1.function, None
candidates_str: str = ", ".join(
str(candidate.function) for candidate in candidates
)
message: str = f"Multiple matching overloads {for_args}: {candidates_str}"
if report_errors:
self.reporter.error(location, message)
return None, message
def map_call_arguments(
self,
function: Function,
location: Location,
positional: list[TypedExpr[E]],
keywords: dict[str, TypedExpr[E]],
report_errors: bool = True,
) -> tuple[bool, list[MappedArgument]]:
"""Map call arguments to a function's parameters as defined in its signature
This method maps positional-only, keyword-only and mixed parameter definitions
with the arguments passed at the call site
Any mismatched, missing or unexpected argument is reported as a diagnostic,
unless `report_errors` is set to `False`
Args:
function (Function): the function definition
location (Location): the call location
positional (list[TypedExpr]): the list of positional arguments
keywords (dict[str, TypedExpr]): the map of keyword arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
tuple[bool, list[MappedArgument]]: a boolean reporting whether
the call is valid and the list of mapped arguments
"""
set_args: set[str] = set()
required_positional: list[str] = [
arg.name for arg in function.pos_args + function.args if arg.required
]
required_keyword: list[str] = [
arg.name for arg in function.kw_args if arg.required
]
mapped: list[MappedArgument[E]] = []
pos_params: list[Function.Argument] = list(function.pos_args)
mixed_params: list[Function.Argument] = list(function.args)
kw_params: dict[str, Function.Argument] = {
arg.name: arg for arg in function.kw_args
}
valid_call: bool = True
# TODO: handle *args and **kwargs sinks
for arg in positional:
param: Function.Argument
if len(pos_params) != 0:
param = pos_params.pop(0)
elif len(mixed_params) != 0:
param = mixed_params.pop(0)
else:
if report_errors:
self.reporter.error(
arg[0].location, "Too many positional arguments"
)
valid_call = False
break
name: str = param.name
if name in required_positional:
required_positional.remove(name)
if name in required_keyword:
required_keyword.remove(name)
set_args.add(name)
mapped.append(
MappedArgument(
expr=arg[0],
type=arg[1],
argument=param,
)
)
kw_params.update({arg.name: arg for arg in mixed_params})
for name, arg in keywords.items():
param: Function.Argument
if name not in kw_params:
if report_errors:
if name in set_args:
self.reporter.error(
arg[0].location, f"Multiple values for argument '{name}'"
)
else:
self.reporter.error(
arg[0].location, f"Unknown keyword argument '{name}'"
)
valid_call = False
continue
param = kw_params.pop(name)
if name in required_positional:
required_positional.remove(name)
if name in required_keyword:
required_keyword.remove(name)
set_args.add(name)
mapped.append(
MappedArgument(
expr=arg[0],
type=arg[1],
argument=param,
)
)
def join_args(args: list[str]) -> str:
args = list(map(lambda a: f"'{a}'", args))
if len(args) == 0:
return ""
if len(args) == 1:
return args[0]
return ", ".join(args[:-1]) + " and " + args[-1]
if len(required_positional) != 0:
plural: str = "" if len(required_positional) == 1 else "s"
args: str = join_args(required_positional)
if report_errors:
self.reporter.error(
location,
f"Missing required positional argument{plural}: {args}",
)
valid_call = False
if len(required_keyword) != 0:
plural: str = "" if len(required_keyword) == 1 else "s"
args: str = join_args(required_keyword)
if report_errors:
self.reporter.error(
location,
f"Missing required keyword argument{plural}: {args}",
)
valid_call = False
return valid_call, mapped
def _are_mapped_subtypes(
self, mapped1: list[MappedArgument[E]], mapped2: list[MappedArgument[E]]
) -> bool:
"""Check whether the given argument mappings are subtype/supertype of one another
This function checks whether the argument mappings `mapped1` are subtypes
of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
of the corresponding parameter in `mapped2`, `False` is returned.
This is used to check whether a given overload is
a more specific function/ a subtype of another.
Args:
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
Returns:
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
"""
by_expr: dict[E, Type] = {}
for arg in mapped1:
by_expr[arg.expr] = arg.argument.type
for arg in mapped2:
type2: Type = arg.argument.type
type1: Type = by_expr[arg.expr]
if not self.types.is_subtype(type1, type2):
return False
return True

View File

@@ -0,0 +1,198 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Optional
from midas.ast.location import Location
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import (
ColumnType,
DataFrameType,
Function,
OverloadedFunction,
TopType,
Type,
UnknownType,
unfold_type,
)
if TYPE_CHECKING:
from midas.checker.python import PythonTyper, TypedExpr
@staticmethod
def frame_method(*names: str):
def wrapper(func):
names_: tuple[str, ...] = names
if len(names_) == 0:
names_ = (func.__name__,)
setattr(func, "__method_names__", names_)
return func
return wrapper
@dataclass(frozen=True, kw_only=True)
class Call:
location: Location
frame: DataFrameType
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
class _MethodRegistryMeta(type):
_methods: dict[str, Callable] = {}
def __new__(
cls,
name: str,
bases: tuple[type, ...],
namespace: dict[str, Any],
):
new_class = super().__new__(cls, name, bases, namespace)
new_class._methods = {}
for attr in namespace.values():
if callable(attr) and hasattr(attr, "__method_names__"):
for name in attr.__method_names__: # type: ignore
new_class._methods[name] = attr
return new_class
class MethodRegistry(metaclass=_MethodRegistryMeta):
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
@property
def reporter(self) -> FileReporter:
return self.typer.reporter
@property
def types(self) -> TypesRegistry:
return self.typer.types
def call(
self,
method: str,
call: Call,
) -> Type:
func: Optional[Callable] = self._methods.get(method)
if func is None:
self.reporter.error(call.location, f"Unknown method {method}")
return UnknownType()
return func(self, call)
@frame_method("add", "__add__")
def add(
self,
call: Call,
) -> Type:
# TODO: support add with scalar, sequence, Series, dict
# TODO: check operation exists on inner column types
new_columns: list[DataFrameType.Column] = []
by_name: dict[str, DataFrameType.Column] = {}
frame2: Optional[DataFrameType] = None
if len(call.positional) != 0:
other: Type = call.positional[0][1]
unfolded_other: Type = unfold_type(other)
if isinstance(unfolded_other, DataFrameType):
frame2 = unfolded_other
by_name = {
col.name: col for col in frame2.columns if col.name is not None
}
in_frame1: set[str] = set()
for column in call.frame.columns:
if column.name is not None:
in_frame1.add(column.name)
col_type1: Type = column.type
col_type: Type = ColumnType(type=UnknownType())
if column.name in by_name:
column2 = by_name[column.name]
col_type2: Type = column2.type
if self.types.are_equivalent(col_type2, col_type1):
col_type = col_type1
new_column = DataFrameType.Column(
index=column.index,
name=column.name,
type=col_type,
)
new_columns.append(new_column)
if frame2 is not None:
for column in frame2.columns:
if column.name in in_frame1:
continue
new_columns.append(
DataFrameType.Column(
index=len(new_columns),
name=column.name,
type=ColumnType(type=UnknownType()),
)
)
signature = Function(
args=[
Function.Argument(
pos=0,
name="other",
type=DataFrameType(columns=[]),
required=True,
),
],
returns=DataFrameType(columns=new_columns),
)
return (
self.typer._get_call_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
or UnknownType()
)
@frame_method()
def mean(self, call: Call) -> Type:
with_axis = Function(
kw_args=[
Function.Argument(
pos=0,
name="axis",
type=self.types.get_type("int"),
required=False,
)
],
returns=ColumnType(type=TopType()),
)
without_axis = Function(
kw_args=[
Function.Argument(
pos=0,
name="axis",
type=self.types.get_type("None"),
required=True,
)
],
returns=TopType(),
)
overload = OverloadedFunction(
overloads=[
with_axis,
without_axis,
]
)
return (
self.typer._get_call_result(
location=call.location,
callee=overload,
positional=call.positional,
keywords=call.keywords,
)
or UnknownType()
)

View File

@@ -4,20 +4,9 @@ from typing import TYPE_CHECKING, Optional, TypeGuard, cast
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.frames.frame_groupby_methods import Call as GroupByCall
from midas.checker.frames.frame_groupby_methods import FrameGroupByMethodRegistry
from midas.checker.frames.frame_methods import Call, FrameMethodRegistry
from midas.checker.registry import TypesRegistry
from midas.checker.frame_methods import Call, MethodRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import (
ColumnGroupBy,
ColumnType,
DataFrameType,
FrameGroupBy,
TupleType,
Type,
UnknownType,
)
from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType
if TYPE_CHECKING:
from midas.checker.python import PythonTyper, TypedExpr
@@ -30,10 +19,7 @@ def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
class FrameManager:
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
self.method_resolver: FrameMethodRegistry = FrameMethodRegistry(self.typer)
self.groupby_method_resolver: FrameGroupByMethodRegistry = (
FrameGroupByMethodRegistry(self.typer)
)
self.method_resolver: MethodRegistry = MethodRegistry(self.typer)
def assign(
self,
@@ -48,41 +34,12 @@ class FrameManager:
return self.assign_column(reporter, location, frame, name, value_type)
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
isinstance(index.value, str) for index in indices
isinstance(idx, str) for idx in indices
):
names: list[str] = [cast(str, index.value) for index in indices]
if not isinstance(value_type, TupleType):
reporter.error(
location,
f"Cannot assign {type} to dataframe columns. Must be a tuple of columns",
)
return UnknownType()
if len(names) != len(value_type.items):
reporter.error(
location,
f"Wrong number of columns. Cannot assign {len(value_type.items)} to {len(names)} targets",
)
return UnknownType()
new_frame: Type = frame
for name, value in zip(names, value_type.items):
new_frame = self.assign_column(
reporter,
location,
new_frame,
name,
value,
)
if not isinstance(new_frame, DataFrameType):
return new_frame
return new_frame
raise NotImplementedError
case _:
reporter.error(
location, f"Invalid index type {index} on {frame} (assignment)"
)
reporter.error(location, f"Invalid index type {index} on {frame}")
return UnknownType()
def assign_column(
@@ -130,31 +87,9 @@ class FrameManager:
return TupleType(items=tuple(columns))
case _:
reporter.error(
location, f"Invalid index type {index} on {frame} (access)"
)
reporter.error(location, f"Invalid index type {index} on {frame}")
return UnknownType()
def groupby_get(
self,
reporter: FileReporter,
location: Location,
groupby: FrameGroupBy,
index: p.Expr,
) -> Type:
result: Type = self.get(reporter, location, groupby.frame, index)
match result:
case ColumnType():
result = ColumnGroupBy(column=result)
case TupleType(items=columns):
result = TupleType(
items=tuple(
ColumnGroupBy(column=cast(ColumnType, column))
for column in columns
)
)
return result
@classmethod
def _set_column(
cls, frame: DataFrameType, name: str, column: ColumnType
@@ -206,50 +141,14 @@ class FrameManager:
self,
method: str,
location: Location,
call_expr: p.Expr,
frame: DataFrameType,
frame_expr: p.Expr,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
call: Call = Call(
location=location,
call_expr=call_expr,
frame=frame,
frame_expr=frame_expr,
positional=positional,
keywords=keywords,
)
return self.method_resolver.call(method, call)
def groupby_call(
self,
method: str,
location: Location,
call_expr: p.Expr,
groupby: FrameGroupBy,
groupby_expr: p.Expr,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
call: GroupByCall = GroupByCall(
location=location,
call_expr=call_expr,
groupby=groupby,
groupby_expr=groupby_expr,
positional=positional,
keywords=keywords,
)
return self.groupby_method_resolver.call(method, call)
def get_attribute(self, frame: DataFrameType, name: str) -> Optional[Type]:
types: TypesRegistry = self.typer.types
match name:
case "ndim" | "size":
return types.get_type("int")
case "shape":
return types.tuple_of("int", "int")
case _:
return None

View File

@@ -1,203 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.dispatcher import CallResult
from midas.checker.frames.utils import MethodRegistry, method
from midas.checker.types import ColumnGroupBy, ColumnType, Function, TopType, Type
if TYPE_CHECKING:
from midas.checker.python import TypedExpr
@dataclass(frozen=True, kw_only=True)
class Call:
location: Location
call_expr: p.Expr
groupby: ColumnGroupBy
groupby_expr: p.Expr
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
@property
def subject(self) -> TypedExpr:
return (self.groupby_expr, self.groupby)
class ColumnGroupByMethodRegistry(MethodRegistry[Call]):
NAMED_ARGS: dict[str, str] = {
"numeric_only": "bool",
"skipna": "bool",
"engine": "str",
"engine_kwargs": "dict",
}
def _aggregate(
self,
call: Call,
args: list[str | tuple[str, str, bool]] = [],
*,
preserve_inner_type: bool = False,
) -> Type:
real_args: list[Function.Argument] = []
for i, arg in enumerate(args):
match arg:
case str() as name:
arg = Function.Argument(
pos=i,
name=name,
type=self.types.get_type(self.NAMED_ARGS[name]),
required=False,
)
case (name, type, required):
arg = Function.Argument(
pos=i,
name=name,
type=self.types.get_type(type),
required=required,
)
real_args.append(arg)
signature = Function(
args=real_args,
returns=(
call.groupby.column
if preserve_inner_type
else ColumnType(type=TopType())
),
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
return result.result
@method()
def kurt(self, call: Call) -> Type:
return self._aggregate(
call,
["skipna", "numeric_only"],
)
@method()
def max(self, call: Call) -> Type:
return self._aggregate(
call,
[
"numeric_only",
(
"min_count",
"int",
False,
),
"skipna",
"engine",
"engine_kwargs",
],
preserve_inner_type=True,
)
@method()
def mean(self, call: Call) -> Type:
return self._aggregate(
call,
["numeric_only", "skipna", "engine", "engine_kwargs"],
)
@method()
def median(self, call: Call) -> Type:
return self._aggregate(
call,
["numeric_only", "skipna"],
preserve_inner_type=True,
)
@method()
def min(self, call: Call) -> Type:
return self._aggregate(
call,
[
"numeric_only",
(
"min_count",
"int",
False,
),
"skipna",
"engine",
"engine_kwargs",
],
preserve_inner_type=True,
)
@method()
def prod(self, call: Call) -> Type:
return self._aggregate(
call,
[
"numeric_only",
(
"min_count",
"int",
False,
),
"skipna",
],
)
@method()
def std(self, call: Call) -> Type:
return self._aggregate(
call,
[
(
"ddof",
"int",
False,
),
"engine",
"engine_kwargs",
"numeric_only",
"skipna",
],
)
@method()
def sum(self, call: Call) -> Type:
return self._aggregate(
call,
[
"numeric_only",
(
"min_count",
"int",
False,
),
"skipna",
"engine",
"engine_kwargs",
],
)
@method()
def var(self, call: Call) -> Type:
return self._aggregate(
call,
[
(
"var",
"int",
False,
),
"engine",
"engine_kwargs",
"numeric_only",
"skipna",
],
)

View File

@@ -1,78 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.frames.column_groupby_methods import Call as GroupByCall
from midas.checker.frames.column_groupby_methods import ColumnGroupByMethodRegistry
from midas.checker.frames.column_methods import Call, ColumnMethodRegistry
from midas.checker.registry import TypesRegistry
from midas.checker.types import ColumnGroupBy, ColumnType, Type
if TYPE_CHECKING:
from midas.checker.python import PythonTyper, TypedExpr
class ColumnManager:
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
self.method_resolver: ColumnMethodRegistry = ColumnMethodRegistry(self.typer)
self.groupby_method_resolver: ColumnGroupByMethodRegistry = (
ColumnGroupByMethodRegistry(self.typer)
)
def call(
self,
method: str,
location: Location,
call_expr: p.Expr,
column: ColumnType,
column_expr: p.Expr,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
call: Call = Call(
location=location,
call_expr=call_expr,
column=column,
column_expr=column_expr,
positional=positional,
keywords=keywords,
)
return self.method_resolver.call(method, call)
def groupby_call(
self,
method: str,
location: Location,
call_expr: p.Expr,
groupby: ColumnGroupBy,
groupby_expr: p.Expr,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
call: GroupByCall = GroupByCall(
location=location,
call_expr=call_expr,
groupby=groupby,
groupby_expr=groupby_expr,
positional=positional,
keywords=keywords,
)
return self.groupby_method_resolver.call(method, call)
def get_attribute(self, column: ColumnType, name: str) -> Optional[Type]:
types: TypesRegistry = self.typer.types
match name:
case "ndim" | "size":
return types.get_type("int")
case "shape":
return types.tuple_of("int")
case "T":
return column
case _:
return None

View File

@@ -1,410 +0,0 @@
from __future__ import annotations
import ast
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.dispatcher import CallResult
from midas.checker.frames.utils import MethodRegistry, method
from midas.checker.types import (
ColumnGroupBy,
ColumnType,
Function,
GenericType,
TopType,
Type,
TypeVar,
UnknownType,
unfold_type,
)
if TYPE_CHECKING:
from midas.checker.python import TypedExpr
@dataclass(frozen=True, kw_only=True)
class Call:
location: Location
call_expr: p.Expr
column: ColumnType
column_expr: p.Expr
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
@property
def subject(self) -> TypedExpr:
return (self.column_expr, self.column)
class ColumnMethodRegistry(MethodRegistry[Call]):
def _element_binary_op(self, call: Call, method: str) -> ColumnType:
"""Compute the result of an element-wise binary operation
This function delegates to the inner types for computing the resulting
type.
Args:
call (Call): the call that triggered this resolution
method (str): the method name
Returns:
ColumnType: the resulting column type
"""
column2: Optional[ColumnType] = None
col_type1: Type = call.column.type
new_column: Type = ColumnType(type=UnknownType())
if len(call.positional) != 0:
other: Type = call.positional[0][1]
unfolded_other: Type = unfold_type(other)
if isinstance(unfolded_other, ColumnType):
column2 = unfolded_other
col_type2: Type = column2.type
new_inner_type = self.typer.result_of_binary_op(
location=call.location,
expr=call.call_expr,
left=(call.column_expr, col_type1),
right=(call.positional[0][0], col_type2),
method=method,
)
new_column = ColumnType(type=new_inner_type)
return new_column
def _element_wise(self, call: Call, method: str) -> Type:
# TODO: support add with scalar
# Build signature with new column type and generic operand
param_type: TypeVar = TypeVar(name="T", bound=None)
signature = GenericType(
name="add",
params=[param_type],
body=Function(
args=[
Function.Argument(
pos=0,
name="other",
type=ColumnType(type=param_type),
required=True,
),
],
returns=self._element_binary_op(call, method),
),
)
# Map arguments and compute result type
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
if result.is_valid:
self._assert_same_length(
call.call_expr, call.column_expr, call.positional[0][0]
)
return result.result
@method("add", "__add__")
def add(self, call: Call) -> Type:
return self._element_wise(call, "__add__")
@method("sub", "__sub__")
def sub(self, call: Call) -> Type:
return self._element_wise(call, "__sub__")
@method("mul", "__mul__")
def mul(self, call: Call) -> Type:
return self._element_wise(call, "__mul__")
@method("div", "truediv", "__truediv__")
def truediv(self, call: Call) -> Type:
return self._element_wise(call, "__truediv__")
@method("floordiv", "__floordiv__")
def floordiv(self, call: Call) -> Type:
return self._element_wise(call, "__floordiv__")
@method("mod", "__mod__")
def mod(self, call: Call) -> Type:
return self._element_wise(call, "__mod__")
@method("pow", "__pow__")
def pow(self, call: Call) -> Type:
return self._element_wise(call, "__pow__")
@method("lt", "__lt__")
def lt(self, call: Call) -> Type:
return self._element_wise(call, "__lt__")
@method("gt", "__gt__")
def gt(self, call: Call) -> Type:
return self._element_wise(call, "__gt__")
@method("le", "__le__")
def le(self, call: Call) -> Type:
return self._element_wise(call, "__le__")
@method("ge", "__ge__")
def ge(self, call: Call) -> Type:
return self._element_wise(call, "__ge__")
@method("ne", "__ne__")
def ne(self, call: Call) -> Type:
return self._element_wise(call, "__ne__")
@method("eq", "__eq__")
def eq(self, call: Call) -> Type:
return self._element_wise(call, "__eq__")
def _aggregate(
self,
call: Call,
kwargs: list[Function.Argument] = [],
*,
preserve_inner_type: bool = False,
) -> Type:
signature = Function(
kw_args=[
Function.Argument(
pos=0,
name="axis",
type=TopType(),
required=False,
),
*kwargs,
],
returns=call.column if preserve_inner_type else ColumnType(type=TopType()),
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
return result.result
@method("kurtosis", "kurt")
def kurtosis(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def max(self, call: Call) -> Type:
return self._aggregate(call, preserve_inner_type=True)
@method()
def mean(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def median(self, call: Call) -> Type:
return self._aggregate(call, preserve_inner_type=True)
@method()
def min(self, call: Call) -> Type:
return self._aggregate(call, preserve_inner_type=True)
@method()
def mode(self, call: Call) -> Type:
return self._aggregate(call, preserve_inner_type=True)
@method("product", "prod")
def product(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def std(self, call: Call) -> Type:
return self._aggregate(
call,
[
Function.Argument(
pos=1,
name="ddof",
type=self.types.get_type("int"),
required=False,
)
],
)
@method()
def sum(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def var(self, call: Call) -> Type:
return self._aggregate(
call,
[
Function.Argument(
pos=1,
name="var",
type=self.types.get_type("int"),
required=False,
)
],
)
@method()
def head(self, call: Call) -> Type:
signature = Function(
args=[
Function.Argument(
pos=0,
name="n",
type=self.types.get_type("int"),
required=False,
),
],
returns=call.column,
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
return result.result
@method()
def tail(self, call: Call) -> Type:
signature = Function(
args=[
Function.Argument(
pos=0,
name="n",
type=self.types.get_type("int"),
required=False,
),
],
returns=call.column,
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
return result.result
@method()
def groupby(self, call: Call) -> Type:
bool_: Type = self.types.get_type("bool")
function: Function = Function(
args=[
Function.Argument(
pos=0,
name="by",
type=TopType(),
required=False,
),
Function.Argument(
pos=1,
name="level",
type=TopType(),
required=False,
),
],
kw_args=[
Function.Argument(
pos=2,
name="as_index",
type=bool_,
required=False,
),
Function.Argument(
pos=3,
name="sort",
type=bool_,
required=False,
),
Function.Argument(
pos=4,
name="group_keys",
type=bool_,
required=False,
),
Function.Argument(
pos=5,
name="observed",
type=bool_,
required=False,
),
Function.Argument(
pos=6,
name="dropna",
type=bool_,
required=False,
),
],
returns=ColumnGroupBy(column=call.column),
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=function,
positional=call.positional,
keywords=call.keywords,
)
return result.result
def _assert_same_length(self, call_expr: p.Expr, column1: p.Expr, column2: p.Expr):
func_name: str = "__midas_column_same_length__"
# Efficiently compute length
# https://stackoverflow.com/a/15943975/11109181
def len_of_col(col: ast.expr) -> ast.expr:
return ast.Call(
func=ast.Name(id="len"),
args=[
ast.Attribute(
value=col,
attr="index",
)
],
keywords=[],
)
self.assertions.define(
func_name,
ast.FunctionDef(
name=func_name,
args=ast.arguments(
posonlyargs=[],
args=[
ast.arg(arg="column1"),
ast.arg(arg="column2"),
],
kwonlyargs=[],
defaults=[],
kw_defaults=[],
),
body=[
ast.Return(
value=ast.Compare(
left=len_of_col(ast.Name(id="column1")),
ops=[ast.Eq()],
comparators=[
len_of_col(ast.Name(id="column2")),
],
)
)
],
decorator_list=[],
),
)
self.assertions.add(
bound_expr=call_expr,
inputs=[column1, column2],
builder=lambda c1, c2: ast.Call(
func=ast.Name(id=func_name),
args=[c1, c2],
keywords=[],
),
message="Columns must have the same length",
)

View File

@@ -1,103 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.frames.utils import MethodRegistry, method
from midas.checker.types import (
ColumnGroupBy,
ColumnType,
DataFrameType,
FrameGroupBy,
Type,
UnknownType,
)
if TYPE_CHECKING:
from midas.checker.python import TypedExpr
@dataclass(frozen=True, kw_only=True)
class Call:
location: Location
call_expr: p.Expr
groupby: FrameGroupBy
groupby_expr: p.Expr
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
@property
def subject(self) -> TypedExpr:
return (self.groupby_expr, self.groupby)
class FrameGroupByMethodRegistry(MethodRegistry[Call]):
NAMED_ARGS: dict[str, str] = {
"numeric_only": "bool",
"skipna": "bool",
"engine": "str",
"engine_kwargs": "dict",
}
def _aggregate(self, call: Call, method: str) -> Type:
new_columns: list[DataFrameType.Column] = []
for column in call.groupby.frame.columns:
column_groupby: ColumnGroupBy = ColumnGroupBy(column=column.type)
result_type: Type = self.typer.call_method(
location=call.location,
call_expr=call.call_expr,
obj=(call.groupby_expr, column_groupby),
method_name=method,
positional=call.positional,
keywords=call.keywords,
)
if not isinstance(result_type, ColumnType):
result_type = ColumnType(type=UnknownType())
new_columns.append(
DataFrameType.Column(
index=column.index,
name=column.name,
type=result_type,
)
)
return DataFrameType(columns=new_columns)
@method()
def kurt(self, call: Call) -> Type:
return self._aggregate(call, "kurt")
@method()
def max(self, call: Call) -> Type:
return self._aggregate(call, "max")
@method()
def mean(self, call: Call) -> Type:
return self._aggregate(call, "mean")
@method()
def median(self, call: Call) -> Type:
return self._aggregate(call, "median")
@method()
def min(self, call: Call) -> Type:
return self._aggregate(call, "min")
@method()
def prod(self, call: Call) -> Type:
return self._aggregate(call, "prod")
@method()
def std(self, call: Call) -> Type:
return self._aggregate(call, "std")
@method()
def sum(self, call: Call) -> Type:
return self._aggregate(call, "sum")
@method()
def var(self, call: Call) -> Type:
return self._aggregate(call, "var")

View File

@@ -1,487 +0,0 @@
from __future__ import annotations
import ast
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.dispatcher import CallResult
from midas.checker.frames.utils import MethodRegistry, method
from midas.checker.types import (
ColumnType,
DataFrameType,
FrameGroupBy,
Function,
OverloadedFunction,
TopType,
Type,
UnknownType,
unfold_type,
)
if TYPE_CHECKING:
from midas.checker.python import TypedExpr
@dataclass(frozen=True, kw_only=True)
class Call:
location: Location
call_expr: p.Expr
frame: DataFrameType
frame_expr: p.Expr
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
@property
def subject(self) -> TypedExpr:
return (self.frame_expr, self.frame)
class FrameMethodRegistry(MethodRegistry[Call]):
def _get_method_result(
self,
call: Call,
column1: ColumnType,
column2: ColumnType,
method: str,
) -> ColumnType:
"""Get the result of calling a method on a column, passing a second
This function delegates to the main typer the resolution of the method
member, as well as computing the result type. Because we don't have any
AST expression for the individual columns, the frame expressions are
used instead.
Args:
call (Call): the call that triggered this resolution
column1 (ColumnType): the first column, i.e. left operand
column2 (ColumnType): the second column, i.e. right operand
method (str): the method name
Returns:
ColumnType: the resulting column.
If the operation is invalid / doesn't exist,
`ColumnType(type=UnknownType())` is returned
"""
result: Type = self.typer.result_of_binary_op(
location=call.location,
expr=call.call_expr,
left=(call.frame_expr, column1),
right=(call.positional[0][0], column2),
method=method,
)
if not isinstance(result, ColumnType):
return ColumnType(type=UnknownType())
return result
def _element_binary_op(self, call: Call, method: str) -> DataFrameType:
"""Compute the result of an element-wise binary operation
This function delegates to the matching columns for computing resulting
types. Any column only present in one of the frames is forwarded as a
generic `ColumnType(type=UnknownType())`. Columns only in the second
frame are append at the end of the schema.
Args:
call (Call): the call that triggered this resolution
method (str): the method name
Returns:
DataFrameType: the resulting frame type
"""
new_columns: list[DataFrameType.Column] = []
by_name: dict[str, DataFrameType.Column] = {}
frame2: Optional[DataFrameType] = None
# Get map of operand's columns by name, if there is at least 1 operand, which is a dataframe
if len(call.positional) != 0:
operand: TypedExpr = call.positional[0]
unfolded_other: Type = unfold_type(operand[1])
if isinstance(unfolded_other, DataFrameType):
frame2 = unfolded_other
by_name = {
col.name: col for col in frame2.columns if col.name is not None
}
# Compute new schema:
# Step 1: for all columns in frame1:
# - if present in frame2 -> delegate operation to columns
# - if not -> add to schema as unknown
in_frame1: set[str] = set()
for column in call.frame.columns:
if column.name is not None:
in_frame1.add(column.name)
col_type1: ColumnType = column.type
col_type: ColumnType = ColumnType(type=UnknownType())
if column.name in by_name:
column2 = by_name[column.name]
col_type2: ColumnType = column2.type
col_type = self._get_method_result(call, col_type1, col_type2, method)
new_column = DataFrameType.Column(
index=column.index,
name=column.name,
type=col_type,
)
new_columns.append(new_column)
# Step 2: for all columns in frame2
# - if not in frame1 -> add to schema as unknown
if frame2 is not None:
for column in frame2.columns:
if column.name in in_frame1:
continue
new_columns.append(
DataFrameType.Column(
index=len(new_columns),
name=column.name,
type=ColumnType(type=UnknownType()),
)
)
return DataFrameType(columns=new_columns)
def _element_wise(self, call: Call, method: str) -> Type:
# TODO: support scalar, sequence, Series, dict operand
# Build signature with new schema and generic operand
signature = Function(
args=[
Function.Argument(
pos=0,
name="other",
type=DataFrameType(columns=[]),
required=True,
),
],
returns=self._element_binary_op(call, method),
)
# Map arguments and compute result type
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
if result.is_valid:
self._assert_same_length(
call.call_expr, call.frame_expr, call.positional[0][0]
)
return result.result
@method("add", "__add__")
def add(self, call: Call) -> Type:
return self._element_wise(call, "__add__")
@method("sub", "__sub__")
def sub(self, call: Call) -> Type:
return self._element_wise(call, "__sub__")
@method("mul", "__mul__")
def mul(self, call: Call) -> Type:
return self._element_wise(call, "__mul__")
@method("div", "truediv", "__truediv__")
def truediv(self, call: Call) -> Type:
return self._element_wise(call, "__truediv__")
@method("floordiv", "__floordiv__")
def floordiv(self, call: Call) -> Type:
return self._element_wise(call, "__floordiv__")
@method("mod", "__mod__")
def mod(self, call: Call) -> Type:
return self._element_wise(call, "__mod__")
@method("pow", "__pow__")
def pow(self, call: Call) -> Type:
return self._element_wise(call, "__pow__")
@method("lt", "__lt__")
def lt(self, call: Call) -> Type:
return self._element_wise(call, "__lt__")
@method("gt", "__gt__")
def gt(self, call: Call) -> Type:
return self._element_wise(call, "__gt__")
@method("le", "__le__")
def le(self, call: Call) -> Type:
return self._element_wise(call, "__le__")
@method("ge", "__ge__")
def ge(self, call: Call) -> Type:
return self._element_wise(call, "__ge__")
@method("ne", "__ne__")
def ne(self, call: Call) -> Type:
return self._element_wise(call, "__ne__")
@method("eq", "__eq__")
def eq(self, call: Call) -> Type:
return self._element_wise(call, "__eq__")
def _aggregate(self, call: Call, kwargs: list[Function.Argument] = []) -> Type:
with_axis = Function(
kw_args=[
Function.Argument(
pos=0,
name="axis",
type=self.types.get_type("int"),
required=False,
),
*kwargs,
],
returns=ColumnType(type=TopType()),
)
without_axis = Function(
kw_args=[
Function.Argument(
pos=0,
name="axis",
type=self.types.get_type("None"),
required=True,
),
*kwargs,
],
returns=TopType(),
)
overload = OverloadedFunction(
overloads=[
with_axis,
without_axis,
]
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=overload,
positional=call.positional,
keywords=call.keywords,
)
return result.result
@method("kurtosis", "kurt")
def kurtosis(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def max(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def mean(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def median(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def min(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def mode(self, call: Call) -> Type:
return self._aggregate(call)
@method("product", "prod")
def product(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def std(self, call: Call) -> Type:
return self._aggregate(
call,
[
Function.Argument(
pos=1,
name="ddof",
type=self.types.get_type("int"),
required=False,
)
],
)
@method()
def sum(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def var(self, call: Call) -> Type:
return self._aggregate(
call,
[
Function.Argument(
pos=1,
name="var",
type=self.types.get_type("int"),
required=False,
)
],
)
@method()
def head(self, call: Call) -> Type:
signature = Function(
args=[
Function.Argument(
pos=0,
name="n",
type=self.types.get_type("int"),
required=False,
),
],
returns=call.frame,
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
return result.result
@method()
def tail(self, call: Call) -> Type:
signature = Function(
args=[
Function.Argument(
pos=0,
name="n",
type=self.types.get_type("int"),
required=False,
),
],
returns=call.frame,
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
return result.result
@method()
def groupby(self, call: Call) -> Type:
bool_: Type = self.types.get_type("bool")
function: Function = Function(
args=[
Function.Argument(
pos=0,
name="by",
type=TopType(),
required=False,
),
Function.Argument(
pos=1,
name="level",
type=TopType(),
required=False,
),
],
kw_args=[
Function.Argument(
pos=2,
name="as_index",
type=bool_,
required=False,
),
Function.Argument(
pos=3,
name="sort",
type=bool_,
required=False,
),
Function.Argument(
pos=4,
name="group_keys",
type=bool_,
required=False,
),
Function.Argument(
pos=5,
name="observed",
type=bool_,
required=False,
),
Function.Argument(
pos=6,
name="dropna",
type=bool_,
required=False,
),
],
returns=FrameGroupBy(frame=call.frame),
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=function,
positional=call.positional,
keywords=call.keywords,
)
return result.result
def _assert_same_length(self, call_expr: p.Expr, frame1: p.Expr, frame2: p.Expr):
func_name: str = "__midas_frame_same_length__"
# Efficiently compute length
# https://stackoverflow.com/a/15943975/11109181
def len_of_df(df: ast.expr) -> ast.expr:
return ast.Call(
func=ast.Name(id="len"),
args=[
ast.Attribute(
value=df,
attr="index",
)
],
keywords=[],
)
self.assertions.define(
func_name,
ast.FunctionDef(
name=func_name,
args=ast.arguments(
posonlyargs=[],
args=[
ast.arg(arg="frame1"),
ast.arg(arg="frame2"),
],
kwonlyargs=[],
defaults=[],
kw_defaults=[],
),
body=[
ast.Return(
value=ast.Compare(
left=len_of_df(ast.Name(id="frame1")),
ops=[ast.Eq()],
comparators=[len_of_df(ast.Name(id="frame2"))],
)
)
],
decorator_list=[],
),
)
self.assertions.add(
bound_expr=call_expr,
inputs=[frame1, frame2],
builder=lambda f1, f2: ast.Call(
func=ast.Name(id=func_name),
args=[f1, f2],
keywords=[],
),
message="DataFrames must have the same length",
)

View File

@@ -1,100 +0,0 @@
from __future__ import annotations
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Optional,
Protocol,
Self,
TypeVar,
)
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.dispatcher import CallDispatcher
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import Type, UnknownType
from midas.generator.collector import AssertionCollector
if TYPE_CHECKING:
from midas.checker.python import PythonTyper, TypedExpr
class _MethodRegistryMeta(type):
_methods: dict[str, Callable[..., Type]] = {}
def __new__(
cls,
name: str,
bases: tuple[type, ...],
namespace: dict[str, Any],
):
new_class = super().__new__(cls, name, bases, namespace)
new_class._methods = {}
for attr in namespace.values():
if callable(attr) and hasattr(attr, "__method_names__"):
for name in attr.__method_names__: # type: ignore
new_class._methods[name] = attr # type: ignore
return new_class
class MethodCall(Protocol):
@property
def location(self) -> Location: ...
@property
def call_expr(self) -> p.Expr: ...
@property
def subject(self) -> TypedExpr: ...
T = TypeVar("T", bound=MethodCall)
class MethodRegistry(Generic[T], metaclass=_MethodRegistryMeta):
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
@property
def reporter(self) -> FileReporter:
return self.typer.reporter
@property
def types(self) -> TypesRegistry:
return self.typer.types
@property
def dispatcher(self) -> CallDispatcher[p.Expr]:
return self.typer.dispatcher
@property
def assertions(self) -> AssertionCollector:
return self.typer.assertions
def call(self, method: str, call: T) -> Type:
func: Optional[Callable[[Self, T], Type]] = self._methods.get(method)
if func is None:
self.reporter.warning(
call.location, f"Unknown method {method} on {call.subject[1]}"
)
return UnknownType()
return func(self, call)
_Self = TypeVar("_Self", bound=MethodRegistry[Any])
Method = Callable[[_Self, T], Type]
def method(*names: str) -> Callable[[Method[_Self, T]], Method[_Self, T]]:
def wrapper(func: Method[_Self, T]) -> Method[_Self, T]:
names_: tuple[str, ...] = names
if len(names_) == 0:
names_ = (func.__name__,)
setattr(func, "__method_names__", names_)
return func
return wrapper

View File

@@ -6,25 +6,27 @@ from typing import Optional
import midas.ast.midas as m
from midas.ast.location import Location
from midas.checker.builtins import define_builtins
from midas.checker.dispatcher import CallDispatcher, CallResult
from midas.checker.environment import Environment
from midas.checker.operators import MIDAS_BINARY_METHODS, MIDAS_UNARY_METHODS
from midas.checker.preamble import Preamble
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter
from midas.checker.types import (
AliasType,
AppliedType,
ColumnType,
ComplexType,
ConstraintType,
DataFrameType,
DerivedType,
ExtensionType,
Function,
GenericType,
OverloadedFunction,
Predicate,
Type,
TypeVar,
UnknownType,
unfold_type,
)
from midas.checker.variance import VarianceInferrer
from midas.lexer.midas import MidasLexer
@@ -39,6 +41,9 @@ class TypedParamSpec:
kw: list[Function.Argument]
TypedExpr = tuple[m.Expr, Type]
class ReturnException(Exception):
pass
@@ -62,11 +67,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
self.logger: logging.Logger = logging.getLogger("MidasTyper")
self.reporter: FileReporter = reporter.for_file(None)
self.types: TypesRegistry = types
self.dispatcher: CallDispatcher[m.Expr] = CallDispatcher[m.Expr](
self.types, self.reporter
)
self.types: TypesRegistry = types
self._local_variables: dict[str, TypeVar] = {}
self._predicate_params: dict[str, Type] = {}
@@ -81,14 +83,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
self._preamble: Environment = Preamble(self.types)
def set_reporter(self, reporter: FileReporter):
self.reporter = reporter
self.dispatcher.set_reporter(reporter)
def process(self, source: str, path: Optional[str]):
reporter: FileReporter = self.reporter.for_file(path)
self.set_reporter(reporter)
self.reporter = self.reporter.for_file(path)
lexer: MidasLexer = MidasLexer(source)
tokens: list[Token] = lexer.process()
parser: MidasParser = MidasParser(tokens)
@@ -158,18 +154,11 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
if len(params) != 0:
type = GenericType(name=name, params=params, body=type)
else:
type = DerivedType(name=name, type=type)
type = AliasType(name=name, type=type)
self.types.define_type(name, type)
self._local_variables.clear()
self._current_name = None
def visit_alias_stmt(self, stmt: m.AliasStmt) -> None:
name: str = stmt.name.lexeme
self._current_name = name
type: Type = stmt.type.accept(self)
self.types.define_type(name, type)
self._current_name = None
def visit_member_stmt(self, stmt: m.MemberStmt) -> None: ...
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
@@ -263,13 +252,13 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
)
return UnknownType()
result: CallResult = self.dispatcher.get_result(
location=location,
callee=operation,
positional=[(right_expr, right)],
keywords={},
result: Optional[Type] = self._get_call_result(
location,
operation,
[(right_expr, right)],
{},
)
return result.result
return result or UnknownType()
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
method: Optional[str] = MIDAS_UNARY_METHODS.get(expr.operator.type)
@@ -289,29 +278,31 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
)
return UnknownType()
result: CallResult = self.dispatcher.get_result(
location=expr.location,
callee=operation,
positional=[],
keywords={},
result: Optional[Type] = self._get_call_result(
expr.location,
operation,
[],
{},
)
return result.result
return result or UnknownType()
def visit_call_expr(self, expr: m.CallExpr) -> Type:
callee: Type = expr.callee.accept(self)
positional: list[tuple[m.Expr, Type]] = [
positional: list[TypedExpr] = [
(arg, self.type_of(arg)) for arg in expr.arguments
]
keywords: dict[str, tuple[m.Expr, Type]] = {
keywords: dict[str, TypedExpr] = {
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
}
result: CallResult = self.dispatcher.get_result(
location=expr.location,
callee=callee,
positional=positional,
keywords=keywords,
return (
self._get_call_result(
expr.location,
callee,
positional,
keywords,
)
or UnknownType()
)
return result.result
def visit_get_expr(self, expr: m.GetExpr) -> Type:
object: Type = expr.expr.accept(self)
@@ -435,3 +426,343 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
self._local_variables[name] = var
vars.append(var)
return vars
def _get_call_result(
self,
location: Location,
callee: Type,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
report_errors: bool = True,
) -> Optional[Type]:
"""Get the result type of a function call
If the function has overloads, the function will try to resolve the
appropriate signature.
Argument types are matched to the defined parameters.
The function doesn't take the raw expression as a parameter to accommodate
for desugared calls such as for operators.
Args:
location (Location): the call location
callee (Type): the called function
positional (list[TypedExpr]): the list positional arguments
keywords (dict[str, TypedExpr]): the map of keyword arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
Type: the return type of the call, or `None` if either
the call is invalid or no overload matched the arguments uniquely
"""
match callee:
case Function() as function:
valid: bool
mapped: list[MappedArgument]
valid, mapped = self.map_call_arguments(
function, location, positional, keywords
)
valid = valid and self._are_arguments_valid(mapped, report_errors)
if not valid:
return None
return function.returns
case OverloadedFunction(overloads=overloads):
function = self._match_overload(
overloads, location, positional, keywords, report_errors
)
if function is None:
return None
return function.returns
case AppliedType(body=body):
return self._get_call_result(
location, body, positional, keywords, report_errors
)
case UnknownType():
return UnknownType()
case _:
if report_errors:
self.reporter.error(location, f"{callee} is not callable")
return None
def _are_arguments_valid(
self,
arguments: list[MappedArgument],
report_errors: bool = True,
) -> bool:
"""Check whether the passed argument types correspond to their matched parameter definitions
Args:
arguments (list[MappedArgument]): the list of argument/parameter pairs
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
bool: True if all arguments fit the matching parameter definitions, False otherwise
"""
valid: bool = True
for arg in arguments:
if not self.types.is_subtype(arg.type, arg.argument.type):
if report_errors:
self.reporter.error(
arg.expr.location,
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
)
valid = False
return valid
def _match_overload(
self,
overloads: list[Type],
location: Location,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
report_errors: bool = True,
) -> Optional[Function]:
"""Try and resolve the appropriate overload for the given arguments
Args:
overloads (list[Type]): the list of possible overloads
location (Location): the call location
positional (list[TypedExpr]): the list of positional arguments
keywords (dict[str, TypedExpr]): the map of keywords arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
Optional[Function]: the resolved function signature if it can be
determined unambiguously, or `None`.
"""
candidates: list[OverloadCandidate] = []
for overload in overloads:
function: Type = unfold_type(overload)
if not isinstance(function, Function):
if report_errors:
self.logger.error(
f"Overload is not a function: {overload} is {function}"
)
continue
valid, mapped = self.map_call_arguments(
function=function,
location=location,
positional=positional,
keywords=keywords,
report_errors=False,
)
if valid and self._are_arguments_valid(mapped, report_errors=False):
candidates.append(
OverloadCandidate(
function=function,
mapped=mapped,
)
)
pos_types: str = ", ".join(str(type) for _, type in positional)
kw_types: str = ", ".join(
f"{name}: {type}" for name, (_, type) in keywords.items()
)
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
n_candidates: int = len(candidates)
# Exactly 1 match -> return it
if n_candidates == 1:
return candidates[0].function
# No match -> invalid call
if n_candidates == 0:
overloads_str: str = ", ".join(map(str, overloads))
if report_errors:
self.reporter.error(
location,
f"No matching overload in [{overloads_str}] {for_args}",
)
return None
# Multiple matches -> see if one <: all others (more specific)
for i1, c1 in enumerate(candidates):
mapped1: list[MappedArgument] = c1.mapped
best_match: bool = True
for i2, c2 in enumerate(candidates):
if i1 == i2:
continue
mapped2: list[MappedArgument] = c2.mapped
if not self._are_mapped_subtypes(mapped1, mapped2):
best_match = False
break
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
if best_match:
return c1.function
candidates_str: str = ", ".join(
str(candidate.function) for candidate in candidates
)
if report_errors:
self.reporter.error(
location,
f"Multiple matching overloads {for_args}: {candidates_str}",
)
return None
def map_call_arguments(
self,
function: Function,
location: Location,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
report_errors: bool = True,
) -> tuple[bool, list[MappedArgument]]:
"""Map call arguments to a function's parameters as defined in its signature
This method maps positional-only, keyword-only and mixed parameter definitions
with the arguments passed at the call site
Any mismatched, missing or unexpected argument is reported as a diagnostic,
unless `report_errors` is set to `False`
Args:
function (Function): the function definition
location (Location): the call location
positional (list[TypedExpr]): the list of positional arguments
keywords (dict[str, TypedExpr]): the map of keyword arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
tuple[bool, list[MappedArgument]]: a boolean reporting whether
the call is valid and the list of mapped arguments
"""
set_args: set[str] = set()
required_positional: list[str] = [
arg.name for arg in function.pos_args + function.args if arg.required
]
required_keyword: list[str] = [
arg.name for arg in function.kw_args if arg.required
]
mapped: list[MappedArgument] = []
pos_params: list[Function.Argument] = list(function.pos_args)
mixed_params: list[Function.Argument] = list(function.args)
kw_params: dict[str, Function.Argument] = {
arg.name: arg for arg in function.kw_args
}
valid_call: bool = True
# TODO: handle *args and **kwargs sinks
for arg in positional:
param: Function.Argument
if len(pos_params) != 0:
param = pos_params.pop(0)
elif len(mixed_params) != 0:
param = mixed_params.pop(0)
else:
if report_errors:
self.reporter.error(
arg[0].location, "Too many positional arguments"
)
valid_call = False
break
name: str = param.name
if name in required_positional:
required_positional.remove(name)
if name in required_keyword:
required_keyword.remove(name)
set_args.add(name)
mapped.append(
MappedArgument(
expr=arg[0],
type=arg[1],
argument=param,
)
)
kw_params.update({arg.name: arg for arg in mixed_params})
for name, arg in keywords.items():
param: Function.Argument
if name not in kw_params:
if report_errors:
if name in set_args:
self.reporter.error(
arg[0].location, f"Multiple values for argument '{name}'"
)
else:
self.reporter.error(
arg[0].location, f"Unknown keyword argument '{name}'"
)
valid_call = False
continue
param = kw_params.pop(name)
if name in required_positional:
required_positional.remove(name)
if name in required_keyword:
required_keyword.remove(name)
set_args.add(name)
mapped.append(
MappedArgument(
expr=arg[0],
type=arg[1],
argument=param,
)
)
def join_args(args: list[str]) -> str:
args = list(map(lambda a: f"'{a}'", args))
if len(args) == 0:
return ""
if len(args) == 1:
return args[0]
return ", ".join(args[:-1]) + " and " + args[-1]
if len(required_positional) != 0:
plural: str = "" if len(required_positional) == 1 else "s"
args: str = join_args(required_positional)
if report_errors:
self.reporter.error(
location,
f"Missing required positional argument{plural}: {args}",
)
valid_call = False
if len(required_keyword) != 0:
plural: str = "" if len(required_keyword) == 1 else "s"
args: str = join_args(required_keyword)
if report_errors:
self.reporter.error(
location,
f"Missing required keyword argument{plural}: {args}",
)
valid_call = False
return valid_call, mapped
def _are_mapped_subtypes(
self, mapped1: list[MappedArgument], mapped2: list[MappedArgument]
) -> bool:
"""Check whether the given argument mappings are subtype/supertype of one another
This function checks whether the argument mappings `mapped1` are subtypes
of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
of the corresponding parameter in `mapped2`, `False` is returned.
This is used to check whether a given overload is
a more specific function/ a subtype of another.
Args:
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
Returns:
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
"""
by_expr: dict[m.Expr, Type] = {}
for arg in mapped1:
by_expr[arg.expr] = arg.argument.type
for arg in mapped2:
type2: Type = arg.argument.type
type1: Type = by_expr[arg.expr]
if not self.types.is_subtype(type1, type2):
return False
return True

View File

@@ -41,7 +41,7 @@ PY_UNARY_METHODS: dict[Type[ast.unaryop], str] = {
MIDAS_BINARY_METHODS: dict[TokenType, str] = {
TokenType.PLUS: "__add__",
# TokenType.PLUS: "__add__",
TokenType.MINUS: "__sub__",
TokenType.STAR: "__mul__",
TokenType.SLASH: "__truediv__",

View File

@@ -1,17 +1,9 @@
from dataclasses import dataclass
from typing import Any, Callable, Optional
from typing import Callable, Optional
from midas.checker.environment import Environment
from midas.checker.registry import TypesRegistry
from midas.checker.types import (
Function,
GenericType,
OverloadedFunction,
TopType,
Type,
TypeVar,
UnitType,
)
from midas.checker.types import Function, GenericType, TopType, Type, TypeVar, UnitType
@dataclass(frozen=True)
@@ -25,7 +17,7 @@ class Preamble(Environment):
def __init__(self, types: TypesRegistry) -> None:
super().__init__()
self._types: TypesRegistry = types
self._python_funcs: dict[str, Callable[..., Any]] = {}
self._python_funcs: dict[str, Callable] = {}
self._def_type_constructor("object", object)
self._def_type_constructor("float", float)
@@ -42,7 +34,7 @@ class Preamble(Environment):
# TODO: use sink
self._def_function(
name="print",
pos=[Param("object", TopType(), required=False)],
pos=[Param("object", TopType())],
returns=UnitType(),
py_function=print,
)
@@ -72,48 +64,11 @@ class Preamble(Environment):
pos=[Param("prompt", TopType(), required=False)],
returns=self._types.get_type("str"),
)
self._def_function(
name="len",
pos=[Param("object", TopType())],
returns=self._types.get_type("int"),
)
T = TypeVar(name="T", bound=None)
self._def_overloads(
name="max",
py_function=max,
signatures=[
(
[Param("arg1", T), Param("arg2", T)],
[],
[],
T,
[T],
),
([Param("iterable", self._list_of(T))], [], [], T, [T]),
],
)
self._def_overloads(
name="min",
py_function=min,
signatures=[
(
[Param("arg1", T), Param("arg2", T)],
[],
[],
T,
[T],
),
([Param("iterable", self._list_of(T))], [], [], T, [T]),
],
)
def _list_of(self, item_type: Type) -> Type:
return self._types.apply_generic(self._types.get_type("list"), [item_type])
def _list_of(self, item_type: str | Type) -> Type:
return self._types.list_of(item_type)
def _def_type_constructor(
self, name: str, py_function: Optional[Callable[..., Any]] = None
):
def _def_type_constructor(self, name: str, py_function: Optional[Callable] = None):
# TODO: more specific arg types
self._def_function(
name=name,
@@ -166,7 +121,7 @@ class Preamble(Environment):
kw: list[Param] = [],
returns: Type = UnitType(),
type_vars: list[TypeVar] = [],
py_function: Optional[Callable[..., Any]] = None,
py_function: Optional[Callable] = None,
):
function: Type = self._make_function(
name=name,
@@ -180,31 +135,5 @@ class Preamble(Environment):
if py_function is not None:
self._python_funcs[name] = py_function
def _def_overloads(
self,
*,
name: str,
signatures: list[
tuple[list[Param], list[Param], list[Param], Type, list[TypeVar]]
],
py_function: Optional[Callable[..., Any]] = None,
):
overloads: list[Type] = []
for pos, mixed, kw, returns, type_vars in signatures:
overloads.append(
self._make_function(
name=name,
pos=pos,
mixed=mixed,
kw=kw,
returns=returns,
type_vars=type_vars,
)
)
function: Type = OverloadedFunction(overloads=overloads)
self.define(name, function)
if py_function is not None:
self._python_funcs[name] = py_function
def get_py_func(self, name: str) -> Optional[Callable[..., Any]]:
def get_py_func(self, name: str) -> Optional[Callable]:
return self._python_funcs.get(name)

View File

@@ -6,11 +6,9 @@ from typing import Any, Optional
import midas.ast.python as p
from midas.ast.location import Location
from midas.ast.printer import MidasPrinter
from midas.checker.dispatcher import CallDispatcher, CallResult
from midas.checker.environment import Environment
from midas.checker.evaluator import Evaluator
from midas.checker.frames.column_manager import ColumnManager
from midas.checker.frames.frame_manager import FrameManager
from midas.checker.frames import FrameManager
from midas.checker.operators import (
PY_COMPARATOR_METHODS,
PY_OPERATOR_METHODS,
@@ -21,17 +19,15 @@ from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter
from midas.checker.resolver import Resolver
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ColumnGroupBy,
ColumnType,
ConstraintType,
DataFrameType,
DerivedType,
FrameGroupBy,
Function,
GenericType,
TopType,
OverloadedFunction,
TupleType,
Type,
TypeVar,
@@ -40,7 +36,7 @@ from midas.checker.types import (
Variance,
unfold_type,
)
from midas.generator.collector import AssertionCollector
from midas.checker.unifier import Unifier
from midas.parser.python import PythonParser
from midas.utils import TypedAST
@@ -84,24 +80,14 @@ class PythonTyper(
self.reporter: FileReporter = reporter.for_file(None)
self.types: TypesRegistry = types
self.frame_mgr: FrameManager = FrameManager(self)
self.column_mgr: ColumnManager = ColumnManager(self)
self.global_env: Environment = Preamble(self.types)
self.env: Environment = self.global_env
self.locals: dict[p.Expr, int] = {}
self.judgements: list[tuple[p.Expr, Type]] = []
self.evaluated_casts: list[p.CastExpr] = []
self.dispatcher: CallDispatcher[p.Expr] = CallDispatcher[p.Expr](
self.types, self.reporter
)
self.assertions: AssertionCollector = AssertionCollector()
def set_reporter(self, reporter: FileReporter):
self.reporter = reporter
self.dispatcher.set_reporter(self.reporter)
def process(self, source: str, path: Optional[str]) -> TypedAST:
reporter: FileReporter = self.reporter.for_file(path)
self.set_reporter(reporter)
self.reporter = self.reporter.for_file(path)
tree: ast.Module = ast.parse(source, filename=path or "<unknown>")
parser = PythonParser()
@@ -120,7 +106,6 @@ class PythonTyper(
stmts=stmts,
judgements=self.judgements,
evaluated_casts=self.evaluated_casts,
assertions=self.assertions,
)
def judge(self, expr: p.Expr, type: Type):
@@ -217,69 +202,32 @@ class PythonTyper(
def call_method(
self,
location: Location,
call_expr: p.Expr,
obj: TypedExpr,
obj: Type,
method_name: str,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
unfolded: Type = unfold_type(obj[1])
) -> Optional[Type]:
unfolded: Type = unfold_type(obj)
match unfolded:
case DataFrameType():
return self.frame_mgr.call(
method=method_name,
location=location,
call_expr=call_expr,
frame=unfolded,
frame_expr=obj[0],
positional=positional,
keywords=keywords,
)
case FrameGroupBy():
return self.frame_mgr.groupby_call(
method=method_name,
location=location,
call_expr=call_expr,
groupby=unfolded,
groupby_expr=obj[0],
positional=positional,
keywords=keywords,
)
case ColumnType():
return self.column_mgr.call(
method=method_name,
location=location,
call_expr=call_expr,
column=unfolded,
column_expr=obj[0],
positional=positional,
keywords=keywords,
)
case ColumnGroupBy():
return self.column_mgr.groupby_call(
method=method_name,
location=location,
call_expr=call_expr,
groupby=unfolded,
groupby_expr=obj[0],
positional=positional,
keywords=keywords,
)
method: Optional[Type] = self.types.lookup_member(obj[1], method_name)
method: Optional[Type] = self.types.lookup_member(obj, method_name)
if method is None:
raise UndefinedMethodException
result: CallResult = self.dispatcher.get_result(
location=location,
callee=method,
positional=positional,
keywords=keywords,
return self._get_call_result(
location,
method,
positional,
keywords,
)
return result.result
def is_subtype(self, type1: Type, type2: Type) -> bool:
return self.types.is_subtype(type1, type2)
@@ -465,16 +413,13 @@ class PythonTyper(
value_type: Type,
):
var_type: Type = self.type_of(var)
unfolded_type: Type = unfold_type(var_type)
# TODO: what happens if type is an alias of a dataframe type
match unfolded_type:
match var_type:
case DataFrameType() as frame:
new_type: Type = self.frame_mgr.assign(
self.reporter, location, frame, index, value_type
)
self.env.assign(var.name, new_type)
case UnknownType():
return
case _:
self.reporter.error(
location,
@@ -494,10 +439,8 @@ class PythonTyper(
# print(m) # <- m is still defined
test_type: Type = self.type_of(stmt.test)
if (
not self.types.is_subtype(test_type, self.types.get_type("bool"))
and test_type != UnknownType()
):
# TODO Allow subtypes or any type
if test_type != self.types.get_type("bool"):
self.reporter.error(
stmt.test.location, f"If test must be a boolean, got {test_type}"
)
@@ -513,16 +456,13 @@ class PythonTyper(
pass
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
item_type: Type = UnknownType()
iterator_type: Type = self.type_of(stmt.iterator)
if iterator_type != UnknownType():
maybe_item_type = self._get_iterator_type(stmt.iterator, iterator_type)
if maybe_item_type is None:
self.reporter.error(
stmt.iterator.location, f"{iterator_type} is not iterable"
)
else:
item_type = maybe_item_type
item_type: Optional[Type] = self._get_iterator_type(stmt.iterator)
if item_type is None:
iterator_type: Type = self.compute_type(stmt.iterator)
self.reporter.error(
stmt.iterator.location, f"{iterator_type} is not iterable"
)
item_type = UnknownType()
self._assign(stmt.location, stmt.target, item_type)
self.judge(stmt.target, item_type)
@@ -543,15 +483,7 @@ class PythonTyper(
)
return UnknownType()
left: Type = self.type_of(expr.left)
right: Type = self.type_of(expr.right)
return self.result_of_binary_op(
expr.location,
expr,
(expr.left, left),
(expr.right, right),
method,
)
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
method: Optional[str] = PY_COMPARATOR_METHODS.get(expr.operator.__class__)
@@ -562,40 +494,26 @@ class PythonTyper(
)
return UnknownType()
left: Type = self.type_of(expr.left)
right: Type = self.type_of(expr.right)
return self.result_of_binary_op(
expr.location,
expr,
(expr.left, left),
(expr.right, right),
method,
)
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
def result_of_binary_op(
self,
location: Location,
expr: p.Expr,
left: TypedExpr,
right: TypedExpr,
method: str,
def _visit_binary_expr(
self, location: Location, left_expr: p.Expr, right_expr: p.Expr, method: str
) -> Type:
left: Type = self.type_of(left_expr)
right: Type = self.type_of(right_expr)
result: Optional[Type]
try:
return self.call_method(
location=location,
call_expr=expr,
obj=left,
method_name=method,
positional=[right],
keywords={},
)
result = self.call_method(location, left, method, [(right_expr, right)], {})
except UndefinedMethodException:
self.reporter.error(
location,
f"Undefined operation {method} between {left[1]} and {right[1]}",
f"Undefined operation {method} between {left} and {right}",
)
return UnknownType()
return result or UnknownType()
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
method: Optional[str] = PY_UNARY_METHODS.get(expr.operator.__class__)
if method is None:
@@ -607,15 +525,9 @@ class PythonTyper(
operand: Type = self.type_of(expr.right)
result: Optional[Type]
try:
return self.call_method(
location=expr.location,
call_expr=expr,
obj=(expr.right, operand),
method_name=method,
positional=[],
keywords={},
)
result = self.call_method(expr.location, operand, method, [], {})
except UndefinedMethodException:
self.reporter.error(
expr.location,
@@ -623,6 +535,8 @@ class PythonTyper(
)
return UnknownType()
return result or UnknownType()
def visit_call_expr(self, expr: p.CallExpr) -> Type:
match expr.callee:
case p.VariableExpr(name="TypeVar"):
@@ -638,37 +552,32 @@ class PythonTyper(
match expr.callee:
case p.GetExpr(object=obj, name=method):
obj_type: Type = self.type_of(obj)
return self.call_method(
location=expr.location,
call_expr=expr,
obj=(obj, obj_type),
method_name=method,
positional=positional,
keywords=keywords,
)
unfolded: Type = unfold_type(obj_type)
if isinstance(unfolded, DataFrameType):
return self.frame_mgr.call(
method,
expr.location,
unfolded,
positional,
keywords,
)
callee: Type = self.type_of(expr.callee)
result: CallResult = self.dispatcher.get_result(
location=expr.location,
callee=callee,
positional=positional,
keywords=keywords,
return (
self._get_call_result(
location=expr.location,
callee=callee,
positional=positional,
keywords=keywords,
)
or UnknownType()
)
return result.result
def visit_get_expr(self, expr: p.GetExpr) -> Type:
object: Type = self.type_of(expr.object)
member: Optional[Type] = self.types.lookup_member(object, expr.name)
if member is None:
match object:
case DataFrameType():
member = self.frame_mgr.get_attribute(object, expr.name)
case ColumnType():
member = self.column_mgr.get_attribute(object, expr.name)
if member is None:
self.reporter.warning(
self.reporter.error(
expr.location, f"Unknown member '{expr.name}' of {object}"
)
return UnknownType()
@@ -729,10 +638,7 @@ class PythonTyper(
test_type: Type = self.type_of(expr.test)
# TODO Allow subtypes or any type
if (
not self.is_subtype(test_type, self.types.get_type("bool"))
and test_type != UnknownType()
):
if test_type != self.types.get_type("bool"):
self.reporter.error(
expr.test.location, f"If test must be a boolean, got {test_type}"
)
@@ -761,7 +667,7 @@ class PythonTyper(
if len(item_types) == 1:
item_type: Type = item_types[0]
return self.types.apply_generic(list_type, [item_type])
self.reporter.warning(
self.reporter.error(
expr.location,
f"Heterogeneous list items: [{', '.join(map(str, item_types))}]",
)
@@ -793,7 +699,7 @@ class PythonTyper(
if len(key_types) == 1:
key_type = key_types[0]
else:
self.reporter.warning(
self.reporter.error(
expr.location,
f"Heterogeneous dict keys: [{', '.join(map(str, key_types))}]",
)
@@ -801,7 +707,7 @@ class PythonTyper(
if len(value_types) == 1:
value_type = value_types[0]
else:
self.reporter.warning(
self.reporter.error(
expr.location,
f"Heterogeneous dict values: [{', '.join(map(str, value_types))}]",
)
@@ -815,8 +721,6 @@ class PythonTyper(
return self._visit_tuple_subscript(unfolded, expr)
case DataFrameType():
return self._visit_frame_subscript(unfolded, expr)
case FrameGroupBy():
return self._visit_frame_groupby_subscript(unfolded, expr)
operation: Optional[Type] = self.types.lookup_member(object, "__getitem__")
if operation is None:
@@ -827,22 +731,14 @@ class PythonTyper(
return UnknownType()
index: Type = self.type_of(expr.index)
result: CallResult = self.dispatcher.get_result(
location=expr.location,
callee=operation,
positional=[(expr.index, index)],
keywords={},
return (
self._get_call_result(expr.location, operation, [(expr.index, index)], {})
or UnknownType()
)
return result.result
def visit_slice_expr(self, expr: p.SliceExpr) -> Type:
return self.types.get_type("slice")
def visit_tuple_expr(self, expr: p.TupleExpr) -> Type:
return TupleType(
items=tuple(self.type_of(item) for item in expr.items),
)
def visit_raw_expr(self, expr: p.RawExpr) -> Type:
return UnknownType()
@@ -854,9 +750,9 @@ class PythonTyper(
self.reporter.warning(node.location, f"Unknown type '{node.base}'")
return UnknownType()
if len(node.args) != 0:
args: list[Type] = [self.resolve_type_expr(arg) for arg in node.args]
return self.types.apply_generic(base, args)
if node.param is not None:
param: Type = self.resolve_type_expr(node.param)
return self.types.apply_generic(base, [param])
return base
def visit_constraint_type(self, node: p.ConstraintType) -> Type:
@@ -884,24 +780,393 @@ class PythonTyper(
]
)
def _get_iterator_type(self, expr: p.Expr, type: Type) -> Optional[Type]:
def _get_call_result(
self,
location: Location,
callee: Type,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
report_errors: bool = True,
) -> Optional[Type]:
"""Get the result type of a function call
If the function has overloads, the function will try to resolve the
appropriate signature.
Argument types are matched to the defined parameters.
The function doesn't take the raw expression as a parameter to accommodate
for desugared calls such as for operators.
Args:
location (Location): the call location
callee (Type): the called function
positional (list[TypedExpr]): the list positional arguments
keywords (dict[str, TypedExpr]): the map of keyword arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
Type: the return type of the call, or `None` if either
the call is invalid or no overload matched the arguments uniquely
"""
match callee:
case Function() as function:
valid: bool
mapped: list[MappedArgument]
valid, mapped = self.map_call_arguments(
function, location, positional, keywords
)
valid = valid and self._are_arguments_valid(mapped, report_errors)
if not valid:
return None
return function.returns
case OverloadedFunction(overloads=overloads):
function = self._match_overload(
overloads, location, positional, keywords, report_errors
)
if function is None:
return None
return function.returns
case AppliedType(body=body):
return self._get_call_result(
location, body, positional, keywords, report_errors
)
case UnknownType():
return UnknownType()
case AliasType(type=base):
return self._get_call_result(
location, base, positional, keywords, report_errors
)
case GenericType():
unifier: Unifier = Unifier(self.types)
pos: list[Type] = [a[1] for a in positional]
kw: dict[str, Type] = {k: v[1] for k, v in keywords.items()}
unified: Optional[Type] = unifier.unify_call(callee, pos, kw)
if unified is None:
if report_errors:
pos_str: str = ", ".join(str(t) for t in pos)
kw_str: str = ", ".join(f"{k}: {v}" for k, v in kw.items())
self.reporter.error(
location,
f"Could not unify {callee}={callee.body} with pos=[{pos_str}] and kw={{{kw_str}}}",
)
return None
return self._get_call_result(
location,
unified,
positional,
keywords,
report_errors,
)
case _:
if report_errors:
self.reporter.error(
location,
f"{callee} ({callee.__class__.__name__}) is not callable",
)
return None
def _are_arguments_valid(
self,
arguments: list[MappedArgument],
report_errors: bool = True,
) -> bool:
"""Check whether the passed argument types correspond to their matched parameter definitions
Args:
arguments (list[MappedArgument]): the list of argument/parameter pairs
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
bool: True if all arguments fit the matching parameter definitions, False otherwise
"""
valid: bool = True
for arg in arguments:
if not self.is_subtype(arg.type, arg.argument.type):
if report_errors:
self.reporter.error(
arg.expr.location,
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
)
valid = False
return valid
def _match_overload(
self,
overloads: list[Type],
location: Location,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
report_errors: bool = True,
) -> Optional[Function]:
"""Try and resolve the appropriate overload for the given arguments
Args:
overloads (list[Type]): the list of possible overloads
location (Location): the call location
positional (list[TypedExpr]): the list of positional arguments
keywords (dict[str, TypedExpr]): the map of keywords arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
Optional[Function]: the resolved function signature if it can be
determined unambiguously, or `None`.
"""
candidates: list[OverloadCandidate] = []
for overload in overloads:
function: Type = unfold_type(overload)
if not isinstance(function, Function):
if report_errors:
self.logger.error(
f"Overload is not a function: {overload} is {function}"
)
continue
valid, mapped = self.map_call_arguments(
function=function,
location=location,
positional=positional,
keywords=keywords,
report_errors=False,
)
if valid and self._are_arguments_valid(mapped, report_errors=False):
candidates.append(
OverloadCandidate(
function=function,
mapped=mapped,
)
)
pos_types: str = ", ".join(str(type) for _, type in positional)
kw_types: str = ", ".join(
f"{name}: {type}" for name, (_, type) in keywords.items()
)
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
n_candidates: int = len(candidates)
# Exactly 1 match -> return it
if n_candidates == 1:
return candidates[0].function
# No match -> invalid call
if n_candidates == 0:
overloads_str: str = ", ".join(map(str, overloads))
if report_errors:
self.reporter.error(
location,
f"No matching overload in [{overloads_str}] {for_args}",
)
return None
# Multiple matches -> see if one <: all others (more specific)
for i1, c1 in enumerate(candidates):
mapped1: list[MappedArgument] = c1.mapped
best_match: bool = True
for i2, c2 in enumerate(candidates):
if i1 == i2:
continue
mapped2: list[MappedArgument] = c2.mapped
if not self._are_mapped_subtypes(mapped1, mapped2):
best_match = False
break
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
if best_match:
return c1.function
candidates_str: str = ", ".join(
str(candidate.function) for candidate in candidates
)
if report_errors:
self.reporter.error(
location,
f"Multiple matching overloads {for_args}: {candidates_str}",
)
return None
def map_call_arguments(
self,
function: Function,
location: Location,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
report_errors: bool = True,
) -> tuple[bool, list[MappedArgument]]:
"""Map call arguments to a function's parameters as defined in its signature
This method maps positional-only, keyword-only and mixed parameter definitions
with the arguments passed at the call site
Any mismatched, missing or unexpected argument is reported as a diagnostic,
unless `report_errors` is set to `False`
Args:
function (Function): the function definition
location (Location): the call location
positional (list[TypedExpr]): the list of positional arguments
keywords (dict[str, TypedExpr]): the map of keyword arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
tuple[bool, list[MappedArgument]]: a boolean reporting whether
the call is valid and the list of mapped arguments
"""
set_args: set[str] = set()
required_positional: list[str] = [
arg.name for arg in function.pos_args + function.args if arg.required
]
required_keyword: list[str] = [
arg.name for arg in function.kw_args if arg.required
]
mapped: list[MappedArgument] = []
pos_params: list[Function.Argument] = list(function.pos_args)
mixed_params: list[Function.Argument] = list(function.args)
kw_params: dict[str, Function.Argument] = {
arg.name: arg for arg in function.kw_args
}
valid_call: bool = True
# TODO: handle *args and **kwargs sinks
for arg in positional:
param: Function.Argument
if len(pos_params) != 0:
param = pos_params.pop(0)
elif len(mixed_params) != 0:
param = mixed_params.pop(0)
else:
if report_errors:
self.reporter.error(
arg[0].location, "Too many positional arguments"
)
valid_call = False
break
name: str = param.name
if name in required_positional:
required_positional.remove(name)
if name in required_keyword:
required_keyword.remove(name)
set_args.add(name)
mapped.append(
MappedArgument(
expr=arg[0],
type=arg[1],
argument=param,
)
)
kw_params.update({arg.name: arg for arg in mixed_params})
for name, arg in keywords.items():
param: Function.Argument
if name not in kw_params:
if report_errors:
if name in set_args:
self.reporter.error(
arg[0].location, f"Multiple values for argument '{name}'"
)
else:
self.reporter.error(
arg[0].location, f"Unknown keyword argument '{name}'"
)
valid_call = False
continue
param = kw_params.pop(name)
if name in required_positional:
required_positional.remove(name)
if name in required_keyword:
required_keyword.remove(name)
set_args.add(name)
mapped.append(
MappedArgument(
expr=arg[0],
type=arg[1],
argument=param,
)
)
def join_args(args: list[str]) -> str:
args = list(map(lambda a: f"'{a}'", args))
if len(args) == 0:
return ""
if len(args) == 1:
return args[0]
return ", ".join(args[:-1]) + " and " + args[-1]
if len(required_positional) != 0:
plural: str = "" if len(required_positional) == 1 else "s"
args: str = join_args(required_positional)
if report_errors:
self.reporter.error(
location,
f"Missing required positional argument{plural}: {args}",
)
valid_call = False
if len(required_keyword) != 0:
plural: str = "" if len(required_keyword) == 1 else "s"
args: str = join_args(required_keyword)
if report_errors:
self.reporter.error(
location,
f"Missing required keyword argument{plural}: {args}",
)
valid_call = False
return valid_call, mapped
def _are_mapped_subtypes(
self, mapped1: list[MappedArgument], mapped2: list[MappedArgument]
) -> bool:
"""Check whether the given argument mappings are subtype/supertype of one another
This function checks whether the argument mappings `mapped1` are subtypes
of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
of the corresponding parameter in `mapped2`, `False` is returned.
This is used to check whether a given overload is
a more specific function/ a subtype of another.
Args:
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
Returns:
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
"""
by_expr: dict[p.Expr, Type] = {}
for arg in mapped1:
by_expr[arg.expr] = arg.argument.type
for arg in mapped2:
type2: Type = arg.argument.type
type1: Type = by_expr[arg.expr]
if not self.is_subtype(type1, type2):
return False
return True
def _get_iterator_type(self, expr: p.Expr) -> Optional[Type]:
# TODO: lookup __iter__
type: Type = self.type_of(expr)
getitem: Optional[Type] = self.types.lookup_member(type, "__getitem__")
if getitem is None:
return None
index: p.Expr = p.LiteralExpr(location=expr.location, value=0)
index_type: Type = self.compute_type(index)
result: CallResult = self.dispatcher.get_result(
result: Optional[Type] = self._get_call_result(
location=expr.location,
callee=getitem,
positional=[(index, index_type)],
keywords={},
report_errors=False,
)
if not result.is_valid:
return None
return result.result
return result
def define_typevar(self, call: p.CallExpr) -> Optional[TypeVar]:
def is_kw_true(name: str) -> bool:
@@ -953,7 +1218,7 @@ class PythonTyper(
node: ast.Expression = ast.parse(value, mode="eval")
return parser._parse_type(node.body)
case p.VariableExpr(name=name):
return p.BaseType(location=location, base=name, args=())
return p.BaseType(location=location, base=name, param=None)
case _:
raise NotImplementedError
@@ -992,22 +1257,6 @@ class PythonTyper(
pairs.append((key_val, value_val))
return True, dict(pairs)
case p.UnaryExpr(operator=operator, right=operand):
is_lit, operand_val = self._get_literal(operand)
if not is_lit:
return False, None
match operator:
case ast.UAdd():
return True, operand_val
case ast.USub():
return True, -operand_val
case ast.Invert():
return True, ~operand_val
case ast.Not():
return True, not operand_val
case _: # Should never be reached
return False, None
case _:
return False, None
@@ -1015,56 +1264,11 @@ class PythonTyper(
self, expr: p.CastExpr, subject_type: Type, target_type: Type, lit_value: Any
) -> bool:
match target_type:
case TopType():
return True
case UnitType():
if lit_value is not None:
self.reporter.error(
expr.location, f"Value {lit_value!r} is not None"
)
return False
return True
case DerivedType(type=base):
case AliasType(type=base):
return self._evaluate_cast_statically(
expr, subject_type, base, lit_value
)
case AppliedType(name="list", args=[item_type]) if isinstance(
lit_value, list
):
match subject_type:
case AppliedType(name="list", args=[lit_item_type]):
evaluated: bool = True
for item in lit_value:
if not self._evaluate_cast_statically(
expr, lit_item_type, item_type, item
):
evaluated = False
return evaluated
case _:
return False
case AppliedType(name="dict", args=[key_type, value_type]) if isinstance(
lit_value, dict
):
match subject_type:
case AppliedType(name="dict", args=[lit_key_type, lit_value_type]):
evaluated: bool = True
for key, value in lit_value.items():
if not self._evaluate_cast_statically(
expr, lit_key_type, key_type, key
):
evaluated = False
if not self._evaluate_cast_statically(
expr, lit_value_type, value_type, value
):
evaluated = False
return evaluated
case _:
return False
case AppliedType(body=body):
return self._evaluate_cast_statically(
expr, subject_type, body, lit_value
@@ -1079,19 +1283,10 @@ class PythonTyper(
evaluator = Evaluator(self.types)
evaluator.set_value("_", lit_value)
printer = MidasPrinter()
constraint_str: str = printer.print(constraint)
res: Any
try:
res = evaluator.evaluate(constraint)
except Exception as e:
self.reporter.error(
expr.location,
f"An error occurred while checking constraint '{constraint_str}' on the value {lit_value!r}: {e}",
)
return False
res = evaluator.evaluate(constraint)
if not res:
printer = MidasPrinter()
constraint_str: str = printer.print(constraint)
self.reporter.error(
expr.location,
f"Value {lit_value!r} does not fit constraint '{constraint_str}'",
@@ -1111,12 +1306,6 @@ class PythonTyper(
return False
return True
case DataFrameType() | ColumnType():
self.reporter.error(
expr.location, f"Cannot cast {lit_value!r} to {target_type}"
)
return False
case _:
self.reporter.info(
expr.location, f"Cannot evaluate cast to {target_type} statically"
@@ -1142,10 +1331,3 @@ class PythonTyper(
self, frame: DataFrameType, expr: p.SubscriptExpr
) -> Type:
return self.frame_mgr.get(self.reporter, expr.location, frame, expr.index)
def _visit_frame_groupby_subscript(
self, groupby: FrameGroupBy, expr: p.SubscriptExpr
) -> Type:
return self.frame_mgr.groupby_get(
self.reporter, expr.location, groupby, expr.index
)

View File

@@ -5,20 +5,19 @@ from typing import Optional
from midas.ast.midas import MemberKind
from midas.checker.builtins import BUILTIN_SUBTYPES
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ColumnType,
ComplexType,
ConstraintType,
DataFrameType,
DerivedType,
ExtensionType,
Function,
GenericType,
OverloadedFunction,
Predicate,
TopType,
TupleType,
Type,
TypeVar,
UnknownType,
@@ -113,15 +112,6 @@ class TypesRegistry:
raise ValueError(f"Predicate {name} already defined")
self._predicates[name] = predicate
def is_builtin_subtype(self, name1: str, name2: str) -> bool:
subtypes: set[str] = BUILTIN_SUBTYPES.get(name2, set())
if name1 in subtypes:
return True
for subtype in subtypes:
if self.is_builtin_subtype(name1, subtype):
return True
return False
def is_subtype(self, type1: Type, type2: Type) -> bool:
"""Check whether `type1` is a subtype of `type2`
@@ -155,11 +145,11 @@ class TypesRegistry:
return True
return self.is_subtype(type1, bound)
case (DerivedType(type=base1), _):
case (AliasType(type=base1), _):
return self.is_subtype(base1, type2)
case (BaseType(name=name1), BaseType(name=name2)):
return self.is_builtin_subtype(name1, name2)
return name1 in BUILTIN_SUBTYPES.get(name2, set())
case (ComplexType(properties=props1), ComplexType(properties=props2)):
for k, t in props2.items():
@@ -327,8 +317,8 @@ class TypesRegistry:
def apply_generic(self, type: Type, args: list[Type]) -> Type:
match type:
case DerivedType(name=name, type=base):
return DerivedType(name=name, type=self.apply_generic(base, args))
case AliasType(name=name, type=base):
return AliasType(name=name, type=self.apply_generic(base, args))
case GenericType(name=name, params=type_vars, body=body):
n_args: int = len(args)
@@ -356,9 +346,6 @@ class TypesRegistry:
body=substitute_typevars(body, substitutions),
)
case BaseType(name="tuple"):
return TupleType(items=tuple(args))
case _:
raise ValueError(f"{type} is not a generic type")
@@ -398,7 +385,7 @@ class TypesRegistry:
return self._members[name][member_name].type
return None
case DerivedType(name=name, type=base):
case AliasType(name=name, type=base):
if name in self._members:
if member_name in self._members[name]:
return self._members[name][member_name].type
@@ -452,29 +439,3 @@ class TypesRegistry:
def lookup_predicate(self, name: str) -> Optional[Predicate]:
return self._predicates.get(name)
def _by_name_or_type(self, name_or_type: str | Type) -> Type:
if isinstance(name_or_type, str):
return self.get_type(name_or_type)
return name_or_type
def list_of(self, item_type: str | Type) -> Type:
list_ = self.get_type("list")
return self.apply_generic(list_, [self._by_name_or_type(item_type)])
def tuple_of(self, *item_types: str | Type) -> Type:
tuple_ = self.get_type("tuple")
return self.apply_generic(
tuple_,
[self._by_name_or_type(item_type) for item_type in item_types],
)
def dict_of(self, key_type: str | Type, value_type: str | Type) -> Type:
dict_ = self.get_type("dict")
return self.apply_generic(
dict_,
[
self._by_name_or_type(key_type),
self._by_name_or_type(value_type),
],
)

View File

@@ -236,9 +236,5 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
if expr.step is not None:
self.resolve(expr.step)
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
for item in expr.items:
self.resolve(item)
def visit_raw_expr(self, expr: p.RawExpr) -> None:
pass

View File

@@ -23,7 +23,7 @@ class BaseType:
@dataclass(frozen=True, kw_only=True)
class DerivedType:
class AliasType:
name: str
type: Type
@@ -187,22 +187,6 @@ class DataFrameType:
type: ColumnType
@dataclass(frozen=True, kw_only=True)
class FrameGroupBy:
frame: DataFrameType
def __str__(self) -> str:
return f"FrameGroupBy[{self.frame}]"
@dataclass(frozen=True, kw_only=True)
class ColumnGroupBy:
column: ColumnType
def __str__(self) -> str:
return f"ColumnGroupBy[{self.column}]"
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
def sub_argument(arg: Function.Argument):
return Function.Argument(
@@ -229,10 +213,8 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
case BaseType():
return type
case DerivedType(name=name, type=type2):
return DerivedType(
name=name, type=substitute_typevars(type2, substitutions)
)
case AliasType(name=name, type=type2):
return AliasType(name=name, type=substitute_typevars(type2, substitutions))
case Function(
pos_args=pos_args,
@@ -321,20 +303,11 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
columns=list(map(sub_column, columns)),
)
case FrameGroupBy(frame=frame):
return FrameGroupBy(
frame=cast(DataFrameType, substitute_typevars(frame, substitutions))
)
case ColumnGroupBy(column=column):
return ColumnGroupBy(
column=cast(ColumnType, substitute_typevars(column, substitutions))
)
case UnknownType() | UnitType():
return type
case TopType() | GenericType():
raise NotImplementedError(f"Unsupported type {type}")
# Ensure exhaustiveness
@@ -344,7 +317,7 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
def unfold_type(type: Type) -> Type:
match type:
case DerivedType(type=ref_type):
case AliasType(type=ref_type):
return unfold_type(ref_type)
case _:
return type
@@ -367,7 +340,7 @@ def to_annotation(type: Type) -> str:
case BaseType(name=name):
return name
case DerivedType(name=name):
case AliasType(name=name):
return name
case UnknownType():
@@ -407,12 +380,6 @@ def to_annotation(type: Type) -> str:
case DataFrameType():
return "pd.DataFrame"
case FrameGroupBy():
return "pd.api.typing.DataFrameGroupBy"
case ColumnGroupBy():
return "pd.api.typing.SeriesGroupBy"
case _:
assert_never(type)
@@ -427,7 +394,7 @@ class Predicate:
Type = (
TopType
| BaseType
| DerivedType
| AliasType
| UnknownType
| UnitType
| Function
@@ -441,6 +408,4 @@ Type = (
| TupleType
| ColumnType
| DataFrameType
| FrameGroupBy
| ColumnGroupBy
)

View File

@@ -4,8 +4,6 @@ from typing import Optional
from midas.checker.registry import TypesRegistry
from midas.checker.types import (
AppliedType,
ColumnType,
DataFrameType,
Function,
GenericType,
TopType,
@@ -100,30 +98,6 @@ class Unifier:
return substitutions
case (
DataFrameType(columns=template_columns),
DataFrameType(columns=concrete_columns),
) if len(template_columns) == len(concrete_columns):
substitutions: dict[str, Type] = {}
for template_column, concrete_column in zip(
template_columns, concrete_columns
):
if template_column.index != concrete_column or (
template_column.name != concrete_column.name
):
self.logger.debug(
f"Column mismatch: template={template_column}, concrete={concrete_column}"
)
raise UnificationError
new_substistutions: dict[str, Type] = self.match(
template_column.type, concrete_column.type
)
substitutions = self.merge(substitutions, new_substistutions)
return substitutions
case (ColumnType(type=template_column), ColumnType(type=concrete_column)):
return self.match(template_column, concrete_column)
case (Function(), Function()):
mapped: list[tuple[Function.Argument, Function.Argument]] = (
self.map_params(template, concrete)

View File

@@ -5,7 +5,7 @@
import sys
from pathlib import Path
from typing import Optional, TextIO
from typing import TextIO
import click
@@ -19,23 +19,18 @@ from midas.utils import TypedAST
@click.command(help="Compile source")
@click.argument("file", type=click.File("r"))
@click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.option("-s", "--stubs", type=str, multiple=True)
@click.option("--ignore-errors", is_flag=True)
def compile(
file: TextIO,
types: tuple[TextIO],
stubs: tuple[str],
ignore_errors: bool,
):
source: str = file.read()
source_path: Path = Path(file.name).resolve()
checker = TypeChecker()
type_files: list[tuple[Path, Optional[str]]] = []
for i, types_file in enumerate(types):
in_path: Path = Path(types_file.name).resolve()
checker.import_midas(in_path)
type_files.append((in_path, stubs[i] if i < len(stubs) else None))
for types_file in types:
checker.import_midas(Path(types_file.name).resolve())
typed_ast: TypedAST = checker.type_check_source(source, str(source_path))
diagnostics: list[Diagnostic] = checker.diagnostics.copy()
@@ -48,4 +43,4 @@ def compile(
sys.exit(1)
generator = Generator(workdir=source_path.parent, types=checker.types)
generator.generate(typed_ast, source_path, type_files=type_files)
generator.generate(typed_ast, source_path)

View File

@@ -11,14 +11,14 @@ import click
from midas.ast.printer import MidasPrinter
from midas.checker.checker import TypeChecker
from midas.checker.registry import Member
from midas.checker.types import AppliedType, BaseType, DerivedType, GenericType, Type
from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type
def base_type(type: Type) -> Type:
match type:
case BaseType():
return type
case DerivedType(type=base):
case AliasType(type=base):
return base
case AppliedType(body=body):
return body

View File

@@ -1,7 +1,7 @@
import ast
import time
from pathlib import Path
from typing import Optional, TextIO
from typing import TextIO
import black
import click
@@ -38,17 +38,15 @@ class Handler(FileSystemEventHandler):
@click.command(help="Generate stubs from Midas definitions")
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"))
@click.option("-o", "--output", type=click.File("w"), default="-")
@click.option("-w", "--watch", is_flag=True)
def stubs(
file: TextIO,
output: Optional[TextIO],
output: TextIO,
watch: bool,
):
source_path: Path = Path(file.name).resolve()
out_path: Path = source_path.with_suffix(".pyi")
if output is not None:
out_path = Path(output.name).resolve()
out_path: Path = Path(output.name).resolve()
generate_stubs(source_path, out_path)
if watch:

View File

@@ -134,9 +134,9 @@ class PythonHighlighter(
def visit_base_type(self, node: p.BaseType) -> None:
self.wrap(node, "base-type")
for arg in node.args:
self.wrap(arg, "arg")
arg.accept(self)
if node.param is not None:
self.wrap(node.param, "param")
node.param.accept(self)
def visit_constraint_type(self, node: p.ConstraintType) -> None:
self.wrap(node, "constraint-type")
@@ -247,10 +247,6 @@ class PythonHighlighter(
if expr.step is not None:
expr.step.accept(self)
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
for item in expr.items:
item.accept(self)
def visit_raw_expr(self, expr: p.RawExpr) -> None: ...
def visit_raw_stmt(self, stmt: p.RawStmt) -> None: ...

View File

@@ -3,7 +3,7 @@ span {
--col: 108, 233, 108;
}
&.arg {
&.param {
--col: 103, 192, 224;
}

View File

@@ -1,59 +0,0 @@
import ast
from dataclasses import dataclass
from typing import Callable
import midas.ast.python as p
AssertionBuilder = Callable[..., ast.expr]
@dataclass
class Assertion:
bound_expr: p.Expr
inputs: list[p.Expr]
builder: AssertionBuilder
message: str
def is_bound_to(self, expr: p.Expr) -> bool:
return expr == self.bound_expr
class AssertionCollector:
def __init__(self):
self.assertions: list[Assertion] = []
self.definitions: dict[str, ast.stmt] = {}
def add(
self,
bound_expr: p.Expr,
inputs: list[p.Expr],
builder: AssertionBuilder,
message: str,
):
self.assertions.append(
Assertion(
bound_expr=bound_expr,
inputs=inputs,
builder=builder,
message=message,
)
)
def remove(self, assertion: Assertion):
try:
self.assertions.remove(assertion)
except ValueError:
pass
def define(self, name: str, stmt: ast.stmt):
if name not in self.definitions:
self.definitions[name] = stmt
def get_definitions(self) -> list[ast.stmt]:
return list(self.definitions.values())
def get_assertions(self) -> list[Assertion]:
return self.assertions
def get_assertions_for(self, expr: p.Expr) -> list[Assertion]:
return list(filter(lambda a: a.is_bound_to(expr), self.assertions))

View File

@@ -9,19 +9,16 @@ import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.ast.printer import MidasPrinter
from midas.checker.checker import TypeChecker
from midas.checker.registry import TypesRegistry
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ColumnGroupBy,
ColumnType,
ComplexType,
ConstraintType,
DataFrameType,
DerivedType,
ExtensionType,
FrameGroupBy,
Function,
GenericType,
OverloadedFunction,
@@ -32,22 +29,17 @@ from midas.checker.types import (
UnitType,
UnknownType,
)
from midas.generator.collector import Assertion, AssertionCollector
from midas.generator.constraints import ConstraintGenerator
from midas.generator.stubs import StubsGenerator
from midas.utils import TypedAST
@dataclass
class Scope:
pre_assertions: list[ast.stmt] = field(default_factory=list[ast.stmt])
aliases: list[str] = field(default_factory=list[str])
pre_assertions: list[ast.stmt] = field(default_factory=list)
aliases: list[str] = field(default_factory=list)
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
IS_DATAFRAME_FUNC = "__midas_is_dataframe__"
IS_COLUMN_FUNC = "__midas_is_column__"
def __init__(self, workdir: Path, types: TypesRegistry) -> None:
self.workdir: Path = workdir.resolve()
self.build_dir: Path = self.workdir / "build" / "midas"
@@ -58,47 +50,28 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
stmts=[],
judgements=[],
evaluated_casts=[],
assertions=AssertionCollector(),
)
self._alias_count: int = 0
self._predicate_count: int = 0
self._scopes: list[Scope] = []
self._aliases: list[tuple[p.Expr, ast.expr]] = []
self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
self._constraints: list[tuple[m.Expr, ast.expr]] = []
self.define_is_dataframe: bool = False
self.define_is_column: bool = False
def set_src_path(self, path: Path):
self.rel_src_path = path.resolve().relative_to(self.workdir)
def generate_ast(self, typed_ast: TypedAST) -> ast.AST:
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
self.rel_src_path = src_path.resolve().relative_to(self.workdir)
self._typed_ast = typed_ast
body: list[ast.stmt] = self._visit_body(typed_ast.stmts, can_be_empty=True)
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
predicates: list[ast.stmt] = self._constraint_generator.get_definitions()
body = predicates + body
if self.define_is_dataframe:
body = [self._is_dataframe_definition()] + body
if self.define_is_column:
body = [self._is_column_definition()] + body
module = ast.Module(body=body, type_ignores=[])
module = ast.Module(body=predicates + body, type_ignores=[])
module = ast.fix_missing_locations(module)
return module
def generate(
self,
typed_ast: TypedAST,
src_path: Path,
out_path: Optional[Path] = None,
type_files: Optional[list[tuple[Path, Optional[str]]]] = None,
self, typed_ast: TypedAST, src_path: Path, out_path: Optional[Path] = None
) -> Path:
self.set_src_path(src_path)
module: ast.AST = self.generate_ast(typed_ast, src_path)
compiled: str = ast.unparse(module)
if out_path is None:
if self.build_dir.exists():
shutil.rmtree(self.build_dir)
@@ -110,72 +83,43 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
raise ValueError(
f"Directory traversal, {self.rel_src_path} points outside of parent directory"
)
out_dir: Path = out_path.parent
out_dir.parent.mkdir(parents=True, exist_ok=True)
if type_files is not None:
for in_path, out_name in type_files:
if out_name is None:
out_name = in_path.stem
self.generate_stubs(in_path, out_dir / f"{out_name}.py")
module: ast.AST = self.generate_ast(typed_ast)
compiled: str = ast.unparse(module)
out_path.parent.mkdir(parents=True, exist_ok=True)
out_path.write_text(compiled)
return out_path
def generate_stubs(self, in_path: Path, out_path: Path):
checker = TypeChecker()
checker.import_midas(in_path)
generator = StubsGenerator(checker.types)
module: ast.Module = generator.generate_stubs()
module = ast.fix_missing_locations(module)
output: str = ast.unparse(module)
out_path.write_text(output)
def convert(self, expr: p.Expr) -> ast.expr:
for expr2, alias in self._aliases:
if expr2 == expr:
return alias
assertions = self._typed_ast.assertions.get_assertions_for(expr)
if len(assertions) != 0:
return self._apply_assertions(expr, assertions)
return expr.accept(self)
def visit_binary_expr(self, expr: p.BinaryExpr) -> ast.expr:
return ast.BinOp(
left=self.convert(expr.left),
left=expr.left.accept(self),
op=expr.operator,
right=self.convert(expr.right),
right=expr.right.accept(self),
)
def visit_compare_expr(self, expr: p.CompareExpr) -> ast.expr:
return ast.Compare(
left=self.convert(expr.left),
left=expr.left.accept(self),
ops=[expr.operator],
comparators=[self.convert(expr.right)],
comparators=[expr.right.accept(self)],
)
def visit_unary_expr(self, expr: p.UnaryExpr) -> ast.expr:
return ast.UnaryOp(
op=expr.operator,
operand=self.convert(expr.right),
operand=expr.right.accept(self),
)
def visit_call_expr(self, expr: p.CallExpr) -> ast.expr:
return ast.Call(
func=self.convert(expr.callee),
args=[self.convert(arg) for arg in expr.arguments],
func=expr.callee.accept(self),
args=[arg.accept(self) for arg in expr.arguments],
keywords=[
ast.keyword(arg=name, value=self.convert(arg))
ast.keyword(arg=name, value=arg.accept(self))
for name, arg in expr.keywords.items()
],
)
def visit_get_expr(self, expr: p.GetExpr) -> ast.expr:
return ast.Attribute(
value=self.convert(expr.object),
value=expr.object.accept(self),
attr=expr.name,
)
@@ -188,58 +132,51 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def visit_logical_expr(self, expr: p.LogicalExpr) -> ast.expr:
return ast.BoolOp(
op=expr.operator,
values=[self.convert(expr.left), self.convert(expr.right)],
values=[expr.left.accept(self), expr.right.accept(self)],
)
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
expr2: ast.expr = self.convert(expr.expr)
expr2: ast.expr = expr.expr.accept(self)
if expr in self._typed_ast.evaluated_casts or expr.unsafe:
return expr2
alias: ast.expr = self._make_alias(expr.expr, expr2)
alias: ast.expr = self._make_alias(expr2)
type: Type = self._get_expr_type(expr)
asserts: list[ast.stmt] = self._make_cast_asserts(expr.location, alias, type)
for assert_ in asserts:
self._add_assert(assert_)
self._make_cast_asserts(expr.location, alias, type)
return alias
def visit_ternary_expr(self, expr: p.TernaryExpr) -> ast.expr:
return ast.IfExp(
test=self.convert(expr.test),
body=self.convert(expr.if_true),
orelse=self.convert(expr.if_false),
test=expr.test.accept(self),
body=expr.if_true.accept(self),
orelse=expr.if_false.accept(self),
)
def visit_list_expr(self, expr: p.ListExpr) -> ast.expr:
return ast.List(
elts=[self.convert(item) for item in expr.items],
elts=[item.accept(self) for item in expr.items],
)
def visit_dict_expr(self, expr: p.DictExpr) -> ast.expr:
return ast.Dict(
keys=[self.convert(key) if key is not None else None for key in expr.keys],
values=[self.convert(value) for value in expr.values],
keys=[key.accept(self) if key is not None else None for key in expr.keys],
values=[value.accept(self) for value in expr.values],
)
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> ast.expr:
return ast.Subscript(
value=self.convert(expr.object),
slice=self.convert(expr.index),
value=expr.object.accept(self),
slice=expr.index.accept(self),
)
def visit_slice_expr(self, expr: p.SliceExpr) -> ast.expr:
return ast.Slice(
lower=self.convert(expr.lower) if expr.lower is not None else None,
upper=self.convert(expr.upper) if expr.upper is not None else None,
step=self.convert(expr.step) if expr.step is not None else None,
)
def visit_tuple_expr(self, expr: p.TupleExpr) -> ast.expr:
return ast.Tuple(
elts=[self.convert(item) for item in expr.items],
lower=expr.lower.accept(self) if expr.lower is not None else None,
upper=expr.upper.accept(self) if expr.upper is not None else None,
step=expr.step.accept(self) if expr.step is not None else None,
)
def visit_raw_expr(self, expr: p.RawExpr) -> ast.expr:
@@ -247,7 +184,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> ast.stmt:
return ast.Expr(
value=self.convert(stmt.expr),
value=stmt.expr.accept(self),
)
def visit_function(self, stmt: p.Function) -> ast.stmt:
@@ -260,12 +197,12 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
kwonlyargs=[ast.arg(arg=arg.name) for arg in stmt.kwonlyargs],
kwarg=None,
defaults=[
self.convert(arg.default)
arg.default.accept(self)
for arg in stmt.posonlyargs + stmt.args
if arg.default is not None
],
kw_defaults=[
self.convert(arg.default) if arg.default is not None else None
arg.default.accept(self) if arg.default is not None else None
for arg in stmt.kwonlyargs
],
),
@@ -279,20 +216,20 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def visit_assign_stmt(self, stmt: p.AssignStmt) -> ast.stmt:
return ast.Assign(
targets=[self.convert(target) for target in stmt.targets],
value=self.convert(stmt.value),
targets=[target.accept(self) for target in stmt.targets],
value=stmt.value.accept(self),
)
def visit_return_stmt(self, stmt: p.ReturnStmt) -> ast.stmt:
return ast.Return(
value=self.convert(stmt.value) if stmt.value is not None else None,
value=stmt.value.accept(self) if stmt.value is not None else None,
)
def visit_if_stmt(self, stmt: p.IfStmt) -> ast.stmt:
return ast.If(
test=self.convert(stmt.test),
test=stmt.test.accept(self),
body=self._visit_body(stmt.body),
orelse=self._visit_body(stmt.orelse, can_be_empty=True),
orelse=self._visit_body(stmt.orelse),
)
def visit_pass(self, stmt: p.Pass) -> ast.stmt:
@@ -300,8 +237,8 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def visit_for_stmt(self, stmt: p.ForStmt) -> ast.stmt:
return ast.For(
target=self.convert(stmt.target),
iter=self.convert(stmt.iterator),
target=stmt.target.accept(self),
iter=stmt.iterator.accept(self),
body=self._visit_body(stmt.body),
orelse=[],
)
@@ -309,9 +246,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def visit_raw_stmt(self, stmt: p.RawStmt) -> ast.stmt:
return stmt.stmt
def _visit_body(
self, stmts: list[p.Stmt], can_be_empty: bool = False
) -> list[ast.stmt]:
def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]:
generated: list[ast.stmt] = []
for stmt in stmts:
scope = Scope()
@@ -329,11 +264,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
# Remove redundant pass statements
if len(generated) > 1:
generated = [stmt for stmt in generated if not isinstance(stmt, ast.Pass)]
if len(generated) == 0 and not can_be_empty:
generated = [ast.Pass()]
return generated
def _make_alias(self, node: p.Expr, expr: ast.expr) -> ast.expr:
def _make_alias(self, expr: ast.expr) -> ast.expr:
name: str = f"__midas_a{self._alias_count}__"
alias = ast.Name(id=name)
self._alias_count += 1
@@ -344,182 +277,97 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
value=expr,
)
)
self._aliases.append((node, alias))
return alias
def _build_assert(self, expr: ast.expr, message: str | ast.expr) -> ast.stmt:
def _add_assert(self, expr: ast.expr, message: str | ast.expr):
if isinstance(message, str):
message = ast.Constant(value=message)
return ast.Assert(
test=expr,
msg=message,
self._scopes[-1].pre_assertions.append(
ast.Assert(
test=expr,
msg=message,
)
)
def _add_assert(self, assertion: ast.stmt):
self._scopes[-1].pre_assertions.append(assertion)
def _get_expr_type(self, query: p.Expr) -> Type:
for expr, type in self._typed_ast.judgements:
if expr == query:
return type
raise RuntimeError(f"Cannot get type judgement for {query}")
def _make_cast_asserts(
self, src_location: Location, expr: ast.expr, type: Type
) -> list[ast.stmt]:
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type):
match type:
case UnknownType() | TopType():
return []
case UnknownType():
pass
case BaseType(name=name):
return [
self._build_assert(
ast.Call(
func=ast.Name(id="isinstance"),
args=[expr, ast.Name(id=name)],
keywords=[],
),
self._make_cast_assert_message(src_location, expr, type),
)
]
self._add_assert(
ast.Call(
func=ast.Name(id="isinstance"),
args=[expr, ast.Name(id=name)],
keywords=[],
),
self._make_cast_assert_message(src_location, expr, type),
)
case DerivedType(type=base):
return self._make_cast_asserts(src_location, expr, base)
case AliasType(type=base):
self._make_cast_asserts(src_location, expr, base)
case UnitType():
return [
self._build_assert(
ast.Compare(
left=expr,
ops=[ast.Is()],
comparators=[
ast.Constant(value=None),
],
),
self._make_cast_assert_message(src_location, expr, type),
self._add_assert(
ast.Compare(
left=expr,
ops=[ast.Is()],
comparators=[
ast.Constant(value=None),
],
),
]
self._make_cast_assert_message(src_location, expr, type),
)
case AppliedType(body=body):
return self._make_cast_asserts(src_location, expr, body)
self._make_cast_asserts(src_location, expr, body)
case ConstraintType(type=base, constraint=constraint):
asserts: list[ast.stmt] = self._make_cast_asserts(
src_location, expr, base
)
asserts.append(
self._make_constraint_assert(src_location, expr, constraint)
)
return asserts
self._make_cast_asserts(src_location, expr, base)
self._make_constraint_assert(src_location, expr, constraint)
case TypeVar(bound=bound):
# TODO: check with type from arguments / use call-site context
if bound is None:
return []
return self._make_cast_asserts(src_location, expr, bound)
if bound is not None:
self._make_cast_asserts(src_location, expr, bound)
case TupleType(items=items):
asserts: list[ast.stmt] = [
self._build_assert(
ast.Call(
func=ast.Name(id="isinstance"),
args=[expr, ast.Name(id="tuple")],
keywords=[],
),
self._make_cast_assert_message(src_location, expr, type),
self._add_assert(
ast.Call(
func=ast.Name(id="isinstance"),
args=[expr, ast.Name(id="tuple")],
keywords=[],
),
]
self._make_cast_assert_message(src_location, expr, type),
)
assert isinstance(expr, ast.Tuple)
for item, item_type in zip(expr.elts, items):
asserts.extend(
self._make_cast_asserts(src_location, item, item_type)
)
return asserts
case DataFrameType(columns=columns):
self.define_is_dataframe = True
asserts: list[ast.stmt] = [
self._build_assert(
ast.Call(
func=ast.Name(id=self.IS_DATAFRAME_FUNC),
args=[expr],
keywords=[],
),
self._make_cast_assert_message(
src_location, expr, type, ": Not a dataframe"
),
),
]
for column in columns:
asserts.append(
self._build_assert(
ast.Compare(
left=ast.Constant(value=column.name),
ops=[ast.In()],
comparators=[expr],
),
self._make_cast_assert_message(
src_location,
expr,
type,
f": Missing column {column.name}",
),
)
)
asserts.extend(
self._make_cast_asserts(
src_location,
ast.Subscript(
value=expr, slice=ast.Constant(value=column.name)
),
column.type,
)
)
return asserts
case ColumnType():
self.define_is_column = True
asserts: list[ast.stmt] = [
self._build_assert(
ast.Call(
func=ast.Name(id=self.IS_COLUMN_FUNC),
args=[expr],
keywords=[],
),
self._make_cast_assert_message(
src_location, expr, type, ": Not a column"
),
),
]
inner_assert: Optional[ast.stmt] = self._make_column_inner_assert(
src_location, expr, type
)
if inner_assert is not None:
asserts.append(inner_assert)
return asserts
self._make_cast_asserts(src_location, item, item_type)
case (
Function()
TopType()
| Function()
| OverloadedFunction()
| ComplexType()
| ExtensionType()
| GenericType()
| FrameGroupBy()
| ColumnGroupBy()
| ColumnType()
| DataFrameType()
):
self.logger.warning(f"Can't make assertion for type {type}")
return []
# Ensure exhaustiveness
case _:
assert_never(type)
def _make_cast_assert_message(
self,
location: Location,
expr: ast.expr,
type: Type,
extra: Optional[str] = None,
self, location: Location, expr: ast.expr, type: Type
) -> ast.expr:
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
# f"file.py:L1:1: CastError: Cannot cast {type(expr).__name__} to Type"
@@ -537,15 +385,15 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
),
conversion=-1,
),
ast.Constant(f" to {type}{extra or ''}"),
ast.Constant(f" to {type}"),
]
)
def _make_constraint_assert(
self, src_location: Location, expr: ast.expr, constraint: m.Expr
) -> ast.stmt:
):
test_func: ast.expr = self._get_constraint(constraint)
return self._build_assert(
self._add_assert(
ast.Call(
func=test_func,
args=[expr],
@@ -573,117 +421,3 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
constraint: ast.expr = self._constraint_generator.generate(expr)
self._constraints.append((expr, constraint))
return constraint
def _is_dataframe_definition(self) -> ast.stmt:
"""
def IS_DATAFRAME_FUNC(obj) -> bool:
import pandas as pd
return isinstance(obj, pd.DataFrame)
"""
return ast.FunctionDef(
name=self.IS_DATAFRAME_FUNC,
args=ast.arguments(
posonlyargs=[ast.arg(arg="obj")],
args=[],
kwonlyargs=[],
defaults=[],
kw_defaults=[],
),
body=[
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
ast.Return(
value=ast.Call(
func=ast.Name(id="isinstance"),
args=[
ast.Name(id="obj"),
ast.Attribute(
value=ast.Name(id="pd"),
attr="DataFrame",
),
],
keywords=[],
)
),
],
decorator_list=[],
returns=ast.Name(id="bool"),
)
def _is_column_definition(self) -> ast.stmt:
"""
def IS_COLUMN_FUNC(obj) -> bool:
import pandas as pd
return isinstance(obj, pd.Series)
"""
return ast.FunctionDef(
name=self.IS_COLUMN_FUNC,
args=ast.arguments(
posonlyargs=[ast.arg(arg="obj")],
args=[],
kwonlyargs=[],
defaults=[],
kw_defaults=[],
),
body=[
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
ast.Return(
value=ast.Call(
func=ast.Name(id="isinstance"),
args=[
ast.Name(id="obj"),
ast.Attribute(
value=ast.Name(id="pd"),
attr="Series",
),
],
keywords=[],
)
),
],
decorator_list=[],
returns=ast.Name(id="bool"),
)
def _make_column_inner_assert(
self, src_location: Location, column: ast.expr, type: ColumnType
) -> Optional[ast.stmt]:
# TODO: improve message, maybe chain contexts
col: ast.expr = ast.Name(id="col")
body: list[ast.stmt] = self._make_cast_asserts(src_location, col, type.type)
if len(body) == 0:
return None
return ast.For(
target=col,
iter=column,
body=body,
orelse=[],
)
def _convert_assertion(self, assertion: Assertion) -> ast.stmt:
inputs: list[ast.expr] = []
for input in assertion.inputs:
converted: ast.expr = self.convert(input)
alias: ast.expr = self._make_alias(input, converted)
inputs.append(alias)
test: ast.expr = assertion.builder(*inputs)
location: Location = assertion.bound_expr.location
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
return self._build_assert(
test, f"{loc_str}: AssertionError: {assertion.message}"
)
def _apply_assertions(self, expr: p.Expr, assertions: list[Assertion]) -> ast.expr:
for assertion in assertions:
assert_stmt: ast.stmt
assert_stmt = self._convert_assertion(assertion)
self._add_assert(assert_stmt)
# Mutating list in frozen dataclass
# Not ideal but easiest way to avoid duplicate assertions
self._typed_ast.assertions.remove(assertion)
return expr.accept(self)

View File

@@ -4,16 +4,14 @@ from typing import Optional, assert_never
import midas.ast.midas as m
from midas.checker.registry import Member, TypesRegistry
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ColumnGroupBy,
ColumnType,
ComplexType,
ConstraintType,
DataFrameType,
DerivedType,
ExtensionType,
FrameGroupBy,
Function,
GenericType,
OverloadedFunction,
@@ -93,21 +91,6 @@ class StubsGenerator:
def generate_stub(self, name: str, type: Type):
base_type: Type = type
# TODO: improve
match type:
case DerivedType(name=name_) | GenericType(name=name_) if name_ == name:
pass
case UnitType() if name == "None":
pass
case TopType() if name == "Any":
pass
case _:
alias = ast.Assign(
targets=[ast.Name(id=name)], value=self.dump_type(type)
)
self.add_stub(alias)
return
members: dict[str, Member] = self.types._members.get(name, {})
if isinstance(base_type, (BaseType, TopType, UnitType)) and len(members) == 0:
return
@@ -129,7 +112,7 @@ class StubsGenerator:
def get_bases(self, type: Type) -> tuple[list[ast.expr], dict[str, Type]]:
match type:
case DerivedType(type=base):
case AliasType(type=base):
return [self.dump_type(base)], {}
case GenericType(params=params, body=body):
@@ -194,7 +177,7 @@ class StubsGenerator:
def dump_type(self, type: Type) -> ast.expr:
match type:
case DerivedType(name=name) | GenericType(name=name) if (
case AliasType(name=name) | GenericType(name=name) if (
name in self.substitutions
):
type = substitute_typevars(type, self.substitutions[name])
@@ -207,7 +190,7 @@ class StubsGenerator:
case BaseType(name=name):
return ast.Name(id=name)
case DerivedType(name=name):
case AliasType(name=name):
return ast.Name(id=name)
case UnitType():
@@ -289,32 +272,6 @@ class StubsGenerator:
attr="DataFrame",
)
case FrameGroupBy():
self.import_pandas = True
return ast.Attribute(
value=ast.Attribute(
value=ast.Attribute(
value=ast.Name(id="pd"),
attr="api",
),
attr="typing",
),
attr="DataFrameGroupBy",
)
case ColumnGroupBy():
self.import_pandas = True
return ast.Attribute(
value=ast.Attribute(
value=ast.Attribute(
value=ast.Name(id="pd"),
attr="api",
),
attr="typing",
),
attr="SeriesGroupBy",
)
case _:
assert_never(type)

View File

@@ -46,8 +46,8 @@ class MidasLexer(Lexer):
self.add_token(TokenType.UNDERSCORE)
case "-" if self.match(">"):
self.add_token(TokenType.ARROW)
case "+":
self.add_token(TokenType.PLUS)
# case "+":
# self.add_token(TokenType.PLUS)
case "-":
self.add_token(TokenType.MINUS)
case "*":

View File

@@ -25,7 +25,7 @@ class TokenType(Enum):
DOT = auto()
# Operators
PLUS = auto()
# PLUS = auto()
MINUS = auto()
STAR = auto()
SLASH = auto()
@@ -47,7 +47,6 @@ class TokenType(Enum):
# Keywords
TYPE = auto()
ALIAS = auto()
PREDICATE = auto()
EXTEND = auto()
WHERE = auto()
@@ -64,7 +63,6 @@ class TokenType(Enum):
KEYWORDS: dict[str, TokenType] = {
"type": TokenType.TYPE,
"alias": TokenType.ALIAS,
"predicate": TokenType.PREDICATE,
"extend": TokenType.EXTEND,
"where": TokenType.WHERE,

View File

@@ -2,7 +2,6 @@ from typing import Optional
from midas.ast.location import Location
from midas.ast.midas import (
AliasStmt,
BinaryExpr,
CallExpr,
ComplexType,
@@ -81,8 +80,6 @@ class MidasParser(Parser):
try:
if self.match(TokenType.TYPE):
return self.type_declaration()
if self.match(TokenType.ALIAS):
return self.alias_declaration()
if self.match(TokenType.EXTEND):
return self.extend_declaration()
if self.match(TokenType.PREDICATE):
@@ -162,25 +159,6 @@ class MidasParser(Parser):
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after type parameters")
return params
def alias_declaration(self) -> AliasStmt:
"""Parse an alias declaration
Returns:
AliasStmt: the parsed alias declaration statement
"""
keyword: Token = self.previous()
name: Token = self.consume_identifier("Expected type name")
self.consume(TokenType.EQUAL, "Expected '=' before alias definition")
type: Type = self.type_expr()
return AliasStmt(
location=keyword.location_to(self.previous()),
name=name,
type=type,
)
def type_expr(self) -> Type:
"""Parse a type expression
@@ -361,35 +339,13 @@ class MidasParser(Parser):
Returns:
Expr: the parsed expression
"""
expr: Expr = self.term()
expr: Expr = self.unary()
while self.match(
TokenType.LESS,
TokenType.LESS_EQUAL,
TokenType.GREATER,
TokenType.GREATER_EQUAL,
):
operator: Token = self.previous()
right: Expr = self.term()
location: Location = Location.span(expr.location, right.location)
expr = BinaryExpr(
location=location, left=expr, operator=operator, right=right
)
return expr
def term(self) -> Expr:
expr: Expr = self.factor()
while self.match(TokenType.PLUS, TokenType.MINUS):
operator: Token = self.previous()
right: Expr = self.factor()
location: Location = Location.span(expr.location, right.location)
expr = BinaryExpr(
location=location, left=expr, operator=operator, right=right
)
return expr
def factor(self) -> Expr:
expr: Expr = self.unary()
while self.match(TokenType.STAR, TokenType.SLASH):
operator: Token = self.previous()
right: Expr = self.unary()
location: Location = Location.span(expr.location, right.location)
@@ -421,7 +377,7 @@ class MidasParser(Parser):
pos_args: list[Expr] = []
kw_args: dict[str, Expr] = {}
keywords: bool = False
while not self.check(TokenType.RIGHT_PAREN):
while not self.match(TokenType.RIGHT_PAREN):
if self.check_identifier() and self.check_next(TokenType.EQUAL):
keywords = True
keyword: Token = self.advance()

View File

@@ -30,7 +30,6 @@ from midas.ast.python import (
Stmt,
SubscriptExpr,
TernaryExpr,
TupleExpr,
TypeAssign,
UnaryExpr,
VariableExpr,
@@ -301,28 +300,26 @@ class PythonParser:
case ast.Subscript(value=ast.Name(id="Frame"), slice=schema):
return self._parse_frame_type(schema)
case ast.Subscript(value=ast.Name(id=name), slice=arg):
args: tuple[MidasType, ...] = (
tuple(self._parse_type(a) for a in arg.elts)
if isinstance(arg, ast.Tuple)
else (self._parse_type(arg),)
)
case ast.Subscript(value=ast.Name(id=name), slice=param):
return BaseType(
location=loc,
base=name,
args=args,
param=self._parse_type(param),
)
case ast.Name(id=name):
return BaseType(
location=loc,
base=name,
args=(),
param=None,
)
case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr):
left = self._parse_type(left_expr)
match left:
case None:
raise InvalidSyntaxError()
# If chained constraints, separate base type and rebuild constraint
case ConstraintType(type=left_type, constraint=left_constraint):
constraint = ast.BinOp(
@@ -348,7 +345,7 @@ class PythonParser:
return BaseType(
location=loc,
base="None",
args=(),
param=None,
)
case _:
@@ -480,12 +477,6 @@ class PythonParser:
step=self.parse_expr(step) if step is not None else None,
)
case ast.Tuple(elts=items):
return TupleExpr(
location=location,
items=tuple(self.parse_expr(item) for item in items),
)
case _:
print(f"Unsupported expression: {ast.unparse(node)}")
return RawExpr(location=location, expr=node)

View File

@@ -3,7 +3,6 @@ from typing import Any, Callable, Optional
import midas.ast.python as p
from midas.checker.types import Type
from midas.generator.collector import AssertionCollector
AllowRepeat = Callable[[object], bool]
@@ -64,4 +63,3 @@ class TypedAST:
stmts: list[p.Stmt]
judgements: list[tuple[p.Expr, Type]]
evaluated_casts: list[p.CastExpr]
assertions: AssertionCollector

View File

@@ -1,43 +0,0 @@
from typing import Type
from midas.cli.ansi import Ansi
from tests.base import Tester
from tests.checker import CheckerTester
from tests.generator import GeneratorTester
from tests.midas import MidasTester
from tests.python import PythonTester
def print_banner(name: str):
horizontal: str = "+" + "-" * (len(name) + 2) + "+"
print(horizontal)
print(f"| {name} |")
print(horizontal)
def run_tests(tester_cls: Type[Tester]) -> bool:
print_banner(tester_cls.__name__)
tester: Tester = tester_cls()
success: bool = tester.run_all_tests()
print()
return success
def main():
testers: list[Type[Tester]] = [
PythonTester,
MidasTester,
CheckerTester,
GeneratorTester,
]
success: bool = all(map(run_tests, testers))
if success:
print(Ansi.FG(Ansi.BRIGHT_GREEN) + "All tests passed!" + Ansi.RESET)
else:
print(Ansi.FG(Ansi.BRIGHT_RED) + "Some tests failed!" + Ansi.RESET)
if __name__ == "__main__":
main()

View File

@@ -7,8 +7,6 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import Iterator, Protocol
from midas.cli.ansi import Ansi
class CaseResult(Protocol):
def dumps(self) -> str: ...
@@ -46,11 +44,8 @@ class Tester(ABC):
print(rule)
for i, test in enumerate(tests):
path: Path = test.resolve().relative_to(self.CASES_DIR)
print(f"{Ansi.FG(Ansi.BRIGHT_CYAN)}Case {i+1}/{n}: {path}{Ansi.RESET}")
print(Ansi.DIM, end="")
print(f"Case {i+1}/{n}: {test.resolve().relative_to(self.CASES_DIR)}")
success: bool = self._run_test(test)
print(Ansi.RESET, end="")
if success:
successes += 1
else:
@@ -151,9 +146,8 @@ class Tester(ABC):
if not success:
sys.exit(1)
case None:
success: bool = tester.run_all_tests()
if not success:
sys.exit(1)
print("No subcommand provided. Available subcommands: run, update")
sys.exit(1)
case _:
print(f"Unknown subcommand '{args.subcommand}'")
sys.exit(1)

View File

@@ -24,7 +24,7 @@
"type": {
"_type": "BaseType",
"base": "Meter",
"args": []
"param": null
},
"expr": {
"_type": "LiteralExpr",
@@ -62,7 +62,7 @@
"type": {
"_type": "BaseType",
"base": "Second",
"args": []
"param": null
},
"expr": {
"_type": "LiteralExpr",

View File

@@ -317,7 +317,7 @@
"pos": 0,
"name": "object",
"type": {},
"required": false
"required": true
}
],
"args": [],

View File

@@ -1,117 +0,0 @@
# type: ignore
# ruff: disable [F821]
df1: Frame[a:int, b:float]
df2: Frame[a:int, b:float]
_: Any
# Arithmetic
_ = df1 + df2
_ = df1 - df2
_ = df1 * df2
_ = df1 / df2
_ = df1 // df2
_ = df1 % df2
_ = df1**df2
# Comparisons
_ = df1 < df2
_ = df1 > df2
_ = df1 <= df2
_ = df1 >= df2
_ = df1 != df2
_ = df1 == df2
# Aggregate
_ = df1.kurt()
_ = df1.kurtosis()
_ = df1.max()
_ = df1.mean()
_ = df1.median()
_ = df1.min()
_ = df1.mode()
_ = df1.prod()
_ = df1.product()
_ = df1.std()
_ = df1.sum()
_ = df1.var()
# Groupby
df_gb = df1.groupby(by="a")
_ = df_gb.kurt()
_ = df_gb.max()
_ = df_gb.mean()
_ = df_gb.median()
_ = df_gb.min()
_ = df_gb.prod()
_ = df_gb.std()
_ = df_gb.sum()
_ = df_gb.var()
# Columns
col1 = df1["a"]
col2 = df1["a"]
# Arithmetic
_ = col1 + col2
_ = col1 - col2
_ = col1 * col2
_ = col1 / col2
_ = col1 // col2
_ = col1 % col2
_ = col1**col2
# Comparisons
_ = col1 < col2
_ = col1 > col2
_ = col1 <= col2
_ = col1 >= col2
_ = col1 != col2
_ = col1 == col2
# Aggregate
_ = col1.kurt()
_ = col1.kurtosis()
_ = col1.max()
_ = col1.mean()
_ = col1.median()
_ = col1.min()
_ = col1.mode()
_ = col1.prod()
_ = col1.product()
_ = col1.std()
_ = col1.sum()
_ = col1.var()
# Groupby
col_gb = col1.groupby(level=0)
_ = col_gb.kurt()
_ = col_gb.max()
_ = col_gb.mean()
_ = col_gb.median()
_ = col_gb.min()
_ = col_gb.prod()
_ = col_gb.std()
_ = col_gb.sum()
_ = col_gb.var()
# Attributes
_ = df1.ndim # int
_ = df1.size # int
_ = df1.shape # (int, int)
_ = col1.ndim # int
_ = col1.size # int
_ = col1.shape # (int)
_ = col1.T # Column[int]
# Misc
_ = df1.head()
_ = df1.tail()
_ = col1.head()
_ = col1.tail()

File diff suppressed because it is too large Load Diff

View File

@@ -16,7 +16,7 @@
"type": {
"_type": "BaseType",
"base": "bool",
"args": []
"param": null
}
},
{
@@ -25,7 +25,7 @@
"type": {
"_type": "BaseType",
"base": "int",
"args": []
"param": null
}
},
{
@@ -36,7 +36,7 @@
"type": {
"_type": "BaseType",
"base": "float",
"args": []
"param": null
},
"constraint": "(_ > 0) + (_ < 250)"
}
@@ -47,7 +47,7 @@
"type": {
"_type": "BaseType",
"base": "str",
"args": []
"param": null
}
},
{
@@ -56,7 +56,7 @@
"type": {
"_type": "BaseType",
"base": "datetime",
"args": []
"param": null
}
},
{
@@ -65,7 +65,7 @@
"type": {
"_type": "BaseType",
"base": "float",
"args": []
"param": null
}
},
{
@@ -79,7 +79,7 @@
"type": {
"_type": "BaseType",
"base": "_",
"args": []
"param": null
}
}
]

View File

@@ -16,7 +16,7 @@
"type": {
"_type": "BaseType",
"base": "GeoLocation",
"args": []
"param": null
}
}
]
@@ -28,13 +28,11 @@
"type": {
"_type": "BaseType",
"base": "Column",
"args": [
{
"_type": "BaseType",
"base": "GeoLocation",
"args": []
}
]
"param": {
"_type": "BaseType",
"base": "GeoLocation",
"param": null
}
}
},
{
@@ -67,13 +65,11 @@
"type": {
"_type": "BaseType",
"base": "Column",
"args": [
{
"_type": "BaseType",
"base": "GeoLocation",
"args": []
}
]
"param": {
"_type": "BaseType",
"base": "GeoLocation",
"param": null
}
}
},
{
@@ -121,7 +117,7 @@
"type": {
"_type": "BaseType",
"base": "Latitude",
"args": []
"param": null
}
},
{
@@ -150,7 +146,7 @@
"type": {
"_type": "BaseType",
"base": "Latitude",
"args": []
"param": null
}
},
{
@@ -179,13 +175,11 @@
"type": {
"_type": "BaseType",
"base": "Difference",
"args": [
{
"_type": "BaseType",
"base": "Latitude",
"args": []
}
]
"param": {
"_type": "BaseType",
"base": "Latitude",
"param": null
}
}
},
{
@@ -223,7 +217,7 @@
"type": {
"_type": "BaseType",
"base": "int",
"args": []
"param": null
},
"constraint": "_ >= 0"
}
@@ -236,7 +230,7 @@
"type": {
"_type": "BaseType",
"base": "float",
"args": []
"param": null
},
"constraint": "_ >= 0"
}
@@ -258,7 +252,7 @@
"type": {
"_type": "BaseType",
"base": "int",
"args": []
"param": null
},
"constraint": "Positive"
}
@@ -271,7 +265,7 @@
"type": {
"_type": "BaseType",
"base": "float",
"args": []
"param": null
},
"constraint": "Positive"
}

View File

@@ -14,17 +14,15 @@
"type": {
"_type": "BaseType",
"base": "Column",
"args": [
{
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"args": []
},
"constraint": "0 <= _ <= 1"
}
]
"param": {
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"param": null
},
"constraint": "0 <= _ <= 1"
}
},
"default": null
},
@@ -33,17 +31,15 @@
"type": {
"_type": "BaseType",
"base": "Column",
"args": [
{
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"args": []
},
"constraint": "0 <= _ <= 1"
}
]
"param": {
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"param": null
},
"constraint": "0 <= _ <= 1"
}
},
"default": null
}
@@ -54,17 +50,15 @@
"returns": {
"_type": "BaseType",
"base": "Column",
"args": [
{
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"args": []
},
"constraint": "0 <= _ <= 2"
}
]
"param": {
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"param": null
},
"constraint": "0 <= _ <= 2"
}
},
"body": [
{
@@ -73,17 +67,15 @@
"type": {
"_type": "BaseType",
"base": "Column",
"args": [
{
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"args": []
},
"constraint": "0 <= _ <= 2"
}
]
"param": {
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"param": null
},
"constraint": "0 <= _ <= 2"
}
}
},
{
@@ -125,7 +117,7 @@
"type": {
"_type": "BaseType",
"base": "int",
"args": []
"param": null
},
"default": null
}
@@ -136,7 +128,7 @@
"type": {
"_type": "BaseType",
"base": "float",
"args": []
"param": null
},
"default": null
}
@@ -148,7 +140,7 @@
"type": {
"_type": "BaseType",
"base": "str",
"args": []
"param": null
},
"default": null
}

View File

@@ -46,8 +46,7 @@ class GeneratorTester(Tester):
if not any(d.type == DiagnosticType.ERROR for d in checker.diagnostics):
generator = Generator(workdir=path.parent, types=checker.types)
generator.set_src_path(path)
result.compiled_ast = generator.generate_ast(typed_ast)
result.compiled_ast = generator.generate_ast(typed_ast, path)
return result

View File

@@ -1,7 +1,6 @@
from typing import Optional, Sequence
from midas.ast.midas import (
AliasStmt,
BinaryExpr,
CallExpr,
ComplexType,
@@ -62,13 +61,6 @@ class MidasAstJsonSerializer(
"bound": self._serialize_optional(param.bound),
}
def visit_alias_stmt(self, stmt: AliasStmt) -> dict:
return {
"_type": "AliasStmt",
"name": stmt.name.lexeme,
"type": stmt.type.accept(self),
}
def visit_member_stmt(self, stmt: MemberStmt) -> dict:
return {
"_type": "MemberStmt",

View File

@@ -30,7 +30,6 @@ from midas.ast.python import (
Stmt,
SubscriptExpr,
TernaryExpr,
TupleExpr,
TypeAssign,
UnaryExpr,
VariableExpr,
@@ -99,7 +98,7 @@ class PythonAstJsonSerializer(
return {
"_type": "BaseType",
"base": node.base,
"args": self._serialize_list(node.args),
"param": self._serialize_optional(node.param),
}
def visit_constraint_type(self, node: ConstraintType) -> dict:
@@ -303,12 +302,6 @@ class PythonAstJsonSerializer(
"step": self._serialize_optional(expr.step),
}
def visit_tuple_expr(self, expr: TupleExpr) -> dict:
return {
"_type": "TupleExpr",
"items": [item.accept(self) for item in expr.items],
}
def visit_raw_expr(self, expr: RawExpr) -> dict:
return {
"_type": "RawExpr",