Compare commits
119 Commits
252a5abdfd
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 03bc32400b | |||
|
4a93ee45d9
|
|||
|
8197131d8d
|
|||
|
cf91187b7a
|
|||
|
1b2bdf0b79
|
|||
| c6cc38bfeb | |||
|
4d3e3f44a1
|
|||
|
ec80b1e92e
|
|||
|
4ea15519f3
|
|||
|
7a6e01cff8
|
|||
|
733c8736b8
|
|||
|
20173a0b07
|
|||
|
a143972ef1
|
|||
|
0c70048b62
|
|||
|
1c0c917873
|
|||
|
1f6189daa4
|
|||
|
66b585c3d6
|
|||
|
819ab3c2bf
|
|||
|
d8c0b17512
|
|||
|
6e06f9078e
|
|||
|
ece2e3a6a3
|
|||
|
74c07c9afb
|
|||
|
be2fd4c837
|
|||
|
1bc4c704c3
|
|||
|
0288a05901
|
|||
|
b14f46d405
|
|||
|
8e8ed62266
|
|||
|
2fce2f4bfc
|
|||
|
640f2d1771
|
|||
|
b48dfe5301
|
|||
|
0d5840a4ce
|
|||
|
3c92f0867d
|
|||
|
b5acae4078
|
|||
|
5d20f8ec3e
|
|||
|
955c2233ed
|
|||
|
ff69b65171
|
|||
|
8df01afd8c
|
|||
|
47b2dfdd73
|
|||
|
bd4d793ce0
|
|||
|
f7a36f61b6
|
|||
|
ad2fabf471
|
|||
|
a59a58d21a
|
|||
| 3260ae4a1e | |||
|
bd1c9581c7
|
|||
|
663642ea6c
|
|||
|
e2abc04fe4
|
|||
|
a4016b55ce
|
|||
|
1ea5da7024
|
|||
|
a017a8cf1f
|
|||
|
8fc5ab623e
|
|||
|
14007db846
|
|||
|
6ad2ce4b68
|
|||
|
9a276c34c7
|
|||
|
6e717a3f9e
|
|||
|
77aadfa264
|
|||
| c81287df7f | |||
|
ffccc1bedd
|
|||
|
d14f208897
|
|||
|
293953a078
|
|||
|
bccc96e4d0
|
|||
|
9db56adf56
|
|||
|
3f99563ac8
|
|||
|
b36896cc7b
|
|||
|
cb75878ae9
|
|||
|
a5fe985eb2
|
|||
|
e324f414e6
|
|||
|
256536562f
|
|||
|
64f4314f0d
|
|||
|
6f6245d283
|
|||
|
3392bc347d
|
|||
|
7e0319906a
|
|||
|
75bd203d4a
|
|||
|
db40198357
|
|||
|
d79e1dee18
|
|||
|
4ea400265c
|
|||
|
24bffdabd4
|
|||
|
d7bb6326de
|
|||
|
dbf6f9e2db
|
|||
|
3cdc9031d3
|
|||
|
00e2ca8fe3
|
|||
|
4efb01285c
|
|||
|
f84a19159f
|
|||
|
946b2e0d2e
|
|||
|
08dd7408ec
|
|||
|
b33fadf768
|
|||
|
7219109e5d
|
|||
|
cdf1725c26
|
|||
|
7074b074bc
|
|||
|
ede7272c09
|
|||
|
87d5e286d2
|
|||
|
c91b206791
|
|||
|
a31d295eb1
|
|||
|
0d20993f02
|
|||
|
5357ca8e58
|
|||
|
556765fd35
|
|||
| d039a8e4b3 | |||
|
c4533421eb
|
|||
|
73769b42c1
|
|||
|
087f6b4669
|
|||
| d582df5927 | |||
|
6a0401833c
|
|||
|
e15607b763
|
|||
|
e28f324a85
|
|||
|
31e696c938
|
|||
|
759b416bf3
|
|||
|
4b2b0fe476
|
|||
|
4c39504750
|
|||
|
f9f3ade6c7
|
|||
|
386018b956
|
|||
|
bd47d33355
|
|||
|
93ddb28802
|
|||
| f7c43837b5 | |||
|
32ed62a6f1
|
|||
|
66f39acec0
|
|||
|
6c04e2fee4
|
|||
| 2bb2e0a684 | |||
|
5630320d21
|
|||
|
9f05ba3224
|
|||
|
5fbe965919
|
117
assets/icon.svg
Normal file
117
assets/icon.svg
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
|
||||||
|
<!-- Created with Inkscape (http://www.inkscape.org/) -->
|
||||||
|
|
||||||
|
<svg
|
||||||
|
width="128"
|
||||||
|
height="128"
|
||||||
|
viewBox="0 0 128 128"
|
||||||
|
version="1.1"
|
||||||
|
id="svg1"
|
||||||
|
inkscape:export-filename="logo.png"
|
||||||
|
inkscape:export-xdpi="96"
|
||||||
|
inkscape:export-ydpi="96"
|
||||||
|
inkscape:version="1.4.4 (1:1.4.4+202605061436+dcaf3e7d9e)"
|
||||||
|
sodipodi:docname="logo.svg"
|
||||||
|
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
|
||||||
|
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
|
||||||
|
xmlns:xlink="http://www.w3.org/1999/xlink"
|
||||||
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
|
xmlns:svg="http://www.w3.org/2000/svg">
|
||||||
|
<sodipodi:namedview
|
||||||
|
id="namedview1"
|
||||||
|
pagecolor="#ffffff"
|
||||||
|
bordercolor="#000000"
|
||||||
|
borderopacity="0.25"
|
||||||
|
inkscape:showpageshadow="2"
|
||||||
|
inkscape:pageopacity="0.0"
|
||||||
|
inkscape:pagecheckerboard="0"
|
||||||
|
inkscape:deskcolor="#d1d1d1"
|
||||||
|
inkscape:document-units="mm"
|
||||||
|
showgrid="true"
|
||||||
|
inkscape:zoom="1.9332778"
|
||||||
|
inkscape:cx="-8.2760999"
|
||||||
|
inkscape:cy="112.2446"
|
||||||
|
inkscape:window-width="2584"
|
||||||
|
inkscape:window-height="1028"
|
||||||
|
inkscape:window-x="0"
|
||||||
|
inkscape:window-y="24"
|
||||||
|
inkscape:window-maximized="1"
|
||||||
|
inkscape:current-layer="layer1">
|
||||||
|
<inkscape:grid
|
||||||
|
id="grid1"
|
||||||
|
units="px"
|
||||||
|
originx="0"
|
||||||
|
originy="0"
|
||||||
|
spacingx="4"
|
||||||
|
spacingy="4"
|
||||||
|
empcolor="#0099e5"
|
||||||
|
empopacity="0.30196078"
|
||||||
|
color="#0099e5"
|
||||||
|
opacity="0.14901961"
|
||||||
|
empspacing="4"
|
||||||
|
enabled="true"
|
||||||
|
visible="true" />
|
||||||
|
</sodipodi:namedview>
|
||||||
|
<defs
|
||||||
|
id="defs1">
|
||||||
|
<linearGradient
|
||||||
|
inkscape:collect="always"
|
||||||
|
xlink:href="#linearGradient4689"
|
||||||
|
id="linearGradient1478"
|
||||||
|
gradientUnits="userSpaceOnUse"
|
||||||
|
gradientTransform="matrix(0.562541,0,0,0.567972,-9.399749,-5.305317)"
|
||||||
|
x1="26.648937"
|
||||||
|
y1="20.603781"
|
||||||
|
x2="135.66525"
|
||||||
|
y2="114.39767" />
|
||||||
|
<linearGradient
|
||||||
|
id="linearGradient4689">
|
||||||
|
<stop
|
||||||
|
style="stop-color:#e1be1e;stop-opacity:1;"
|
||||||
|
offset="0"
|
||||||
|
id="stop4691" />
|
||||||
|
<stop
|
||||||
|
style="stop-color:#ffeb82;stop-opacity:1;"
|
||||||
|
offset="1"
|
||||||
|
id="stop4693" />
|
||||||
|
</linearGradient>
|
||||||
|
<linearGradient
|
||||||
|
inkscape:collect="always"
|
||||||
|
xlink:href="#linearGradient4671"
|
||||||
|
id="linearGradient1475"
|
||||||
|
gradientUnits="userSpaceOnUse"
|
||||||
|
gradientTransform="matrix(0.562541,0,0,0.567972,-9.399749,-5.305317)"
|
||||||
|
x1="150.96111"
|
||||||
|
y1="192.35176"
|
||||||
|
x2="112.03144"
|
||||||
|
y2="137.27299" />
|
||||||
|
<linearGradient
|
||||||
|
id="linearGradient4671">
|
||||||
|
<stop
|
||||||
|
style="stop-color:#ffdc21;stop-opacity:1;"
|
||||||
|
offset="0"
|
||||||
|
id="stop4673" />
|
||||||
|
<stop
|
||||||
|
style="stop-color:#ffeb82;stop-opacity:1;"
|
||||||
|
offset="1"
|
||||||
|
id="stop4675" />
|
||||||
|
</linearGradient>
|
||||||
|
</defs>
|
||||||
|
<g
|
||||||
|
inkscape:label="Calque 1"
|
||||||
|
inkscape:groupmode="layer"
|
||||||
|
id="layer1">
|
||||||
|
<g
|
||||||
|
id="g1"
|
||||||
|
transform="translate(2.911719,3.414527)">
|
||||||
|
<path
|
||||||
|
style="fill:url(#linearGradient1478);fill-opacity:1"
|
||||||
|
d="m 60.510156,6.3979729 c -4.583653,0.021298 -8.960939,0.4122177 -12.8125,1.09375 C 36.35144,9.4962267 34.291407,13.691825 34.291406,21.429223 v 10.21875 h 26.8125 v 3.40625 h -26.8125 -10.0625 c -7.792459,0 -14.6157592,4.683717 -16.7500002,13.59375 -2.46182,10.212966 -2.5710151,16.586023 0,27.25 1.9059283,7.937852 6.4575432,13.593748 14.2500002,13.59375 h 9.21875 v -12.25 c 0,-8.849902 7.657144,-16.656248 16.75,-16.65625 h 26.78125 c 7.454951,0 13.406253,-6.138164 13.40625,-13.625 v -25.53125 c 0,-7.266339 -6.12998,-12.7247775 -13.40625,-13.9375001 -4.605987,-0.7667253 -9.385097,-1.1150483 -13.96875,-1.09375 z m -14.5,8.2187501 c 2.769547,0 5.03125,2.298646 5.03125,5.125 -2e-6,2.816336 -2.261703,5.09375 -5.03125,5.09375 -2.779476,-1e-6 -5.03125,-2.277415 -5.03125,-5.09375 -1e-6,-2.826353 2.251774,-5.125 5.03125,-5.125 z"
|
||||||
|
id="path1948" />
|
||||||
|
<path
|
||||||
|
style="fill:url(#linearGradient1475);fill-opacity:1"
|
||||||
|
d="m 91.228906,35.054223 v 11.90625 c 0,9.230755 -7.825895,16.999999 -16.75,17 h -26.78125 c -7.335833,0 -13.406249,6.278483 -13.40625,13.625 v 25.531247 c 0,7.26634 6.318588,11.54032 13.40625,13.625 8.487331,2.49561 16.626237,2.94663 26.78125,0 6.750155,-1.95439 13.406253,-5.88761 13.40625,-13.625 V 92.897973 h -26.78125 v -3.40625 h 26.78125 13.406254 c 7.79246,0 10.69625,-5.435408 13.40624,-13.59375 2.79933,-8.398886 2.68022,-16.475776 0,-27.25 -1.92578,-7.757441 -5.60387,-13.59375 -13.40624,-13.59375 z m -15.0625,64.65625 c 2.779478,3e-6 5.03125,2.277417 5.03125,5.093747 -2e-6,2.82635 -2.251775,5.125 -5.03125,5.125 -2.76955,0 -5.03125,-2.29865 -5.03125,-5.125 2e-6,-2.81633 2.261697,-5.093747 5.03125,-5.093747 z"
|
||||||
|
id="path1950" />
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 4.7 KiB |
809
docs/manual.typ
Normal file
809
docs/manual.typ
Normal file
@@ -0,0 +1,809 @@
|
|||||||
|
//#import "@preview/codly:1.3.0": codly, codly-init
|
||||||
|
// Fix unaligned highlights in v0.15.0 ()
|
||||||
|
// See https://github.com/Dherse/codly/pull/132
|
||||||
|
#import "@local/codly:1.3.1": codly, codly-init
|
||||||
|
|
||||||
|
#import "@preview/codly-languages:0.1.10": codly-languages
|
||||||
|
#import "template.typ": TODO, project
|
||||||
|
#import "@preview/gentle-clues:1.3.1" as gc
|
||||||
|
|
||||||
|
#let midas-version = toml("../pyproject.toml").project.version
|
||||||
|
#let head-ref = read("../.git/HEAD").split(":").at(1).trim()
|
||||||
|
#let commit-hash = read("../.git/" + head-ref).slice(0, 8)
|
||||||
|
|
||||||
|
#show: project.with(
|
||||||
|
title: [Midas User Manual],
|
||||||
|
author: "Louis Heredero",
|
||||||
|
version: midas-version,
|
||||||
|
hash: commit-hash,
|
||||||
|
icon-path: path("../assets/icon.svg"),
|
||||||
|
)
|
||||||
|
|
||||||
|
#show: codly-init
|
||||||
|
#codly(
|
||||||
|
languages: codly-languages
|
||||||
|
+ (
|
||||||
|
midas: (
|
||||||
|
name: "Midas",
|
||||||
|
color: rgb("#eedd47"),
|
||||||
|
icon: box(
|
||||||
|
image(
|
||||||
|
"../assets/icon.svg",
|
||||||
|
height: 130%,
|
||||||
|
fit: "contain",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
= Introduction
|
||||||
|
|
||||||
|
Python is a very popular programming language, especially in data science.
|
||||||
|
However, it has been designed for simplicity, distancing itself from typed languages such as Java or C to embrace dynamic typing.
|
||||||
|
What this means is that in Python, type checks are deferred to runtime when operations are concretely executed.
|
||||||
|
|
||||||
|
For developers, it might seem like a great way of simplifying the language and making it very flexible, but it does come with a cost.
|
||||||
|
Indeed, type errors are very easy to make in Python. While passing an integer where a string is expected might not be an issue in some cases, this is the sort of thing that can cause crashes or incorrect results without a clear diagnostic to help the user fix it.
|
||||||
|
|
||||||
|
Fortunately, developers using IDEs or properly configured text editors can benefit from external type checkers such as MyPy which will perform static type analysis of their Python code. Some can also be configured to be very strict, forcing the user to make the whole code typeable statically, thus avoiding any runtime type errors.
|
||||||
|
|
||||||
|
This is not the end of the problem though. Some parts of a program, especially in data related fields, may not be available at "compile-time". For example, a dataset can be loaded from an external file, or data can be fetched from an API, with no guarantees of having the expected format when analyzing the code statically.
|
||||||
|
|
||||||
|
In turn, that can cause a range of loud and silent errors at runtime. A malformed number will probably crash the program when trying to convert it, but a NaN in a series of values might just produce wrong results without any exception. Combine this with often long-running data-processing pipelines and this is how developers can waste hours of precious computation time.
|
||||||
|
|
||||||
|
Midas is a type system which can be used on top of Python to provide better type checking capabilities and gradual typing.
|
||||||
|
It aims at providing optional but strict type annotations and casting operations which can produce runtime assertions. It also allows the user to define dependent types with value constraints that are translated into runtime checks.
|
||||||
|
|
||||||
|
= Installation
|
||||||
|
|
||||||
|
Midas comes as a very light Python package that you can install on your system in a few simple steps.
|
||||||
|
|
||||||
|
== Requirements
|
||||||
|
|
||||||
|
Here below are the requirements for installing Midas. All Python dependencies will be installed by `uv` in the installation process described in @install-steps.
|
||||||
|
|
||||||
|
- Python 3.11+
|
||||||
|
- `uv`
|
||||||
|
|
||||||
|
== Steps <install-steps>
|
||||||
|
|
||||||
|
1. Clone the repository
|
||||||
|
```bash
|
||||||
|
git clone https://git.kb28.ch/HEL/midas.git
|
||||||
|
```
|
||||||
|
2. Navigate inside the directory
|
||||||
|
```bash
|
||||||
|
cd midas
|
||||||
|
```
|
||||||
|
3. Install Midas as a tool in your local user space
|
||||||
|
```bash
|
||||||
|
uv tool install .
|
||||||
|
```
|
||||||
|
|
||||||
|
And that's it! You can now use Midas commands anywhere, like this:
|
||||||
|
```bash
|
||||||
|
midas --help
|
||||||
|
```
|
||||||
|
|
||||||
|
= Quick Start
|
||||||
|
|
||||||
|
This chapter will give you the keys to quickly start using Midas in your project.
|
||||||
|
|
||||||
|
== Defining custom types
|
||||||
|
|
||||||
|
To begin with, you might want to define some custom types for your project, to avoid handling anonymous float values everywhere. To do so, create a `*.midas` file in your project, and write some definitions for your types. See @midas-ref for more information on syntax and features.
|
||||||
|
|
||||||
|
@qs-midas shows a simple example of what it might look like.
|
||||||
|
|
||||||
|
#codly(header: [types.midas])
|
||||||
|
|
||||||
|
#figure(
|
||||||
|
```midas
|
||||||
|
type Meter = float
|
||||||
|
extend Meter {
|
||||||
|
def __add__: fn(Meter, /) -> Meter
|
||||||
|
def __sub__: fn(Meter, /) -> Meter
|
||||||
|
}
|
||||||
|
|
||||||
|
type Coordinate = object
|
||||||
|
extend Coordinate {
|
||||||
|
prop x: Meter
|
||||||
|
prop y: Meter
|
||||||
|
}
|
||||||
|
```,
|
||||||
|
caption: [Example Midas type definitions],
|
||||||
|
) <qs-midas>
|
||||||
|
|
||||||
|
You can check for any syntax error using the following command:
|
||||||
|
```bash
|
||||||
|
midas validate types.midas
|
||||||
|
```
|
||||||
|
|
||||||
|
When you are happy with your definitions, you can generate Python stubs to use in your source code. This allows other type checkers like MyPy to recognize your custom types and avoid reporting them as undefined. It can also help catch some type errors in your IDE.
|
||||||
|
```bash
|
||||||
|
midas stubs types.midas -o stubs.pyi
|
||||||
|
```
|
||||||
|
|
||||||
|
This command will generate a file as shown in @qs-stubs, providing stub classes to represent the type lattice including methods and properties.
|
||||||
|
|
||||||
|
#codly(header: [stubs.pyi])
|
||||||
|
|
||||||
|
#figure(
|
||||||
|
```pyi
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
class Meter(float):
|
||||||
|
def __add__(self, _0: Meter, /) -> Meter: ...
|
||||||
|
def __sub__(self, _0: Meter, /) -> Meter: ...
|
||||||
|
|
||||||
|
class Coordinate(object):
|
||||||
|
x: Meter
|
||||||
|
y: Meter
|
||||||
|
```,
|
||||||
|
caption: [Generated stubs from example definitions of @qs-midas],
|
||||||
|
) <qs-stubs>
|
||||||
|
|
||||||
|
== Using Midas in Python
|
||||||
|
|
||||||
|
You can now write your Python program as you would normally. You can import your custom types from the generated stubs file and use them in type annotations.
|
||||||
|
|
||||||
|
You can also import the `cast` and `unsafe_cast` functions from `midas.typing` to explicitly cast a value to a specific type (see @cast for more information).
|
||||||
|
|
||||||
|
An example Python script is shown in @qs-python, demonstrating how you can use custom types in type annotations. Notice the comments describing errors that will be caught by the type checker in @qs-type-checking.
|
||||||
|
|
||||||
|
#codly(header: [script.py])
|
||||||
|
|
||||||
|
#figure(
|
||||||
|
```python
|
||||||
|
from lib import load_coordinate
|
||||||
|
from midas.typing import cast
|
||||||
|
from stubs import Coordinate, Meter
|
||||||
|
|
||||||
|
p1 = cast(Coordinate, load_coordinate(0))
|
||||||
|
p2 = cast(Coordinate, load_coordinate(1))
|
||||||
|
|
||||||
|
diff_x = p2.x - p1.x
|
||||||
|
diff_y = p2.y - p1.y
|
||||||
|
|
||||||
|
dist = diff_x + diff_y
|
||||||
|
|
||||||
|
p2.x += cast(Meter, 1)
|
||||||
|
p2.y = True # invalid, wrong type
|
||||||
|
p2.z = 3 # invalid, no property 'z' on Coordinate
|
||||||
|
p2.x.a = 3 # invalid, no properties on Meter
|
||||||
|
```,
|
||||||
|
caption: [Example Python script],
|
||||||
|
) <qs-python>
|
||||||
|
|
||||||
|
== Type checking <qs-type-checking>
|
||||||
|
|
||||||
|
Now that you have defined some types and written a script, you can run the type checker with the following command. You can also skip this step and directly run the compilation command in @qs-compilation.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
midas check -t types.midas script.py
|
||||||
|
```
|
||||||
|
|
||||||
|
== Compiling <qs-compilation>
|
||||||
|
|
||||||
|
The final step is to compile your code. This step will produce a runnable Python script, including runtime assertions generated by `cast` expressions.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
midas compile -t types.midas script.py
|
||||||
|
python3 build/midas/script.py
|
||||||
|
```
|
||||||
|
|
||||||
|
= Midas Language Reference <midas-ref>
|
||||||
|
|
||||||
|
In this chapter, you will find a complete reference for the Midas definition language.
|
||||||
|
|
||||||
|
A `*.midas` file contains a number of statements, which can be:
|
||||||
|
- *`alias`* statements (see @alias-stmt): to define a new type alias
|
||||||
|
- *`type`* statements (see @type-stmt): to define a new type
|
||||||
|
- *`extend`* statements (see @extend-stmt): to define members of a type
|
||||||
|
- *`predicate`* statements (see @predicate-stmt): to define named predicates that can be used in constraint types
|
||||||
|
|
||||||
|
== Alias Statement <alias-stmt>
|
||||||
|
|
||||||
|
An *`alias`* statement lets you define a new type alias. It requires a unique name and base type.
|
||||||
|
|
||||||
|
While a `type` statement (see @type-stmt) allows generic definitions, aliases are purely for giving an alternative name to a type.
|
||||||
|
|
||||||
|
#figure(
|
||||||
|
```midas
|
||||||
|
alias MyType = float
|
||||||
|
```,
|
||||||
|
caption: [Simple `alias` statement declaring a new type "`MyType`" equivalent to `float`],
|
||||||
|
) <midas-simple-alias>
|
||||||
|
|
||||||
|
This statement defines a new type called `MyType` which is equivalent to `float`. `MyType` and `float` can be used interchangeably.
|
||||||
|
|
||||||
|
== Type Statement <type-stmt>
|
||||||
|
|
||||||
|
A *`type`* statement lets you define a new type. It requires a unique name and base type.
|
||||||
|
|
||||||
|
The simplest form of a *`type`* statement is:
|
||||||
|
#figure(
|
||||||
|
```midas
|
||||||
|
type MyType = float
|
||||||
|
```,
|
||||||
|
caption: [Simple `type` statement declaring a new type "`MyType`" as a subtype of `float`],
|
||||||
|
) <midas-simple-type>
|
||||||
|
|
||||||
|
This statement defines a new type called `MyType` which is a subtype of `float`. `MyType` is a `float` but a `float` is not necessarily `MyType`.
|
||||||
|
|
||||||
|
=== Builtin / base types
|
||||||
|
|
||||||
|
A number of base types are provided out of the box, which can be used to derive other types.
|
||||||
|
|
||||||
|
They correspond to Python's builtin types:
|
||||||
|
```py object```,
|
||||||
|
```py str```,
|
||||||
|
```py float```,
|
||||||
|
```py int```,
|
||||||
|
```py bool```,
|
||||||
|
```py list```,
|
||||||
|
```py dict```,
|
||||||
|
```py None```.
|
||||||
|
|
||||||
|
Some differences are to be noted however.
|
||||||
|
1. ```py bool``` is not a subtype of ```py int```
|
||||||
|
2. ```py list``` are homogeneous, i.e. all items must be of the same type
|
||||||
|
3. ```py dict``` keys and values are homogeneous, i.e. all keys must be of the same type and all values must be of the same type (can be different from keys).
|
||||||
|
|
||||||
|
=== Function types
|
||||||
|
|
||||||
|
A function type is written in a similar notation to Python function definitions:
|
||||||
|
#figure(
|
||||||
|
```midas
|
||||||
|
type Repeater = fn(text: str, count: int) -> str
|
||||||
|
```,
|
||||||
|
caption: [Simple function type definition],
|
||||||
|
)
|
||||||
|
|
||||||
|
Midas supports positional-only, keyword-only and mixed arguments (using the `/` and `*` separators). You may omit the name of positional-only arguments. The return type is required.
|
||||||
|
|
||||||
|
Optional parameters can be indicated by adding a question mark (`?`) after their type:
|
||||||
|
#figure(
|
||||||
|
```midas
|
||||||
|
type Repeater = fn(text: str, count: int, *, sep: str?) -> str
|
||||||
|
```,
|
||||||
|
caption: [Function type definition with an optional keyword-only parameter],
|
||||||
|
)
|
||||||
|
|
||||||
|
#gc.warning[
|
||||||
|
Sink arguments (`*args`, `**kwargs`) are not currently supported.
|
||||||
|
]
|
||||||
|
|
||||||
|
=== Constraint types
|
||||||
|
|
||||||
|
A useful feature provided by Midas is the possibility to combine types with custom value constraints. For example, you might want to define a type for positive amounts of money:
|
||||||
|
|
||||||
|
#figure(
|
||||||
|
```midas
|
||||||
|
type Money = float
|
||||||
|
type Income = Money where _ >= 0
|
||||||
|
```,
|
||||||
|
caption: [Simple constraint type definition],
|
||||||
|
)
|
||||||
|
|
||||||
|
Constraints can be combined with any type using the `where` keyword, followed by a constraint expression (see @constraint-expr).
|
||||||
|
|
||||||
|
=== Generic types
|
||||||
|
|
||||||
|
For more complex types, you might want to use type parameters. For example, to define a container, we might write:
|
||||||
|
#figure(
|
||||||
|
```midas
|
||||||
|
type Container[T] = object
|
||||||
|
```,
|
||||||
|
caption: [Simple generic container type definition],
|
||||||
|
)
|
||||||
|
|
||||||
|
To better refine a generic type, you can also bound type parameters using the following syntax:
|
||||||
|
#figure(
|
||||||
|
```midas
|
||||||
|
type Container[T <: float] = object
|
||||||
|
```,
|
||||||
|
caption: [Generic container type definition with a bound],
|
||||||
|
)
|
||||||
|
|
||||||
|
This can be read as "`Container` is a generic type which takes one type parameter `T` that must be a subtype of `float`".\
|
||||||
|
You can use a generic type, i.e. instantiate it, by using a similar syntax with concrete types as arguments:
|
||||||
|
|
||||||
|
#figure(
|
||||||
|
```midas
|
||||||
|
type MyContainer = Container[MyType]
|
||||||
|
```,
|
||||||
|
caption: [Application of a generic type],
|
||||||
|
)
|
||||||
|
|
||||||
|
Generic types can also take multiple parameters, which are then separated by commas:
|
||||||
|
#figure(
|
||||||
|
```midas
|
||||||
|
type ZipCodeRegistry = dict[int, str]
|
||||||
|
```,
|
||||||
|
caption: [Application of a multi-parameter generic type],
|
||||||
|
)
|
||||||
|
|
||||||
|
The _body_ of a generic type, i.e. the right-hand side of the definition, can contain or even be equal to any number of its parameters.#footnote[The latter is not something that is expressible in standard Python, yet it brings a semantic distinction on top of structurally equivalent values.] For example, the following is a valid type statement:
|
||||||
|
#figure(
|
||||||
|
```midas
|
||||||
|
type Price[T <: Currency] = T where _ > 0
|
||||||
|
```,
|
||||||
|
caption: [Type parameters in a generic type's body],
|
||||||
|
)
|
||||||
|
|
||||||
|
=== `Column` / `Frame` types
|
||||||
|
|
||||||
|
To provide useful type-checking for data engineers, Midas offers two special types: `Column` and `Frame`.
|
||||||
|
Their goal is to help type check Pandas' `Series` and `DataFrame` respectively.
|
||||||
|
|
||||||
|
==== `Column`
|
||||||
|
|
||||||
|
The `Column` type is a generic type used to represent a `pandas.Series` object.
|
||||||
|
You can use it like any other generic type and it will provide type checking for some common methods and attributes offered by Pandas.
|
||||||
|
|
||||||
|
#figure(
|
||||||
|
```midas
|
||||||
|
type Temperature = float
|
||||||
|
alias Temperatures = Column[Temperature]
|
||||||
|
```,
|
||||||
|
caption: [Simple column type definition],
|
||||||
|
)
|
||||||
|
|
||||||
|
==== `Frame` <frame-type>
|
||||||
|
|
||||||
|
The `Frame` type is a super-powered generic type used to represent a `pandas.DataFrame` object.
|
||||||
|
In place of type arguments, `Frame` accepts a schema, i.e. a series of column definitions.
|
||||||
|
@simple-frame shows how you can define a simple frame type with 3 columns:
|
||||||
|
- `name`: a column of `Name` values
|
||||||
|
- `age`: a column of `int` values
|
||||||
|
- `height`: a column of `float where _ >= 0` values
|
||||||
|
|
||||||
|
Notice that you don't need to specify `Column` types.
|
||||||
|
|
||||||
|
#figure(
|
||||||
|
```midas
|
||||||
|
type Name = str where len(_) != 0
|
||||||
|
alias Data = Frame[
|
||||||
|
name: Name,
|
||||||
|
age: int,
|
||||||
|
height: float where _ >= 0
|
||||||
|
]
|
||||||
|
```,
|
||||||
|
) <simple-frame>
|
||||||
|
|
||||||
|
#pagebreak()
|
||||||
|
|
||||||
|
== Extend Statement <extend-stmt>
|
||||||
|
|
||||||
|
Type statements allow you to define new types, kind of like type aliases. However, a type might have properties or methods of its own. These might override those of the parent type or be brand new members.
|
||||||
|
|
||||||
|
This is where the `extend` statement comes into play. It allows defining members on a given type. Members can either be properties (`prop`) or methods (`def`). The only difference between the two is that methods must be functions and can be overloaded.
|
||||||
|
|
||||||
|
Here is a simple example showing how to define a property and a method on a custom type:
|
||||||
|
#figure(
|
||||||
|
```midas
|
||||||
|
type MyType = float
|
||||||
|
extend MyType {
|
||||||
|
prop norm: float
|
||||||
|
def double: fn() -> MyType
|
||||||
|
}
|
||||||
|
```,
|
||||||
|
caption: [Simple `extend` statement defining a property and a method],
|
||||||
|
)
|
||||||
|
|
||||||
|
An `extend` statement can appear anywhere after the type it extends has been defined.
|
||||||
|
You may want to override Python's dunder methods to implement type checking for some basic operators, like `__add__` for the `+` operator.
|
||||||
|
|
||||||
|
#figure(
|
||||||
|
```midas
|
||||||
|
type Money = float
|
||||||
|
extend Money {
|
||||||
|
def __add__: fn(Money, /) -> Money
|
||||||
|
def __mul__: fn(float, /) -> Money
|
||||||
|
}
|
||||||
|
```,
|
||||||
|
caption: [Simple `extend` statement overriding some dunder methods],
|
||||||
|
)
|
||||||
|
|
||||||
|
When extending a generic type, you must specify the whole type, including its parameter(s):
|
||||||
|
#figure(
|
||||||
|
```midas
|
||||||
|
type Container[T <: float] = object
|
||||||
|
extend Container[T <: float] {
|
||||||
|
prop content: T
|
||||||
|
def set_content: fn(content: T) -> None
|
||||||
|
}
|
||||||
|
```,
|
||||||
|
caption: [Generic `extend` statement using type parameters in the declared members],
|
||||||
|
)
|
||||||
|
|
||||||
|
#pagebreak()
|
||||||
|
|
||||||
|
== Predicate Statement <predicate-stmt>
|
||||||
|
|
||||||
|
A *`predicate`* statement lets you define a named constraint expression, like a function, which can then be used in other constraint expressions (either in other predicate statements or in constraint types). See @constraint-expr for more information about the syntax of constraint expressions.
|
||||||
|
|
||||||
|
The left-hand side of a predicate statement is written as a function signature, without a return type. The right-hand side is a constraint expression. For example:
|
||||||
|
#figure(
|
||||||
|
```midas
|
||||||
|
predicate is_positive(v: float) = v >= 0
|
||||||
|
```,
|
||||||
|
caption: [Simple `predicate` statement defining an `is_positive` predicate],
|
||||||
|
)
|
||||||
|
|
||||||
|
The left-hand side can also be curried to allow partial application. For example:
|
||||||
|
#figure(
|
||||||
|
```midas
|
||||||
|
predicate in_range(mn: float, mx: float)(v: float) = mn <= v & v <= mx
|
||||||
|
predicate is_ratio = in_range(0.0, 1.0)
|
||||||
|
```,
|
||||||
|
caption: [Curried `predicate` statement and partial application],
|
||||||
|
) <midas-predicate-partial>
|
||||||
|
|
||||||
|
Notice that the second predicate statement doesn't take any parameters. This is simply a partial application of another predicate, kind of like an alias. You can use it in other expressions to finalize the call:
|
||||||
|
#figure(
|
||||||
|
```midas
|
||||||
|
type Efficiency = float where is_ratio(_)
|
||||||
|
```,
|
||||||
|
caption: [Constraint type definition using the partially applied predicate from @midas-predicate-partial],
|
||||||
|
)
|
||||||
|
|
||||||
|
Of course you can also directly call `in_range`:
|
||||||
|
#figure(
|
||||||
|
```midas
|
||||||
|
type Efficiency = float where in_range(0.0, 1.0)(_)
|
||||||
|
```,
|
||||||
|
caption: [Full call of curried predicate from @midas-predicate-partial],
|
||||||
|
)
|
||||||
|
|
||||||
|
When compiled, named predicates are translated to Python functions which are used in runtime assertions. Only predicates that are referenced are compiled.
|
||||||
|
|
||||||
|
#pagebreak()
|
||||||
|
== Constraint Expressions <constraint-expr>
|
||||||
|
|
||||||
|
*Constraint expressions* are Python-like expressions which can appear in *`predicate`* statements or in constraint types.
|
||||||
|
|
||||||
|
They can contain comparisons, simple computations, logical operations and must evaluate to a boolean value.
|
||||||
|
|
||||||
|
Context is quite restricted inside these expressions. You can only reference some builtin functions, such as type constructors (`float(...)`, `str(...)`, etc.), parameters of predicate statements, and named predicates. In constraint types, the special variable `_` can be used to reference the value targeted by the type. For example:
|
||||||
|
|
||||||
|
#figure(
|
||||||
|
```midas
|
||||||
|
predicate not_nan(v: float) = v != float("nan")
|
||||||
|
type RealFloat = float where not_nan(_)
|
||||||
|
```,
|
||||||
|
caption: [Example constraint expressions],
|
||||||
|
) <ex-constraint-expr>
|
||||||
|
|
||||||
|
In the predicate statement (@ex-constraint-expr:1), we reference the parameter `v` and the builtin `float` function.
|
||||||
|
In the constraint type definition (@ex-constraint-expr:2), we then reference the named predicate `not_nan`, passing the value targeted by the type itself (`_`).
|
||||||
|
|
||||||
|
= Supported Python Syntax <python-ref>
|
||||||
|
|
||||||
|
Midas integrates naturally in Python via type annotations. Through generated stubs, even other type checkers can detect your custom types (see @cmd-stubs).
|
||||||
|
|
||||||
|
It has been designed to leave the user free to type any amount of their code but be strict about the parts that are annotated. By default, any untyped Python expression is assigned `UnknownType`.
|
||||||
|
|
||||||
|
Any operation is permitted on `UnknownType` and will result in `UnknownType` values.
|
||||||
|
|
||||||
|
The moment an expression can be typed, be it thanks to an annotation or a literal value, the type checker kicks in and will validate your statements.
|
||||||
|
|
||||||
|
Because Python is a very flexible language with many features, some expressions and statements might be more complex to properly type check, thus only a subset of the Python language is fully supported. This chapter lists all supported features of Python and how they affect type checking.
|
||||||
|
|
||||||
|
Some examples are presented in the following sections in the form of code blocks. Highlights in the code blocks indicate the type assigned to each expression by the type checker. Some types may be omitted for readability. For example:
|
||||||
|
|
||||||
|
#codly(
|
||||||
|
highlights: (
|
||||||
|
(
|
||||||
|
line: 1,
|
||||||
|
start: 5,
|
||||||
|
fill: green,
|
||||||
|
tag: [_int_],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
line: 2,
|
||||||
|
start: 7,
|
||||||
|
end: 7,
|
||||||
|
fill: green,
|
||||||
|
tag: [_int_],
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
```python
|
||||||
|
v = 3
|
||||||
|
print(v)
|
||||||
|
```
|
||||||
|
|
||||||
|
== Literals
|
||||||
|
|
||||||
|
Literal Python values are type checked using builtin types. Lists and dictionaries of literals are also typed like literals. This does not include list/dict comprehensions (```py [. for . in .]```), nor formatted strings (```py f"..."```). @supported-literals shows the list of supported literal values and their type.
|
||||||
|
|
||||||
|
#let supported-literals = table(
|
||||||
|
columns: 2,
|
||||||
|
table.header[*Example value*][*Judged Type*],
|
||||||
|
```py 42```, ```py int```,
|
||||||
|
```py 3.14```, ```py float```,
|
||||||
|
```py True```, ```py bool```,
|
||||||
|
```py "Midas"```, ```py str```,
|
||||||
|
```py None```, ```py None```,
|
||||||
|
```py [1, 2, 3]```, ```py list[int]```,
|
||||||
|
```py {1: "One", 2: "Two"}```, ```py dict[int, str]```,
|
||||||
|
```py ("1", 1, True)```, ```py tuple[str, int, bool]```,
|
||||||
|
)
|
||||||
|
|
||||||
|
#figure(
|
||||||
|
supported-literals,
|
||||||
|
caption: [Supported literal values and their judged types],
|
||||||
|
) <supported-literals>
|
||||||
|
|
||||||
|
== Assignments
|
||||||
|
|
||||||
|
Variable assignments allow assigning a new value to a variable. For the type checker, this implies two things:
|
||||||
|
1. If the variable was not already declared in the current scope, it is declared at that point with the type of the right-hand side expression
|
||||||
|
2. If the variable was already declared, the type of the right-hand side expression is checked against the declared type of the variable. Only a subtype of the variable's type can be assigned to it
|
||||||
|
|
||||||
|
Once a variable has been given a type, it cannot be changed in the same scope.
|
||||||
|
|
||||||
|
The walrus operator (```py :=```) is not currently supported.
|
||||||
|
|
||||||
|
A simple annotation declaration, without assigning a value, is enough to declare a variable. For example:
|
||||||
|
#figure(
|
||||||
|
```python
|
||||||
|
var: float
|
||||||
|
```,
|
||||||
|
caption: [Bare Python variable annotation without assignment],
|
||||||
|
)
|
||||||
|
|
||||||
|
Because unpacking is not supported, assigning to multiple values is also not handled by the type checker.
|
||||||
|
For more information about type annotations, see @type-annotations.
|
||||||
|
|
||||||
|
== Arithmetic
|
||||||
|
|
||||||
|
- All basic binary operators are supported, through dunder methods.
|
||||||
|
- All comparison operators except ```py in``` are supported.
|
||||||
|
- All unary operators are supported (`+`, `-`, `~`).
|
||||||
|
- All logical operators are supported (```py and```, ```py or```, ```py not```).
|
||||||
|
|
||||||
|
== Ternary operator
|
||||||
|
|
||||||
|
The ternary operator ```py . if . else .``` is supported. As for `if` statements (see @if-else), the test expression must be a boolean. Additionally, both branches must be of the same type.
|
||||||
|
|
||||||
|
For example:
|
||||||
|
#codly(
|
||||||
|
highlights: (
|
||||||
|
(
|
||||||
|
line: 1,
|
||||||
|
start: 10,
|
||||||
|
end: 44,
|
||||||
|
tag: [_str_],
|
||||||
|
fill: blue,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
line: 1,
|
||||||
|
start: 11,
|
||||||
|
end: 16,
|
||||||
|
tag: [_str_],
|
||||||
|
fill: green,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
line: 1,
|
||||||
|
start: 39,
|
||||||
|
end: 43,
|
||||||
|
tag: [_str_],
|
||||||
|
fill: green,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
line: 1,
|
||||||
|
start: 21,
|
||||||
|
end: 32,
|
||||||
|
tag: [_bool_],
|
||||||
|
fill: green,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
#figure(
|
||||||
|
```python
|
||||||
|
parity = ("even" if num % 2 == 0 else "odd")
|
||||||
|
```,
|
||||||
|
caption: [Typing of ternary operator],
|
||||||
|
)
|
||||||
|
|
||||||
|
== Control flow
|
||||||
|
|
||||||
|
Some control flow features are supported. Given the limited scope of this project, not all constructs are supported. The following are those currently handled and type checked by Midas.
|
||||||
|
|
||||||
|
=== `if` / `elif` / `else` <if-else>
|
||||||
|
|
||||||
|
Conditional statements are checked relatively strictly by Midas. The test expression, i.e. what comes after the ```py if``` keyword, must be a boolean. While Python allows introducing and leaking new variables from inside an ```py if``` statement, Midas will strictly forbid leaks by restraining bindings to the scope they are defined in. For example, the following Python code will not compile with Midas:
|
||||||
|
#figure(
|
||||||
|
```python
|
||||||
|
age = 22
|
||||||
|
if age >= 18:
|
||||||
|
msg = "You're an adult"
|
||||||
|
else:
|
||||||
|
msg = "You're still a child"
|
||||||
|
print(msg) # -> unknown variable 'msg'
|
||||||
|
```,
|
||||||
|
caption: [`if`/`else` statement cannot leak variables],
|
||||||
|
)
|
||||||
|
|
||||||
|
=== `for` loops
|
||||||
|
|
||||||
|
Simple forms of `for` loops can be used, that is using a single variable and iterating over an object implementing the `__getitem__` method. Like above in @if-else, leaking variables from inside the loop is ignored.
|
||||||
|
|
||||||
|
`for`-`else` statements are not supported. `while` loops are also not supported.
|
||||||
|
|
||||||
|
== Functions
|
||||||
|
|
||||||
|
You can define functions as usual and the type checker will do its best to type them. Apart from argument sinks (`*args`, `**kwargs`), all forms of parameter specifications are supported (positional-only, keyword-only, mixed, optional).
|
||||||
|
|
||||||
|
As for the rest of your code, type annotations are optional, but recommended. If you omit the return type hint, the type checker will try to infer it from the function body and its return statements. If you did specify a return type, all return paths must return values that are subtypes of the type hint.
|
||||||
|
|
||||||
|
#codly(
|
||||||
|
highlights: (
|
||||||
|
(
|
||||||
|
line: 2,
|
||||||
|
start: 12,
|
||||||
|
end: 16,
|
||||||
|
tag: [_float_],
|
||||||
|
fill: green,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
line: 2,
|
||||||
|
start: 12,
|
||||||
|
tag: [_float_],
|
||||||
|
fill: blue,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
line: 3,
|
||||||
|
start: 10,
|
||||||
|
end: 15,
|
||||||
|
tag: [_(value: float) -> float_],
|
||||||
|
fill: green,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
line: 3,
|
||||||
|
start: 17,
|
||||||
|
end: 19,
|
||||||
|
tag: [_float_],
|
||||||
|
fill: green,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
line: 3,
|
||||||
|
start: 10,
|
||||||
|
tag: [_float_],
|
||||||
|
fill: blue,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
#figure(
|
||||||
|
```python
|
||||||
|
def double(value: float) -> float:
|
||||||
|
return value * 2
|
||||||
|
result = double(4.0)
|
||||||
|
```,
|
||||||
|
caption: [Typing of function's body and call],
|
||||||
|
)
|
||||||
|
|
||||||
|
Anonymous functions (```py lambda```) are not yet supported.
|
||||||
|
|
||||||
|
== Casts <cast>
|
||||||
|
|
||||||
|
#gc.info[
|
||||||
|
The functions discussed in this section are provided by the `midas.typing` submodule. You can import them in your script like so:
|
||||||
|
#figure(
|
||||||
|
```python
|
||||||
|
from midas.typing import cast, unsafe_cast
|
||||||
|
```,
|
||||||
|
caption: [Importing cast functions],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
Sometimes, you may want to use a value whose type is not known to the type checker in a place where it expects a particular type. In that case, if you do know that the runtime type will correspond to what is expected, you can use a `cast` expression.
|
||||||
|
|
||||||
|
Similar to the `cast` function from the `typing` package of Python's Standard Library, it allows telling the type checker that a value has a given type. While `typing`'s function doesn't have any runtime side-effect, Midas' will generate runtime assertions, ensuring that your statement is true when running the code. What cannot be checked statically is checked at runtime.
|
||||||
|
|
||||||
|
In the following example, a runtime check would be generated to ensure that the value is indeed a `float` and that it satisfies the type's constraint (i.e. `>= 0`):
|
||||||
|
|
||||||
|
#codly(
|
||||||
|
highlights: (
|
||||||
|
(
|
||||||
|
line: 1,
|
||||||
|
start: 35,
|
||||||
|
end: 47,
|
||||||
|
tag: [_UnknownType_],
|
||||||
|
fill: red,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
line: 2,
|
||||||
|
start: 7,
|
||||||
|
end: 17,
|
||||||
|
tag: [_PositiveFloat_],
|
||||||
|
fill: green,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
#figure(
|
||||||
|
```python
|
||||||
|
typed_value = cast(PositiveFloat, unknown_value)
|
||||||
|
print(typed_value)
|
||||||
|
```,
|
||||||
|
caption: [Typing of `cast` expression],
|
||||||
|
)
|
||||||
|
|
||||||
|
#gc.warning[
|
||||||
|
Assertions are statements inserted just before a statement using a `cast` expression. This means that the expression is evaluated _before_ its actual intended usage location, which might cause issues if you rely on logical operator short-circuiting. See @eager-eval for more information.
|
||||||
|
]
|
||||||
|
|
||||||
|
There may be some cases where the cost of checking a value at runtime is simply not worth the safety, for example when dealing with a big dataset. If you wish to do so, you can use `unsafe_cast`, which will only tell the type checker the type of the value, without generating a runtime assertion. This maps to the default behavior of `typing`'s own `cast` function.
|
||||||
|
|
||||||
|
If the value passed to `cast` or `unsafe_cast` is a literal (e.g. an integer, a string, a list of literals, etc.), the assertion is evaluated _at compile-time_ and no runtime assertion is generated.
|
||||||
|
|
||||||
|
== Annotations / Type Hints <type-annotations>
|
||||||
|
|
||||||
|
Vanilla Python already lets you use type hints to specify the type of variables and function parameters.
|
||||||
|
|
||||||
|
Midas uses them to type check your code. Additionally, it allows you to use a special syntax to define `Frame` types directly in these annotations.
|
||||||
|
|
||||||
|
Because these annotations are not interpretable by Python, your integrated type checker might complain loudly about them being invalid.
|
||||||
|
A workaround is to silence it by adding a type comment at the end of the line, as shown in @silence-errors.
|
||||||
|
|
||||||
|
#figure(
|
||||||
|
```python
|
||||||
|
var: Frame[name: str, age: float] # type: ignore # noqa: F821
|
||||||
|
```,
|
||||||
|
caption: [MyPy's and Pylance's complaints about custom type annotation can be silenced with type comments],
|
||||||
|
) <silence-errors>
|
||||||
|
|
||||||
|
=== Frame type annotation
|
||||||
|
|
||||||
|
The syntax is similar to how you can define frame types in the Midas language (see @frame-type). The only difference is that types can only be name references; you cannot inline constraint types.
|
||||||
|
|
||||||
|
The example of @python-frame-type shows how you can annotate a dataframe with some columns directly in Python.
|
||||||
|
|
||||||
|
#figure(
|
||||||
|
```python
|
||||||
|
df: Frame[name: Name, age: float, height: Length[Meter]] = ...
|
||||||
|
```,
|
||||||
|
caption: [Frame type annotation in Python],
|
||||||
|
) <python-frame-type>
|
||||||
|
|
||||||
|
= Commands <commands>
|
||||||
|
|
||||||
|
#TODO
|
||||||
|
|
||||||
|
== Type Checking (`check`) <cmd-check>
|
||||||
|
== Compiling (`compile`) <cmd-compile>
|
||||||
|
== Formatting (`format`) <cmd-format>
|
||||||
|
== Highlighting (`highlight`) <cmd-highlight>
|
||||||
|
== Dumping the AST (`parse`) <cmd-parse>
|
||||||
|
== Dumping the Registry (`dump-registry`) <cmd-registry>
|
||||||
|
== Generating Stubs (`stubs`) <cmd-stubs>
|
||||||
|
== Showing Type Judgements (`types`) <cmd-types>
|
||||||
|
== Validating Definitions (`validate`) <cmd-validate>
|
||||||
|
|
||||||
|
= Known limitations <limitations>
|
||||||
|
|
||||||
|
== Eager evaluation in runtime assertions <eager-eval>
|
||||||
|
|
||||||
|
The process of generating assertions to ensure safety at runtime, mainly for `cast` expressions, leads to the creation of aliases for the expressions being cast. These alias definitions are eagerly evaluated before the assertion, and most importantly before the real usage location. This means that you should avoid using `cast` expressions inside logical expressions like `and` or `or`, because the normal "short-circuit" behavior will no longer prevent the evaluation of the operands.
|
||||||
|
|
||||||
|
For example:
|
||||||
|
|
||||||
|
#figure(
|
||||||
|
```py
|
||||||
|
def foo():
|
||||||
|
print("Foo")
|
||||||
|
return True
|
||||||
|
def bar():
|
||||||
|
print("Bar")
|
||||||
|
return True
|
||||||
|
result = foo() or bar()
|
||||||
|
# Foo
|
||||||
|
# Bar
|
||||||
|
```,
|
||||||
|
caption: [Runtime assertions may eagerly evaluate expressions and bypass logical operator's short-circuit],
|
||||||
|
)
|
||||||
211
docs/midas.sublime-syntax
Normal file
211
docs/midas.sublime-syntax
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
%YAML 1.2
|
||||||
|
---
|
||||||
|
name: Midas
|
||||||
|
file_extensions:
|
||||||
|
- midas
|
||||||
|
scope: source.midas
|
||||||
|
|
||||||
|
variables:
|
||||||
|
identifier: "[a-zA-Z_][a-zA-Z0-9_]*"
|
||||||
|
|
||||||
|
contexts:
|
||||||
|
prototype:
|
||||||
|
- include: comments
|
||||||
|
|
||||||
|
main:
|
||||||
|
- include: keywords
|
||||||
|
- include: types
|
||||||
|
|
||||||
|
comments:
|
||||||
|
- match: "//"
|
||||||
|
scope: punctuation.definition.comment.midas
|
||||||
|
push:
|
||||||
|
- meta_scope: comment.line.midas
|
||||||
|
- match: $
|
||||||
|
pop: true
|
||||||
|
- match: /\*
|
||||||
|
scope: punctuation.definition.comment.midas
|
||||||
|
push:
|
||||||
|
- meta_scope: comment.block.midas
|
||||||
|
- match: \*/
|
||||||
|
pop: true
|
||||||
|
|
||||||
|
string:
  # Context pushed after an opening double quote; it is popped at the next
  # double quote. The prototype (comment matching) is disabled here so that
  # comment markers inside a string are not treated as comments.
  - meta_include_prototype: false
  # Use the grammar's own `.midas` scope suffix, like every other scope in
  # this file (was `.c`, a leftover from another grammar).
  - meta_scope: string.quoted.double.midas
  - match: '"'
    pop: true
|
||||||
|
|
||||||
|
keywords:
|
||||||
|
- match: \balias\b
|
||||||
|
scope: keyword.declaration.midas
|
||||||
|
push: alias-stmt
|
||||||
|
- match: \btype\b
|
||||||
|
scope: keyword.declaration.midas
|
||||||
|
push: type-stmt
|
||||||
|
- match: \bextend\b
|
||||||
|
scope: keyword.declaration.midas
|
||||||
|
push: extend-stmt
|
||||||
|
- match: \bpredicate\b
|
||||||
|
scope: keyword.declaration.midas
|
||||||
|
push: predicate-stmt
|
||||||
|
|
||||||
|
alias-stmt:
|
||||||
|
- match: "{{identifier}}"
|
||||||
|
scope: entity.name.type
|
||||||
|
- match: "="
|
||||||
|
scope: keyword.operator.equal.midas
|
||||||
|
push: type-expr
|
||||||
|
- match: $
|
||||||
|
pop: true
|
||||||
|
|
||||||
|
type-stmt:
|
||||||
|
- match: "{{identifier}}"
|
||||||
|
scope: entity.name.type
|
||||||
|
- match: \[
|
||||||
|
push: type-params
|
||||||
|
- match: "="
|
||||||
|
scope: keyword.operator.equal.midas
|
||||||
|
push: type-expr
|
||||||
|
- match: $
|
||||||
|
pop: true
|
||||||
|
|
||||||
|
type-expr:
|
||||||
|
- match: \b(fn)\s*(\()
|
||||||
|
captures:
|
||||||
|
1: keyword.other.midas
|
||||||
|
2: punctuation.section.group.begin
|
||||||
|
push: fn-params
|
||||||
|
- match: \b(where)\b
|
||||||
|
scope: keyword.other.midas
|
||||||
|
set: constraint
|
||||||
|
- match: "Frame"
|
||||||
|
scope: entity.name.type
|
||||||
|
push:
|
||||||
|
- match: \[
|
||||||
|
push: frame-schema
|
||||||
|
- match: $
|
||||||
|
pop: true
|
||||||
|
- match: "{{identifier}}"
|
||||||
|
scope: entity.name.type
|
||||||
|
- match: $
|
||||||
|
pop: 2
|
||||||
|
|
||||||
|
fn-params:
|
||||||
|
- match: "({{identifier}})(:)"
|
||||||
|
captures:
|
||||||
|
1: variable.parameter.midas
|
||||||
|
2: punctuation.separator.annotation.midas
|
||||||
|
push:
|
||||||
|
- include: type-expr
|
||||||
|
- match: \?
|
||||||
|
scope: keyword.operator.qmark.midas
|
||||||
|
- match: "(?=,)"
|
||||||
|
scope: punctuation.separator.midas
|
||||||
|
pop: true
|
||||||
|
- match: '(?=\))'
|
||||||
|
pop: true
|
||||||
|
- include: type-expr
|
||||||
|
- match: '\)'
|
||||||
|
set:
|
||||||
|
- match: "->"
|
||||||
|
scope: keyword.operator.arrow.midas
|
||||||
|
set: type-expr
|
||||||
|
|
||||||
|
constraint:
|
||||||
|
- match: $
|
||||||
|
pop: 2
|
||||||
|
|
||||||
|
- match: \d+(\.\d+)?
|
||||||
|
scope: constant.numeric.midas
|
||||||
|
|
||||||
|
- match: \b(true|false|none)\b
|
||||||
|
scope: constant.language.midas
|
||||||
|
|
||||||
|
- match: '"'
|
||||||
|
push: string
|
||||||
|
|
||||||
|
- match: (<=|>=|<|>|==|!=|&)
|
||||||
|
scope: keyword.operator
|
||||||
|
|
||||||
|
- match: _
|
||||||
|
scope: variable.language.midas
|
||||||
|
|
||||||
|
- match: '{{identifier}}(?=\s*\()'
|
||||||
|
scope: variable.function.midas
|
||||||
|
|
||||||
|
- match: "{{identifier}}"
|
||||||
|
scope: variable.other.readwrite.midas
|
||||||
|
|
||||||
|
type-params:
|
||||||
|
- match: "<:"
|
||||||
|
scope: keyword.operator.subtype.midas
|
||||||
|
- match: "[a-zA-Z][a-zA-Z_0-9]*"
|
||||||
|
scope: entity.name.type
|
||||||
|
- match: "]"
|
||||||
|
pop: true
|
||||||
|
|
||||||
|
extend-stmt:
|
||||||
|
- match: "{{identifier}}"
|
||||||
|
scope: entity.name.type
|
||||||
|
- match: \[
|
||||||
|
push: type-params
|
||||||
|
- match: \{
|
||||||
|
scope: punctuation.section.block.begin
|
||||||
|
set: extend-body
|
||||||
|
|
||||||
|
extend-body:
|
||||||
|
- include: member-stmt
|
||||||
|
- match: \}
|
||||||
|
scope: punctuation.section.block.end
|
||||||
|
pop: true
|
||||||
|
|
||||||
|
member-stmt:
|
||||||
|
- match: \b(prop|def)\b
|
||||||
|
scope: keyword.other.midas
|
||||||
|
push:
|
||||||
|
- match: "{{identifier}}"
|
||||||
|
scope: variable.other.member
|
||||||
|
- match: ":"
|
||||||
|
push: type-expr
|
||||||
|
- match: $
|
||||||
|
pop: true
|
||||||
|
|
||||||
|
predicate-stmt:
|
||||||
|
- match: "{{identifier}}"
|
||||||
|
scope: entity.name.function.midas
|
||||||
|
- match: '\('
|
||||||
|
push: predicate-params
|
||||||
|
- match: "="
|
||||||
|
scope: keyword.operator.equal.midas
|
||||||
|
set: constraint
|
||||||
|
- match: $
|
||||||
|
pop: true
|
||||||
|
|
||||||
|
predicate-params:
|
||||||
|
- match: "({{identifier}})(:)"
|
||||||
|
captures:
|
||||||
|
1: variable.parameter.midas
|
||||||
|
2: punctuation.separator.annotation.midas
|
||||||
|
push:
|
||||||
|
- include: type-expr
|
||||||
|
- match: "(?=,)"
|
||||||
|
scope: punctuation.separator.midas
|
||||||
|
pop: true
|
||||||
|
- match: '(?=\))'
|
||||||
|
pop: true
|
||||||
|
|
||||||
|
- match: '\)'
|
||||||
|
pop: true
|
||||||
|
|
||||||
|
frame-schema:
|
||||||
|
- include: frame-column
|
||||||
|
- match: \]
|
||||||
|
# scope: punctuation.section.block.end
|
||||||
|
pop: true
|
||||||
|
|
||||||
|
frame-column:
|
||||||
|
- match: "{{identifier}}"
|
||||||
|
scope: variable.other.member
|
||||||
|
- match: ":"
|
||||||
|
push: type-expr
|
||||||
143
docs/template.typ
Normal file
143
docs/template.typ
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
#import "@preview/modpattern:0.2.0": modpattern
|
||||||
|
|
||||||
|
#let TODO = block(
|
||||||
|
width: 6em,
|
||||||
|
height: 3em,
|
||||||
|
stroke: red,
|
||||||
|
fill: modpattern(
|
||||||
|
size: (10pt, 10pt),
|
||||||
|
line(
|
||||||
|
start: (0%, 0%),
|
||||||
|
end: (100%, 100%),
|
||||||
|
stroke: gray.transparentize(60%) + 2pt,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
align(
|
||||||
|
center + horizon,
|
||||||
|
text(fill: red, size: 1.5em)[*TODO*],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
#let _render-header(version, hash) = {
|
||||||
|
let last-heading = query(heading.where(level: 1).before(here())).last(default: none)
|
||||||
|
let next-heading = query(heading.where(level: 1).after(here())).first(default: none)
|
||||||
|
|
||||||
|
let current-heading = if next-heading != none and next-heading.location().page() == here().page() {
|
||||||
|
next-heading
|
||||||
|
} else if last-heading != none {
|
||||||
|
last-heading
|
||||||
|
} else { none }
|
||||||
|
|
||||||
|
let chapter = if current-heading != none {
|
||||||
|
let body = current-heading.body
|
||||||
|
if current-heading.numbering != none {
|
||||||
|
let num = counter(heading).display(current-heading.numbering, at: current-heading.location())
|
||||||
|
body = [#num #body]
|
||||||
|
}
|
||||||
|
body
|
||||||
|
} else []
|
||||||
|
|
||||||
|
grid(
|
||||||
|
columns: (1fr, auto, 1fr),
|
||||||
|
align: (left, center, right),
|
||||||
|
document.title, [v#version - #hash], chapter,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#let _unshift-prefix(prefix, content) = context {
|
||||||
|
pad(left: -measure(prefix).width, prefix + content)
|
||||||
|
}
|
||||||
|
|
||||||
|
#let project(
|
||||||
|
title: none,
|
||||||
|
author: none,
|
||||||
|
version: "0.0.1",
|
||||||
|
hash: "abcdefgh",
|
||||||
|
icon-path: none,
|
||||||
|
doc,
|
||||||
|
) = {
|
||||||
|
assert(title != none, message: "Please provide a title")
|
||||||
|
|
||||||
|
set document(
|
||||||
|
title: title,
|
||||||
|
author: author,
|
||||||
|
)
|
||||||
|
set text(
|
||||||
|
font: "Source Sans 3",
|
||||||
|
)
|
||||||
|
|
||||||
|
set raw(syntaxes: path("midas.sublime-syntax"))
|
||||||
|
|
||||||
|
let front-page() = {
|
||||||
|
align(center)[
|
||||||
|
#{
|
||||||
|
set text(size: 1.5em)
|
||||||
|
std.title()
|
||||||
|
}
|
||||||
|
|
||||||
|
v#version - #hash
|
||||||
|
|
||||||
|
#if icon-path != none {
|
||||||
|
v(1cm)
|
||||||
|
image(icon-path)
|
||||||
|
}
|
||||||
|
]
|
||||||
|
pagebreak()
|
||||||
|
}
|
||||||
|
|
||||||
|
let outlines() = {
|
||||||
|
outline()
|
||||||
|
pagebreak()
|
||||||
|
outline(
|
||||||
|
title: [List of Listings],
|
||||||
|
target: figure.where(kind: raw),
|
||||||
|
)
|
||||||
|
|
||||||
|
outline(
|
||||||
|
title: [List of Tables],
|
||||||
|
target: figure.where(kind: table),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
let main() = {
|
||||||
|
// Adapted from https://github.com/hei-templates/hei-synd-thesis/blob/7d2b941197babae0bf3afd4e5914754e09a64001/lib/template-thesis.typ#L242-L261
|
||||||
|
show heading.where(level: 1): it => {
|
||||||
|
pagebreak()
|
||||||
|
|
||||||
|
set text(size: 1.5em)
|
||||||
|
set block(above: 1.2em, below: 1.2em)
|
||||||
|
if it.numbering != none {
|
||||||
|
let num = numbering(it.numbering, ..counter(heading).at(it.location()))
|
||||||
|
let prefix = num + h(1em)
|
||||||
|
_unshift-prefix(prefix, it.body)
|
||||||
|
} else {
|
||||||
|
it
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
show heading.where(level: 2): it => {
|
||||||
|
if it.numbering != none {
|
||||||
|
let num = numbering(it.numbering, ..counter(heading).at(it.location()))
|
||||||
|
_unshift-prefix(num + h(0.8em), it.body)
|
||||||
|
} else {
|
||||||
|
it
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
set page(
|
||||||
|
header: context _render-header(version, hash),
|
||||||
|
footer: context if page.numbering != none {
|
||||||
|
align(center, counter(page).display(page.numbering, both: true))
|
||||||
|
},
|
||||||
|
numbering: "1 / 1",
|
||||||
|
)
|
||||||
|
|
||||||
|
show heading: set heading(numbering: "I.1.")
|
||||||
|
counter(page).update(1)
|
||||||
|
doc
|
||||||
|
}
|
||||||
|
|
||||||
|
front-page()
|
||||||
|
outlines()
|
||||||
|
main()
|
||||||
|
}
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
from typing import TypeVar, cast
|
from typing import TypeVar
|
||||||
|
|
||||||
from demo_stubs import CHF, EUR, USD, Currency, Price, Discount
|
from demo_stubs import CHF, EUR, USD, Currency, Discount, Price
|
||||||
|
|
||||||
|
from midas.typing import cast, unsafe_cast
|
||||||
|
|
||||||
T = TypeVar("T", bound=Currency)
|
T = TypeVar("T", bound=Currency)
|
||||||
|
|
||||||
@@ -28,3 +30,6 @@ discounted = apply_discount(
|
|||||||
)
|
)
|
||||||
|
|
||||||
print(f"Discounted: CHF {discounted}")
|
print(f"Discounted: CHF {discounted}")
|
||||||
|
|
||||||
|
large_data = [i * 10 for i in range(100)]
|
||||||
|
prices = unsafe_cast(list[Price[EUR]], large_data)
|
||||||
|
|||||||
15
gen/midas.py
15
gen/midas.py
@@ -44,6 +44,11 @@ class TypeStmt:
|
|||||||
type: Type
|
type: Type
|
||||||
|
|
||||||
|
|
||||||
|
class AliasStmt:
|
||||||
|
name: Token
|
||||||
|
type: Type
|
||||||
|
|
||||||
|
|
||||||
class MemberStmt:
|
class MemberStmt:
|
||||||
name: Token
|
name: Token
|
||||||
type: Type
|
type: Type
|
||||||
@@ -152,4 +157,14 @@ class FunctionType:
|
|||||||
required: bool
|
required: bool
|
||||||
|
|
||||||
|
|
||||||
|
class FrameType:
|
||||||
|
columns: list[Column]
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Column:
|
||||||
|
location: Optional[Location] = None
|
||||||
|
name: Token
|
||||||
|
type: Type
|
||||||
|
|
||||||
|
|
||||||
###<
|
###<
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from midas.ast.location import Location
|
|||||||
###> MidasType | Type annotations | node
|
###> MidasType | Type annotations | node
|
||||||
class BaseType:
|
class BaseType:
|
||||||
base: str
|
base: str
|
||||||
param: Optional[MidasType]
|
args: tuple[MidasType, ...]
|
||||||
|
|
||||||
|
|
||||||
class ConstraintType:
|
class ConstraintType:
|
||||||
@@ -145,6 +145,7 @@ class LogicalExpr:
|
|||||||
class CastExpr:
|
class CastExpr:
|
||||||
type: MidasType
|
type: MidasType
|
||||||
expr: Expr
|
expr: Expr
|
||||||
|
unsafe: bool
|
||||||
|
|
||||||
|
|
||||||
class TernaryExpr:
|
class TernaryExpr:
|
||||||
@@ -173,6 +174,10 @@ class SliceExpr:
|
|||||||
step: Optional[Expr]
|
step: Optional[Expr]
|
||||||
|
|
||||||
|
|
||||||
|
class TupleExpr:
|
||||||
|
items: tuple[Expr, ...]
|
||||||
|
|
||||||
|
|
||||||
class RawExpr:
|
class RawExpr:
|
||||||
expr: ast.expr
|
expr: ast.expr
|
||||||
|
|
||||||
|
|||||||
@@ -51,6 +51,9 @@ class Stmt(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_type_stmt(self, stmt: TypeStmt) -> T: ...
|
def visit_type_stmt(self, stmt: TypeStmt) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_alias_stmt(self, stmt: AliasStmt) -> T: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_member_stmt(self, stmt: MemberStmt) -> T: ...
|
def visit_member_stmt(self, stmt: MemberStmt) -> T: ...
|
||||||
|
|
||||||
@@ -71,6 +74,15 @@ class TypeStmt(Stmt):
|
|||||||
return visitor.visit_type_stmt(self)
|
return visitor.visit_type_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AliasStmt(Stmt):
|
||||||
|
name: Token
|
||||||
|
type: Type
|
||||||
|
|
||||||
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_alias_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class MemberStmt(Stmt):
|
class MemberStmt(Stmt):
|
||||||
name: Token
|
name: Token
|
||||||
@@ -253,6 +265,9 @@ class Type(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_function_type(self, type: FunctionType) -> T: ...
|
def visit_function_type(self, type: FunctionType) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_frame_type(self, type: FrameType) -> T: ...
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class NamedType(Type):
|
class NamedType(Type):
|
||||||
@@ -311,3 +326,17 @@ class FunctionType(Type):
|
|||||||
|
|
||||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||||
return visitor.visit_function_type(self)
|
return visitor.visit_function_type(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class FrameType(Type):
|
||||||
|
columns: list[Column]
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Column:
|
||||||
|
location: Optional[Location] = None
|
||||||
|
name: Token
|
||||||
|
type: Type
|
||||||
|
|
||||||
|
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_frame_type(self)
|
||||||
|
|||||||
@@ -105,6 +105,14 @@ class MidasAstPrinter(
|
|||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
stmt.type.accept(self)
|
stmt.type.accept(self)
|
||||||
|
|
||||||
|
def visit_alias_stmt(self, stmt: m.AliasStmt) -> None:
|
||||||
|
self._write_line("AliasStmt")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||||
|
self._write_line("type", last=True)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
stmt.type.accept(self)
|
||||||
|
|
||||||
def _print_type_param(self, param: m.TypeParam) -> None:
|
def _print_type_param(self, param: m.TypeParam) -> None:
|
||||||
self._write_line("Param")
|
self._write_line("Param")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
@@ -350,6 +358,25 @@ class MidasAstPrinter(
|
|||||||
arg.type.accept(self)
|
arg.type.accept(self)
|
||||||
self._write_line(f"required: {arg.required}", last=True)
|
self._write_line(f"required: {arg.required}", last=True)
|
||||||
|
|
||||||
|
def visit_frame_type(self, type: m.FrameType) -> None:
|
||||||
|
self._write_line("FrameType")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
self._write_line("columns")
|
||||||
|
with self._child_level():
|
||||||
|
for i, column in enumerate(type.columns):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(type.columns) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._print_frame_column(column)
|
||||||
|
|
||||||
|
def _print_frame_column(self, column: m.FrameType.Column) -> None:
|
||||||
|
self._write_line("Column")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line(f'name: "{column.name.lexeme}"')
|
||||||
|
self._write_line("type")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
column.type.accept(self)
|
||||||
|
|
||||||
|
|
||||||
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
|
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
|
||||||
def __init__(self, indent: int = 4):
|
def __init__(self, indent: int = 4):
|
||||||
@@ -371,6 +398,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
|||||||
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
|
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
|
||||||
return self.indented(res)
|
return self.indented(res)
|
||||||
|
|
||||||
|
def visit_alias_stmt(self, stmt: m.AliasStmt) -> str:
|
||||||
|
return self.indented(f"alias {stmt.name.lexeme} = {stmt.type.accept(self)}")
|
||||||
|
|
||||||
def _print_type_param(self, param: m.TypeParam) -> str:
|
def _print_type_param(self, param: m.TypeParam) -> str:
|
||||||
res: str = param.name.lexeme
|
res: str = param.name.lexeme
|
||||||
if param.bound is not None:
|
if param.bound is not None:
|
||||||
@@ -502,6 +532,23 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
|||||||
res += "?"
|
res += "?"
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
def visit_frame_type(self, type: m.FrameType) -> str:
|
||||||
|
res: str = self.indented("Frame[")
|
||||||
|
if len(type.columns) != 0:
|
||||||
|
res += "\n"
|
||||||
|
self.level += 1
|
||||||
|
columns: list[str] = []
|
||||||
|
for column in type.columns:
|
||||||
|
columns.append(self.indented(self._print_frame_column(column)))
|
||||||
|
res += ",\n".join(columns)
|
||||||
|
self.level -= 1
|
||||||
|
res += "\n"
|
||||||
|
res += "]"
|
||||||
|
return res
|
||||||
|
|
||||||
|
def _print_frame_column(self, column: m.FrameType.Column) -> str:
|
||||||
|
return f"{column.name.lexeme}: {column.type.accept(self)}"
|
||||||
|
|
||||||
|
|
||||||
class PythonAstPrinter(
|
class PythonAstPrinter(
|
||||||
AstPrinter,
|
AstPrinter,
|
||||||
@@ -513,7 +560,13 @@ class PythonAstPrinter(
|
|||||||
self._write_line("BaseType")
|
self._write_line("BaseType")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
self._write_line(f"base: {node.base}")
|
self._write_line(f"base: {node.base}")
|
||||||
self._write_optional_child("param", node.param, last=True)
|
self._write_line("args:", last=True)
|
||||||
|
with self._child_level():
|
||||||
|
for i, arg in enumerate(node.args):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(node.args) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
arg.accept(self)
|
||||||
|
|
||||||
def visit_constraint_type(self, node: p.ConstraintType) -> None:
|
def visit_constraint_type(self, node: p.ConstraintType) -> None:
|
||||||
self._write_line("ConstraintType")
|
self._write_line("ConstraintType")
|
||||||
@@ -757,9 +810,10 @@ class PythonAstPrinter(
|
|||||||
self._write_line("type")
|
self._write_line("type")
|
||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
expr.type.accept(self)
|
expr.type.accept(self)
|
||||||
self._write_line("expr", last=True)
|
self._write_line("expr")
|
||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
expr.expr.accept(self)
|
expr.expr.accept(self)
|
||||||
|
self._write_line(f"unsafe: {expr.unsafe}", last=True)
|
||||||
|
|
||||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
|
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
|
||||||
self._write_line("TernaryExpr")
|
self._write_line("TernaryExpr")
|
||||||
@@ -825,6 +879,17 @@ class PythonAstPrinter(
|
|||||||
self._write_optional_child("upper", expr.upper)
|
self._write_optional_child("upper", expr.upper)
|
||||||
self._write_optional_child("step", expr.step, last=True)
|
self._write_optional_child("step", expr.step, last=True)
|
||||||
|
|
||||||
|
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
    """Print a `TupleExpr` node followed by its items as nested children

    Args:
        expr (p.TupleExpr): the tuple expression to print
    """
    self._write_line("TupleExpr")
    with self._child_level():
        # "items" is the only child of the node, hence it is also the last one
        self._write_line("items", last=True)
        with self._child_level():
            final_index: int = len(expr.items) - 1
            for position, element in enumerate(expr.items):
                self._idx = position
                if position == final_index:
                    self._mark_last()
                element.accept(self)
|
||||||
|
|
||||||
def visit_raw_expr(self, expr: p.RawExpr) -> None:
|
def visit_raw_expr(self, expr: p.RawExpr) -> None:
|
||||||
self._write_line("RawExpr")
|
self._write_line("RawExpr")
|
||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ class MidasType(ABC):
|
|||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class BaseType(MidasType):
|
class BaseType(MidasType):
|
||||||
base: str
|
base: str
|
||||||
param: Optional[MidasType]
|
args: tuple[MidasType, ...]
|
||||||
|
|
||||||
def accept(self, visitor: MidasType.Visitor[T]) -> T:
|
def accept(self, visitor: MidasType.Visitor[T]) -> T:
|
||||||
return visitor.visit_base_type(self)
|
return visitor.visit_base_type(self)
|
||||||
@@ -268,6 +268,9 @@ class Expr(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_slice_expr(self, expr: SliceExpr) -> T: ...
|
def visit_slice_expr(self, expr: SliceExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_tuple_expr(self, expr: TupleExpr) -> T: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_raw_expr(self, expr: RawExpr) -> T: ...
|
def visit_raw_expr(self, expr: RawExpr) -> T: ...
|
||||||
|
|
||||||
@@ -350,6 +353,7 @@ class LogicalExpr(Expr):
|
|||||||
class CastExpr(Expr):
|
class CastExpr(Expr):
|
||||||
type: MidasType
|
type: MidasType
|
||||||
expr: Expr
|
expr: Expr
|
||||||
|
unsafe: bool
|
||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
return visitor.visit_cast_expr(self)
|
return visitor.visit_cast_expr(self)
|
||||||
@@ -401,6 +405,14 @@ class SliceExpr(Expr):
|
|||||||
return visitor.visit_slice_expr(self)
|
return visitor.visit_slice_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class TupleExpr(Expr):
    """Expression node holding an ordered, fixed sequence of sub-expressions"""

    # The expressions making up the tuple, in source order
    items: tuple[Expr, ...]

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to the visitor's `visit_tuple_expr` and return its result"""
        return visitor.visit_tuple_expr(self)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class RawExpr(Expr):
|
class RawExpr(Expr):
|
||||||
expr: ast.expr
|
expr: ast.expr
|
||||||
|
|||||||
@@ -178,4 +178,100 @@ extend dict[K, V] {
|
|||||||
// def __ior__: fn(value: SupportsKeysAndGetItem[K, V], /) -> dict[K, V]
|
// def __ior__: fn(value: SupportsKeysAndGetItem[K, V], /) -> dict[K, V]
|
||||||
// def __ior__: fn(value: Iterable[tuple[K, V]], /) -> dict[K, V]
|
// def __ior__: fn(value: Iterable[tuple[K, V]], /) -> dict[K, V]
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
extend str {
|
||||||
|
def capitalize: fn() -> str
|
||||||
|
def casefold: fn() -> str
|
||||||
|
def center: fn(width: int, fillchar: str?, /) -> str
|
||||||
|
def count: fn(sub: str, start: None?, end: None?, /) -> int
|
||||||
|
def count: fn(sub: str, start: int, end: None?, /) -> int
|
||||||
|
def count: fn(sub: str, start: None, end: int, /) -> int
|
||||||
|
def count: fn(sub: str, start: int, end: int, /) -> int
|
||||||
|
def encode: fn(encoding: str?, errors: str?) -> bytes
|
||||||
|
def endswith: fn(suffix: str, start: None?, end: None?, /) -> bool
|
||||||
|
def endswith: fn(suffix: str, start: int, end: None?, /) -> bool
|
||||||
|
def endswith: fn(suffix: str, start: None, end: int, /) -> bool
|
||||||
|
def endswith: fn(suffix: str, start: int, end: int, /) -> bool
|
||||||
|
def expandtabs: fn(tabsize: int?) -> str
|
||||||
|
def find: fn(sub: str, start: None?, end: None?, /) -> int
|
||||||
|
def find: fn(sub: str, start: int, end: None?, /) -> int
|
||||||
|
def find: fn(sub: str, start: None, end: int, /) -> int
|
||||||
|
def find: fn(sub: str, start: int, end: int, /) -> int
|
||||||
|
// def format: fn(*args: object, **kwargs: object) -> str
|
||||||
|
// def format_map: fn(mapping: _FormatMapMapping, /) -> str
|
||||||
|
def index: fn(sub: str, start: None?, end: None?, /) -> int
|
||||||
|
def index: fn(sub: str, start: int, end: None?, /) -> int
|
||||||
|
def index: fn(sub: str, start: None, end: int, /) -> int
|
||||||
|
def index: fn(sub: str, start: int, end: int, /) -> int
|
||||||
|
def isalnum: fn() -> bool
|
||||||
|
def isalpha: fn() -> bool
|
||||||
|
def isascii: fn() -> bool
|
||||||
|
def isdecimal: fn() -> bool
|
||||||
|
def isdigit: fn() -> bool
|
||||||
|
def isidentifier: fn() -> bool
|
||||||
|
def islower: fn() -> bool
|
||||||
|
def isnumeric: fn() -> bool
|
||||||
|
def isprintable: fn() -> bool
|
||||||
|
def isspace: fn() -> bool
|
||||||
|
def istitle: fn() -> bool
|
||||||
|
def isupper: fn() -> bool
|
||||||
|
def join: fn(iterable: list[str], /) -> str // TODO: use Iterable
|
||||||
|
def ljust: fn(width: int, fillchar: str?, /) -> str
|
||||||
|
def lower: fn() -> str
|
||||||
|
def lstrip: fn(chars: None?, /) -> str
|
||||||
|
def lstrip: fn(chars: str, /) -> str
|
||||||
|
def partition: fn(sep: str, /) -> tuple[str, str, str]
|
||||||
|
|
||||||
|
def replace: fn(old: str, new: str, count: int?, /) -> str
|
||||||
|
|
||||||
|
def removeprefix: fn(prefix: str, /) -> str
|
||||||
|
def removesuffix: fn(suffix: str, /) -> str
|
||||||
|
def rfind: fn(sub: str, start: None?, end: None?, /) -> int
|
||||||
|
def rfind: fn(sub: str, start: int, end: None?, /) -> int
|
||||||
|
def rfind: fn(sub: str, start: None, end: int, /) -> int
|
||||||
|
def rfind: fn(sub: str, start: int, end: int, /) -> int
|
||||||
|
def rindex: fn(sub: str, start: None?, end: None?, /) -> int
|
||||||
|
def rindex: fn(sub: str, start: int, end: None?, /) -> int
|
||||||
|
def rindex: fn(sub: str, start: None, end: int, /) -> int
|
||||||
|
def rindex: fn(sub: str, start: int, end: int, /) -> int
|
||||||
|
def rjust: fn(width: int, fillchar: str?, /) -> str
|
||||||
|
def rpartition: fn(sep: str, /) -> tuple[str, str, str]
|
||||||
|
def rsplit: fn(sep: None?, maxsplit: int?) -> list[str]
|
||||||
|
def rsplit: fn(sep: str, maxsplit: int?) -> list[str]
|
||||||
|
def rstrip: fn(chars: None?, /) -> str
|
||||||
|
def rstrip: fn(chars: str, /) -> str
|
||||||
|
def split: fn(sep: None?, maxsplit: int?) -> list[str]
|
||||||
|
def split: fn(sep: str, maxsplit: int?) -> list[str]
|
||||||
|
def splitlines: fn(keepends: bool?) -> list[str]
|
||||||
|
def startswith: fn(prefix: str, start: None?, end: None?, /) -> bool
|
||||||
|
def startswith: fn(prefix: str, start: int, end: None?, /) -> bool
|
||||||
|
def startswith: fn(prefix: str, start: None, end: int, /) -> bool
|
||||||
|
def startswith: fn(prefix: str, start: int, end: int, /) -> bool
|
||||||
|
def strip: fn(chars: None?, /) -> str
|
||||||
|
def strip: fn(chars: str, /) -> str
|
||||||
|
def swapcase: fn() -> str
|
||||||
|
def title: fn() -> str
|
||||||
|
// def translate: fn(table: _TranslateTable, /) -> str
|
||||||
|
def upper: fn() -> str
|
||||||
|
def zfill: fn(width: int, /) -> str
|
||||||
|
def __add__: fn(value: str, /) -> str
|
||||||
|
// Incompatible with Sequence.__contains__
|
||||||
|
def __contains__: fn(key: str, /) -> bool
|
||||||
|
def __eq__: fn(value: object, /) -> bool
|
||||||
|
def __ge__: fn(value: str, /) -> bool
|
||||||
|
def __getitem__: fn(key: slice, /) -> str
|
||||||
|
def __getitem__: fn(key: int, /) -> str
|
||||||
|
def __gt__: fn(value: str, /) -> bool
|
||||||
|
def __hash__: fn() -> int
|
||||||
|
// def __iter__: fn() -> Iterator[str]
|
||||||
|
def __le__: fn(value: str, /) -> bool
|
||||||
|
def __len__: fn() -> int
|
||||||
|
def __lt__: fn(value: str, /) -> bool
|
||||||
|
def __mod__: fn(value: Any, /) -> str
|
||||||
|
def __mul__: fn(value: int, /) -> str
|
||||||
|
def __ne__: fn(value: object, /) -> bool
|
||||||
|
def __rmul__: fn(value: int, /) -> str
|
||||||
|
def __getnewargs__: fn() -> tuple[str]
|
||||||
|
def __format__: fn(format_spec: str, /) -> str
|
||||||
|
}
|
||||||
|
|||||||
@@ -14,10 +14,11 @@ if TYPE_CHECKING:
|
|||||||
from midas.checker.registry import TypesRegistry
|
from midas.checker.registry import TypesRegistry
|
||||||
|
|
||||||
|
|
||||||
|
# Hard-coded subtype relationships between builtin types
|
||||||
|
# Circular dependencies and diamond inheritance MUST be avoided
|
||||||
BUILTIN_SUBTYPES: dict[str, set[str]] = {
|
BUILTIN_SUBTYPES: dict[str, set[str]] = {
|
||||||
"object": {"float", "list", "dict", "str"},
|
"object": {"float", "list", "dict", "str", "bytes", "tuple"},
|
||||||
"float": {"int"},
|
"float": {"int"},
|
||||||
"int": {"bool"},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -26,12 +27,15 @@ def define_builtins(reg: TypesRegistry):
|
|||||||
any = reg.define_type("Any", TopType())
|
any = reg.define_type("Any", TopType())
|
||||||
unit = reg.define_type("None", UnitType())
|
unit = reg.define_type("None", UnitType())
|
||||||
object = reg.define_type("object", BaseType(name="object"))
|
object = reg.define_type("object", BaseType(name="object"))
|
||||||
|
bytes = reg.define_type("bytes", BaseType(name="bytes"))
|
||||||
bool = reg.define_type("bool", BaseType(name="bool"))
|
bool = reg.define_type("bool", BaseType(name="bool"))
|
||||||
int = reg.define_type("int", BaseType(name="int"))
|
int = reg.define_type("int", BaseType(name="int"))
|
||||||
float = reg.define_type("float", BaseType(name="float"))
|
float = reg.define_type("float", BaseType(name="float"))
|
||||||
str = reg.define_type("str", BaseType(name="str"))
|
str = reg.define_type("str", BaseType(name="str"))
|
||||||
slice = reg.define_type("slice", BaseType(name="slice"))
|
slice = reg.define_type("slice", BaseType(name="slice"))
|
||||||
|
|
||||||
|
tuple = reg.define_type("tuple", BaseType(name="tuple"))
|
||||||
|
|
||||||
list = reg.define_type(
|
list = reg.define_type(
|
||||||
"list",
|
"list",
|
||||||
GenericType(
|
GenericType(
|
||||||
|
|||||||
484
midas/checker/dispatcher.py
Normal file
484
midas/checker/dispatcher.py
Normal file
@@ -0,0 +1,484 @@
|
|||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import Generic, Optional, Protocol, TypeVar, Union
|
||||||
|
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
|
from midas.checker.reporter import FileReporter
|
||||||
|
from midas.checker.types import (
|
||||||
|
AppliedType,
|
||||||
|
DerivedType,
|
||||||
|
Function,
|
||||||
|
GenericType,
|
||||||
|
OverloadedFunction,
|
||||||
|
Type,
|
||||||
|
UnknownType,
|
||||||
|
)
|
||||||
|
from midas.checker.unifier import Unifier
|
||||||
|
|
||||||
|
|
||||||
|
class HasLocation(Protocol):
    """Structural type for any object exposing a read-only `location`"""

    @property
    def location(self) -> Location:
        """The source location attached to this object"""
        ...
|
||||||
|
|
||||||
|
|
||||||
|
# Expression type handled by the dispatcher: anything carrying a `location`
E = TypeVar("E", bound=HasLocation)

# An expression paired with its type, as `(expression, type)`
TypedExpr = tuple[E, Type]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
class MappedArgument(Generic[E]):
    """A call-site argument paired with the parameter it was mapped to"""

    # The argument expression as passed at the call site
    expr: E
    # The type of the argument expression
    type: Type
    # The parameter definition this argument is bound to
    argument: Function.Argument
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
class OverloadCandidate:
    """An overload that accepted the call, with its argument mapping"""

    # The candidate function signature
    function: Function
    # The call arguments mapped onto the parameters of `function`
    mapped: list[MappedArgument]
|
||||||
|
|
||||||
|
|
||||||
|
class CallError(StrEnum):
|
||||||
|
INVALID_ARGS = "Invalid arguments"
|
||||||
|
NO_MATCHING_OVERLOAD = "No matching overload"
|
||||||
|
IMPOSSIBLE_UNIFICATION = "Parameters unification failed"
|
||||||
|
NOT_CALLABLE = "Not callable"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
class CallResult:
    """Outcome of dispatching a call: a result type, or an error"""

    # The kind of failure, or `None` when the call is valid
    error: Optional[CallError] = None
    # The type produced by the call (`UnknownType` unless set explicitly)
    result: Type = UnknownType()
    # Optional detailed message, taking precedence over the generic error text
    message: Optional[str] = None

    @property
    def is_valid(self) -> bool:
        """Whether the call succeeded, i.e. no error was recorded"""
        return self.error is None

    @property
    def error_message(self) -> str:
        """The most specific message available, or an empty string

        The explicit `message` wins over the text of `error`.
        """
        if self.message is not None:
            return self.message
        return "" if self.error is None else str(self.error)
|
||||||
|
|
||||||
|
|
||||||
|
class CallDispatcher(Generic[E]):
|
||||||
|
def __init__(self, types: TypesRegistry, reporter: FileReporter) -> None:
    """Create a dispatcher bound to a types registry and a diagnostics reporter

    Args:
        types (TypesRegistry): the registry used for subtype checks and unification
        reporter (FileReporter): the reporter receiving error diagnostics
    """
    self.logger: logging.Logger = logging.getLogger("CallDispatcher")
    self.reporter: FileReporter = reporter
    self.types: TypesRegistry = types
|
||||||
|
|
||||||
|
def set_reporter(self, reporter: FileReporter):
    """Replace the reporter that receives subsequent diagnostics

    Args:
        reporter (FileReporter): the new reporter
    """
    self.reporter = reporter
|
||||||
|
|
||||||
|
def get_result(
|
||||||
|
self,
|
||||||
|
location: Location,
|
||||||
|
callee: Type,
|
||||||
|
positional: list[TypedExpr[E]],
|
||||||
|
keywords: dict[str, TypedExpr[E]],
|
||||||
|
report_errors: bool = True,
|
||||||
|
) -> CallResult:
|
||||||
|
"""Get the result type of a function call
|
||||||
|
|
||||||
|
If the function has overloads, the function will try to resolve the
|
||||||
|
appropriate signature.
|
||||||
|
Argument types are matched to the defined parameters.
|
||||||
|
The function doesn't take the raw expression as a parameter to accommodate
|
||||||
|
for desugared calls such as for operators.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
location (Location): the call location
|
||||||
|
callee (Type): the called function
|
||||||
|
positional (list[TypedExpr]): the list positional arguments
|
||||||
|
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||||
|
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Type: the return type of the call, or `None` if either
|
||||||
|
the call is invalid or no overload matched the arguments uniquely
|
||||||
|
"""
|
||||||
|
match callee:
|
||||||
|
case Function() as function:
|
||||||
|
valid: bool
|
||||||
|
mapped: list[MappedArgument[E]]
|
||||||
|
valid, mapped = self.map_call_arguments(
|
||||||
|
function, location, positional, keywords
|
||||||
|
)
|
||||||
|
valid = valid and self._are_arguments_valid(mapped, report_errors)
|
||||||
|
if not valid:
|
||||||
|
return CallResult(error=CallError.INVALID_ARGS)
|
||||||
|
return CallResult(result=function.returns)
|
||||||
|
|
||||||
|
case OverloadedFunction(overloads=overloads):
|
||||||
|
res = self._match_overload(
|
||||||
|
overloads, location, positional, keywords, report_errors
|
||||||
|
)
|
||||||
|
if res[0] is None:
|
||||||
|
return CallResult(
|
||||||
|
error=CallError.NO_MATCHING_OVERLOAD,
|
||||||
|
message=res[1],
|
||||||
|
)
|
||||||
|
return CallResult(result=res[0].returns)
|
||||||
|
|
||||||
|
case AppliedType(body=body):
|
||||||
|
return self.get_result(
|
||||||
|
location, body, positional, keywords, report_errors
|
||||||
|
)
|
||||||
|
|
||||||
|
case UnknownType():
|
||||||
|
return CallResult(result=UnknownType())
|
||||||
|
|
||||||
|
case DerivedType(type=base):
|
||||||
|
return self.get_result(
|
||||||
|
location, base, positional, keywords, report_errors
|
||||||
|
)
|
||||||
|
|
||||||
|
case GenericType():
|
||||||
|
unifier: Unifier = Unifier(self.types)
|
||||||
|
pos: list[Type] = [a[1] for a in positional]
|
||||||
|
kw: dict[str, Type] = {k: v[1] for k, v in keywords.items()}
|
||||||
|
unified: Optional[Type] = unifier.unify_call(callee, pos, kw)
|
||||||
|
if unified is None:
|
||||||
|
pos_str: str = ", ".join(str(t) for t in pos)
|
||||||
|
kw_str: str = ", ".join(f"{k}: {v}" for k, v in kw.items())
|
||||||
|
message: str = (
|
||||||
|
f"Could not unify {callee}={callee.body} with pos=[{pos_str}] and kw={{{kw_str}}}"
|
||||||
|
)
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(location, message)
|
||||||
|
return CallResult(
|
||||||
|
error=CallError.IMPOSSIBLE_UNIFICATION,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
return self.get_result(
|
||||||
|
location,
|
||||||
|
unified,
|
||||||
|
positional,
|
||||||
|
keywords,
|
||||||
|
report_errors,
|
||||||
|
)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
message: str = f"{callee} ({callee.__class__.__name__}) is not callable"
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(location, message)
|
||||||
|
return CallResult(
|
||||||
|
error=CallError.NOT_CALLABLE,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _unwrap_function(
|
||||||
|
self,
|
||||||
|
callee: Type,
|
||||||
|
positional: list[TypedExpr[E]],
|
||||||
|
keywords: dict[str, TypedExpr[E]],
|
||||||
|
) -> Union[tuple[Function, None], tuple[None, CallError]]:
|
||||||
|
match callee:
|
||||||
|
case DerivedType(type=base):
|
||||||
|
return self._unwrap_function(base, positional, keywords)
|
||||||
|
|
||||||
|
case GenericType():
|
||||||
|
unifier: Unifier = Unifier(self.types)
|
||||||
|
unified: Optional[Type] = unifier.unify_call(
|
||||||
|
callee,
|
||||||
|
[a[1] for a in positional],
|
||||||
|
{k: v[1] for k, v in keywords.items()},
|
||||||
|
)
|
||||||
|
if unified is None:
|
||||||
|
return None, CallError.IMPOSSIBLE_UNIFICATION
|
||||||
|
return self._unwrap_function(unified, positional, keywords)
|
||||||
|
|
||||||
|
case Function():
|
||||||
|
return callee, None
|
||||||
|
|
||||||
|
case AppliedType(body=body):
|
||||||
|
return self._unwrap_function(body, positional, keywords)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
return None, CallError.NOT_CALLABLE
|
||||||
|
|
||||||
|
def _are_arguments_valid(
    self,
    arguments: list[MappedArgument[E]],
    report_errors: bool = True,
) -> bool:
    """Check that every argument type is a subtype of its mapped parameter type

    All arguments are checked, even after a first mismatch, so that every
    offending argument can be reported.

    Args:
        arguments (list[MappedArgument]): the list of argument/parameter pairs
        report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.

    Returns:
        bool: True if all arguments fit the matching parameter definitions, False otherwise
    """
    all_fit: bool = True
    for mapped in arguments:
        expected = mapped.argument.type
        if self.types.is_subtype(mapped.type, expected):
            continue
        all_fit = False
        if report_errors:
            self.reporter.error(
                mapped.expr.location,
                f"Wrong type for argument '{mapped.argument.name}', expected {mapped.argument.type}, got {mapped.type}",
            )
    return all_fit
|
||||||
|
|
||||||
|
def _match_overload(
    self,
    overloads: list[Type],
    location: Location,
    positional: list[TypedExpr[E]],
    keywords: dict[str, TypedExpr[E]],
    report_errors: bool = True,
) -> Union[tuple[Function, None], tuple[None, str]]:
    """Try and resolve the appropriate overload for the given arguments

    Every overload accepting the arguments is a candidate. A single
    candidate is returned directly. With several candidates, the one whose
    parameter types are subtypes of those of all the others is selected.

    Args:
        overloads (list[Type]): the list of possible overloads
        location (Location): the call location
        positional (list[TypedExpr]): the list of positional arguments
        keywords (dict[str, TypedExpr]): the map of keywords arguments
        report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.

    Returns:
        Union[tuple[Function, None], tuple[None, str]]: the resolved function
            and `None` if it can be determined unambiguously, otherwise
            `None` and the error message
    """
    matches: list[OverloadCandidate] = []
    unwrap_failures: list[CallError] = []
    for overload in overloads:
        function, unwrap_error = self._unwrap_function(
            overload, positional, keywords
        )
        if function is None:
            unwrap_failures.append(unwrap_error)  # type: ignore
            continue

        # Candidates are probed silently; only the final outcome is reported
        valid, mapped = self.map_call_arguments(
            function=function,
            location=location,
            positional=positional,
            keywords=keywords,
            report_errors=False,
        )
        if not valid or not self._are_arguments_valid(mapped, report_errors=False):
            continue
        matches.append(OverloadCandidate(function=function, mapped=mapped))

    pos_types: str = ", ".join(str(type) for _, type in positional)
    kw_types: str = ", ".join(
        f"{name}: {type}" for name, (_, type) in keywords.items()
    )
    for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"

    # Exactly 1 match -> return it
    if len(matches) == 1:
        return matches[0].function, None

    # No match -> invalid call
    if len(matches) == 0:
        overloads_str: str = ", ".join(map(str, overloads))
        errors_str: str = ", ".join(unwrap_failures)
        message: str = (
            f"No matching overload in [{overloads_str}] {for_args} (errors: {errors_str})"
        )
        if report_errors:
            self.reporter.error(location, message)
        return None, message

    # Multiple matches -> see if one <: all others (more specific)
    for i1, c1 in enumerate(matches):
        for i2, c2 in enumerate(matches):
            if i1 == i2:
                continue
            if not self._are_mapped_subtypes(c1.mapped, c2.mapped):
                break
            self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
        else:
            # No other candidate rejected c1: it is the most specific one
            return c1.function, None

    candidates_str: str = ", ".join(str(match.function) for match in matches)
    message: str = f"Multiple matching overloads {for_args}: {candidates_str}"
    if report_errors:
        self.reporter.error(location, message)
    return None, message
|
||||||
|
|
||||||
|
def map_call_arguments(
    self,
    function: Function,
    location: Location,
    positional: list[TypedExpr[E]],
    keywords: dict[str, TypedExpr[E]],
    report_errors: bool = True,
) -> tuple[bool, list[MappedArgument]]:
    """Map call arguments to a function's parameters as defined in its signature

    Positional arguments fill the positional-only parameters first, then the
    mixed ones. Keyword arguments are matched by name against the
    keyword-only parameters and the mixed parameters left unfilled.

    Any mismatched, missing or unexpected argument is reported as a diagnostic,
    unless `report_errors` is set to `False`

    Args:
        function (Function): the function definition
        location (Location): the call location
        positional (list[TypedExpr]): the list of positional arguments
        keywords (dict[str, TypedExpr]): the map of keyword arguments
        report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.

    Returns:
        tuple[bool, list[MappedArgument]]: a boolean reporting whether
            the call is valid and the list of mapped arguments
    """
    assigned: set[str] = set()
    missing_positional: list[str] = [
        param.name for param in function.pos_args + function.args if param.required
    ]
    missing_keyword: list[str] = [
        param.name for param in function.kw_args if param.required
    ]
    bound: list[MappedArgument[E]] = []

    positional_only: list[Function.Argument] = list(function.pos_args)
    # Slots fillable positionally, in order: positional-only first, then mixed
    slots: list[Function.Argument] = positional_only + list(function.args)
    by_keyword: dict[str, Function.Argument] = {
        param.name: param for param in function.kw_args
    }
    call_ok: bool = True

    def mark_provided(name: str) -> None:
        # The parameter received a value: it is no longer missing
        if name in missing_positional:
            missing_positional.remove(name)
        if name in missing_keyword:
            missing_keyword.remove(name)
        assigned.add(name)

    # TODO: handle *args and **kwargs sinks
    consumed: int = 0
    for expr, expr_type in positional:
        if consumed == len(slots):
            if report_errors:
                self.reporter.error(expr.location, "Too many positional arguments")
            call_ok = False
            break
        slot: Function.Argument = slots[consumed]
        consumed += 1
        mark_provided(slot.name)
        bound.append(MappedArgument(expr=expr, type=expr_type, argument=slot))

    # Mixed parameters not filled positionally may still be passed by keyword
    leftover_mixed = slots[max(consumed, len(positional_only)):]
    by_keyword.update({param.name: param for param in leftover_mixed})

    for name, (expr, expr_type) in keywords.items():
        if name in by_keyword:
            target: Function.Argument = by_keyword.pop(name)
            mark_provided(name)
            bound.append(MappedArgument(expr=expr, type=expr_type, argument=target))
            continue

        call_ok = False
        if not report_errors:
            continue
        if name in assigned:
            self.reporter.error(
                expr.location, f"Multiple values for argument '{name}'"
            )
        else:
            self.reporter.error(
                expr.location, f"Unknown keyword argument '{name}'"
            )

    def join_args(args: list[str]) -> str:
        # Quote each name and join as "'a', 'b' and 'c'"
        quoted = [f"'{a}'" for a in args]
        if len(quoted) <= 1:
            return "".join(quoted)
        return ", ".join(quoted[:-1]) + " and " + quoted[-1]

    if missing_positional:
        plural: str = "s" if len(missing_positional) != 1 else ""
        args: str = join_args(missing_positional)
        call_ok = False
        if report_errors:
            self.reporter.error(
                location,
                f"Missing required positional argument{plural}: {args}",
            )

    if missing_keyword:
        plural: str = "s" if len(missing_keyword) != 1 else ""
        args: str = join_args(missing_keyword)
        call_ok = False
        if report_errors:
            self.reporter.error(
                location,
                f"Missing required keyword argument{plural}: {args}",
            )

    return call_ok, bound
|
||||||
|
|
||||||
|
def _are_mapped_subtypes(
    self, mapped1: list[MappedArgument[E]], mapped2: list[MappedArgument[E]]
) -> bool:
    """Check whether the parameter types of one mapping are subtypes of another's

    Both mappings are matched by argument expression. For each argument of
    `mapped2`, the parameter type it is bound to in `mapped1` must be a
    subtype of the parameter type it is bound to in `mapped2`.

    This is used to check whether a given overload is
    a more specific function / a subtype of another.

    Args:
        mapped1 (list[MappedArgument]): the first argument mappings (subtype)
        mapped2 (list[MappedArgument]): the second argument mappings (supertype)

    Returns:
        bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
    """
    # Parameter type each argument expression is bound to in the first mapping
    param_type_of: dict[E, Type] = {
        entry.expr: entry.argument.type for entry in mapped1
    }
    return all(
        self.types.is_subtype(param_type_of[entry.expr], entry.argument.type)
        for entry in mapped2
    )
|
||||||
203
midas/checker/frames/column_groupby_methods.py
Normal file
203
midas/checker/frames/column_groupby_methods.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import midas.ast.python as p
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.checker.dispatcher import CallResult
|
||||||
|
from midas.checker.frames.utils import MethodRegistry, method
|
||||||
|
from midas.checker.types import ColumnGroupBy, ColumnType, Function, TopType, Type
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from midas.checker.python import TypedExpr
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Call:
|
||||||
|
location: Location
|
||||||
|
call_expr: p.Expr
|
||||||
|
groupby: ColumnGroupBy
|
||||||
|
groupby_expr: p.Expr
|
||||||
|
positional: list[TypedExpr]
|
||||||
|
keywords: dict[str, TypedExpr]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def subject(self) -> TypedExpr:
|
||||||
|
return (self.groupby_expr, self.groupby)
|
||||||
|
|
||||||
|
|
||||||
|
class ColumnGroupByMethodRegistry(MethodRegistry[Call]):
|
||||||
|
NAMED_ARGS: dict[str, str] = {
|
||||||
|
"numeric_only": "bool",
|
||||||
|
"skipna": "bool",
|
||||||
|
"engine": "str",
|
||||||
|
"engine_kwargs": "dict",
|
||||||
|
}
|
||||||
|
|
||||||
|
def _aggregate(
|
||||||
|
self,
|
||||||
|
call: Call,
|
||||||
|
args: list[str | tuple[str, str, bool]] = [],
|
||||||
|
*,
|
||||||
|
preserve_inner_type: bool = False,
|
||||||
|
) -> Type:
|
||||||
|
real_args: list[Function.Argument] = []
|
||||||
|
for i, arg in enumerate(args):
|
||||||
|
match arg:
|
||||||
|
case str() as name:
|
||||||
|
arg = Function.Argument(
|
||||||
|
pos=i,
|
||||||
|
name=name,
|
||||||
|
type=self.types.get_type(self.NAMED_ARGS[name]),
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
case (name, type, required):
|
||||||
|
arg = Function.Argument(
|
||||||
|
pos=i,
|
||||||
|
name=name,
|
||||||
|
type=self.types.get_type(type),
|
||||||
|
required=required,
|
||||||
|
)
|
||||||
|
real_args.append(arg)
|
||||||
|
|
||||||
|
signature = Function(
|
||||||
|
args=real_args,
|
||||||
|
returns=(
|
||||||
|
call.groupby.column
|
||||||
|
if preserve_inner_type
|
||||||
|
else ColumnType(type=TopType())
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
result: CallResult = self.dispatcher.get_result(
|
||||||
|
location=call.location,
|
||||||
|
callee=signature,
|
||||||
|
positional=call.positional,
|
||||||
|
keywords=call.keywords,
|
||||||
|
)
|
||||||
|
return result.result
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def kurt(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(
|
||||||
|
call,
|
||||||
|
["skipna", "numeric_only"],
|
||||||
|
)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def max(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(
|
||||||
|
call,
|
||||||
|
[
|
||||||
|
"numeric_only",
|
||||||
|
(
|
||||||
|
"min_count",
|
||||||
|
"int",
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
"skipna",
|
||||||
|
"engine",
|
||||||
|
"engine_kwargs",
|
||||||
|
],
|
||||||
|
preserve_inner_type=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def mean(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(
|
||||||
|
call,
|
||||||
|
["numeric_only", "skipna", "engine", "engine_kwargs"],
|
||||||
|
)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def median(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(
|
||||||
|
call,
|
||||||
|
["numeric_only", "skipna"],
|
||||||
|
preserve_inner_type=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def min(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(
|
||||||
|
call,
|
||||||
|
[
|
||||||
|
"numeric_only",
|
||||||
|
(
|
||||||
|
"min_count",
|
||||||
|
"int",
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
"skipna",
|
||||||
|
"engine",
|
||||||
|
"engine_kwargs",
|
||||||
|
],
|
||||||
|
preserve_inner_type=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def prod(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(
|
||||||
|
call,
|
||||||
|
[
|
||||||
|
"numeric_only",
|
||||||
|
(
|
||||||
|
"min_count",
|
||||||
|
"int",
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
"skipna",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def std(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(
|
||||||
|
call,
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"ddof",
|
||||||
|
"int",
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
"engine",
|
||||||
|
"engine_kwargs",
|
||||||
|
"numeric_only",
|
||||||
|
"skipna",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def sum(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(
|
||||||
|
call,
|
||||||
|
[
|
||||||
|
"numeric_only",
|
||||||
|
(
|
||||||
|
"min_count",
|
||||||
|
"int",
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
"skipna",
|
||||||
|
"engine",
|
||||||
|
"engine_kwargs",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def var(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(
|
||||||
|
call,
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"var",
|
||||||
|
"int",
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
"engine",
|
||||||
|
"engine_kwargs",
|
||||||
|
"numeric_only",
|
||||||
|
"skipna",
|
||||||
|
],
|
||||||
|
)
|
||||||
78
midas/checker/frames/column_manager.py
Normal file
78
midas/checker/frames/column_manager.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
import midas.ast.python as p
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.checker.frames.column_groupby_methods import Call as GroupByCall
|
||||||
|
from midas.checker.frames.column_groupby_methods import ColumnGroupByMethodRegistry
|
||||||
|
from midas.checker.frames.column_methods import Call, ColumnMethodRegistry
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
|
from midas.checker.types import ColumnGroupBy, ColumnType, Type
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from midas.checker.python import PythonTyper, TypedExpr
|
||||||
|
|
||||||
|
|
||||||
|
class ColumnManager:
    """Entry point used to type operations on columns and grouped columns.

    Method calls are packed into a call record and delegated to one of two
    registries; plain attributes are resolved by `get_attribute`.
    """

    def __init__(self, typer: PythonTyper) -> None:
        """Create the manager and its two method registries

        Args:
            typer (PythonTyper): the typer handed to both registries
        """
        self.typer: PythonTyper = typer
        self.method_resolver: ColumnMethodRegistry = ColumnMethodRegistry(self.typer)
        self.groupby_method_resolver: ColumnGroupByMethodRegistry = (
            ColumnGroupByMethodRegistry(self.typer)
        )

    def call(
        self,
        method: str,
        location: Location,
        call_expr: p.Expr,
        column: ColumnType,
        column_expr: p.Expr,
        positional: list[TypedExpr],
        keywords: dict[str, TypedExpr],
    ) -> Type:
        """Type a method call on a column

        Args:
            method (str): the name of the called method
            location (Location): the location of the call
            call_expr (p.Expr): the call expression
            column (ColumnType): the type of the receiver
            column_expr (p.Expr): the receiver expression
            positional (list[TypedExpr]): the positional arguments
            keywords (dict[str, TypedExpr]): the keyword arguments

        Returns:
            Type: whatever the column method registry returns for this call
        """
        return self.method_resolver.call(
            method,
            Call(
                location=location,
                call_expr=call_expr,
                column=column,
                column_expr=column_expr,
                positional=positional,
                keywords=keywords,
            ),
        )

    def groupby_call(
        self,
        method: str,
        location: Location,
        call_expr: p.Expr,
        groupby: ColumnGroupBy,
        groupby_expr: p.Expr,
        positional: list[TypedExpr],
        keywords: dict[str, TypedExpr],
    ) -> Type:
        """Type a method call on a grouped column

        Args:
            method (str): the name of the called method
            location (Location): the location of the call
            call_expr (p.Expr): the call expression
            groupby (ColumnGroupBy): the type of the receiver
            groupby_expr (p.Expr): the receiver expression
            positional (list[TypedExpr]): the positional arguments
            keywords (dict[str, TypedExpr]): the keyword arguments

        Returns:
            Type: whatever the grouped column method registry returns
        """
        return self.groupby_method_resolver.call(
            method,
            GroupByCall(
                location=location,
                call_expr=call_expr,
                groupby=groupby,
                groupby_expr=groupby_expr,
                positional=positional,
                keywords=keywords,
            ),
        )

    def get_attribute(self, column: ColumnType, name: str) -> Optional[Type]:
        """Get the type of a (non-method) attribute of a column

        Args:
            column (ColumnType): the column type
            name (str): the attribute name

        Returns:
            Optional[Type]: the attribute type, `None` if `name` is not a
                supported attribute
        """
        types: TypesRegistry = self.typer.types
        if name in ("ndim", "size"):
            return types.get_type("int")
        if name == "shape":
            return types.tuple_of("int")
        if name == "T":
            # `T` has the type of the column itself
            return column
        return None
|
||||||
410
midas/checker/frames/column_methods.py
Normal file
410
midas/checker/frames/column_methods.py
Normal file
@@ -0,0 +1,410 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ast
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
import midas.ast.python as p
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.checker.dispatcher import CallResult
|
||||||
|
from midas.checker.frames.utils import MethodRegistry, method
|
||||||
|
from midas.checker.types import (
|
||||||
|
ColumnGroupBy,
|
||||||
|
ColumnType,
|
||||||
|
Function,
|
||||||
|
GenericType,
|
||||||
|
TopType,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
UnknownType,
|
||||||
|
unfold_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from midas.checker.python import TypedExpr
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Call:
|
||||||
|
location: Location
|
||||||
|
call_expr: p.Expr
|
||||||
|
column: ColumnType
|
||||||
|
column_expr: p.Expr
|
||||||
|
positional: list[TypedExpr]
|
||||||
|
keywords: dict[str, TypedExpr]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def subject(self) -> TypedExpr:
|
||||||
|
return (self.column_expr, self.column)
|
||||||
|
|
||||||
|
|
||||||
|
class ColumnMethodRegistry(MethodRegistry[Call]):
    """Typing rules for the methods of a column.

    Covers element-wise binary operators, aggregations, `head`/`tail` and
    `groupby`. Every handler builds a signature and lets the dispatcher
    check the actual call against it.
    """

    def _other_operand(self, call: Call) -> Optional[TypedExpr]:
        """Get the right operand of a binary operation

        Args:
            call (Call): the call that triggered this resolution

        Returns:
            Optional[TypedExpr]: the first positional argument if there is
                one, else the `other` keyword argument, else `None`
        """
        if call.positional:
            return call.positional[0]
        return call.keywords.get("other")

    def _element_binary_op(self, call: Call, method: str) -> ColumnType:
        """Compute the result of an element-wise binary operation

        This function delegates to the inner types for computing the resulting
        type.

        Args:
            call (Call): the call that triggered this resolution
            method (str): the method name

        Returns:
            ColumnType: the resulting column type, a column of `UnknownType`
                if there is no operand or if it is not a column
        """
        operand: Optional[TypedExpr] = self._other_operand(call)
        if operand is None:
            return ColumnType(type=UnknownType())

        other_expr, other_type = operand
        unfolded_other: Type = unfold_type(other_type)
        if not isinstance(unfolded_other, ColumnType):
            return ColumnType(type=UnknownType())

        new_inner_type = self.typer.result_of_binary_op(
            location=call.location,
            expr=call.call_expr,
            left=(call.column_expr, call.column.type),
            right=(other_expr, unfolded_other.type),
            method=method,
        )
        return ColumnType(type=new_inner_type)

    def _element_wise(self, call: Call, method: str) -> Type:
        """Type an element-wise binary operation between two columns

        When the call is valid, a runtime assertion checking that both
        columns have the same length is also registered.

        Args:
            call (Call): the call that triggered this resolution
            method (str): the dunder method applied to the inner types

        Returns:
            Type: the result computed by the dispatcher
        """
        # TODO: support add with scalar

        # Build signature with new column type and generic operand
        # NOTE(review): the signature is named "add" whatever `method` is;
        # confirm whether this name surfaces in diagnostics
        param_type: TypeVar = TypeVar(name="T", bound=None)
        signature = GenericType(
            name="add",
            params=[param_type],
            body=Function(
                args=[
                    Function.Argument(
                        pos=0,
                        name="other",
                        type=ColumnType(type=param_type),
                        required=True,
                    ),
                ],
                returns=self._element_binary_op(call, method),
            ),
        )

        # Map arguments and compute result type
        result: CallResult = self.dispatcher.get_result(
            location=call.location,
            callee=signature,
            positional=call.positional,
            keywords=call.keywords,
        )
        if result.is_valid:
            # The operand may have been passed by keyword, in which case
            # `call.positional` is empty and must not be indexed
            operand: Optional[TypedExpr] = self._other_operand(call)
            if operand is not None:
                self._assert_same_length(
                    call.call_expr, call.column_expr, operand[0]
                )

        return result.result

    @method("add", "__add__")
    def add(self, call: Call) -> Type:
        """Type an element-wise addition."""
        return self._element_wise(call, "__add__")

    @method("sub", "__sub__")
    def sub(self, call: Call) -> Type:
        """Type an element-wise subtraction."""
        return self._element_wise(call, "__sub__")

    @method("mul", "__mul__")
    def mul(self, call: Call) -> Type:
        """Type an element-wise multiplication."""
        return self._element_wise(call, "__mul__")

    @method("div", "truediv", "__truediv__")
    def truediv(self, call: Call) -> Type:
        """Type an element-wise true division."""
        return self._element_wise(call, "__truediv__")

    @method("floordiv", "__floordiv__")
    def floordiv(self, call: Call) -> Type:
        """Type an element-wise floor division."""
        return self._element_wise(call, "__floordiv__")

    @method("mod", "__mod__")
    def mod(self, call: Call) -> Type:
        """Type an element-wise modulo."""
        return self._element_wise(call, "__mod__")

    @method("pow", "__pow__")
    def pow(self, call: Call) -> Type:
        """Type an element-wise exponentiation."""
        return self._element_wise(call, "__pow__")

    @method("lt", "__lt__")
    def lt(self, call: Call) -> Type:
        """Type an element-wise `<` comparison."""
        return self._element_wise(call, "__lt__")

    @method("gt", "__gt__")
    def gt(self, call: Call) -> Type:
        """Type an element-wise `>` comparison."""
        return self._element_wise(call, "__gt__")

    @method("le", "__le__")
    def le(self, call: Call) -> Type:
        """Type an element-wise `<=` comparison."""
        return self._element_wise(call, "__le__")

    @method("ge", "__ge__")
    def ge(self, call: Call) -> Type:
        """Type an element-wise `>=` comparison."""
        return self._element_wise(call, "__ge__")

    @method("ne", "__ne__")
    def ne(self, call: Call) -> Type:
        """Type an element-wise `!=` comparison."""
        return self._element_wise(call, "__ne__")

    @method("eq", "__eq__")
    def eq(self, call: Call) -> Type:
        """Type an element-wise `==` comparison."""
        return self._element_wise(call, "__eq__")

    def _aggregate(
        self,
        call: Call,
        kwargs: list[Function.Argument] | None = None,
        *,
        preserve_inner_type: bool = False,
    ) -> Type:
        """Type an aggregation call on a column

        Args:
            call (Call): the call that triggered this resolution
            kwargs (list[Function.Argument] | None): keyword parameters
                accepted in addition to `axis`. Defaults to none.
            preserve_inner_type (bool): if `True`, the call returns the
                column type unchanged, otherwise a column of `TopType`

        Returns:
            Type: the result computed by the dispatcher
        """
        signature = Function(
            kw_args=[
                Function.Argument(
                    pos=0,
                    name="axis",
                    type=TopType(),
                    required=False,
                ),
                # `None` replaces the former mutable `[]` default
                *(kwargs or ()),
            ],
            returns=call.column if preserve_inner_type else ColumnType(type=TopType()),
        )

        result: CallResult = self.dispatcher.get_result(
            location=call.location,
            callee=signature,
            positional=call.positional,
            keywords=call.keywords,
        )
        return result.result

    def _ddof_argument(self) -> Function.Argument:
        """Build the optional `ddof` keyword parameter of `std` and `var`."""
        return Function.Argument(
            pos=1,
            name="ddof",
            type=self.types.get_type("int"),
            required=False,
        )

    @method("kurtosis", "kurt")
    def kurtosis(self, call: Call) -> Type:
        """Type a `kurtosis` call; the result is a column of `TopType`."""
        return self._aggregate(call)

    @method()
    def max(self, call: Call) -> Type:
        """Type a `max` call; the column type is preserved."""
        return self._aggregate(call, preserve_inner_type=True)

    @method()
    def mean(self, call: Call) -> Type:
        """Type a `mean` call; the result is a column of `TopType`."""
        return self._aggregate(call)

    @method()
    def median(self, call: Call) -> Type:
        """Type a `median` call; the column type is preserved."""
        return self._aggregate(call, preserve_inner_type=True)

    @method()
    def min(self, call: Call) -> Type:
        """Type a `min` call; the column type is preserved."""
        return self._aggregate(call, preserve_inner_type=True)

    @method()
    def mode(self, call: Call) -> Type:
        """Type a `mode` call; the column type is preserved."""
        return self._aggregate(call, preserve_inner_type=True)

    @method("product", "prod")
    def product(self, call: Call) -> Type:
        """Type a `product` call; the result is a column of `TopType`."""
        return self._aggregate(call)

    @method()
    def std(self, call: Call) -> Type:
        """Type a `std` call; the result is a column of `TopType`."""
        return self._aggregate(call, [self._ddof_argument()])

    @method()
    def sum(self, call: Call) -> Type:
        """Type a `sum` call; the result is a column of `TopType`."""
        return self._aggregate(call)

    @method()
    def var(self, call: Call) -> Type:
        """Type a `var` call; the result is a column of `TopType`."""
        # The extra parameter is `ddof`, like `std` (it used to be declared
        # under the name "var", which rejected `var(ddof=...)`)
        return self._aggregate(call, [self._ddof_argument()])

    def _take_rows(self, call: Call) -> Type:
        """Type a call taking an optional `n` (int) and returning the column

        Args:
            call (Call): the call that triggered this resolution

        Returns:
            Type: the result computed by the dispatcher
        """
        signature = Function(
            args=[
                Function.Argument(
                    pos=0,
                    name="n",
                    type=self.types.get_type("int"),
                    required=False,
                ),
            ],
            returns=call.column,
        )

        result: CallResult = self.dispatcher.get_result(
            location=call.location,
            callee=signature,
            positional=call.positional,
            keywords=call.keywords,
        )
        return result.result

    @method()
    def head(self, call: Call) -> Type:
        """Type a `head` call; the column type is preserved."""
        return self._take_rows(call)

    @method()
    def tail(self, call: Call) -> Type:
        """Type a `tail` call; the column type is preserved."""
        return self._take_rows(call)

    @method()
    def groupby(self, call: Call) -> Type:
        """Type a `groupby` call

        Args:
            call (Call): the call that triggered this resolution

        Returns:
            Type: the result computed by the dispatcher, for a signature
                returning a `ColumnGroupBy` wrapping the column
        """
        bool_: Type = self.types.get_type("bool")
        # Boolean keyword parameters, numbered right after `by` and `level`
        flags: tuple[str, ...] = (
            "as_index",
            "sort",
            "group_keys",
            "observed",
            "dropna",
        )
        function: Function = Function(
            args=[
                Function.Argument(
                    pos=0,
                    name="by",
                    type=TopType(),
                    required=False,
                ),
                Function.Argument(
                    pos=1,
                    name="level",
                    type=TopType(),
                    required=False,
                ),
            ],
            kw_args=[
                Function.Argument(
                    pos=pos,
                    name=flag,
                    type=bool_,
                    required=False,
                )
                for pos, flag in enumerate(flags, start=2)
            ],
            returns=ColumnGroupBy(column=call.column),
        )

        result: CallResult = self.dispatcher.get_result(
            location=call.location,
            callee=function,
            positional=call.positional,
            keywords=call.keywords,
        )
        return result.result

    def _assert_same_length(
        self, call_expr: p.Expr, column1: p.Expr, column2: p.Expr
    ) -> None:
        """Register a runtime assertion that two columns have the same length

        Defines the helper function `__midas_column_same_length__` and binds
        a call to it to `call_expr`.

        Args:
            call_expr (p.Expr): the expression the assertion is bound to
            column1 (p.Expr): the first column expression
            column2 (p.Expr): the second column expression
        """
        func_name: str = "__midas_column_same_length__"

        # Efficiently compute length
        # https://stackoverflow.com/a/15943975/11109181
        def len_of_col(col: ast.expr) -> ast.expr:
            # Builds the AST of `len(<col>.index)`
            return ast.Call(
                func=ast.Name(id="len"),
                args=[
                    ast.Attribute(
                        value=col,
                        attr="index",
                    )
                ],
                keywords=[],
            )

        # AST of:
        #   def __midas_column_same_length__(column1, column2):
        #       return len(column1.index) == len(column2.index)
        self.assertions.define(
            func_name,
            ast.FunctionDef(
                name=func_name,
                args=ast.arguments(
                    posonlyargs=[],
                    args=[
                        ast.arg(arg="column1"),
                        ast.arg(arg="column2"),
                    ],
                    kwonlyargs=[],
                    defaults=[],
                    kw_defaults=[],
                ),
                body=[
                    ast.Return(
                        value=ast.Compare(
                            left=len_of_col(ast.Name(id="column1")),
                            ops=[ast.Eq()],
                            comparators=[
                                len_of_col(ast.Name(id="column2")),
                            ],
                        )
                    )
                ],
                decorator_list=[],
            ),
        )
        self.assertions.add(
            bound_expr=call_expr,
            inputs=[column1, column2],
            builder=lambda c1, c2: ast.Call(
                func=ast.Name(id=func_name),
                args=[c1, c2],
                keywords=[],
            ),
            message="Columns must have the same length",
        )
|
||||||
103
midas/checker/frames/frame_groupby_methods.py
Normal file
103
midas/checker/frames/frame_groupby_methods.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import midas.ast.python as p
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.checker.frames.utils import MethodRegistry, method
|
||||||
|
from midas.checker.types import (
|
||||||
|
ColumnGroupBy,
|
||||||
|
ColumnType,
|
||||||
|
DataFrameType,
|
||||||
|
FrameGroupBy,
|
||||||
|
Type,
|
||||||
|
UnknownType,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from midas.checker.python import TypedExpr
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Call:
|
||||||
|
location: Location
|
||||||
|
call_expr: p.Expr
|
||||||
|
groupby: FrameGroupBy
|
||||||
|
groupby_expr: p.Expr
|
||||||
|
positional: list[TypedExpr]
|
||||||
|
keywords: dict[str, TypedExpr]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def subject(self) -> TypedExpr:
|
||||||
|
return (self.groupby_expr, self.groupby)
|
||||||
|
|
||||||
|
|
||||||
|
class FrameGroupByMethodRegistry(MethodRegistry[Call]):
    """Typing rules for the aggregation methods of a grouped dataframe.

    Each aggregation is applied column by column, by calling the method of
    the same name on a `ColumnGroupBy` built from every column of the frame.
    """

    # Type name of common optional parameters (not read in this class)
    NAMED_ARGS: dict[str, str] = {
        "numeric_only": "bool",
        "skipna": "bool",
        "engine": "str",
        "engine_kwargs": "dict",
    }

    def _aggregate(self, call: Call, method: str) -> Type:
        """Apply an aggregation to every column of the grouped frame

        Args:
            call (Call): the call that triggered this resolution
            method (str): the name of the method called on each grouped column

        Returns:
            Type: a dataframe with the same column indices and names, typed
                with the per-column results
        """

        def aggregate_column(source: DataFrameType.Column) -> DataFrameType.Column:
            # The frame expression stands in for the column expression
            outcome: Type = self.typer.call_method(
                location=call.location,
                call_expr=call.call_expr,
                obj=(call.groupby_expr, ColumnGroupBy(column=source.type)),
                method_name=method,
                positional=call.positional,
                keywords=call.keywords,
            )
            # Anything that is not a column degrades to a column of unknown
            if not isinstance(outcome, ColumnType):
                outcome = ColumnType(type=UnknownType())
            return DataFrameType.Column(
                index=source.index,
                name=source.name,
                type=outcome,
            )

        return DataFrameType(
            columns=[aggregate_column(col) for col in call.groupby.frame.columns]
        )

    @method()
    def kurt(self, call: Call) -> Type:
        """Type a `kurt` call, column by column."""
        return self._aggregate(call, "kurt")

    @method()
    def max(self, call: Call) -> Type:
        """Type a `max` call, column by column."""
        return self._aggregate(call, "max")

    @method()
    def mean(self, call: Call) -> Type:
        """Type a `mean` call, column by column."""
        return self._aggregate(call, "mean")

    @method()
    def median(self, call: Call) -> Type:
        """Type a `median` call, column by column."""
        return self._aggregate(call, "median")

    @method()
    def min(self, call: Call) -> Type:
        """Type a `min` call, column by column."""
        return self._aggregate(call, "min")

    @method()
    def prod(self, call: Call) -> Type:
        """Type a `prod` call, column by column."""
        return self._aggregate(call, "prod")

    @method()
    def std(self, call: Call) -> Type:
        """Type a `std` call, column by column."""
        return self._aggregate(call, "std")

    @method()
    def sum(self, call: Call) -> Type:
        """Type a `sum` call, column by column."""
        return self._aggregate(call, "sum")

    @method()
    def var(self, call: Call) -> Type:
        """Type a `var` call, column by column."""
        return self._aggregate(call, "var")
|
||||||
255
midas/checker/frames/frame_manager.py
Normal file
255
midas/checker/frames/frame_manager.py
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Optional, TypeGuard, cast
|
||||||
|
|
||||||
|
import midas.ast.python as p
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.checker.frames.frame_groupby_methods import Call as GroupByCall
|
||||||
|
from midas.checker.frames.frame_groupby_methods import FrameGroupByMethodRegistry
|
||||||
|
from midas.checker.frames.frame_methods import Call, FrameMethodRegistry
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
|
from midas.checker.reporter import FileReporter
|
||||||
|
from midas.checker.types import (
|
||||||
|
ColumnGroupBy,
|
||||||
|
ColumnType,
|
||||||
|
DataFrameType,
|
||||||
|
FrameGroupBy,
|
||||||
|
TupleType,
|
||||||
|
Type,
|
||||||
|
UnknownType,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from midas.checker.python import PythonTyper, TypedExpr
|
||||||
|
|
||||||
|
|
||||||
|
def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
|
||||||
|
return all(isinstance(expr, p.LiteralExpr) for expr in exprs)
|
||||||
|
|
||||||
|
|
||||||
|
class FrameManager:
|
||||||
|
def __init__(self, typer: PythonTyper) -> None:
|
||||||
|
self.typer: PythonTyper = typer
|
||||||
|
self.method_resolver: FrameMethodRegistry = FrameMethodRegistry(self.typer)
|
||||||
|
self.groupby_method_resolver: FrameGroupByMethodRegistry = (
|
||||||
|
FrameGroupByMethodRegistry(self.typer)
|
||||||
|
)
|
||||||
|
|
||||||
|
def assign(
|
||||||
|
self,
|
||||||
|
reporter: FileReporter,
|
||||||
|
location: Location,
|
||||||
|
frame: DataFrameType,
|
||||||
|
index: p.Expr,
|
||||||
|
value_type: Type,
|
||||||
|
) -> Type:
|
||||||
|
match index:
|
||||||
|
case p.LiteralExpr(value=str() as name):
|
||||||
|
return self.assign_column(reporter, location, frame, name, value_type)
|
||||||
|
|
||||||
|
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
|
||||||
|
isinstance(index.value, str) for index in indices
|
||||||
|
):
|
||||||
|
names: list[str] = [cast(str, index.value) for index in indices]
|
||||||
|
|
||||||
|
if not isinstance(value_type, TupleType):
|
||||||
|
reporter.error(
|
||||||
|
location,
|
||||||
|
f"Cannot assign {type} to dataframe columns. Must be a tuple of columns",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
if len(names) != len(value_type.items):
|
||||||
|
reporter.error(
|
||||||
|
location,
|
||||||
|
f"Wrong number of columns. Cannot assign {len(value_type.items)} to {len(names)} targets",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
new_frame: Type = frame
|
||||||
|
for name, value in zip(names, value_type.items):
|
||||||
|
new_frame = self.assign_column(
|
||||||
|
reporter,
|
||||||
|
location,
|
||||||
|
new_frame,
|
||||||
|
name,
|
||||||
|
value,
|
||||||
|
)
|
||||||
|
if not isinstance(new_frame, DataFrameType):
|
||||||
|
return new_frame
|
||||||
|
return new_frame
|
||||||
|
|
||||||
|
case _:
|
||||||
|
reporter.error(
|
||||||
|
location, f"Invalid index type {index} on {frame} (assignment)"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def assign_column(
|
||||||
|
self,
|
||||||
|
reporter: FileReporter,
|
||||||
|
location: Location,
|
||||||
|
frame: DataFrameType,
|
||||||
|
name: str,
|
||||||
|
type: Type,
|
||||||
|
) -> Type:
|
||||||
|
if not isinstance(type, ColumnType):
|
||||||
|
reporter.error(
|
||||||
|
location,
|
||||||
|
f"Cannot assign {type} to dataframe column. Must be a ColumnType",
|
||||||
|
)
|
||||||
|
return frame
|
||||||
|
return self._set_column(frame, name, type)
|
||||||
|
|
||||||
|
def get(
|
||||||
|
self,
|
||||||
|
reporter: FileReporter,
|
||||||
|
location: Location,
|
||||||
|
frame: DataFrameType,
|
||||||
|
index: p.Expr,
|
||||||
|
) -> Type:
|
||||||
|
match index:
|
||||||
|
case p.LiteralExpr(value=str() as name):
|
||||||
|
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
|
||||||
|
if column is None:
|
||||||
|
reporter.error(location, f"Unknown column '{name}' on {frame}")
|
||||||
|
return UnknownType()
|
||||||
|
return column
|
||||||
|
|
||||||
|
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
|
||||||
|
isinstance(index.value, str) for index in indices
|
||||||
|
):
|
||||||
|
names: list[str] = [cast(str, index.value) for index in indices]
|
||||||
|
columns: list[ColumnType] = []
|
||||||
|
for name in names:
|
||||||
|
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
|
||||||
|
if column is None:
|
||||||
|
reporter.error(location, f"Unknown column '{name}' on {frame}")
|
||||||
|
return UnknownType()
|
||||||
|
columns.append(column)
|
||||||
|
return TupleType(items=tuple(columns))
|
||||||
|
|
||||||
|
case _:
|
||||||
|
reporter.error(
|
||||||
|
location, f"Invalid index type {index} on {frame} (access)"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def groupby_get(
|
||||||
|
self,
|
||||||
|
reporter: FileReporter,
|
||||||
|
location: Location,
|
||||||
|
groupby: FrameGroupBy,
|
||||||
|
index: p.Expr,
|
||||||
|
) -> Type:
|
||||||
|
result: Type = self.get(reporter, location, groupby.frame, index)
|
||||||
|
match result:
|
||||||
|
case ColumnType():
|
||||||
|
result = ColumnGroupBy(column=result)
|
||||||
|
case TupleType(items=columns):
|
||||||
|
result = TupleType(
|
||||||
|
items=tuple(
|
||||||
|
ColumnGroupBy(column=cast(ColumnType, column))
|
||||||
|
for column in columns
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _set_column(
|
||||||
|
cls, frame: DataFrameType, name: str, column: ColumnType
|
||||||
|
) -> DataFrameType:
|
||||||
|
new_columns: list[DataFrameType.Column] = []
|
||||||
|
index: int = len(frame.columns)
|
||||||
|
replace: bool = False
|
||||||
|
for i, col in enumerate(frame.columns):
|
||||||
|
if col.name == name:
|
||||||
|
index = i
|
||||||
|
replace = True
|
||||||
|
# TODO: check column type here to prevent changing it
|
||||||
|
new_columns.append(col)
|
||||||
|
|
||||||
|
new_col: DataFrameType.Column = DataFrameType.Column(
|
||||||
|
index=index,
|
||||||
|
name=name,
|
||||||
|
type=column,
|
||||||
|
)
|
||||||
|
if replace:
|
||||||
|
new_columns[index] = new_col
|
||||||
|
else:
|
||||||
|
new_columns.append(new_col)
|
||||||
|
|
||||||
|
return DataFrameType(columns=new_columns)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _set_columns(
|
||||||
|
cls, frame: DataFrameType, names: list[str], columns: list[ColumnType]
|
||||||
|
) -> DataFrameType:
|
||||||
|
for name, col in zip(names, columns):
|
||||||
|
frame = cls._set_column(frame, name, col)
|
||||||
|
return frame
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_column(cls, frame: DataFrameType, name: str) -> Optional[ColumnType]:
|
||||||
|
for col in frame.columns:
|
||||||
|
if col.name == name:
|
||||||
|
return col.type
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_columns(
|
||||||
|
cls, frame: DataFrameType, names: list[str]
|
||||||
|
) -> list[Optional[ColumnType]]:
|
||||||
|
return [cls._get_column(frame, name) for name in names]
|
||||||
|
|
||||||
|
def call(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
location: Location,
|
||||||
|
call_expr: p.Expr,
|
||||||
|
frame: DataFrameType,
|
||||||
|
frame_expr: p.Expr,
|
||||||
|
positional: list[TypedExpr],
|
||||||
|
keywords: dict[str, TypedExpr],
|
||||||
|
) -> Type:
|
||||||
|
call: Call = Call(
|
||||||
|
location=location,
|
||||||
|
call_expr=call_expr,
|
||||||
|
frame=frame,
|
||||||
|
frame_expr=frame_expr,
|
||||||
|
positional=positional,
|
||||||
|
keywords=keywords,
|
||||||
|
)
|
||||||
|
return self.method_resolver.call(method, call)
|
||||||
|
|
||||||
|
def groupby_call(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
location: Location,
|
||||||
|
call_expr: p.Expr,
|
||||||
|
groupby: FrameGroupBy,
|
||||||
|
groupby_expr: p.Expr,
|
||||||
|
positional: list[TypedExpr],
|
||||||
|
keywords: dict[str, TypedExpr],
|
||||||
|
) -> Type:
|
||||||
|
call: GroupByCall = GroupByCall(
|
||||||
|
location=location,
|
||||||
|
call_expr=call_expr,
|
||||||
|
groupby=groupby,
|
||||||
|
groupby_expr=groupby_expr,
|
||||||
|
positional=positional,
|
||||||
|
keywords=keywords,
|
||||||
|
)
|
||||||
|
return self.groupby_method_resolver.call(method, call)
|
||||||
|
|
||||||
|
def get_attribute(self, frame: DataFrameType, name: str) -> Optional[Type]:
|
||||||
|
types: TypesRegistry = self.typer.types
|
||||||
|
match name:
|
||||||
|
case "ndim" | "size":
|
||||||
|
return types.get_type("int")
|
||||||
|
|
||||||
|
case "shape":
|
||||||
|
return types.tuple_of("int", "int")
|
||||||
|
|
||||||
|
case _:
|
||||||
|
return None
|
||||||
487
midas/checker/frames/frame_methods.py
Normal file
487
midas/checker/frames/frame_methods.py
Normal file
@@ -0,0 +1,487 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ast
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
import midas.ast.python as p
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.checker.dispatcher import CallResult
|
||||||
|
from midas.checker.frames.utils import MethodRegistry, method
|
||||||
|
from midas.checker.types import (
|
||||||
|
ColumnType,
|
||||||
|
DataFrameType,
|
||||||
|
FrameGroupBy,
|
||||||
|
Function,
|
||||||
|
OverloadedFunction,
|
||||||
|
TopType,
|
||||||
|
Type,
|
||||||
|
UnknownType,
|
||||||
|
unfold_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from midas.checker.python import TypedExpr
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Call:
|
||||||
|
location: Location
|
||||||
|
call_expr: p.Expr
|
||||||
|
frame: DataFrameType
|
||||||
|
frame_expr: p.Expr
|
||||||
|
positional: list[TypedExpr]
|
||||||
|
keywords: dict[str, TypedExpr]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def subject(self) -> TypedExpr:
|
||||||
|
return (self.frame_expr, self.frame)
|
||||||
|
|
||||||
|
|
||||||
|
class FrameMethodRegistry(MethodRegistry[Call]):
|
||||||
|
def _get_method_result(
    self,
    call: Call,
    column1: ColumnType,
    column2: ColumnType,
    method: str,
) -> ColumnType:
    """Get the result of calling a binary method on a column, passing a
    second column as operand

    The resolution of the method member and the computation of the result
    type are delegated to the main typer. Since there is no AST expression
    for the individual columns, the frame expression stands in for the
    left column and the first positional argument's expression for the
    right one.

    Args:
        call (Call): the call that triggered this resolution
        column1 (ColumnType): the first column, i.e. left operand
        column2 (ColumnType): the second column, i.e. right operand
        method (str): the method name

    Returns:
        ColumnType: the resulting column.
            If the typer does not produce a column,
            `ColumnType(type=UnknownType())` is returned
    """
    left_operand = (call.frame_expr, column1)
    right_operand = (call.positional[0][0], column2)

    outcome: Type = self.typer.result_of_binary_op(
        location=call.location,
        expr=call.call_expr,
        left=left_operand,
        right=right_operand,
        method=method,
    )

    if isinstance(outcome, ColumnType):
        return outcome
    return ColumnType(type=UnknownType())
|
||||||
|
|
||||||
|
def _element_binary_op(self, call: Call, method: str) -> DataFrameType:
    """Compute the schema resulting from an element-wise binary operation

    Columns present in both frames (matched by name) get their type from
    the column-level operation. Any column present in only one frame is
    forwarded as a generic `ColumnType(type=UnknownType())`. Columns that
    only exist in the second frame are appended at the end of the schema.

    Args:
        call (Call): the call that triggered this resolution
        method (str): the method name

    Returns:
        DataFrameType: the resulting frame type
    """
    # Second operand: only considered when the first positional argument
    # unfolds to a dataframe
    other_frame: Optional[DataFrameType] = None
    other_by_name: dict[str, DataFrameType.Column] = {}
    if call.positional:
        candidate: Type = unfold_type(call.positional[0][1])
        if isinstance(candidate, DataFrameType):
            other_frame = candidate
            for other_col in candidate.columns:
                if other_col.name is not None:
                    other_by_name[other_col.name] = other_col

    schema: list[DataFrameType.Column] = []

    # Pass 1: columns of the first frame, in their original order.
    # Matched columns delegate to the column operation, others are unknown.
    seen_names: set[str] = set()
    for left_col in call.frame.columns:
        if left_col.name is not None:
            seen_names.add(left_col.name)

        match_col = other_by_name.get(left_col.name)
        if match_col is None:
            result_type: ColumnType = ColumnType(type=UnknownType())
        else:
            result_type = self._get_method_result(
                call, left_col.type, match_col.type, method
            )

        schema.append(
            DataFrameType.Column(
                index=left_col.index,
                name=left_col.name,
                type=result_type,
            )
        )

    # Pass 2: columns only found in the second frame, appended as unknown
    if other_frame is not None:
        for right_col in other_frame.columns:
            if right_col.name not in seen_names:
                schema.append(
                    DataFrameType.Column(
                        index=len(schema),
                        name=right_col.name,
                        type=ColumnType(type=UnknownType()),
                    )
                )

    return DataFrameType(columns=schema)
|
||||||
|
|
||||||
|
def _element_wise(self, call: Call, method: str) -> Type:
    """Type an element-wise binary operation between two frames

    A one-argument signature (`other`, a dataframe) is built with the
    schema computed by `_element_binary_op` as return type, and the call
    arguments are checked against it by the dispatcher. When the call is
    valid, a runtime assertion that both frames have the same length is
    registered.

    Args:
        call (Call): the call that triggered this resolution
        method (str): the dunder method name of the operation

    Returns:
        Type: the `result` of the dispatcher's `CallResult`
    """
    # TODO: support scalar, sequence, Series, dict operand
    # Build signature with new schema and generic operand
    signature = Function(
        args=[
            Function.Argument(
                pos=0,
                name="other",
                type=DataFrameType(columns=[]),
                required=True,
            ),
        ],
        returns=self._element_binary_op(call, method),
    )

    # Map arguments and compute result type
    result: CallResult = self.dispatcher.get_result(
        location=call.location,
        callee=signature,
        positional=call.positional,
        keywords=call.keywords,
    )
    if result.is_valid:
        # A valid call may have received `other` by keyword, in which case
        # there is no positional argument to index into.
        operand: Optional[TypedExpr] = (
            call.positional[0] if call.positional else call.keywords.get("other")
        )
        if operand is not None:
            self._assert_same_length(call.call_expr, call.frame_expr, operand[0])

    return result.result
|
||||||
|
|
||||||
|
@method("add", "__add__")
def add(self, call: Call) -> Type:
    """Registered as `add` and `__add__`: element-wise `__add__`."""
    dunder: str = "__add__"
    return self._element_wise(call, dunder)
|
||||||
|
|
||||||
|
@method("sub", "__sub__")
def sub(self, call: Call) -> Type:
    """Registered as `sub` and `__sub__`: element-wise `__sub__`."""
    dunder: str = "__sub__"
    return self._element_wise(call, dunder)
|
||||||
|
|
||||||
|
@method("mul", "__mul__")
def mul(self, call: Call) -> Type:
    """Registered as `mul` and `__mul__`: element-wise `__mul__`."""
    dunder: str = "__mul__"
    return self._element_wise(call, dunder)
|
||||||
|
|
||||||
|
@method("div", "truediv", "__truediv__")
def truediv(self, call: Call) -> Type:
    """Registered as `div`, `truediv` and `__truediv__`: element-wise `__truediv__`."""
    dunder: str = "__truediv__"
    return self._element_wise(call, dunder)
|
||||||
|
|
||||||
|
@method("floordiv", "__floordiv__")
def floordiv(self, call: Call) -> Type:
    """Registered as `floordiv` and `__floordiv__`: element-wise `__floordiv__`."""
    dunder: str = "__floordiv__"
    return self._element_wise(call, dunder)
|
||||||
|
|
||||||
|
@method("mod", "__mod__")
def mod(self, call: Call) -> Type:
    """Registered as `mod` and `__mod__`: element-wise `__mod__`."""
    dunder: str = "__mod__"
    return self._element_wise(call, dunder)
|
||||||
|
|
||||||
|
@method("pow", "__pow__")
def pow(self, call: Call) -> Type:
    """Registered as `pow` and `__pow__`: element-wise `__pow__`."""
    dunder: str = "__pow__"
    return self._element_wise(call, dunder)
|
||||||
|
|
||||||
|
@method("lt", "__lt__")
def lt(self, call: Call) -> Type:
    """Registered as `lt` and `__lt__`: element-wise `__lt__`."""
    dunder: str = "__lt__"
    return self._element_wise(call, dunder)
|
||||||
|
|
||||||
|
@method("gt", "__gt__")
def gt(self, call: Call) -> Type:
    """Registered as `gt` and `__gt__`: element-wise `__gt__`."""
    dunder: str = "__gt__"
    return self._element_wise(call, dunder)
|
||||||
|
|
||||||
|
@method("le", "__le__")
def le(self, call: Call) -> Type:
    """Registered as `le` and `__le__`: element-wise `__le__`."""
    dunder: str = "__le__"
    return self._element_wise(call, dunder)
|
||||||
|
|
||||||
|
@method("ge", "__ge__")
def ge(self, call: Call) -> Type:
    """Registered as `ge` and `__ge__`: element-wise `__ge__`."""
    dunder: str = "__ge__"
    return self._element_wise(call, dunder)
|
||||||
|
|
||||||
|
@method("ne", "__ne__")
def ne(self, call: Call) -> Type:
    """Registered as `ne` and `__ne__`: element-wise `__ne__`."""
    dunder: str = "__ne__"
    return self._element_wise(call, dunder)
|
||||||
|
|
||||||
|
@method("eq", "__eq__")
def eq(self, call: Call) -> Type:
    """Registered as `eq` and `__eq__`: element-wise `__eq__`."""
    dunder: str = "__eq__"
    return self._element_wise(call, dunder)
|
||||||
|
|
||||||
|
def _aggregate(
    self, call: Call, kwargs: Optional[list[Function.Argument]] = None
) -> Type:
    """Type a call to an aggregation method

    Two overloads are offered to the dispatcher. Both take `axis` as a
    keyword argument, followed by the extra `kwargs`:

    - `axis` is an optional `int`: returns `ColumnType(type=TopType())`
    - `axis` is a required `None`: returns `TopType()`

    Args:
        call (Call): the call that triggered this resolution
        kwargs (Optional[list[Function.Argument]], optional): additional
            keyword arguments accepted by the method, appended to both
            overloads. Defaults to None, meaning no extra argument.

    Returns:
        Type: the `result` of the dispatcher's `CallResult`
    """
    # `None` instead of a `[]` default, so that no list object is shared
    # between calls
    extra: list[Function.Argument] = [] if kwargs is None else kwargs

    with_axis = Function(
        kw_args=[
            Function.Argument(
                pos=0,
                name="axis",
                type=self.types.get_type("int"),
                required=False,
            ),
            *extra,
        ],
        returns=ColumnType(type=TopType()),
    )
    without_axis = Function(
        kw_args=[
            Function.Argument(
                pos=0,
                name="axis",
                type=self.types.get_type("None"),
                required=True,
            ),
            *extra,
        ],
        returns=TopType(),
    )
    overload = OverloadedFunction(
        overloads=[
            with_axis,
            without_axis,
        ]
    )

    result: CallResult = self.dispatcher.get_result(
        location=call.location,
        callee=overload,
        positional=call.positional,
        keywords=call.keywords,
    )
    return result.result
|
||||||
|
|
||||||
|
@method("kurtosis", "kurt")
def kurtosis(self, call: Call) -> Type:
    """Registered as `kurtosis` and `kurt`: typed as an aggregation."""
    return self._aggregate(call=call)
|
||||||
|
|
||||||
|
@method()
def max(self, call: Call) -> Type:
    """Registered as `max`: typed as an aggregation."""
    return self._aggregate(call=call)
|
||||||
|
|
||||||
|
@method()
def mean(self, call: Call) -> Type:
    """Registered as `mean`: typed as an aggregation."""
    return self._aggregate(call=call)
|
||||||
|
|
||||||
|
@method()
def median(self, call: Call) -> Type:
    """Registered as `median`: typed as an aggregation."""
    return self._aggregate(call=call)
|
||||||
|
|
||||||
|
@method()
def min(self, call: Call) -> Type:
    """Registered as `min`: typed as an aggregation."""
    return self._aggregate(call=call)
|
||||||
|
|
||||||
|
@method()
def mode(self, call: Call) -> Type:
    """Registered as `mode`: typed as an aggregation."""
    # NOTE(review): pandas' `DataFrame.mode` returns a DataFrame rather than
    # a reduced result - confirm that the aggregation typing is intended.
    return self._aggregate(call=call)
|
||||||
|
|
||||||
|
@method("product", "prod")
def product(self, call: Call) -> Type:
    """Registered as `product` and `prod`: typed as an aggregation."""
    return self._aggregate(call=call)
|
||||||
|
|
||||||
|
@method()
def std(self, call: Call) -> Type:
    """Registered as `std`: an aggregation that also accepts an optional
    `ddof` integer keyword argument."""
    ddof_param = Function.Argument(
        pos=1,
        name="ddof",
        type=self.types.get_type("int"),
        required=False,
    )
    return self._aggregate(call, [ddof_param])
|
||||||
|
|
||||||
|
@method()
def sum(self, call: Call) -> Type:
    """Registered as `sum`: typed as an aggregation."""
    return self._aggregate(call=call)
|
||||||
|
|
||||||
|
@method()
def var(self, call: Call) -> Type:
    """Registered as `var`: an aggregation that also accepts an optional
    `ddof` integer keyword argument."""
    return self._aggregate(
        call,
        [
            Function.Argument(
                pos=1,
                # pandas names this parameter `ddof` (as for `std`), not `var`
                name="ddof",
                type=self.types.get_type("int"),
                required=False,
            )
        ],
    )
|
||||||
|
|
||||||
|
@method()
def head(self, call: Call) -> Type:
    """Registered as `head`: accepts an optional `n` integer and returns a
    frame of the same type as the one it is called on."""
    count_param = Function.Argument(
        pos=0,
        name="n",
        type=self.types.get_type("int"),
        required=False,
    )
    signature = Function(args=[count_param], returns=call.frame)

    outcome: CallResult = self.dispatcher.get_result(
        location=call.location,
        callee=signature,
        positional=call.positional,
        keywords=call.keywords,
    )
    return outcome.result
|
||||||
|
|
||||||
|
@method()
def tail(self, call: Call) -> Type:
    """Registered as `tail`: accepts an optional `n` integer and returns a
    frame of the same type as the one it is called on."""
    count_param = Function.Argument(
        pos=0,
        name="n",
        type=self.types.get_type("int"),
        required=False,
    )
    signature = Function(args=[count_param], returns=call.frame)

    outcome: CallResult = self.dispatcher.get_result(
        location=call.location,
        callee=signature,
        positional=call.positional,
        keywords=call.keywords,
    )
    return outcome.result
|
||||||
|
|
||||||
|
@method()
def groupby(self, call: Call) -> Type:
    """Registered as `groupby`: returns a `FrameGroupBy` wrapping the frame.

    The signature has two optional arguments of any type (`by`, `level`)
    and five optional boolean keyword arguments (`as_index`, `sort`,
    `group_keys`, `observed`, `dropna`).
    """
    bool_type: Type = self.types.get_type("bool")

    # Positions 0 and 1
    positional_params = [
        Function.Argument(pos=position, name=name, type=TopType(), required=False)
        for position, name in enumerate(("by", "level"))
    ]

    # Positions 2 to 6, following the positional parameters
    flag_names = ("as_index", "sort", "group_keys", "observed", "dropna")
    keyword_params = [
        Function.Argument(pos=position, name=name, type=bool_type, required=False)
        for position, name in enumerate(flag_names, start=len(positional_params))
    ]

    signature: Function = Function(
        args=positional_params,
        kw_args=keyword_params,
        returns=FrameGroupBy(frame=call.frame),
    )

    outcome: CallResult = self.dispatcher.get_result(
        location=call.location,
        callee=signature,
        positional=call.positional,
        keywords=call.keywords,
    )
    return outcome.result
|
||||||
|
|
||||||
|
def _assert_same_length(self, call_expr: p.Expr, frame1: p.Expr, frame2: p.Expr):
    """Register an assertion that two frames have the same length

    A helper function `__midas_frame_same_length__(frame1, frame2)`
    returning `len(frame1.index) == len(frame2.index)` is defined in the
    assertion collector, and an assertion calling it on both frame
    expressions is added, bound to `call_expr`.

    Args:
        call_expr (p.Expr): the expression the assertion is bound to
        frame1 (p.Expr): the expression of the first frame
        frame2 (p.Expr): the expression of the second frame
    """
    helper_name: str = "__midas_frame_same_length__"

    def index_length(param: str) -> ast.expr:
        # Efficiently compute length: len(<param>.index)
        # https://stackoverflow.com/a/15943975/11109181
        index_attr = ast.Attribute(value=ast.Name(id=param), attr="index")
        return ast.Call(func=ast.Name(id="len"), args=[index_attr], keywords=[])

    comparison = ast.Compare(
        left=index_length("frame1"),
        ops=[ast.Eq()],
        comparators=[index_length("frame2")],
    )
    parameters = ast.arguments(
        posonlyargs=[],
        args=[ast.arg(arg="frame1"), ast.arg(arg="frame2")],
        kwonlyargs=[],
        defaults=[],
        kw_defaults=[],
    )
    helper_def = ast.FunctionDef(
        name=helper_name,
        args=parameters,
        body=[ast.Return(value=comparison)],
        decorator_list=[],
    )
    self.assertions.define(helper_name, helper_def)

    def build_check(f1, f2):
        # Call to the helper with the two frame expressions
        return ast.Call(func=ast.Name(id=helper_name), args=[f1, f2], keywords=[])

    self.assertions.add(
        bound_expr=call_expr,
        inputs=[frame1, frame2],
        builder=build_check,
        message="DataFrames must have the same length",
    )
|
||||||
100
midas/checker/frames/utils.py
Normal file
100
midas/checker/frames/utils.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Generic,
|
||||||
|
Optional,
|
||||||
|
Protocol,
|
||||||
|
Self,
|
||||||
|
TypeVar,
|
||||||
|
)
|
||||||
|
|
||||||
|
import midas.ast.python as p
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.checker.dispatcher import CallDispatcher
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
|
from midas.checker.reporter import FileReporter
|
||||||
|
from midas.checker.types import Type, UnknownType
|
||||||
|
from midas.generator.collector import AssertionCollector
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from midas.checker.python import PythonTyper, TypedExpr
|
||||||
|
|
||||||
|
|
||||||
|
class _MethodRegistryMeta(type):
    """Metaclass collecting the methods tagged with `__method_names__`

    Each class created with this metaclass gets its own `_methods` map,
    built only from the callables declared in that class's own namespace.
    Every name listed in a callable's `__method_names__` maps to it.
    """

    _methods: dict[str, Callable[..., Type]] = {}

    def __new__(
        cls,
        name: str,
        bases: tuple[type, ...],
        namespace: dict[str, Any],
    ):
        new_class = super().__new__(cls, name, bases, namespace)

        registry: dict[str, Callable[..., Type]] = {}
        for member in namespace.values():
            if not callable(member):
                continue
            if not hasattr(member, "__method_names__"):
                continue
            for alias in member.__method_names__:  # type: ignore
                registry[alias] = member  # type: ignore

        new_class._methods = registry
        return new_class
|
||||||
|
|
||||||
|
|
||||||
|
class MethodCall(Protocol):
    """Structural interface of the call objects handled by a method registry"""

    @property
    def location(self) -> Location:
        """Location of the call"""
        ...

    @property
    def call_expr(self) -> p.Expr:
        """The whole call expression"""
        ...

    @property
    def subject(self) -> TypedExpr:
        """The object the method is called on, as a typed expression"""
        ...
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=MethodCall)
|
||||||
|
|
||||||
|
|
||||||
|
class MethodRegistry(Generic[T], metaclass=_MethodRegistryMeta):
    """Base class for registries resolving method calls by name

    The methods collected by the metaclass in `_methods` are looked up by
    `call`. The reporter, types registry, dispatcher and assertion
    collector of the owning typer are exposed as properties.
    """

    def __init__(self, typer: PythonTyper) -> None:
        self.typer: PythonTyper = typer

    @property
    def reporter(self) -> FileReporter:
        """The typer's reporter"""
        return self.typer.reporter

    @property
    def types(self) -> TypesRegistry:
        """The typer's types registry"""
        return self.typer.types

    @property
    def dispatcher(self) -> CallDispatcher[p.Expr]:
        """The typer's call dispatcher"""
        return self.typer.dispatcher

    @property
    def assertions(self) -> AssertionCollector:
        """The typer's assertion collector"""
        return self.typer.assertions

    def call(self, method: str, call: T) -> Type:
        """Resolve and type a method call

        Args:
            method (str): the name of the called method
            call (T): the description of the call

        Returns:
            Type: the type returned by the registered method, or
            `UnknownType()` (with a warning) if no method is registered
            under that name
        """
        handler: Optional[Callable[[Self, T], Type]] = self._methods.get(method)
        if handler is not None:
            return handler(self, call)

        self.reporter.warning(
            call.location, f"Unknown method {method} on {call.subject[1]}"
        )
        return UnknownType()
|
||||||
|
|
||||||
|
|
||||||
|
_Self = TypeVar("_Self", bound=MethodRegistry[Any])
|
||||||
|
Method = Callable[[_Self, T], Type]
|
||||||
|
|
||||||
|
|
||||||
|
def method(*names: str) -> Callable[[Method[_Self, T]], Method[_Self, T]]:
    """Tag a function with the method names it should be registered under

    The names are stored on the function as `__method_names__`. The
    function itself is returned unchanged.

    Args:
        *names (str): the names to register the function under. If none is
            given, the function's own name is used.

    Returns:
        Callable[[Method[_Self, T]], Method[_Self, T]]: the decorator
    """

    def wrapper(func: Method[_Self, T]) -> Method[_Self, T]:
        registered: tuple[str, ...] = names if len(names) != 0 else (func.__name__,)
        func.__method_names__ = registered  # type: ignore
        return func

    return wrapper
|
||||||
@@ -6,25 +6,25 @@ from typing import Optional
|
|||||||
import midas.ast.midas as m
|
import midas.ast.midas as m
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
from midas.checker.builtins import define_builtins
|
from midas.checker.builtins import define_builtins
|
||||||
|
from midas.checker.dispatcher import CallDispatcher, CallResult
|
||||||
from midas.checker.environment import Environment
|
from midas.checker.environment import Environment
|
||||||
from midas.checker.operators import MIDAS_BINARY_METHODS, MIDAS_UNARY_METHODS
|
from midas.checker.operators import MIDAS_BINARY_METHODS, MIDAS_UNARY_METHODS
|
||||||
from midas.checker.preamble import Preamble
|
from midas.checker.preamble import Preamble
|
||||||
from midas.checker.registry import TypesRegistry
|
from midas.checker.registry import TypesRegistry
|
||||||
from midas.checker.reporter import FileReporter, Reporter
|
from midas.checker.reporter import FileReporter, Reporter
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AliasType,
|
ColumnType,
|
||||||
AppliedType,
|
|
||||||
ComplexType,
|
ComplexType,
|
||||||
ConstraintType,
|
ConstraintType,
|
||||||
|
DataFrameType,
|
||||||
|
DerivedType,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
OverloadedFunction,
|
|
||||||
Predicate,
|
Predicate,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
UnknownType,
|
UnknownType,
|
||||||
unfold_type,
|
|
||||||
)
|
)
|
||||||
from midas.checker.variance import VarianceInferrer
|
from midas.checker.variance import VarianceInferrer
|
||||||
from midas.lexer.midas import MidasLexer
|
from midas.lexer.midas import MidasLexer
|
||||||
@@ -39,9 +39,6 @@ class TypedParamSpec:
|
|||||||
kw: list[Function.Argument]
|
kw: list[Function.Argument]
|
||||||
|
|
||||||
|
|
||||||
TypedExpr = tuple[m.Expr, Type]
|
|
||||||
|
|
||||||
|
|
||||||
class ReturnException(Exception):
|
class ReturnException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -65,8 +62,11 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
|||||||
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
|
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
|
||||||
self.logger: logging.Logger = logging.getLogger("MidasTyper")
|
self.logger: logging.Logger = logging.getLogger("MidasTyper")
|
||||||
self.reporter: FileReporter = reporter.for_file(None)
|
self.reporter: FileReporter = reporter.for_file(None)
|
||||||
|
|
||||||
self.types: TypesRegistry = types
|
self.types: TypesRegistry = types
|
||||||
|
self.dispatcher: CallDispatcher[m.Expr] = CallDispatcher[m.Expr](
|
||||||
|
self.types, self.reporter
|
||||||
|
)
|
||||||
|
|
||||||
self._local_variables: dict[str, TypeVar] = {}
|
self._local_variables: dict[str, TypeVar] = {}
|
||||||
|
|
||||||
self._predicate_params: dict[str, Type] = {}
|
self._predicate_params: dict[str, Type] = {}
|
||||||
@@ -81,8 +81,14 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
|||||||
|
|
||||||
self._preamble: Environment = Preamble(self.types)
|
self._preamble: Environment = Preamble(self.types)
|
||||||
|
|
||||||
|
def set_reporter(self, reporter: FileReporter):
    """Replace the reporter of this typer and of its call dispatcher

    Args:
        reporter (FileReporter): the new reporter
    """
    self.reporter = reporter
    dispatcher = self.dispatcher
    dispatcher.set_reporter(reporter)
|
||||||
|
|
||||||
def process(self, source: str, path: Optional[str]):
|
def process(self, source: str, path: Optional[str]):
|
||||||
self.reporter = self.reporter.for_file(path)
|
reporter: FileReporter = self.reporter.for_file(path)
|
||||||
|
self.set_reporter(reporter)
|
||||||
|
|
||||||
lexer: MidasLexer = MidasLexer(source)
|
lexer: MidasLexer = MidasLexer(source)
|
||||||
tokens: list[Token] = lexer.process()
|
tokens: list[Token] = lexer.process()
|
||||||
parser: MidasParser = MidasParser(tokens)
|
parser: MidasParser = MidasParser(tokens)
|
||||||
@@ -152,11 +158,18 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
|||||||
if len(params) != 0:
|
if len(params) != 0:
|
||||||
type = GenericType(name=name, params=params, body=type)
|
type = GenericType(name=name, params=params, body=type)
|
||||||
else:
|
else:
|
||||||
type = AliasType(name=name, type=type)
|
type = DerivedType(name=name, type=type)
|
||||||
self.types.define_type(name, type)
|
self.types.define_type(name, type)
|
||||||
self._local_variables.clear()
|
self._local_variables.clear()
|
||||||
self._current_name = None
|
self._current_name = None
|
||||||
|
|
||||||
|
def visit_alias_stmt(self, stmt: m.AliasStmt) -> None:
    """Define an alias: the name is bound directly to the aliased type

    `_current_name` is set to the alias name while the aliased type is
    being resolved, and reset to `None` afterwards.

    Args:
        stmt (m.AliasStmt): the alias statement
    """
    alias_name: str = stmt.name.lexeme
    self._current_name = alias_name

    aliased: Type = stmt.type.accept(self)
    self.types.define_type(alias_name, aliased)

    self._current_name = None
|
||||||
|
|
||||||
def visit_member_stmt(self, stmt: m.MemberStmt) -> None: ...
|
def visit_member_stmt(self, stmt: m.MemberStmt) -> None: ...
|
||||||
|
|
||||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||||
@@ -250,13 +263,13 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
|||||||
)
|
)
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
result: Optional[Type] = self._get_call_result(
|
result: CallResult = self.dispatcher.get_result(
|
||||||
location,
|
location=location,
|
||||||
operation,
|
callee=operation,
|
||||||
[(right_expr, right)],
|
positional=[(right_expr, right)],
|
||||||
{},
|
keywords={},
|
||||||
)
|
)
|
||||||
return result or UnknownType()
|
return result.result
|
||||||
|
|
||||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
|
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
|
||||||
method: Optional[str] = MIDAS_UNARY_METHODS.get(expr.operator.type)
|
method: Optional[str] = MIDAS_UNARY_METHODS.get(expr.operator.type)
|
||||||
@@ -276,31 +289,29 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
|||||||
)
|
)
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
result: Optional[Type] = self._get_call_result(
|
result: CallResult = self.dispatcher.get_result(
|
||||||
expr.location,
|
location=expr.location,
|
||||||
operation,
|
callee=operation,
|
||||||
[],
|
positional=[],
|
||||||
{},
|
keywords={},
|
||||||
)
|
)
|
||||||
return result or UnknownType()
|
return result.result
|
||||||
|
|
||||||
def visit_call_expr(self, expr: m.CallExpr) -> Type:
|
def visit_call_expr(self, expr: m.CallExpr) -> Type:
|
||||||
callee: Type = expr.callee.accept(self)
|
callee: Type = expr.callee.accept(self)
|
||||||
positional: list[TypedExpr] = [
|
positional: list[tuple[m.Expr, Type]] = [
|
||||||
(arg, self.type_of(arg)) for arg in expr.arguments
|
(arg, self.type_of(arg)) for arg in expr.arguments
|
||||||
]
|
]
|
||||||
keywords: dict[str, TypedExpr] = {
|
keywords: dict[str, tuple[m.Expr, Type]] = {
|
||||||
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
|
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
|
||||||
}
|
}
|
||||||
return (
|
result: CallResult = self.dispatcher.get_result(
|
||||||
self._get_call_result(
|
location=expr.location,
|
||||||
expr.location,
|
callee=callee,
|
||||||
callee,
|
positional=positional,
|
||||||
positional,
|
keywords=keywords,
|
||||||
keywords,
|
|
||||||
)
|
|
||||||
or UnknownType()
|
|
||||||
)
|
)
|
||||||
|
return result.result
|
||||||
|
|
||||||
def visit_get_expr(self, expr: m.GetExpr) -> Type:
|
def visit_get_expr(self, expr: m.GetExpr) -> Type:
|
||||||
object: Type = expr.expr.accept(self)
|
object: Type = expr.expr.accept(self)
|
||||||
@@ -401,6 +412,18 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
|||||||
kw=[process_arg(arg, i + n_pos + n_mixed) for i, arg in enumerate(spec.kw)],
|
kw=[process_arg(arg, i + n_pos + n_mixed) for i, arg in enumerate(spec.kw)],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def visit_frame_type(self, type: m.FrameType) -> Type:
    """Resolve a frame type expression into a `DataFrameType`

    Each declared column becomes a `DataFrameType.Column` whose index is
    its position in the declaration, and whose type is the resolved column
    type wrapped in a `ColumnType`.

    Args:
        type (m.FrameType): the frame type expression

    Returns:
        Type: the resolved `DataFrameType`
    """
    columns: list[DataFrameType.Column] = []
    for position, declared in enumerate(type.columns):
        columns.append(
            DataFrameType.Column(
                index=position,
                name=declared.name.lexeme,
                type=ColumnType(type=declared.type.accept(self)),
            )
        )
    return DataFrameType(columns=columns)
|
||||||
|
|
||||||
def _resolve_type_params(self, params: list[m.TypeParam]):
|
def _resolve_type_params(self, params: list[m.TypeParam]):
|
||||||
vars: list[TypeVar] = []
|
vars: list[TypeVar] = []
|
||||||
for param in params:
|
for param in params:
|
||||||
@@ -412,343 +435,3 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
|||||||
self._local_variables[name] = var
|
self._local_variables[name] = var
|
||||||
vars.append(var)
|
vars.append(var)
|
||||||
return vars
|
return vars
|
||||||
|
|
||||||
def _get_call_result(
|
|
||||||
self,
|
|
||||||
location: Location,
|
|
||||||
callee: Type,
|
|
||||||
positional: list[TypedExpr],
|
|
||||||
keywords: dict[str, TypedExpr],
|
|
||||||
report_errors: bool = True,
|
|
||||||
) -> Optional[Type]:
|
|
||||||
"""Get the result type of a function call
|
|
||||||
|
|
||||||
If the function has overloads, the function will try to resolve the
|
|
||||||
appropriate signature.
|
|
||||||
Argument types are matched to the defined parameters.
|
|
||||||
The function doesn't take the raw expression as a parameter to accommodate
|
|
||||||
for desugared calls such as for operators.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
location (Location): the call location
|
|
||||||
callee (Type): the called function
|
|
||||||
positional (list[TypedExpr]): the list positional arguments
|
|
||||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
|
||||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Type: the return type of the call, or `None` if either
|
|
||||||
the call is invalid or no overload matched the arguments uniquely
|
|
||||||
"""
|
|
||||||
match callee:
|
|
||||||
case Function() as function:
|
|
||||||
valid: bool
|
|
||||||
mapped: list[MappedArgument]
|
|
||||||
valid, mapped = self.map_call_arguments(
|
|
||||||
function, location, positional, keywords
|
|
||||||
)
|
|
||||||
valid = valid and self._are_arguments_valid(mapped, report_errors)
|
|
||||||
if not valid:
|
|
||||||
return None
|
|
||||||
return function.returns
|
|
||||||
|
|
||||||
case OverloadedFunction(overloads=overloads):
|
|
||||||
function = self._match_overload(
|
|
||||||
overloads, location, positional, keywords, report_errors
|
|
||||||
)
|
|
||||||
if function is None:
|
|
||||||
return None
|
|
||||||
return function.returns
|
|
||||||
|
|
||||||
case AppliedType(body=body):
|
|
||||||
return self._get_call_result(
|
|
||||||
location, body, positional, keywords, report_errors
|
|
||||||
)
|
|
||||||
|
|
||||||
case UnknownType():
|
|
||||||
return UnknownType()
|
|
||||||
|
|
||||||
case _:
|
|
||||||
if report_errors:
|
|
||||||
self.reporter.error(location, f"{callee} is not callable")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _are_arguments_valid(
    self,
    arguments: list[MappedArgument],
    report_errors: bool = True,
) -> bool:
    """Check that every passed argument is a subtype of its matched parameter

    All arguments are checked, even after a first mismatch, so that every
    mismatch can be reported.

    Args:
        arguments (list[MappedArgument]): the list of argument/parameter pairs
        report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.

    Returns:
        bool: True if all arguments fit the matching parameter definitions, False otherwise
    """
    all_ok: bool = True
    for mapped in arguments:
        if self.types.is_subtype(mapped.type, mapped.argument.type):
            continue

        all_ok = False
        if report_errors:
            self.reporter.error(
                mapped.expr.location,
                f"Wrong type for argument '{mapped.argument.name}', expected {mapped.argument.type}, got {mapped.type}",
            )
    return all_ok
|
|
||||||
|
|
||||||
def _match_overload(
    self,
    overloads: list[Type],
    location: Location,
    positional: list[TypedExpr],
    keywords: dict[str, TypedExpr],
    report_errors: bool = True,
) -> Optional[Function]:
    """Try and resolve the appropriate overload for the given arguments

    Each overload is unfolded and tried silently against the call
    arguments. A single surviving candidate is returned directly. When
    several survive, the first one whose parameter types are subtypes of
    those of every other candidate is returned.

    Args:
        overloads (list[Type]): the list of possible overloads
        location (Location): the call location
        positional (list[TypedExpr]): the list of positional arguments
        keywords (dict[str, TypedExpr]): the map of keywords arguments
        report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.

    Returns:
        Optional[Function]: the resolved function signature if it can be
            determined unambiguously, or `None`.
    """
    candidates: list[OverloadCandidate] = []
    for overload in overloads:
        function: Type = unfold_type(overload)
        if not isinstance(function, Function):
            if report_errors:
                self.logger.error(
                    f"Overload is not a function: {overload} is {function}"
                )
            continue
        # Candidates are probed silently; diagnostics are only emitted
        # once the overall outcome is known.
        valid, mapped = self.map_call_arguments(
            function=function,
            location=location,
            positional=positional,
            keywords=keywords,
            report_errors=False,
        )
        if valid and self._are_arguments_valid(mapped, report_errors=False):
            candidates.append(
                OverloadCandidate(
                    function=function,
                    mapped=mapped,
                )
            )

    # Exactly 1 match -> return it
    if len(candidates) == 1:
        return candidates[0].function

    def describe_arguments() -> str:
        # Only built when a diagnostic is actually emitted
        pos_types: str = ", ".join(str(arg_type) for _, arg_type in positional)
        kw_types: str = ", ".join(
            f"{name}: {arg_type}" for name, (_, arg_type) in keywords.items()
        )
        return f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"

    # No match -> invalid call
    if not candidates:
        if report_errors:
            overloads_str: str = ", ".join(map(str, overloads))
            self.reporter.error(
                location,
                f"No matching overload in [{overloads_str}] {describe_arguments()}",
            )
        return None

    # Multiple matches -> see if one <: all others (more specific)
    for i1, c1 in enumerate(candidates):
        best_match: bool = True
        for i2, c2 in enumerate(candidates):
            if i1 == i2:
                continue
            if not self._are_mapped_subtypes(c1.mapped, c2.mapped):
                best_match = False
                break
            self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
        if best_match:
            return c1.function

    if report_errors:
        candidates_str: str = ", ".join(
            str(candidate.function) for candidate in candidates
        )
        self.reporter.error(
            location,
            f"Multiple matching overloads {describe_arguments()}: {candidates_str}",
        )
    return None
|
|
||||||
|
|
||||||
def map_call_arguments(
    self,
    function: Function,
    location: Location,
    positional: list[TypedExpr],
    keywords: dict[str, TypedExpr],
    report_errors: bool = True,
) -> tuple[bool, list[MappedArgument]]:
    """Map call arguments to a function's parameters as defined in its signature

    This method maps positional-only, keyword-only and mixed parameter definitions
    with the arguments passed at the call site

    Any mismatched, missing or unexpected argument is reported as a diagnostic,
    unless `report_errors` is set to `False`

    Only the shape of the call is checked here (arity and names); the
    argument types are carried in the returned mappings but not compared.

    Args:
        function (Function): the function definition
        location (Location): the call location
        positional (list[TypedExpr]): the list of positional arguments
        keywords (dict[str, TypedExpr]): the map of keyword arguments
        report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.

    Returns:
        tuple[bool, list[MappedArgument]]: a boolean reporting whether
            the call is valid and the list of mapped arguments
    """
    # Parameters fillable positionally, in order: positional-only first,
    # then mixed (positional-or-keyword) ones.
    positional_only: list[Function.Argument] = list(function.pos_args)
    ordered_params: list[Function.Argument] = positional_only + list(function.args)

    required_positional: list[str] = [
        param.name for param in ordered_params if param.required
    ]
    required_keyword: list[str] = [
        param.name for param in function.kw_args if param.required
    ]
    kw_params: dict[str, Function.Argument] = {
        param.name: param for param in function.kw_args
    }

    set_args: set[str] = set()
    mapped: list[MappedArgument] = []
    valid_call: bool = True

    def bind(param: Function.Argument, arg: TypedExpr) -> None:
        # Record that `param` received `arg` and is no longer missing
        name: str = param.name
        if name in required_positional:
            required_positional.remove(name)
        if name in required_keyword:
            required_keyword.remove(name)
        set_args.add(name)
        mapped.append(
            MappedArgument(
                expr=arg[0],
                type=arg[1],
                argument=param,
            )
        )

    # TODO: handle *args and **kwargs sinks
    consumed: int = 0
    for arg in positional:
        if consumed == len(ordered_params):
            if report_errors:
                self.reporter.error(
                    arg[0].location, "Too many positional arguments"
                )
            valid_call = False
            break
        bind(ordered_params[consumed], arg)
        consumed += 1

    # Mixed parameters not filled positionally may still be passed by
    # keyword; positional-only parameters never can.
    first_unfilled_mixed: int = max(consumed, len(positional_only))
    kw_params.update(
        {param.name: param for param in ordered_params[first_unfilled_mixed:]}
    )

    for name, arg in keywords.items():
        if name not in kw_params:
            if report_errors:
                if name in set_args:
                    self.reporter.error(
                        arg[0].location, f"Multiple values for argument '{name}'"
                    )
                else:
                    self.reporter.error(
                        arg[0].location, f"Unknown keyword argument '{name}'"
                    )
            valid_call = False
            continue
        bind(kw_params.pop(name), arg)

    def join_args(names: list[str]) -> str:
        # 'a' / 'a' and 'b' / 'a', 'b' and 'c'
        quoted: list[str] = [f"'{name}'" for name in names]
        if len(quoted) <= 1:
            return "".join(quoted)
        return ", ".join(quoted[:-1]) + " and " + quoted[-1]

    def report_missing(kind: str, names: list[str]) -> None:
        plural: str = "" if len(names) == 1 else "s"
        self.reporter.error(
            location,
            f"Missing required {kind} argument{plural}: {join_args(names)}",
        )

    if required_positional:
        if report_errors:
            report_missing("positional", required_positional)
        valid_call = False

    if required_keyword:
        if report_errors:
            report_missing("keyword", required_keyword)
        valid_call = False

    return valid_call, mapped
|
|
||||||
|
|
||||||
def _are_mapped_subtypes(
    self, mapped1: list[MappedArgument], mapped2: list[MappedArgument]
) -> bool:
    """Check whether the given argument mappings are subtype/supertype of one another

    This function checks whether the argument mappings `mapped1` are subtypes
    of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
    of the corresponding parameter in `mapped2`, `False` is returned.

    Parameters are paired through the call-site expression they were
    mapped from, not through their name or position.

    This is used to check whether a given overload is
    a more specific function/ a subtype of another.

    Args:
        mapped1 (list[MappedArgument]): the first argument mappings (subtype)
        mapped2 (list[MappedArgument]): the second argument mappings (supertype)

    Returns:
        bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise

    Raises:
        KeyError: if `mapped2` maps an expression that `mapped1` does not
    """
    # Parameter type chosen by the first mapping for each call-site expression
    by_expr: dict[m.Expr, Type] = {arg.expr: arg.argument.type for arg in mapped1}
    return all(
        self.types.is_subtype(by_expr[arg.expr], arg.argument.type)
        for arg in mapped2
    )
|
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ PY_UNARY_METHODS: dict[Type[ast.unaryop], str] = {
|
|||||||
|
|
||||||
|
|
||||||
MIDAS_BINARY_METHODS: dict[TokenType, str] = {
|
MIDAS_BINARY_METHODS: dict[TokenType, str] = {
|
||||||
# TokenType.PLUS: "__add__",
|
TokenType.PLUS: "__add__",
|
||||||
TokenType.MINUS: "__sub__",
|
TokenType.MINUS: "__sub__",
|
||||||
TokenType.STAR: "__mul__",
|
TokenType.STAR: "__mul__",
|
||||||
TokenType.SLASH: "__truediv__",
|
TokenType.SLASH: "__truediv__",
|
||||||
|
|||||||
@@ -1,9 +1,17 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Callable, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
from midas.checker.environment import Environment
|
from midas.checker.environment import Environment
|
||||||
from midas.checker.registry import TypesRegistry
|
from midas.checker.registry import TypesRegistry
|
||||||
from midas.checker.types import Function, GenericType, TopType, Type, TypeVar, UnitType
|
from midas.checker.types import (
|
||||||
|
Function,
|
||||||
|
GenericType,
|
||||||
|
OverloadedFunction,
|
||||||
|
TopType,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
UnitType,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -17,7 +25,7 @@ class Preamble(Environment):
|
|||||||
def __init__(self, types: TypesRegistry) -> None:
|
def __init__(self, types: TypesRegistry) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._types: TypesRegistry = types
|
self._types: TypesRegistry = types
|
||||||
self._python_funcs: dict[str, Callable] = {}
|
self._python_funcs: dict[str, Callable[..., Any]] = {}
|
||||||
|
|
||||||
self._def_type_constructor("object", object)
|
self._def_type_constructor("object", object)
|
||||||
self._def_type_constructor("float", float)
|
self._def_type_constructor("float", float)
|
||||||
@@ -34,7 +42,7 @@ class Preamble(Environment):
|
|||||||
# TODO: use sink
|
# TODO: use sink
|
||||||
self._def_function(
|
self._def_function(
|
||||||
name="print",
|
name="print",
|
||||||
pos=[Param("object", TopType())],
|
pos=[Param("object", TopType(), required=False)],
|
||||||
returns=UnitType(),
|
returns=UnitType(),
|
||||||
py_function=print,
|
py_function=print,
|
||||||
)
|
)
|
||||||
@@ -64,11 +72,48 @@ class Preamble(Environment):
|
|||||||
pos=[Param("prompt", TopType(), required=False)],
|
pos=[Param("prompt", TopType(), required=False)],
|
||||||
returns=self._types.get_type("str"),
|
returns=self._types.get_type("str"),
|
||||||
)
|
)
|
||||||
|
self._def_function(
|
||||||
|
name="len",
|
||||||
|
pos=[Param("object", TopType())],
|
||||||
|
returns=self._types.get_type("int"),
|
||||||
|
)
|
||||||
|
|
||||||
def _list_of(self, item_type: Type) -> Type:
|
T = TypeVar(name="T", bound=None)
|
||||||
return self._types.apply_generic(self._types.get_type("list"), [item_type])
|
self._def_overloads(
|
||||||
|
name="max",
|
||||||
|
py_function=max,
|
||||||
|
signatures=[
|
||||||
|
(
|
||||||
|
[Param("arg1", T), Param("arg2", T)],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
T,
|
||||||
|
[T],
|
||||||
|
),
|
||||||
|
([Param("iterable", self._list_of(T))], [], [], T, [T]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
self._def_overloads(
|
||||||
|
name="min",
|
||||||
|
py_function=min,
|
||||||
|
signatures=[
|
||||||
|
(
|
||||||
|
[Param("arg1", T), Param("arg2", T)],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
T,
|
||||||
|
[T],
|
||||||
|
),
|
||||||
|
([Param("iterable", self._list_of(T))], [], [], T, [T]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
def _def_type_constructor(self, name: str, py_function: Optional[Callable] = None):
|
def _list_of(self, item_type: str | Type) -> Type:
|
||||||
|
return self._types.list_of(item_type)
|
||||||
|
|
||||||
|
def _def_type_constructor(
|
||||||
|
self, name: str, py_function: Optional[Callable[..., Any]] = None
|
||||||
|
):
|
||||||
# TODO: more specific arg types
|
# TODO: more specific arg types
|
||||||
self._def_function(
|
self._def_function(
|
||||||
name=name,
|
name=name,
|
||||||
@@ -121,7 +166,7 @@ class Preamble(Environment):
|
|||||||
kw: list[Param] = [],
|
kw: list[Param] = [],
|
||||||
returns: Type = UnitType(),
|
returns: Type = UnitType(),
|
||||||
type_vars: list[TypeVar] = [],
|
type_vars: list[TypeVar] = [],
|
||||||
py_function: Optional[Callable] = None,
|
py_function: Optional[Callable[..., Any]] = None,
|
||||||
):
|
):
|
||||||
function: Type = self._make_function(
|
function: Type = self._make_function(
|
||||||
name=name,
|
name=name,
|
||||||
@@ -135,5 +180,31 @@ class Preamble(Environment):
|
|||||||
if py_function is not None:
|
if py_function is not None:
|
||||||
self._python_funcs[name] = py_function
|
self._python_funcs[name] = py_function
|
||||||
|
|
||||||
def get_py_func(self, name: str) -> Optional[Callable]:
|
def _def_overloads(
    self,
    *,
    name: str,
    signatures: list[
        tuple[list[Param], list[Param], list[Param], Type, list[TypeVar]]
    ],
    py_function: Optional[Callable[..., Any]] = None,
):
    """Define an overloaded function under a single name

    Args:
        name (str): the name the overload set is defined under
        signatures (list[tuple]): one tuple per overload, holding in order
            the positional parameters, the mixed parameters, the keyword
            parameters, the return type and the type variables
        py_function (Optional[Callable[..., Any]], optional): a Python
            callable to register under `name`. Defaults to None.
    """
    overloads: list[Type] = [
        self._make_function(
            name=name,
            pos=pos,
            mixed=mixed,
            kw=kw,
            returns=returns,
            type_vars=type_vars,
        )
        for pos, mixed, kw, returns, type_vars in signatures
    ]
    function: Type = OverloadedFunction(overloads=overloads)
    self.define(name, function)
    if py_function is not None:
        self._python_funcs[name] = py_function
|
||||||
|
|
||||||
|
def get_py_func(self, name: str) -> Optional[Callable[..., Any]]:
|
||||||
return self._python_funcs.get(name)
|
return self._python_funcs.get(name)
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -5,17 +5,20 @@ from typing import Optional
|
|||||||
from midas.ast.midas import MemberKind
|
from midas.ast.midas import MemberKind
|
||||||
from midas.checker.builtins import BUILTIN_SUBTYPES
|
from midas.checker.builtins import BUILTIN_SUBTYPES
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AliasType,
|
|
||||||
AppliedType,
|
AppliedType,
|
||||||
BaseType,
|
BaseType,
|
||||||
|
ColumnType,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
ConstraintType,
|
ConstraintType,
|
||||||
|
DataFrameType,
|
||||||
|
DerivedType,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
OverloadedFunction,
|
OverloadedFunction,
|
||||||
Predicate,
|
Predicate,
|
||||||
TopType,
|
TopType,
|
||||||
|
TupleType,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
UnknownType,
|
UnknownType,
|
||||||
@@ -110,6 +113,15 @@ class TypesRegistry:
|
|||||||
raise ValueError(f"Predicate {name} already defined")
|
raise ValueError(f"Predicate {name} already defined")
|
||||||
self._predicates[name] = predicate
|
self._predicates[name] = predicate
|
||||||
|
|
||||||
|
def is_builtin_subtype(self, name1: str, name2: str) -> bool:
    """Check whether builtin `name1` is a direct or transitive subtype of `name2`

    The `BUILTIN_SUBTYPES` table maps a type name to its direct subtypes
    and is walked downwards from `name2`. A name is not considered a
    subtype of itself unless the table says so.

    Args:
        name1 (str): the name of the candidate subtype
        name2 (str): the name of the candidate supertype

    Returns:
        bool: True if `name1` is reachable from `name2`, False otherwise
    """
    # Iterative walk with a visited set: terminates on a cyclic table and
    # does not re-explore subtypes shared by several parents.
    visited: set[str] = set()
    pending: list[str] = [name2]
    while pending:
        current: str = pending.pop()
        for subtype in BUILTIN_SUBTYPES.get(current, ()):
            if subtype == name1:
                return True
            if subtype not in visited:
                visited.add(subtype)
                pending.append(subtype)
    return False
|
||||||
|
|
||||||
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||||
"""Check whether `type1` is a subtype of `type2`
|
"""Check whether `type1` is a subtype of `type2`
|
||||||
|
|
||||||
@@ -143,11 +155,11 @@ class TypesRegistry:
|
|||||||
return True
|
return True
|
||||||
return self.is_subtype(type1, bound)
|
return self.is_subtype(type1, bound)
|
||||||
|
|
||||||
case (AliasType(type=base1), _):
|
case (DerivedType(type=base1), _):
|
||||||
return self.is_subtype(base1, type2)
|
return self.is_subtype(base1, type2)
|
||||||
|
|
||||||
case (BaseType(name=name1), BaseType(name=name2)):
|
case (BaseType(name=name1), BaseType(name=name2)):
|
||||||
return name1 in BUILTIN_SUBTYPES.get(name2, set())
|
return self.is_builtin_subtype(name1, name2)
|
||||||
|
|
||||||
case (ComplexType(properties=props1), ComplexType(properties=props2)):
|
case (ComplexType(properties=props1), ComplexType(properties=props2)):
|
||||||
for k, t in props2.items():
|
for k, t in props2.items():
|
||||||
@@ -157,6 +169,24 @@ class TypesRegistry:
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
case (DataFrameType(columns=columns1), DataFrameType(columns=columns2)):
|
||||||
|
# TODO: check order?
|
||||||
|
by_name1: dict[str, DataFrameType.Column] = {
|
||||||
|
col.name: col for col in columns1 if col.name is not None
|
||||||
|
}
|
||||||
|
for col2 in columns2:
|
||||||
|
if col2.name not in by_name1:
|
||||||
|
return False
|
||||||
|
if not self.is_subtype(by_name1[col2.name].type, col2.type):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
case (ColumnType(type=inner1), ColumnType(type=inner2)):
|
||||||
|
# TODO: invariant, replace ColumnType with simple GenericType
|
||||||
|
if not self.are_equivalent(inner1, inner2):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
case (Function(), Function()):
|
case (Function(), Function()):
|
||||||
return self.is_func_subtype(type1, type2)
|
return self.is_func_subtype(type1, type2)
|
||||||
|
|
||||||
@@ -187,6 +217,9 @@ class TypesRegistry:
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def are_equivalent(self, type1: Type, type2: Type) -> bool:
    """Check whether two types are subtypes of one another

    Args:
        type1 (Type): the first type
        type2 (Type): the second type

    Returns:
        bool: True if `type1` is a subtype of `type2` and `type2` is a
            subtype of `type1`, False otherwise
    """
    return self.is_subtype(type1, type2) and self.is_subtype(type2, type1)
|
||||||
|
|
||||||
# TODO: verify the logic in here
|
# TODO: verify the logic in here
|
||||||
def is_func_subtype(self, func1: Function, func2: Function) -> bool:
|
def is_func_subtype(self, func1: Function, func2: Function) -> bool:
|
||||||
"""Check whether a function is a subtype of another
|
"""Check whether a function is a subtype of another
|
||||||
@@ -294,8 +327,8 @@ class TypesRegistry:
|
|||||||
|
|
||||||
def apply_generic(self, type: Type, args: list[Type]) -> Type:
|
def apply_generic(self, type: Type, args: list[Type]) -> Type:
|
||||||
match type:
|
match type:
|
||||||
case AliasType(name=name, type=base):
|
case DerivedType(name=name, type=base):
|
||||||
return AliasType(name=name, type=self.apply_generic(base, args))
|
return DerivedType(name=name, type=self.apply_generic(base, args))
|
||||||
|
|
||||||
case GenericType(name=name, params=type_vars, body=body):
|
case GenericType(name=name, params=type_vars, body=body):
|
||||||
n_args: int = len(args)
|
n_args: int = len(args)
|
||||||
@@ -323,6 +356,9 @@ class TypesRegistry:
|
|||||||
body=substitute_typevars(body, substitutions),
|
body=substitute_typevars(body, substitutions),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
case BaseType(name="tuple"):
|
||||||
|
return TupleType(items=tuple(args))
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"{type} is not a generic type")
|
raise ValueError(f"{type} is not a generic type")
|
||||||
|
|
||||||
@@ -362,7 +398,7 @@ class TypesRegistry:
|
|||||||
return self._members[name][member_name].type
|
return self._members[name][member_name].type
|
||||||
return None
|
return None
|
||||||
|
|
||||||
case AliasType(name=name, type=base):
|
case DerivedType(name=name, type=base):
|
||||||
if name in self._members:
|
if name in self._members:
|
||||||
if member_name in self._members[name]:
|
if member_name in self._members[name]:
|
||||||
return self._members[name][member_name].type
|
return self._members[name][member_name].type
|
||||||
@@ -416,3 +452,29 @@ class TypesRegistry:
|
|||||||
|
|
||||||
def lookup_predicate(self, name: str) -> Optional[Predicate]:
|
def lookup_predicate(self, name: str) -> Optional[Predicate]:
|
||||||
return self._predicates.get(name)
|
return self._predicates.get(name)
|
||||||
|
|
||||||
|
def _by_name_or_type(self, name_or_type: str | Type) -> Type:
    """Resolve a type given either as a name or as a type

    Args:
        name_or_type (str | Type): a type name to look up, or a type

    Returns:
        Type: the result of `get_type` for a string, the argument itself
            otherwise
    """
    if isinstance(name_or_type, str):
        return self.get_type(name_or_type)
    return name_or_type
|
||||||
|
|
||||||
|
def list_of(self, item_type: str | Type) -> Type:
    """Apply the registered `list` type to an item type

    Args:
        item_type (str | Type): the item type, or its name

    Returns:
        Type: the result of applying `list` to the resolved item type
    """
    list_ = self.get_type("list")
    return self.apply_generic(list_, [self._by_name_or_type(item_type)])
|
||||||
|
|
||||||
|
def tuple_of(self, *item_types: str | Type) -> Type:
    """Apply the registered `tuple` type to a sequence of item types

    Args:
        *item_types (str | Type): the item types, or their names, in order

    Returns:
        Type: the result of applying `tuple` to the resolved item types
    """
    tuple_ = self.get_type("tuple")
    return self.apply_generic(
        tuple_,
        [self._by_name_or_type(item_type) for item_type in item_types],
    )
|
||||||
|
|
||||||
|
def dict_of(self, key_type: str | Type, value_type: str | Type) -> Type:
    """Apply the registered `dict` type to a key type and a value type

    Args:
        key_type (str | Type): the key type, or its name
        value_type (str | Type): the value type, or its name

    Returns:
        Type: the result of applying `dict` to the resolved key and value
            types, in that order
    """
    dict_ = self.get_type("dict")
    return self.apply_generic(
        dict_,
        [
            self._by_name_or_type(key_type),
            self._by_name_or_type(value_type),
        ],
    )
|
||||||
|
|||||||
@@ -128,6 +128,10 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
|||||||
|
|
||||||
case p.GetExpr():
|
case p.GetExpr():
|
||||||
target.accept(self)
|
target.accept(self)
|
||||||
|
|
||||||
|
case p.SubscriptExpr():
|
||||||
|
target.accept(self)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
raise Exception(f"Unsupported assignment to {target}")
|
raise Exception(f"Unsupported assignment to {target}")
|
||||||
|
|
||||||
@@ -232,5 +236,9 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
|||||||
if expr.step is not None:
|
if expr.step is not None:
|
||||||
self.resolve(expr.step)
|
self.resolve(expr.step)
|
||||||
|
|
||||||
|
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
|
||||||
|
for item in expr.items:
|
||||||
|
self.resolve(item)
|
||||||
|
|
||||||
def visit_raw_expr(self, expr: p.RawExpr) -> None:
|
def visit_raw_expr(self, expr: p.RawExpr) -> None:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Optional, assert_never
|
from typing import Optional, assert_never, cast
|
||||||
|
|
||||||
import midas.ast.midas as m
|
import midas.ast.midas as m
|
||||||
from midas.ast.printer import MidasPrinter
|
from midas.ast.printer import MidasPrinter
|
||||||
@@ -23,7 +23,7 @@ class BaseType:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class AliasType:
|
class DerivedType:
|
||||||
name: str
|
name: str
|
||||||
type: Type
|
type: Type
|
||||||
|
|
||||||
@@ -156,6 +156,53 @@ class ConstraintType:
|
|||||||
return f"{self.type} where {printer.print(self.constraint)}"
|
return f"{self.type} where {printer.print(self.constraint)}"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class TupleType:
|
||||||
|
items: tuple[Type, ...]
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"({', '.join(map(str, self.items))})"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class ColumnType:
|
||||||
|
type: Type
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"Column[{self.type}]"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class DataFrameType:
|
||||||
|
columns: list[Column]
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
schema: list[str] = [f"{col.name}: {col.type}" for col in self.columns]
|
||||||
|
return f"Frame[{', '.join(schema)}]"
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Column:
|
||||||
|
index: int
|
||||||
|
name: Optional[str]
|
||||||
|
type: ColumnType
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class FrameGroupBy:
|
||||||
|
frame: DataFrameType
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"FrameGroupBy[{self.frame}]"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class ColumnGroupBy:
|
||||||
|
column: ColumnType
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"ColumnGroupBy[{self.column}]"
|
||||||
|
|
||||||
|
|
||||||
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||||
def sub_argument(arg: Function.Argument):
|
def sub_argument(arg: Function.Argument):
|
||||||
return Function.Argument(
|
return Function.Argument(
|
||||||
@@ -165,6 +212,13 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
|||||||
required=arg.required,
|
required=arg.required,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def sub_column(col: DataFrameType.Column):
|
||||||
|
return DataFrameType.Column(
|
||||||
|
index=col.index,
|
||||||
|
name=col.name,
|
||||||
|
type=cast(ColumnType, substitute_typevars(col.type, substitutions)),
|
||||||
|
)
|
||||||
|
|
||||||
match type:
|
match type:
|
||||||
case TopType():
|
case TopType():
|
||||||
return type
|
return type
|
||||||
@@ -175,8 +229,10 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
|||||||
case BaseType():
|
case BaseType():
|
||||||
return type
|
return type
|
||||||
|
|
||||||
case AliasType(name=name, type=type2):
|
case DerivedType(name=name, type=type2):
|
||||||
return AliasType(name=name, type=substitute_typevars(type2, substitutions))
|
return DerivedType(
|
||||||
|
name=name, type=substitute_typevars(type2, substitutions)
|
||||||
|
)
|
||||||
|
|
||||||
case Function(
|
case Function(
|
||||||
pos_args=pos_args,
|
pos_args=pos_args,
|
||||||
@@ -250,6 +306,31 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
|||||||
body=substitute_typevars(body, substitutions),
|
body=substitute_typevars(body, substitutions),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
case TupleType(items=items):
|
||||||
|
return TupleType(
|
||||||
|
items=tuple(substitute_typevars(item, substitutions) for item in items),
|
||||||
|
)
|
||||||
|
|
||||||
|
case ColumnType(type=items_type):
|
||||||
|
return ColumnType(
|
||||||
|
type=substitute_typevars(items_type, substitutions),
|
||||||
|
)
|
||||||
|
|
||||||
|
case DataFrameType(columns=columns):
|
||||||
|
return DataFrameType(
|
||||||
|
columns=list(map(sub_column, columns)),
|
||||||
|
)
|
||||||
|
|
||||||
|
case FrameGroupBy(frame=frame):
|
||||||
|
return FrameGroupBy(
|
||||||
|
frame=cast(DataFrameType, substitute_typevars(frame, substitutions))
|
||||||
|
)
|
||||||
|
|
||||||
|
case ColumnGroupBy(column=column):
|
||||||
|
return ColumnGroupBy(
|
||||||
|
column=cast(ColumnType, substitute_typevars(column, substitutions))
|
||||||
|
)
|
||||||
|
|
||||||
case UnknownType() | UnitType():
|
case UnknownType() | UnitType():
|
||||||
return type
|
return type
|
||||||
|
|
||||||
@@ -263,7 +344,7 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
|||||||
|
|
||||||
def unfold_type(type: Type) -> Type:
|
def unfold_type(type: Type) -> Type:
|
||||||
match type:
|
match type:
|
||||||
case AliasType(type=ref_type):
|
case DerivedType(type=ref_type):
|
||||||
return unfold_type(ref_type)
|
return unfold_type(ref_type)
|
||||||
case _:
|
case _:
|
||||||
return type
|
return type
|
||||||
@@ -286,7 +367,7 @@ def to_annotation(type: Type) -> str:
|
|||||||
case BaseType(name=name):
|
case BaseType(name=name):
|
||||||
return name
|
return name
|
||||||
|
|
||||||
case AliasType(name=name):
|
case DerivedType(name=name):
|
||||||
return name
|
return name
|
||||||
|
|
||||||
case UnknownType():
|
case UnknownType():
|
||||||
@@ -317,6 +398,21 @@ def to_annotation(type: Type) -> str:
|
|||||||
case ConstraintType():
|
case ConstraintType():
|
||||||
return str(type)
|
return str(type)
|
||||||
|
|
||||||
|
case TupleType(items=items):
|
||||||
|
return f"Tuple[{', '.join(map(to_annotation, items))}]"
|
||||||
|
|
||||||
|
case ColumnType():
|
||||||
|
return "pd.Series"
|
||||||
|
|
||||||
|
case DataFrameType():
|
||||||
|
return "pd.DataFrame"
|
||||||
|
|
||||||
|
case FrameGroupBy():
|
||||||
|
return "pd.api.typing.DataFrameGroupBy"
|
||||||
|
|
||||||
|
case ColumnGroupBy():
|
||||||
|
return "pd.api.typing.SeriesGroupBy"
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
assert_never(type)
|
assert_never(type)
|
||||||
|
|
||||||
@@ -331,7 +427,7 @@ class Predicate:
|
|||||||
Type = (
|
Type = (
|
||||||
TopType
|
TopType
|
||||||
| BaseType
|
| BaseType
|
||||||
| AliasType
|
| DerivedType
|
||||||
| UnknownType
|
| UnknownType
|
||||||
| UnitType
|
| UnitType
|
||||||
| Function
|
| Function
|
||||||
@@ -342,4 +438,9 @@ Type = (
|
|||||||
| GenericType
|
| GenericType
|
||||||
| AppliedType
|
| AppliedType
|
||||||
| ConstraintType
|
| ConstraintType
|
||||||
|
| TupleType
|
||||||
|
| ColumnType
|
||||||
|
| DataFrameType
|
||||||
|
| FrameGroupBy
|
||||||
|
| ColumnGroupBy
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ from typing import Optional
|
|||||||
from midas.checker.registry import TypesRegistry
|
from midas.checker.registry import TypesRegistry
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AppliedType,
|
AppliedType,
|
||||||
|
ColumnType,
|
||||||
|
DataFrameType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
TopType,
|
TopType,
|
||||||
@@ -98,6 +100,30 @@ class Unifier:
|
|||||||
|
|
||||||
return substitutions
|
return substitutions
|
||||||
|
|
||||||
|
case (
|
||||||
|
DataFrameType(columns=template_columns),
|
||||||
|
DataFrameType(columns=concrete_columns),
|
||||||
|
) if len(template_columns) == len(concrete_columns):
|
||||||
|
substitutions: dict[str, Type] = {}
|
||||||
|
for template_column, concrete_column in zip(
|
||||||
|
template_columns, concrete_columns
|
||||||
|
):
|
||||||
|
if template_column.index != concrete_column or (
|
||||||
|
template_column.name != concrete_column.name
|
||||||
|
):
|
||||||
|
self.logger.debug(
|
||||||
|
f"Column mismatch: template={template_column}, concrete={concrete_column}"
|
||||||
|
)
|
||||||
|
raise UnificationError
|
||||||
|
new_substistutions: dict[str, Type] = self.match(
|
||||||
|
template_column.type, concrete_column.type
|
||||||
|
)
|
||||||
|
substitutions = self.merge(substitutions, new_substistutions)
|
||||||
|
return substitutions
|
||||||
|
|
||||||
|
case (ColumnType(type=template_column), ColumnType(type=concrete_column)):
|
||||||
|
return self.match(template_column, concrete_column)
|
||||||
|
|
||||||
case (Function(), Function()):
|
case (Function(), Function()):
|
||||||
mapped: list[tuple[Function.Argument, Function.Argument]] = (
|
mapped: list[tuple[Function.Argument, Function.Argument]] = (
|
||||||
self.map_params(template, concrete)
|
self.map_params(template, concrete)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
|
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TextIO
|
from typing import Optional, TextIO
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
@@ -19,24 +19,33 @@ from midas.utils import TypedAST
|
|||||||
@click.command(help="Compile source")
|
@click.command(help="Compile source")
|
||||||
@click.argument("file", type=click.File("r"))
|
@click.argument("file", type=click.File("r"))
|
||||||
@click.option("-t", "--types", type=click.File("r"), multiple=True)
|
@click.option("-t", "--types", type=click.File("r"), multiple=True)
|
||||||
|
@click.option("-s", "--stubs", type=str, multiple=True)
|
||||||
|
@click.option("--ignore-errors", is_flag=True)
|
||||||
def compile(
|
def compile(
|
||||||
file: TextIO,
|
file: TextIO,
|
||||||
types: tuple[TextIO],
|
types: tuple[TextIO],
|
||||||
|
stubs: tuple[str],
|
||||||
|
ignore_errors: bool,
|
||||||
):
|
):
|
||||||
source: str = file.read()
|
source: str = file.read()
|
||||||
source_path: Path = Path(file.name).resolve()
|
source_path: Path = Path(file.name).resolve()
|
||||||
|
|
||||||
checker = TypeChecker()
|
checker = TypeChecker()
|
||||||
for types_file in types:
|
type_files: list[tuple[Path, Optional[str]]] = []
|
||||||
checker.import_midas(Path(types_file.name).resolve())
|
for i, types_file in enumerate(types):
|
||||||
|
in_path: Path = Path(types_file.name).resolve()
|
||||||
|
checker.import_midas(in_path)
|
||||||
|
type_files.append((in_path, stubs[i] if i < len(stubs) else None))
|
||||||
|
|
||||||
typed_ast: TypedAST = checker.type_check_source(source, str(source_path))
|
typed_ast: TypedAST = checker.type_check_source(source, str(source_path))
|
||||||
diagnostics: list[Diagnostic] = checker.diagnostics.copy()
|
diagnostics: list[Diagnostic] = checker.diagnostics.copy()
|
||||||
printer = DiagnosticPrinter()
|
printer = DiagnosticPrinter()
|
||||||
printer.print_all(diagnostics)
|
printer.print_all(diagnostics)
|
||||||
|
|
||||||
if any(map(lambda d: d.type == DiagnosticType.ERROR, diagnostics)):
|
if not ignore_errors and any(
|
||||||
|
map(lambda d: d.type == DiagnosticType.ERROR, diagnostics)
|
||||||
|
):
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
generator = Generator(workdir=source_path.parent, types=checker.types)
|
generator = Generator(workdir=source_path.parent, types=checker.types)
|
||||||
generator.generate(typed_ast, source_path)
|
generator.generate(typed_ast, source_path, type_files=type_files)
|
||||||
|
|||||||
@@ -11,14 +11,14 @@ import click
|
|||||||
from midas.ast.printer import MidasPrinter
|
from midas.ast.printer import MidasPrinter
|
||||||
from midas.checker.checker import TypeChecker
|
from midas.checker.checker import TypeChecker
|
||||||
from midas.checker.registry import Member
|
from midas.checker.registry import Member
|
||||||
from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type
|
from midas.checker.types import AppliedType, BaseType, DerivedType, GenericType, Type
|
||||||
|
|
||||||
|
|
||||||
def base_type(type: Type) -> Type:
|
def base_type(type: Type) -> Type:
|
||||||
match type:
|
match type:
|
||||||
case BaseType():
|
case BaseType():
|
||||||
return type
|
return type
|
||||||
case AliasType(type=base):
|
case DerivedType(type=base):
|
||||||
return base
|
return base
|
||||||
case AppliedType(body=body):
|
case AppliedType(body=body):
|
||||||
return body
|
return body
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import ast
|
import ast
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TextIO
|
from typing import Optional, TextIO
|
||||||
|
|
||||||
import black
|
import black
|
||||||
import click
|
import click
|
||||||
@@ -38,15 +38,17 @@ class Handler(FileSystemEventHandler):
|
|||||||
|
|
||||||
@click.command(help="Generate stubs from Midas definitions")
|
@click.command(help="Generate stubs from Midas definitions")
|
||||||
@click.argument("file", type=click.File("r"))
|
@click.argument("file", type=click.File("r"))
|
||||||
@click.option("-o", "--output", type=click.File("w"), default="-")
|
@click.option("-o", "--output", type=click.File("w"))
|
||||||
@click.option("-w", "--watch", is_flag=True)
|
@click.option("-w", "--watch", is_flag=True)
|
||||||
def stubs(
|
def stubs(
|
||||||
file: TextIO,
|
file: TextIO,
|
||||||
output: TextIO,
|
output: Optional[TextIO],
|
||||||
watch: bool,
|
watch: bool,
|
||||||
):
|
):
|
||||||
source_path: Path = Path(file.name).resolve()
|
source_path: Path = Path(file.name).resolve()
|
||||||
out_path: Path = Path(output.name).resolve()
|
out_path: Path = source_path.with_suffix(".pyi")
|
||||||
|
if output is not None:
|
||||||
|
out_path = Path(output.name).resolve()
|
||||||
generate_stubs(source_path, out_path)
|
generate_stubs(source_path, out_path)
|
||||||
|
|
||||||
if watch:
|
if watch:
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ def types(
|
|||||||
message=f"Type: {type}",
|
message=f"Type: {type}",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
diagnostics.extend(checker.diagnostics)
|
||||||
printer = DiagnosticPrinter()
|
printer = DiagnosticPrinter()
|
||||||
printer.print_all(diagnostics)
|
printer.print_all(diagnostics)
|
||||||
|
|
||||||
|
|||||||
@@ -134,9 +134,9 @@ class PythonHighlighter(
|
|||||||
|
|
||||||
def visit_base_type(self, node: p.BaseType) -> None:
|
def visit_base_type(self, node: p.BaseType) -> None:
|
||||||
self.wrap(node, "base-type")
|
self.wrap(node, "base-type")
|
||||||
if node.param is not None:
|
for arg in node.args:
|
||||||
self.wrap(node.param, "param")
|
self.wrap(arg, "arg")
|
||||||
node.param.accept(self)
|
arg.accept(self)
|
||||||
|
|
||||||
def visit_constraint_type(self, node: p.ConstraintType) -> None:
|
def visit_constraint_type(self, node: p.ConstraintType) -> None:
|
||||||
self.wrap(node, "constraint-type")
|
self.wrap(node, "constraint-type")
|
||||||
@@ -247,6 +247,10 @@ class PythonHighlighter(
|
|||||||
if expr.step is not None:
|
if expr.step is not None:
|
||||||
expr.step.accept(self)
|
expr.step.accept(self)
|
||||||
|
|
||||||
|
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
|
||||||
|
for item in expr.items:
|
||||||
|
item.accept(self)
|
||||||
|
|
||||||
def visit_raw_expr(self, expr: p.RawExpr) -> None: ...
|
def visit_raw_expr(self, expr: p.RawExpr) -> None: ...
|
||||||
|
|
||||||
def visit_raw_stmt(self, stmt: p.RawStmt) -> None: ...
|
def visit_raw_stmt(self, stmt: p.RawStmt) -> None: ...
|
||||||
@@ -350,6 +354,14 @@ class MidasHighlighter(
|
|||||||
for param in spec.pos + spec.mixed + spec.kw:
|
for param in spec.pos + spec.mixed + spec.kw:
|
||||||
param.type.accept(self)
|
param.type.accept(self)
|
||||||
|
|
||||||
|
def visit_frame_type(self, type: m.FrameType) -> None:
|
||||||
|
self.wrap(type, "frame")
|
||||||
|
for column in type.columns:
|
||||||
|
self._visit_frame_column(column)
|
||||||
|
|
||||||
|
def _visit_frame_column(self, column: m.FrameType.Column) -> None:
|
||||||
|
self.wrap(column, "column")
|
||||||
|
|
||||||
|
|
||||||
class DiagnosticsHighlighter(Highlighter):
|
class DiagnosticsHighlighter(Highlighter):
|
||||||
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"
|
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ span {
|
|||||||
--col: 108, 233, 108;
|
--col: 108, 233, 108;
|
||||||
}
|
}
|
||||||
|
|
||||||
&.param {
|
&.arg {
|
||||||
--col: 103, 192, 224;
|
--col: 103, 192, 224;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -7,6 +8,13 @@ from midas.cli.ansi import Ansi
|
|||||||
|
|
||||||
|
|
||||||
class DiagnosticPrinter:
|
class DiagnosticPrinter:
|
||||||
|
COLORS: dict[DiagnosticType, int] = {
|
||||||
|
DiagnosticType.ERROR: Ansi.RED,
|
||||||
|
DiagnosticType.WARNING: Ansi.YELLOW,
|
||||||
|
DiagnosticType.INFO: Ansi.CYAN,
|
||||||
|
DiagnosticType.DEBUG: Ansi.MAGENTA,
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.files: dict[Optional[str], list[str]] = {}
|
self.files: dict[Optional[str], list[str]] = {}
|
||||||
|
|
||||||
@@ -22,10 +30,25 @@ class DiagnosticPrinter:
|
|||||||
return self.files[filename]
|
return self.files[filename]
|
||||||
|
|
||||||
def print_all(self, diagnostics: list[Diagnostic], indent: int = 4):
|
def print_all(self, diagnostics: list[Diagnostic], indent: int = 4):
|
||||||
|
by_type: dict[DiagnosticType, int] = defaultdict(int)
|
||||||
for diagnostic in diagnostics:
|
for diagnostic in diagnostics:
|
||||||
filename: Optional[str] = diagnostic.file_path
|
filename: Optional[str] = diagnostic.file_path
|
||||||
lines = self.get_lines(filename)
|
lines = self.get_lines(filename)
|
||||||
self.print(lines, diagnostic, indent=indent)
|
self.print(lines, diagnostic, indent=indent)
|
||||||
|
by_type[diagnostic.type] += 1
|
||||||
|
|
||||||
|
if len(diagnostics) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
counts: list[str] = []
|
||||||
|
for type in DiagnosticType:
|
||||||
|
if type not in by_type:
|
||||||
|
continue
|
||||||
|
count: int = by_type[type]
|
||||||
|
color: int = self.COLORS.get(type, Ansi.WHITE)
|
||||||
|
counts.append(f"{Ansi.FG(color)}{type.value}s{Ansi.RESET}: {count}")
|
||||||
|
|
||||||
|
print(" ".join(counts))
|
||||||
|
|
||||||
def print(self, lines: list[str], diagnostic: Diagnostic, indent: int = 4):
|
def print(self, lines: list[str], diagnostic: Diagnostic, indent: int = 4):
|
||||||
"""Pretty-print a diagnostic, showing some context if possible
|
"""Pretty-print a diagnostic, showing some context if possible
|
||||||
@@ -45,7 +68,7 @@ class DiagnosticPrinter:
|
|||||||
|
|
||||||
loc: Location = diagnostic.location
|
loc: Location = diagnostic.location
|
||||||
if loc.lineno != loc.end_lineno:
|
if loc.lineno != loc.end_lineno:
|
||||||
print(diagnostic)
|
self.print_multiline(lines, diagnostic, indent)
|
||||||
return
|
return
|
||||||
|
|
||||||
start_offset: int = loc.col_offset
|
start_offset: int = loc.col_offset
|
||||||
@@ -55,12 +78,7 @@ class DiagnosticPrinter:
|
|||||||
before: str = line[:start_offset]
|
before: str = line[:start_offset]
|
||||||
after: str = line[end_offset:]
|
after: str = line[end_offset:]
|
||||||
|
|
||||||
color: int = {
|
color: int = self.COLORS.get(diagnostic.type, Ansi.WHITE)
|
||||||
DiagnosticType.ERROR: Ansi.RED,
|
|
||||||
DiagnosticType.WARNING: Ansi.YELLOW,
|
|
||||||
DiagnosticType.INFO: Ansi.CYAN,
|
|
||||||
DiagnosticType.DEBUG: Ansi.MAGENTA,
|
|
||||||
}.get(diagnostic.type, Ansi.WHITE)
|
|
||||||
|
|
||||||
subject: str = Ansi.FG(color) + line[start_offset:end_offset] + Ansi.RESET
|
subject: str = Ansi.FG(color) + line[start_offset:end_offset] + Ansi.RESET
|
||||||
cursor: str = (
|
cursor: str = (
|
||||||
@@ -77,3 +95,27 @@ class DiagnosticPrinter:
|
|||||||
print(indent_str + before + subject + after)
|
print(indent_str + before + subject + after)
|
||||||
print(indent_str + cursor)
|
print(indent_str + cursor)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
def print_multiline(
|
||||||
|
self, all_lines: list[str], diagnostic: Diagnostic, indent: int = 4
|
||||||
|
):
|
||||||
|
loc: Location = diagnostic.location
|
||||||
|
lines: list[str] = all_lines[loc.lineno - 1 : loc.end_lineno]
|
||||||
|
|
||||||
|
start_offset: int = loc.col_offset
|
||||||
|
end_offset: int = loc.end_col_offset or (start_offset + 1)
|
||||||
|
|
||||||
|
indent_str: str = " " * indent
|
||||||
|
color: int = self.COLORS.get(diagnostic.type, Ansi.WHITE)
|
||||||
|
res: str = indent_str + lines[0][:start_offset]
|
||||||
|
res += Ansi.FG(color) + lines[0][start_offset:]
|
||||||
|
for line in lines[1:-1]:
|
||||||
|
res += "\n" + indent_str + line
|
||||||
|
res += "\n" + indent_str + lines[-1][:end_offset]
|
||||||
|
res += Ansi.RESET + lines[-1][end_offset:]
|
||||||
|
|
||||||
|
print(diagnostic.location_str + ":")
|
||||||
|
print(res)
|
||||||
|
print()
|
||||||
|
print(Ansi.FG(color) + diagnostic.message + Ansi.RESET)
|
||||||
|
print()
|
||||||
|
|||||||
59
midas/generator/collector.py
Normal file
59
midas/generator/collector.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
import ast
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import midas.ast.python as p
|
||||||
|
|
||||||
|
AssertionBuilder = Callable[..., ast.expr]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Assertion:
|
||||||
|
bound_expr: p.Expr
|
||||||
|
inputs: list[p.Expr]
|
||||||
|
builder: AssertionBuilder
|
||||||
|
message: str
|
||||||
|
|
||||||
|
def is_bound_to(self, expr: p.Expr) -> bool:
|
||||||
|
return expr == self.bound_expr
|
||||||
|
|
||||||
|
|
||||||
|
class AssertionCollector:
|
||||||
|
def __init__(self):
|
||||||
|
self.assertions: list[Assertion] = []
|
||||||
|
self.definitions: dict[str, ast.stmt] = {}
|
||||||
|
|
||||||
|
def add(
|
||||||
|
self,
|
||||||
|
bound_expr: p.Expr,
|
||||||
|
inputs: list[p.Expr],
|
||||||
|
builder: AssertionBuilder,
|
||||||
|
message: str,
|
||||||
|
):
|
||||||
|
self.assertions.append(
|
||||||
|
Assertion(
|
||||||
|
bound_expr=bound_expr,
|
||||||
|
inputs=inputs,
|
||||||
|
builder=builder,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def remove(self, assertion: Assertion):
|
||||||
|
try:
|
||||||
|
self.assertions.remove(assertion)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def define(self, name: str, stmt: ast.stmt):
|
||||||
|
if name not in self.definitions:
|
||||||
|
self.definitions[name] = stmt
|
||||||
|
|
||||||
|
def get_definitions(self) -> list[ast.stmt]:
|
||||||
|
return list(self.definitions.values())
|
||||||
|
|
||||||
|
def get_assertions(self) -> list[Assertion]:
|
||||||
|
return self.assertions
|
||||||
|
|
||||||
|
def get_assertions_for(self, expr: p.Expr) -> list[Assertion]:
|
||||||
|
return list(filter(lambda a: a.is_bound_to(expr), self.assertions))
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import ast
|
import ast
|
||||||
|
import logging
|
||||||
import shutil
|
import shutil
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -8,65 +9,96 @@ import midas.ast.midas as m
|
|||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
from midas.ast.printer import MidasPrinter
|
from midas.ast.printer import MidasPrinter
|
||||||
|
from midas.checker.checker import TypeChecker
|
||||||
from midas.checker.registry import TypesRegistry
|
from midas.checker.registry import TypesRegistry
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AliasType,
|
|
||||||
AppliedType,
|
AppliedType,
|
||||||
BaseType,
|
BaseType,
|
||||||
|
ColumnGroupBy,
|
||||||
|
ColumnType,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
ConstraintType,
|
ConstraintType,
|
||||||
|
DataFrameType,
|
||||||
|
DerivedType,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
|
FrameGroupBy,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
OverloadedFunction,
|
OverloadedFunction,
|
||||||
TopType,
|
TopType,
|
||||||
|
TupleType,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
UnitType,
|
UnitType,
|
||||||
UnknownType,
|
UnknownType,
|
||||||
)
|
)
|
||||||
|
from midas.generator.collector import Assertion, AssertionCollector
|
||||||
from midas.generator.constraints import ConstraintGenerator
|
from midas.generator.constraints import ConstraintGenerator
|
||||||
|
from midas.generator.stubs import StubsGenerator
|
||||||
from midas.utils import TypedAST
|
from midas.utils import TypedAST
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Scope:
|
class Scope:
|
||||||
pre_assertions: list[ast.stmt] = field(default_factory=list)
|
pre_assertions: list[ast.stmt] = field(default_factory=list[ast.stmt])
|
||||||
aliases: list[str] = field(default_factory=list)
|
aliases: list[str] = field(default_factory=list[str])
|
||||||
|
|
||||||
|
|
||||||
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||||
|
IS_DATAFRAME_FUNC = "__midas_is_dataframe__"
|
||||||
|
IS_COLUMN_FUNC = "__midas_is_column__"
|
||||||
|
|
||||||
def __init__(self, workdir: Path, types: TypesRegistry) -> None:
|
def __init__(self, workdir: Path, types: TypesRegistry) -> None:
|
||||||
self.workdir: Path = workdir.resolve()
|
self.workdir: Path = workdir.resolve()
|
||||||
self.build_dir: Path = self.workdir / "build" / "midas"
|
self.build_dir: Path = self.workdir / "build" / "midas"
|
||||||
self.rel_src_path: Path = Path()
|
self.rel_src_path: Path = Path()
|
||||||
|
self.logger: logging.Logger = logging.getLogger("Generator")
|
||||||
|
|
||||||
self._typed_ast: TypedAST = TypedAST(
|
self._typed_ast: TypedAST = TypedAST(
|
||||||
stmts=[],
|
stmts=[],
|
||||||
judgements=[],
|
judgements=[],
|
||||||
evaluated_casts=[],
|
evaluated_casts=[],
|
||||||
|
assertions=AssertionCollector(),
|
||||||
)
|
)
|
||||||
self._alias_count: int = 0
|
self._alias_count: int = 0
|
||||||
self._predicate_count: int = 0
|
self._predicate_count: int = 0
|
||||||
self._scopes: list[Scope] = []
|
self._scopes: list[Scope] = []
|
||||||
|
self._aliases: list[tuple[p.Expr, ast.expr]] = []
|
||||||
|
|
||||||
self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
|
self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
|
||||||
self._constraints: list[tuple[m.Expr, ast.expr]] = []
|
self._constraints: list[tuple[m.Expr, ast.expr]] = []
|
||||||
|
|
||||||
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
|
self.define_is_dataframe: bool = False
|
||||||
self.rel_src_path = src_path.resolve().relative_to(self.workdir)
|
self.define_is_column: bool = False
|
||||||
|
|
||||||
|
def set_src_path(self, path: Path):
|
||||||
|
self.rel_src_path = path.resolve().relative_to(self.workdir)
|
||||||
|
|
||||||
|
def generate_ast(self, typed_ast: TypedAST) -> ast.AST:
|
||||||
self._typed_ast = typed_ast
|
self._typed_ast = typed_ast
|
||||||
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
|
body: list[ast.stmt] = self._visit_body(typed_ast.stmts, can_be_empty=True)
|
||||||
predicates: list[ast.stmt] = self._constraint_generator.get_definitions()
|
predicates: list[ast.stmt] = self._constraint_generator.get_definitions()
|
||||||
module = ast.Module(body=predicates + body, type_ignores=[])
|
|
||||||
|
body = predicates + body
|
||||||
|
|
||||||
|
if self.define_is_dataframe:
|
||||||
|
body = [self._is_dataframe_definition()] + body
|
||||||
|
|
||||||
|
if self.define_is_column:
|
||||||
|
body = [self._is_column_definition()] + body
|
||||||
|
|
||||||
|
module = ast.Module(body=body, type_ignores=[])
|
||||||
module = ast.fix_missing_locations(module)
|
module = ast.fix_missing_locations(module)
|
||||||
return module
|
return module
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self, typed_ast: TypedAST, src_path: Path, out_path: Optional[Path] = None
|
self,
|
||||||
|
typed_ast: TypedAST,
|
||||||
|
src_path: Path,
|
||||||
|
out_path: Optional[Path] = None,
|
||||||
|
type_files: Optional[list[tuple[Path, Optional[str]]]] = None,
|
||||||
) -> Path:
|
) -> Path:
|
||||||
module: ast.AST = self.generate_ast(typed_ast, src_path)
|
self.set_src_path(src_path)
|
||||||
compiled: str = ast.unparse(module)
|
|
||||||
if out_path is None:
|
if out_path is None:
|
||||||
if self.build_dir.exists():
|
if self.build_dir.exists():
|
||||||
shutil.rmtree(self.build_dir)
|
shutil.rmtree(self.build_dir)
|
||||||
@@ -78,43 +110,72 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Directory traversal, {self.rel_src_path} points outside of parent directory"
|
f"Directory traversal, {self.rel_src_path} points outside of parent directory"
|
||||||
)
|
)
|
||||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
out_dir: Path = out_path.parent
|
||||||
|
out_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
if type_files is not None:
|
||||||
|
for in_path, out_name in type_files:
|
||||||
|
if out_name is None:
|
||||||
|
out_name = in_path.stem
|
||||||
|
self.generate_stubs(in_path, out_dir / f"{out_name}.py")
|
||||||
|
|
||||||
|
module: ast.AST = self.generate_ast(typed_ast)
|
||||||
|
compiled: str = ast.unparse(module)
|
||||||
|
|
||||||
out_path.write_text(compiled)
|
out_path.write_text(compiled)
|
||||||
return out_path
|
return out_path
|
||||||
|
|
||||||
|
def generate_stubs(self, in_path: Path, out_path: Path):
|
||||||
|
checker = TypeChecker()
|
||||||
|
checker.import_midas(in_path)
|
||||||
|
generator = StubsGenerator(checker.types)
|
||||||
|
module: ast.Module = generator.generate_stubs()
|
||||||
|
module = ast.fix_missing_locations(module)
|
||||||
|
output: str = ast.unparse(module)
|
||||||
|
out_path.write_text(output)
|
||||||
|
|
||||||
|
def convert(self, expr: p.Expr) -> ast.expr:
|
||||||
|
for expr2, alias in self._aliases:
|
||||||
|
if expr2 == expr:
|
||||||
|
return alias
|
||||||
|
assertions = self._typed_ast.assertions.get_assertions_for(expr)
|
||||||
|
if len(assertions) != 0:
|
||||||
|
return self._apply_assertions(expr, assertions)
|
||||||
|
return expr.accept(self)
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> ast.expr:
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> ast.expr:
|
||||||
return ast.BinOp(
|
return ast.BinOp(
|
||||||
left=expr.left.accept(self),
|
left=self.convert(expr.left),
|
||||||
op=expr.operator,
|
op=expr.operator,
|
||||||
right=expr.right.accept(self),
|
right=self.convert(expr.right),
|
||||||
)
|
)
|
||||||
|
|
||||||
def visit_compare_expr(self, expr: p.CompareExpr) -> ast.expr:
|
def visit_compare_expr(self, expr: p.CompareExpr) -> ast.expr:
|
||||||
return ast.Compare(
|
return ast.Compare(
|
||||||
left=expr.left.accept(self),
|
left=self.convert(expr.left),
|
||||||
ops=[expr.operator],
|
ops=[expr.operator],
|
||||||
comparators=[expr.right.accept(self)],
|
comparators=[self.convert(expr.right)],
|
||||||
)
|
)
|
||||||
|
|
||||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> ast.expr:
|
def visit_unary_expr(self, expr: p.UnaryExpr) -> ast.expr:
|
||||||
return ast.UnaryOp(
|
return ast.UnaryOp(
|
||||||
op=expr.operator,
|
op=expr.operator,
|
||||||
operand=expr.right.accept(self),
|
operand=self.convert(expr.right),
|
||||||
)
|
)
|
||||||
|
|
||||||
def visit_call_expr(self, expr: p.CallExpr) -> ast.expr:
|
def visit_call_expr(self, expr: p.CallExpr) -> ast.expr:
|
||||||
return ast.Call(
|
return ast.Call(
|
||||||
func=expr.callee.accept(self),
|
func=self.convert(expr.callee),
|
||||||
args=[arg.accept(self) for arg in expr.arguments],
|
args=[self.convert(arg) for arg in expr.arguments],
|
||||||
keywords=[
|
keywords=[
|
||||||
ast.keyword(arg=name, value=arg.accept(self))
|
ast.keyword(arg=name, value=self.convert(arg))
|
||||||
for name, arg in expr.keywords.items()
|
for name, arg in expr.keywords.items()
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def visit_get_expr(self, expr: p.GetExpr) -> ast.expr:
|
def visit_get_expr(self, expr: p.GetExpr) -> ast.expr:
|
||||||
return ast.Attribute(
|
return ast.Attribute(
|
||||||
value=expr.object.accept(self),
|
value=self.convert(expr.object),
|
||||||
attr=expr.name,
|
attr=expr.name,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -127,51 +188,58 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> ast.expr:
|
def visit_logical_expr(self, expr: p.LogicalExpr) -> ast.expr:
|
||||||
return ast.BoolOp(
|
return ast.BoolOp(
|
||||||
op=expr.operator,
|
op=expr.operator,
|
||||||
values=[expr.left.accept(self), expr.right.accept(self)],
|
values=[self.convert(expr.left), self.convert(expr.right)],
|
||||||
)
|
)
|
||||||
|
|
||||||
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
|
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
|
||||||
expr2: ast.expr = expr.expr.accept(self)
|
expr2: ast.expr = self.convert(expr.expr)
|
||||||
|
|
||||||
if expr in self._typed_ast.evaluated_casts:
|
if expr in self._typed_ast.evaluated_casts or expr.unsafe:
|
||||||
return expr2
|
return expr2
|
||||||
|
|
||||||
alias: ast.expr = self._make_alias(expr2)
|
alias: ast.expr = self._make_alias(expr.expr, expr2)
|
||||||
|
|
||||||
type: Type = self._get_expr_type(expr)
|
type: Type = self._get_expr_type(expr)
|
||||||
self._make_cast_asserts(expr.location, alias, type)
|
asserts: list[ast.stmt] = self._make_cast_asserts(expr.location, alias, type)
|
||||||
|
for assert_ in asserts:
|
||||||
|
self._add_assert(assert_)
|
||||||
|
|
||||||
return alias
|
return alias
|
||||||
|
|
||||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> ast.expr:
|
def visit_ternary_expr(self, expr: p.TernaryExpr) -> ast.expr:
|
||||||
return ast.IfExp(
|
return ast.IfExp(
|
||||||
test=expr.test.accept(self),
|
test=self.convert(expr.test),
|
||||||
body=expr.if_true.accept(self),
|
body=self.convert(expr.if_true),
|
||||||
orelse=expr.if_false.accept(self),
|
orelse=self.convert(expr.if_false),
|
||||||
)
|
)
|
||||||
|
|
||||||
def visit_list_expr(self, expr: p.ListExpr) -> ast.expr:
|
def visit_list_expr(self, expr: p.ListExpr) -> ast.expr:
|
||||||
return ast.List(
|
return ast.List(
|
||||||
elts=[item.accept(self) for item in expr.items],
|
elts=[self.convert(item) for item in expr.items],
|
||||||
)
|
)
|
||||||
|
|
||||||
def visit_dict_expr(self, expr: p.DictExpr) -> ast.expr:
|
def visit_dict_expr(self, expr: p.DictExpr) -> ast.expr:
|
||||||
return ast.Dict(
|
return ast.Dict(
|
||||||
keys=[key.accept(self) if key is not None else None for key in expr.keys],
|
keys=[self.convert(key) if key is not None else None for key in expr.keys],
|
||||||
values=[value.accept(self) for value in expr.values],
|
values=[self.convert(value) for value in expr.values],
|
||||||
)
|
)
|
||||||
|
|
||||||
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> ast.expr:
|
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> ast.expr:
|
||||||
return ast.Subscript(
|
return ast.Subscript(
|
||||||
value=expr.object.accept(self),
|
value=self.convert(expr.object),
|
||||||
slice=expr.index.accept(self),
|
slice=self.convert(expr.index),
|
||||||
)
|
)
|
||||||
|
|
||||||
def visit_slice_expr(self, expr: p.SliceExpr) -> ast.expr:
|
def visit_slice_expr(self, expr: p.SliceExpr) -> ast.expr:
|
||||||
return ast.Slice(
|
return ast.Slice(
|
||||||
lower=expr.lower.accept(self) if expr.lower is not None else None,
|
lower=self.convert(expr.lower) if expr.lower is not None else None,
|
||||||
upper=expr.upper.accept(self) if expr.upper is not None else None,
|
upper=self.convert(expr.upper) if expr.upper is not None else None,
|
||||||
step=expr.step.accept(self) if expr.step is not None else None,
|
step=self.convert(expr.step) if expr.step is not None else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def visit_tuple_expr(self, expr: p.TupleExpr) -> ast.expr:
|
||||||
|
return ast.Tuple(
|
||||||
|
elts=[self.convert(item) for item in expr.items],
|
||||||
)
|
)
|
||||||
|
|
||||||
def visit_raw_expr(self, expr: p.RawExpr) -> ast.expr:
|
def visit_raw_expr(self, expr: p.RawExpr) -> ast.expr:
|
||||||
@@ -179,7 +247,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
|
|
||||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> ast.stmt:
|
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> ast.stmt:
|
||||||
return ast.Expr(
|
return ast.Expr(
|
||||||
value=stmt.expr.accept(self),
|
value=self.convert(stmt.expr),
|
||||||
)
|
)
|
||||||
|
|
||||||
def visit_function(self, stmt: p.Function) -> ast.stmt:
|
def visit_function(self, stmt: p.Function) -> ast.stmt:
|
||||||
@@ -192,12 +260,12 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
kwonlyargs=[ast.arg(arg=arg.name) for arg in stmt.kwonlyargs],
|
kwonlyargs=[ast.arg(arg=arg.name) for arg in stmt.kwonlyargs],
|
||||||
kwarg=None,
|
kwarg=None,
|
||||||
defaults=[
|
defaults=[
|
||||||
arg.default.accept(self)
|
self.convert(arg.default)
|
||||||
for arg in stmt.posonlyargs + stmt.args
|
for arg in stmt.posonlyargs + stmt.args
|
||||||
if arg.default is not None
|
if arg.default is not None
|
||||||
],
|
],
|
||||||
kw_defaults=[
|
kw_defaults=[
|
||||||
arg.default.accept(self) if arg.default is not None else None
|
self.convert(arg.default) if arg.default is not None else None
|
||||||
for arg in stmt.kwonlyargs
|
for arg in stmt.kwonlyargs
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
@@ -211,20 +279,20 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
|
|
||||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> ast.stmt:
|
def visit_assign_stmt(self, stmt: p.AssignStmt) -> ast.stmt:
|
||||||
return ast.Assign(
|
return ast.Assign(
|
||||||
targets=[target.accept(self) for target in stmt.targets],
|
targets=[self.convert(target) for target in stmt.targets],
|
||||||
value=stmt.value.accept(self),
|
value=self.convert(stmt.value),
|
||||||
)
|
)
|
||||||
|
|
||||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> ast.stmt:
|
def visit_return_stmt(self, stmt: p.ReturnStmt) -> ast.stmt:
|
||||||
return ast.Return(
|
return ast.Return(
|
||||||
value=stmt.value.accept(self) if stmt.value is not None else None,
|
value=self.convert(stmt.value) if stmt.value is not None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def visit_if_stmt(self, stmt: p.IfStmt) -> ast.stmt:
|
def visit_if_stmt(self, stmt: p.IfStmt) -> ast.stmt:
|
||||||
return ast.If(
|
return ast.If(
|
||||||
test=stmt.test.accept(self),
|
test=self.convert(stmt.test),
|
||||||
body=self._visit_body(stmt.body),
|
body=self._visit_body(stmt.body),
|
||||||
orelse=self._visit_body(stmt.orelse),
|
orelse=self._visit_body(stmt.orelse, can_be_empty=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
def visit_pass(self, stmt: p.Pass) -> ast.stmt:
|
def visit_pass(self, stmt: p.Pass) -> ast.stmt:
|
||||||
@@ -232,8 +300,8 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
|
|
||||||
def visit_for_stmt(self, stmt: p.ForStmt) -> ast.stmt:
|
def visit_for_stmt(self, stmt: p.ForStmt) -> ast.stmt:
|
||||||
return ast.For(
|
return ast.For(
|
||||||
target=stmt.target.accept(self),
|
target=self.convert(stmt.target),
|
||||||
iter=stmt.iterator.accept(self),
|
iter=self.convert(stmt.iterator),
|
||||||
body=self._visit_body(stmt.body),
|
body=self._visit_body(stmt.body),
|
||||||
orelse=[],
|
orelse=[],
|
||||||
)
|
)
|
||||||
@@ -241,7 +309,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
def visit_raw_stmt(self, stmt: p.RawStmt) -> ast.stmt:
|
def visit_raw_stmt(self, stmt: p.RawStmt) -> ast.stmt:
|
||||||
return stmt.stmt
|
return stmt.stmt
|
||||||
|
|
||||||
def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]:
|
def _visit_body(
|
||||||
|
self, stmts: list[p.Stmt], can_be_empty: bool = False
|
||||||
|
) -> list[ast.stmt]:
|
||||||
generated: list[ast.stmt] = []
|
generated: list[ast.stmt] = []
|
||||||
for stmt in stmts:
|
for stmt in stmts:
|
||||||
scope = Scope()
|
scope = Scope()
|
||||||
@@ -259,9 +329,11 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
# Remove redundant pass statements
|
# Remove redundant pass statements
|
||||||
if len(generated) > 1:
|
if len(generated) > 1:
|
||||||
generated = [stmt for stmt in generated if not isinstance(stmt, ast.Pass)]
|
generated = [stmt for stmt in generated if not isinstance(stmt, ast.Pass)]
|
||||||
|
if len(generated) == 0 and not can_be_empty:
|
||||||
|
generated = [ast.Pass()]
|
||||||
return generated
|
return generated
|
||||||
|
|
||||||
def _make_alias(self, expr: ast.expr) -> ast.expr:
|
def _make_alias(self, node: p.Expr, expr: ast.expr) -> ast.expr:
|
||||||
name: str = f"__midas_a{self._alias_count}__"
|
name: str = f"__midas_a{self._alias_count}__"
|
||||||
alias = ast.Name(id=name)
|
alias = ast.Name(id=name)
|
||||||
self._alias_count += 1
|
self._alias_count += 1
|
||||||
@@ -272,82 +344,182 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
value=expr,
|
value=expr,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
self._aliases.append((node, alias))
|
||||||
return alias
|
return alias
|
||||||
|
|
||||||
def _add_assert(self, expr: ast.expr, message: str | ast.expr):
|
def _build_assert(self, expr: ast.expr, message: str | ast.expr) -> ast.stmt:
|
||||||
if isinstance(message, str):
|
if isinstance(message, str):
|
||||||
message = ast.Constant(value=message)
|
message = ast.Constant(value=message)
|
||||||
self._scopes[-1].pre_assertions.append(
|
return ast.Assert(
|
||||||
ast.Assert(
|
test=expr,
|
||||||
test=expr,
|
msg=message,
|
||||||
msg=message,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _add_assert(self, assertion: ast.stmt):
|
||||||
|
self._scopes[-1].pre_assertions.append(assertion)
|
||||||
|
|
||||||
def _get_expr_type(self, query: p.Expr) -> Type:
|
def _get_expr_type(self, query: p.Expr) -> Type:
|
||||||
for expr, type in self._typed_ast.judgements:
|
for expr, type in self._typed_ast.judgements:
|
||||||
if expr == query:
|
if expr == query:
|
||||||
return type
|
return type
|
||||||
raise RuntimeError(f"Cannot get type judgement for {query}")
|
raise RuntimeError(f"Cannot get type judgement for {query}")
|
||||||
|
|
||||||
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type):
|
def _make_cast_asserts(
|
||||||
|
self, src_location: Location, expr: ast.expr, type: Type
|
||||||
|
) -> list[ast.stmt]:
|
||||||
match type:
|
match type:
|
||||||
case UnknownType():
|
case UnknownType() | TopType():
|
||||||
pass
|
return []
|
||||||
|
|
||||||
case BaseType(name=name):
|
case BaseType(name=name):
|
||||||
self._add_assert(
|
return [
|
||||||
ast.Call(
|
self._build_assert(
|
||||||
func=ast.Name(id="isinstance"),
|
ast.Call(
|
||||||
args=[expr, ast.Name(id=name)],
|
func=ast.Name(id="isinstance"),
|
||||||
keywords=[],
|
args=[expr, ast.Name(id=name)],
|
||||||
),
|
keywords=[],
|
||||||
self._make_cast_assert_message(src_location, expr, type),
|
),
|
||||||
)
|
self._make_cast_assert_message(src_location, expr, type),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
case AliasType(type=base):
|
case DerivedType(type=base):
|
||||||
self._make_cast_asserts(src_location, expr, base)
|
return self._make_cast_asserts(src_location, expr, base)
|
||||||
|
|
||||||
case UnitType():
|
case UnitType():
|
||||||
self._add_assert(
|
return [
|
||||||
ast.Compare(
|
self._build_assert(
|
||||||
left=expr,
|
ast.Compare(
|
||||||
ops=[ast.Is()],
|
left=expr,
|
||||||
comparators=[
|
ops=[ast.Is()],
|
||||||
ast.Constant(value=None),
|
comparators=[
|
||||||
],
|
ast.Constant(value=None),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
self._make_cast_assert_message(src_location, expr, type),
|
||||||
),
|
),
|
||||||
self._make_cast_assert_message(src_location, expr, type),
|
]
|
||||||
)
|
|
||||||
|
|
||||||
case AppliedType(body=body):
|
case AppliedType(body=body):
|
||||||
self._make_cast_asserts(src_location, expr, body)
|
return self._make_cast_asserts(src_location, expr, body)
|
||||||
|
|
||||||
case ConstraintType(type=base, constraint=constraint):
|
case ConstraintType(type=base, constraint=constraint):
|
||||||
self._make_cast_asserts(src_location, expr, base)
|
asserts: list[ast.stmt] = self._make_cast_asserts(
|
||||||
self._make_constraint_assert(src_location, expr, constraint)
|
src_location, expr, base
|
||||||
|
)
|
||||||
|
asserts.append(
|
||||||
|
self._make_constraint_assert(src_location, expr, constraint)
|
||||||
|
)
|
||||||
|
return asserts
|
||||||
|
|
||||||
case TypeVar(bound=bound):
|
case TypeVar(bound=bound):
|
||||||
# TODO: check with type from arguments / use call-site context
|
# TODO: check with type from arguments / use call-site context
|
||||||
if bound is not None:
|
if bound is None:
|
||||||
self._make_cast_asserts(src_location, expr, bound)
|
return []
|
||||||
|
return self._make_cast_asserts(src_location, expr, bound)
|
||||||
|
|
||||||
|
case TupleType(items=items):
|
||||||
|
asserts: list[ast.stmt] = [
|
||||||
|
self._build_assert(
|
||||||
|
ast.Call(
|
||||||
|
func=ast.Name(id="isinstance"),
|
||||||
|
args=[expr, ast.Name(id="tuple")],
|
||||||
|
keywords=[],
|
||||||
|
),
|
||||||
|
self._make_cast_assert_message(src_location, expr, type),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
assert isinstance(expr, ast.Tuple)
|
||||||
|
for item, item_type in zip(expr.elts, items):
|
||||||
|
asserts.extend(
|
||||||
|
self._make_cast_asserts(src_location, item, item_type)
|
||||||
|
)
|
||||||
|
return asserts
|
||||||
|
|
||||||
|
case DataFrameType(columns=columns):
|
||||||
|
self.define_is_dataframe = True
|
||||||
|
asserts: list[ast.stmt] = [
|
||||||
|
self._build_assert(
|
||||||
|
ast.Call(
|
||||||
|
func=ast.Name(id=self.IS_DATAFRAME_FUNC),
|
||||||
|
args=[expr],
|
||||||
|
keywords=[],
|
||||||
|
),
|
||||||
|
self._make_cast_assert_message(
|
||||||
|
src_location, expr, type, ": Not a dataframe"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
for column in columns:
|
||||||
|
asserts.append(
|
||||||
|
self._build_assert(
|
||||||
|
ast.Compare(
|
||||||
|
left=ast.Constant(value=column.name),
|
||||||
|
ops=[ast.In()],
|
||||||
|
comparators=[expr],
|
||||||
|
),
|
||||||
|
self._make_cast_assert_message(
|
||||||
|
src_location,
|
||||||
|
expr,
|
||||||
|
type,
|
||||||
|
f": Missing column {column.name}",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
asserts.extend(
|
||||||
|
self._make_cast_asserts(
|
||||||
|
src_location,
|
||||||
|
ast.Subscript(
|
||||||
|
value=expr, slice=ast.Constant(value=column.name)
|
||||||
|
),
|
||||||
|
column.type,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return asserts
|
||||||
|
|
||||||
|
case ColumnType():
|
||||||
|
self.define_is_column = True
|
||||||
|
asserts: list[ast.stmt] = [
|
||||||
|
self._build_assert(
|
||||||
|
ast.Call(
|
||||||
|
func=ast.Name(id=self.IS_COLUMN_FUNC),
|
||||||
|
args=[expr],
|
||||||
|
keywords=[],
|
||||||
|
),
|
||||||
|
self._make_cast_assert_message(
|
||||||
|
src_location, expr, type, ": Not a column"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
inner_assert: Optional[ast.stmt] = self._make_column_inner_assert(
|
||||||
|
src_location, expr, type
|
||||||
|
)
|
||||||
|
if inner_assert is not None:
|
||||||
|
asserts.append(inner_assert)
|
||||||
|
return asserts
|
||||||
|
|
||||||
case (
|
case (
|
||||||
TopType()
|
Function()
|
||||||
| Function()
|
|
||||||
| OverloadedFunction()
|
| OverloadedFunction()
|
||||||
| ComplexType()
|
| ComplexType()
|
||||||
| ExtensionType()
|
| ExtensionType()
|
||||||
| GenericType()
|
| GenericType()
|
||||||
|
| FrameGroupBy()
|
||||||
|
| ColumnGroupBy()
|
||||||
):
|
):
|
||||||
raise NotImplementedError(f"Can't make assertion for type {type}")
|
self.logger.warning(f"Can't make assertion for type {type}")
|
||||||
|
return []
|
||||||
|
|
||||||
# Ensure exhaustiveness
|
# Ensure exhaustiveness
|
||||||
case _:
|
case _:
|
||||||
assert_never(type)
|
assert_never(type)
|
||||||
|
|
||||||
def _make_cast_assert_message(
|
def _make_cast_assert_message(
|
||||||
self, location: Location, expr: ast.expr, type: Type
|
self,
|
||||||
|
location: Location,
|
||||||
|
expr: ast.expr,
|
||||||
|
type: Type,
|
||||||
|
extra: Optional[str] = None,
|
||||||
) -> ast.expr:
|
) -> ast.expr:
|
||||||
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
|
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
|
||||||
# f"file.py:L1:1: CastError: Cannot cast {type(expr).__name__} to Type"
|
# f"file.py:L1:1: CastError: Cannot cast {type(expr).__name__} to Type"
|
||||||
@@ -365,15 +537,15 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
),
|
),
|
||||||
conversion=-1,
|
conversion=-1,
|
||||||
),
|
),
|
||||||
ast.Constant(f" to {type}"),
|
ast.Constant(f" to {type}{extra or ''}"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def _make_constraint_assert(
|
def _make_constraint_assert(
|
||||||
self, src_location: Location, expr: ast.expr, constraint: m.Expr
|
self, src_location: Location, expr: ast.expr, constraint: m.Expr
|
||||||
):
|
) -> ast.stmt:
|
||||||
test_func: ast.expr = self._get_constraint(constraint)
|
test_func: ast.expr = self._get_constraint(constraint)
|
||||||
self._add_assert(
|
return self._build_assert(
|
||||||
ast.Call(
|
ast.Call(
|
||||||
func=test_func,
|
func=test_func,
|
||||||
args=[expr],
|
args=[expr],
|
||||||
@@ -401,3 +573,117 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
constraint: ast.expr = self._constraint_generator.generate(expr)
|
constraint: ast.expr = self._constraint_generator.generate(expr)
|
||||||
self._constraints.append((expr, constraint))
|
self._constraints.append((expr, constraint))
|
||||||
return constraint
|
return constraint
|
||||||
|
|
||||||
|
def _is_dataframe_definition(self) -> ast.stmt:
|
||||||
|
"""
|
||||||
|
def IS_DATAFRAME_FUNC(obj) -> bool:
|
||||||
|
import pandas as pd
|
||||||
|
return isinstance(obj, pd.DataFrame)
|
||||||
|
"""
|
||||||
|
|
||||||
|
return ast.FunctionDef(
|
||||||
|
name=self.IS_DATAFRAME_FUNC,
|
||||||
|
args=ast.arguments(
|
||||||
|
posonlyargs=[ast.arg(arg="obj")],
|
||||||
|
args=[],
|
||||||
|
kwonlyargs=[],
|
||||||
|
defaults=[],
|
||||||
|
kw_defaults=[],
|
||||||
|
),
|
||||||
|
body=[
|
||||||
|
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
|
||||||
|
ast.Return(
|
||||||
|
value=ast.Call(
|
||||||
|
func=ast.Name(id="isinstance"),
|
||||||
|
args=[
|
||||||
|
ast.Name(id="obj"),
|
||||||
|
ast.Attribute(
|
||||||
|
value=ast.Name(id="pd"),
|
||||||
|
attr="DataFrame",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
keywords=[],
|
||||||
|
)
|
||||||
|
),
|
||||||
|
],
|
||||||
|
decorator_list=[],
|
||||||
|
returns=ast.Name(id="bool"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _is_column_definition(self) -> ast.stmt:
|
||||||
|
"""
|
||||||
|
def IS_COLUMN_FUNC(obj) -> bool:
|
||||||
|
import pandas as pd
|
||||||
|
return isinstance(obj, pd.Series)
|
||||||
|
"""
|
||||||
|
|
||||||
|
return ast.FunctionDef(
|
||||||
|
name=self.IS_COLUMN_FUNC,
|
||||||
|
args=ast.arguments(
|
||||||
|
posonlyargs=[ast.arg(arg="obj")],
|
||||||
|
args=[],
|
||||||
|
kwonlyargs=[],
|
||||||
|
defaults=[],
|
||||||
|
kw_defaults=[],
|
||||||
|
),
|
||||||
|
body=[
|
||||||
|
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
|
||||||
|
ast.Return(
|
||||||
|
value=ast.Call(
|
||||||
|
func=ast.Name(id="isinstance"),
|
||||||
|
args=[
|
||||||
|
ast.Name(id="obj"),
|
||||||
|
ast.Attribute(
|
||||||
|
value=ast.Name(id="pd"),
|
||||||
|
attr="Series",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
keywords=[],
|
||||||
|
)
|
||||||
|
),
|
||||||
|
],
|
||||||
|
decorator_list=[],
|
||||||
|
returns=ast.Name(id="bool"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _make_column_inner_assert(
|
||||||
|
self, src_location: Location, column: ast.expr, type: ColumnType
|
||||||
|
) -> Optional[ast.stmt]:
|
||||||
|
# TODO: improve message, maybe chain contexts
|
||||||
|
col: ast.expr = ast.Name(id="col")
|
||||||
|
body: list[ast.stmt] = self._make_cast_asserts(src_location, col, type.type)
|
||||||
|
if len(body) == 0:
|
||||||
|
return None
|
||||||
|
return ast.For(
|
||||||
|
target=col,
|
||||||
|
iter=column,
|
||||||
|
body=body,
|
||||||
|
orelse=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
def _convert_assertion(self, assertion: Assertion) -> ast.stmt:
|
||||||
|
inputs: list[ast.expr] = []
|
||||||
|
|
||||||
|
for input in assertion.inputs:
|
||||||
|
converted: ast.expr = self.convert(input)
|
||||||
|
alias: ast.expr = self._make_alias(input, converted)
|
||||||
|
inputs.append(alias)
|
||||||
|
|
||||||
|
test: ast.expr = assertion.builder(*inputs)
|
||||||
|
location: Location = assertion.bound_expr.location
|
||||||
|
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
|
||||||
|
return self._build_assert(
|
||||||
|
test, f"{loc_str}: AssertionError: {assertion.message}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _apply_assertions(self, expr: p.Expr, assertions: list[Assertion]) -> ast.expr:
|
||||||
|
for assertion in assertions:
|
||||||
|
assert_stmt: ast.stmt
|
||||||
|
assert_stmt = self._convert_assertion(assertion)
|
||||||
|
self._add_assert(assert_stmt)
|
||||||
|
|
||||||
|
# Mutating list in frozen dataclass
|
||||||
|
# Not ideal but easiest way to avoid duplicate assertions
|
||||||
|
self._typed_ast.assertions.remove(assertion)
|
||||||
|
|
||||||
|
return expr.accept(self)
|
||||||
|
|||||||
@@ -4,16 +4,21 @@ from typing import Optional, assert_never
|
|||||||
import midas.ast.midas as m
|
import midas.ast.midas as m
|
||||||
from midas.checker.registry import Member, TypesRegistry
|
from midas.checker.registry import Member, TypesRegistry
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AliasType,
|
|
||||||
AppliedType,
|
AppliedType,
|
||||||
BaseType,
|
BaseType,
|
||||||
|
ColumnGroupBy,
|
||||||
|
ColumnType,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
ConstraintType,
|
ConstraintType,
|
||||||
|
DataFrameType,
|
||||||
|
DerivedType,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
|
FrameGroupBy,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
OverloadedFunction,
|
OverloadedFunction,
|
||||||
TopType,
|
TopType,
|
||||||
|
TupleType,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
UnitType,
|
UnitType,
|
||||||
@@ -30,6 +35,7 @@ class StubsGenerator:
|
|||||||
self.types: TypesRegistry = types
|
self.types: TypesRegistry = types
|
||||||
self.stubs: list[ast.stmt] = []
|
self.stubs: list[ast.stmt] = []
|
||||||
self.typing_imports: set[str] = set()
|
self.typing_imports: set[str] = set()
|
||||||
|
self.import_pandas: bool = False
|
||||||
self.protocol_idx: int = 0
|
self.protocol_idx: int = 0
|
||||||
self.stub_idx: int = 0
|
self.stub_idx: int = 0
|
||||||
self.type_var_idx: int = 0
|
self.type_var_idx: int = 0
|
||||||
@@ -38,6 +44,7 @@ class StubsGenerator:
|
|||||||
def generate_stubs(self) -> ast.Module:
|
def generate_stubs(self) -> ast.Module:
|
||||||
self.stubs = []
|
self.stubs = []
|
||||||
self.typing_imports = set()
|
self.typing_imports = set()
|
||||||
|
self.import_pandas = False
|
||||||
for name, type in self.types._types.items():
|
for name, type in self.types._types.items():
|
||||||
# Skip builtin types, not just based on name so the user can override
|
# Skip builtin types, not just based on name so the user can override
|
||||||
# TODO: check if added members on builtin type
|
# TODO: check if added members on builtin type
|
||||||
@@ -53,7 +60,7 @@ class StubsGenerator:
|
|||||||
continue
|
continue
|
||||||
self.generate_stub(name, type)
|
self.generate_stub(name, type)
|
||||||
|
|
||||||
imports = [
|
imports: list[ast.stmt] = [
|
||||||
ast.ImportFrom(
|
ast.ImportFrom(
|
||||||
module="__future__",
|
module="__future__",
|
||||||
names=[ast.alias(name="annotations")],
|
names=[ast.alias(name="annotations")],
|
||||||
@@ -70,11 +77,37 @@ class StubsGenerator:
|
|||||||
level=0,
|
level=0,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if self.import_pandas:
|
||||||
|
imports.append(
|
||||||
|
ast.Import(
|
||||||
|
names=[
|
||||||
|
ast.alias(
|
||||||
|
name="pandas",
|
||||||
|
asname="pd",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
return ast.Module(body=imports + self.stubs, type_ignores=[])
|
return ast.Module(body=imports + self.stubs, type_ignores=[])
|
||||||
|
|
||||||
def generate_stub(self, name: str, type: Type):
|
def generate_stub(self, name: str, type: Type):
|
||||||
base_type: Type = type
|
base_type: Type = type
|
||||||
|
|
||||||
|
# TODO: improve
|
||||||
|
match type:
|
||||||
|
case DerivedType(name=name_) | GenericType(name=name_) if name_ == name:
|
||||||
|
pass
|
||||||
|
case UnitType() if name == "None":
|
||||||
|
pass
|
||||||
|
case TopType() if name == "Any":
|
||||||
|
pass
|
||||||
|
case _:
|
||||||
|
alias = ast.Assign(
|
||||||
|
targets=[ast.Name(id=name)], value=self.dump_type(type)
|
||||||
|
)
|
||||||
|
self.add_stub(alias)
|
||||||
|
return
|
||||||
|
|
||||||
members: dict[str, Member] = self.types._members.get(name, {})
|
members: dict[str, Member] = self.types._members.get(name, {})
|
||||||
if isinstance(base_type, (BaseType, TopType, UnitType)) and len(members) == 0:
|
if isinstance(base_type, (BaseType, TopType, UnitType)) and len(members) == 0:
|
||||||
return
|
return
|
||||||
@@ -96,7 +129,7 @@ class StubsGenerator:
|
|||||||
|
|
||||||
def get_bases(self, type: Type) -> tuple[list[ast.expr], dict[str, Type]]:
|
def get_bases(self, type: Type) -> tuple[list[ast.expr], dict[str, Type]]:
|
||||||
match type:
|
match type:
|
||||||
case AliasType(type=base):
|
case DerivedType(type=base):
|
||||||
return [self.dump_type(base)], {}
|
return [self.dump_type(base)], {}
|
||||||
|
|
||||||
case GenericType(params=params, body=body):
|
case GenericType(params=params, body=body):
|
||||||
@@ -161,7 +194,7 @@ class StubsGenerator:
|
|||||||
|
|
||||||
def dump_type(self, type: Type) -> ast.expr:
|
def dump_type(self, type: Type) -> ast.expr:
|
||||||
match type:
|
match type:
|
||||||
case AliasType(name=name) | GenericType(name=name) if (
|
case DerivedType(name=name) | GenericType(name=name) if (
|
||||||
name in self.substitutions
|
name in self.substitutions
|
||||||
):
|
):
|
||||||
type = substitute_typevars(type, self.substitutions[name])
|
type = substitute_typevars(type, self.substitutions[name])
|
||||||
@@ -174,7 +207,7 @@ class StubsGenerator:
|
|||||||
case BaseType(name=name):
|
case BaseType(name=name):
|
||||||
return ast.Name(id=name)
|
return ast.Name(id=name)
|
||||||
|
|
||||||
case AliasType(name=name):
|
case DerivedType(name=name):
|
||||||
return ast.Name(id=name)
|
return ast.Name(id=name)
|
||||||
|
|
||||||
case UnitType():
|
case UnitType():
|
||||||
@@ -231,6 +264,57 @@ class StubsGenerator:
|
|||||||
case ConstraintType():
|
case ConstraintType():
|
||||||
return self.dump_type(type.type)
|
return self.dump_type(type.type)
|
||||||
|
|
||||||
|
case TupleType(items=items):
|
||||||
|
return ast.Subscript(
|
||||||
|
value=ast.Name(id="tuple"),
|
||||||
|
slice=ast.Tuple(
|
||||||
|
elts=[self.dump_type(item) for item in items],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
case ColumnType(type=inner):
|
||||||
|
self.import_pandas = True
|
||||||
|
return ast.Subscript(
|
||||||
|
value=ast.Attribute(
|
||||||
|
value=ast.Name(id="pd"),
|
||||||
|
attr="Series",
|
||||||
|
),
|
||||||
|
slice=self.dump_type(inner),
|
||||||
|
)
|
||||||
|
|
||||||
|
case DataFrameType():
|
||||||
|
self.import_pandas = True
|
||||||
|
return ast.Attribute(
|
||||||
|
value=ast.Name(id="pd"),
|
||||||
|
attr="DataFrame",
|
||||||
|
)
|
||||||
|
|
||||||
|
case FrameGroupBy():
|
||||||
|
self.import_pandas = True
|
||||||
|
return ast.Attribute(
|
||||||
|
value=ast.Attribute(
|
||||||
|
value=ast.Attribute(
|
||||||
|
value=ast.Name(id="pd"),
|
||||||
|
attr="api",
|
||||||
|
),
|
||||||
|
attr="typing",
|
||||||
|
),
|
||||||
|
attr="DataFrameGroupBy",
|
||||||
|
)
|
||||||
|
|
||||||
|
case ColumnGroupBy():
|
||||||
|
self.import_pandas = True
|
||||||
|
return ast.Attribute(
|
||||||
|
value=ast.Attribute(
|
||||||
|
value=ast.Attribute(
|
||||||
|
value=ast.Name(id="pd"),
|
||||||
|
attr="api",
|
||||||
|
),
|
||||||
|
attr="typing",
|
||||||
|
),
|
||||||
|
attr="SeriesGroupBy",
|
||||||
|
)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
assert_never(type)
|
assert_never(type)
|
||||||
|
|
||||||
|
|||||||
@@ -46,8 +46,8 @@ class MidasLexer(Lexer):
|
|||||||
self.add_token(TokenType.UNDERSCORE)
|
self.add_token(TokenType.UNDERSCORE)
|
||||||
case "-" if self.match(">"):
|
case "-" if self.match(">"):
|
||||||
self.add_token(TokenType.ARROW)
|
self.add_token(TokenType.ARROW)
|
||||||
# case "+":
|
case "+":
|
||||||
# self.add_token(TokenType.PLUS)
|
self.add_token(TokenType.PLUS)
|
||||||
case "-":
|
case "-":
|
||||||
self.add_token(TokenType.MINUS)
|
self.add_token(TokenType.MINUS)
|
||||||
case "*":
|
case "*":
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ class TokenType(Enum):
|
|||||||
DOT = auto()
|
DOT = auto()
|
||||||
|
|
||||||
# Operators
|
# Operators
|
||||||
# PLUS = auto()
|
PLUS = auto()
|
||||||
MINUS = auto()
|
MINUS = auto()
|
||||||
STAR = auto()
|
STAR = auto()
|
||||||
SLASH = auto()
|
SLASH = auto()
|
||||||
@@ -47,6 +47,7 @@ class TokenType(Enum):
|
|||||||
|
|
||||||
# Keywords
|
# Keywords
|
||||||
TYPE = auto()
|
TYPE = auto()
|
||||||
|
ALIAS = auto()
|
||||||
PREDICATE = auto()
|
PREDICATE = auto()
|
||||||
EXTEND = auto()
|
EXTEND = auto()
|
||||||
WHERE = auto()
|
WHERE = auto()
|
||||||
@@ -63,6 +64,7 @@ class TokenType(Enum):
|
|||||||
|
|
||||||
KEYWORDS: dict[str, TokenType] = {
|
KEYWORDS: dict[str, TokenType] = {
|
||||||
"type": TokenType.TYPE,
|
"type": TokenType.TYPE,
|
||||||
|
"alias": TokenType.ALIAS,
|
||||||
"predicate": TokenType.PREDICATE,
|
"predicate": TokenType.PREDICATE,
|
||||||
"extend": TokenType.EXTEND,
|
"extend": TokenType.EXTEND,
|
||||||
"where": TokenType.WHERE,
|
"where": TokenType.WHERE,
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
from midas.ast.midas import (
|
from midas.ast.midas import (
|
||||||
|
AliasStmt,
|
||||||
BinaryExpr,
|
BinaryExpr,
|
||||||
CallExpr,
|
CallExpr,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
@@ -9,6 +10,7 @@ from midas.ast.midas import (
|
|||||||
Expr,
|
Expr,
|
||||||
ExtendStmt,
|
ExtendStmt,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
|
FrameType,
|
||||||
FunctionType,
|
FunctionType,
|
||||||
GenericType,
|
GenericType,
|
||||||
GetExpr,
|
GetExpr,
|
||||||
@@ -79,6 +81,8 @@ class MidasParser(Parser):
|
|||||||
try:
|
try:
|
||||||
if self.match(TokenType.TYPE):
|
if self.match(TokenType.TYPE):
|
||||||
return self.type_declaration()
|
return self.type_declaration()
|
||||||
|
if self.match(TokenType.ALIAS):
|
||||||
|
return self.alias_declaration()
|
||||||
if self.match(TokenType.EXTEND):
|
if self.match(TokenType.EXTEND):
|
||||||
return self.extend_declaration()
|
return self.extend_declaration()
|
||||||
if self.match(TokenType.PREDICATE):
|
if self.match(TokenType.PREDICATE):
|
||||||
@@ -158,6 +162,25 @@ class MidasParser(Parser):
|
|||||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after type parameters")
|
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after type parameters")
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
def alias_declaration(self) -> AliasStmt:
|
||||||
|
"""Parse an alias declaration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AliasStmt: the parsed alias declaration statement
|
||||||
|
"""
|
||||||
|
keyword: Token = self.previous()
|
||||||
|
name: Token = self.consume_identifier("Expected type name")
|
||||||
|
|
||||||
|
self.consume(TokenType.EQUAL, "Expected '=' before alias definition")
|
||||||
|
|
||||||
|
type: Type = self.type_expr()
|
||||||
|
|
||||||
|
return AliasStmt(
|
||||||
|
location=keyword.location_to(self.previous()),
|
||||||
|
name=name,
|
||||||
|
type=type,
|
||||||
|
)
|
||||||
|
|
||||||
def type_expr(self) -> Type:
|
def type_expr(self) -> Type:
|
||||||
"""Parse a type expression
|
"""Parse a type expression
|
||||||
|
|
||||||
@@ -204,8 +227,10 @@ class MidasParser(Parser):
|
|||||||
return self.generic_type()
|
return self.generic_type()
|
||||||
|
|
||||||
def generic_type(self) -> Type:
|
def generic_type(self) -> Type:
|
||||||
type: Type = self.named_type()
|
type: NamedType = self.named_type()
|
||||||
if self.check(TokenType.LEFT_BRACKET):
|
if self.check(TokenType.LEFT_BRACKET):
|
||||||
|
if type.name.lexeme == "Frame":
|
||||||
|
return self.frame_type()
|
||||||
args: list[Type] = self.type_args()
|
args: list[Type] = self.type_args()
|
||||||
return GenericType(
|
return GenericType(
|
||||||
location=Location.span(type.location, self.previous().get_location()),
|
location=Location.span(type.location, self.previous().get_location()),
|
||||||
@@ -224,7 +249,7 @@ class MidasParser(Parser):
|
|||||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments")
|
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments")
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def named_type(self) -> Type:
|
def named_type(self) -> NamedType:
|
||||||
name: Token = self.consume_identifier("Expected type name")
|
name: Token = self.consume_identifier("Expected type name")
|
||||||
return NamedType(
|
return NamedType(
|
||||||
location=name.get_location(),
|
location=name.get_location(),
|
||||||
@@ -259,6 +284,32 @@ class MidasParser(Parser):
|
|||||||
members=members,
|
members=members,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def frame_type(self) -> FrameType:
|
||||||
|
keyword: Token = self.previous()
|
||||||
|
self.consume(TokenType.LEFT_BRACKET, "Expected '[' to start frame schema")
|
||||||
|
|
||||||
|
columns: list[FrameType.Column] = []
|
||||||
|
while not self.check(TokenType.RIGHT_BRACKET) and not self.is_at_end():
|
||||||
|
name: Token = self.advance()
|
||||||
|
self.consume(TokenType.COLON, "Expected ':' between column name and type")
|
||||||
|
type: Type = self.type_expr()
|
||||||
|
columns.append(
|
||||||
|
FrameType.Column(
|
||||||
|
location=name.location_to(self.previous()),
|
||||||
|
name=name,
|
||||||
|
type=type,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not self.match(TokenType.COMMA):
|
||||||
|
break
|
||||||
|
|
||||||
|
self.consume(TokenType.RIGHT_BRACKET, "Unclosed frame schema")
|
||||||
|
|
||||||
|
return FrameType(
|
||||||
|
location=keyword.location_to(self.previous()),
|
||||||
|
columns=columns,
|
||||||
|
)
|
||||||
|
|
||||||
def constraint(self) -> Expr:
|
def constraint(self) -> Expr:
|
||||||
"""Parse a constraint
|
"""Parse a constraint
|
||||||
|
|
||||||
@@ -310,13 +361,35 @@ class MidasParser(Parser):
|
|||||||
Returns:
|
Returns:
|
||||||
Expr: the parsed expression
|
Expr: the parsed expression
|
||||||
"""
|
"""
|
||||||
expr: Expr = self.unary()
|
expr: Expr = self.term()
|
||||||
while self.match(
|
while self.match(
|
||||||
TokenType.LESS,
|
TokenType.LESS,
|
||||||
TokenType.LESS_EQUAL,
|
TokenType.LESS_EQUAL,
|
||||||
TokenType.GREATER,
|
TokenType.GREATER,
|
||||||
TokenType.GREATER_EQUAL,
|
TokenType.GREATER_EQUAL,
|
||||||
):
|
):
|
||||||
|
operator: Token = self.previous()
|
||||||
|
right: Expr = self.term()
|
||||||
|
location: Location = Location.span(expr.location, right.location)
|
||||||
|
expr = BinaryExpr(
|
||||||
|
location=location, left=expr, operator=operator, right=right
|
||||||
|
)
|
||||||
|
return expr
|
||||||
|
|
||||||
|
def term(self) -> Expr:
|
||||||
|
expr: Expr = self.factor()
|
||||||
|
while self.match(TokenType.PLUS, TokenType.MINUS):
|
||||||
|
operator: Token = self.previous()
|
||||||
|
right: Expr = self.factor()
|
||||||
|
location: Location = Location.span(expr.location, right.location)
|
||||||
|
expr = BinaryExpr(
|
||||||
|
location=location, left=expr, operator=operator, right=right
|
||||||
|
)
|
||||||
|
return expr
|
||||||
|
|
||||||
|
def factor(self) -> Expr:
|
||||||
|
expr: Expr = self.unary()
|
||||||
|
while self.match(TokenType.STAR, TokenType.SLASH):
|
||||||
operator: Token = self.previous()
|
operator: Token = self.previous()
|
||||||
right: Expr = self.unary()
|
right: Expr = self.unary()
|
||||||
location: Location = Location.span(expr.location, right.location)
|
location: Location = Location.span(expr.location, right.location)
|
||||||
@@ -348,7 +421,7 @@ class MidasParser(Parser):
|
|||||||
pos_args: list[Expr] = []
|
pos_args: list[Expr] = []
|
||||||
kw_args: dict[str, Expr] = {}
|
kw_args: dict[str, Expr] = {}
|
||||||
keywords: bool = False
|
keywords: bool = False
|
||||||
while not self.match(TokenType.RIGHT_PAREN):
|
while not self.check(TokenType.RIGHT_PAREN):
|
||||||
if self.check_identifier() and self.check_next(TokenType.EQUAL):
|
if self.check_identifier() and self.check_next(TokenType.EQUAL):
|
||||||
keywords = True
|
keywords = True
|
||||||
keyword: Token = self.advance()
|
keyword: Token = self.advance()
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from midas.ast.python import (
|
|||||||
Stmt,
|
Stmt,
|
||||||
SubscriptExpr,
|
SubscriptExpr,
|
||||||
TernaryExpr,
|
TernaryExpr,
|
||||||
|
TupleExpr,
|
||||||
TypeAssign,
|
TypeAssign,
|
||||||
UnaryExpr,
|
UnaryExpr,
|
||||||
VariableExpr,
|
VariableExpr,
|
||||||
@@ -49,6 +50,7 @@ class UnsupportedSyntaxError(Exception):
|
|||||||
|
|
||||||
class PythonParser:
|
class PythonParser:
|
||||||
CAST_FUNCTION = "cast"
|
CAST_FUNCTION = "cast"
|
||||||
|
UNSAFE_CAST_FUNCTION = "unsafe_cast"
|
||||||
|
|
||||||
def parse_module(self, node: ast.Module) -> list[Stmt]:
|
def parse_module(self, node: ast.Module) -> list[Stmt]:
|
||||||
statements: list[Stmt] = []
|
statements: list[Stmt] = []
|
||||||
@@ -299,26 +301,28 @@ class PythonParser:
|
|||||||
case ast.Subscript(value=ast.Name(id="Frame"), slice=schema):
|
case ast.Subscript(value=ast.Name(id="Frame"), slice=schema):
|
||||||
return self._parse_frame_type(schema)
|
return self._parse_frame_type(schema)
|
||||||
|
|
||||||
case ast.Subscript(value=ast.Name(id=name), slice=param):
|
case ast.Subscript(value=ast.Name(id=name), slice=arg):
|
||||||
|
args: tuple[MidasType, ...] = (
|
||||||
|
tuple(self._parse_type(a) for a in arg.elts)
|
||||||
|
if isinstance(arg, ast.Tuple)
|
||||||
|
else (self._parse_type(arg),)
|
||||||
|
)
|
||||||
return BaseType(
|
return BaseType(
|
||||||
location=loc,
|
location=loc,
|
||||||
base=name,
|
base=name,
|
||||||
param=self._parse_type(param),
|
args=args,
|
||||||
)
|
)
|
||||||
|
|
||||||
case ast.Name(id=name):
|
case ast.Name(id=name):
|
||||||
return BaseType(
|
return BaseType(
|
||||||
location=loc,
|
location=loc,
|
||||||
base=name,
|
base=name,
|
||||||
param=None,
|
args=(),
|
||||||
)
|
)
|
||||||
|
|
||||||
case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr):
|
case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr):
|
||||||
left = self._parse_type(left_expr)
|
left = self._parse_type(left_expr)
|
||||||
match left:
|
match left:
|
||||||
case None:
|
|
||||||
raise InvalidSyntaxError()
|
|
||||||
|
|
||||||
# If chained constraints, separate base type and rebuild constraint
|
# If chained constraints, separate base type and rebuild constraint
|
||||||
case ConstraintType(type=left_type, constraint=left_constraint):
|
case ConstraintType(type=left_type, constraint=left_constraint):
|
||||||
constraint = ast.BinOp(
|
constraint = ast.BinOp(
|
||||||
@@ -344,7 +348,7 @@ class PythonParser:
|
|||||||
return BaseType(
|
return BaseType(
|
||||||
location=loc,
|
location=loc,
|
||||||
base="None",
|
base="None",
|
||||||
param=None,
|
args=(),
|
||||||
)
|
)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
@@ -423,6 +427,9 @@ class PythonParser:
|
|||||||
case ast.Call(func=ast.Name(id=self.CAST_FUNCTION)):
|
case ast.Call(func=ast.Name(id=self.CAST_FUNCTION)):
|
||||||
return self.parse_cast(node)
|
return self.parse_cast(node)
|
||||||
|
|
||||||
|
case ast.Call(func=ast.Name(id=self.UNSAFE_CAST_FUNCTION)):
|
||||||
|
return self.parse_cast(node)
|
||||||
|
|
||||||
case ast.Call():
|
case ast.Call():
|
||||||
return self.parse_call(node)
|
return self.parse_call(node)
|
||||||
|
|
||||||
@@ -473,6 +480,12 @@ class PythonParser:
|
|||||||
step=self.parse_expr(step) if step is not None else None,
|
step=self.parse_expr(step) if step is not None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
case ast.Tuple(elts=items):
|
||||||
|
return TupleExpr(
|
||||||
|
location=location,
|
||||||
|
items=tuple(self.parse_expr(item) for item in items),
|
||||||
|
)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
print(f"Unsupported expression: {ast.unparse(node)}")
|
print(f"Unsupported expression: {ast.unparse(node)}")
|
||||||
return RawExpr(location=location, expr=node)
|
return RawExpr(location=location, expr=node)
|
||||||
@@ -527,16 +540,19 @@ class PythonParser:
|
|||||||
return expr
|
return expr
|
||||||
|
|
||||||
def parse_cast(self, node: ast.Call) -> CastExpr:
|
def parse_cast(self, node: ast.Call) -> CastExpr:
|
||||||
|
assert isinstance(node.func, ast.Name)
|
||||||
|
func: str = node.func.id
|
||||||
match node:
|
match node:
|
||||||
case ast.Call(args=[type, expr], keywords=[]):
|
case ast.Call(args=[type, expr], keywords=[]):
|
||||||
return CastExpr(
|
return CastExpr(
|
||||||
location=Location.from_ast(node),
|
location=Location.from_ast(node),
|
||||||
type=self._parse_type(type),
|
type=self._parse_type(type),
|
||||||
expr=self.parse_expr(expr),
|
expr=self.parse_expr(expr),
|
||||||
|
unsafe=func == self.UNSAFE_CAST_FUNCTION,
|
||||||
)
|
)
|
||||||
case _:
|
case _:
|
||||||
raise InvalidSyntaxError(
|
raise InvalidSyntaxError(
|
||||||
f"Invalid call to {self.CAST_FUNCTION}, expected type and expression"
|
f"Invalid call to {func}, expected type and expression"
|
||||||
)
|
)
|
||||||
|
|
||||||
def parse_call(self, node: ast.Call) -> CallExpr:
|
def parse_call(self, node: ast.Call) -> CallExpr:
|
||||||
|
|||||||
52
midas/typing.py
Normal file
52
midas/typing.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
from typing import Generic, TypeVar
|
||||||
|
from typing import cast as typing_cast
|
||||||
|
|
||||||
|
cast = typing_cast
|
||||||
|
"""### Midas documentation
|
||||||
|
Cast a value to a type.
|
||||||
|
|
||||||
|
- **Compile-time**: tells the type checker that the return value has the designated type.
|
||||||
|
- **Run-time**: generates assertions to ensure the value can be interpreted as the given type.
|
||||||
|
|
||||||
|
---
|
||||||
|
<br>
|
||||||
|
<br>
|
||||||
|
<br>
|
||||||
|
|
||||||
|
_**Internal Python documentation**_
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
unsafe_cast = typing_cast
|
||||||
|
"""### Midas documentation
|
||||||
|
Cast a value to a type.
|
||||||
|
|
||||||
|
- **Compile-time**: tells the type checker that the return value has the designated type.
|
||||||
|
- **Run-time**: -
|
||||||
|
|
||||||
|
This operation is unsound, use at your own risk!
|
||||||
|
|
||||||
|
---
|
||||||
|
<br>
|
||||||
|
<br>
|
||||||
|
<br>
|
||||||
|
|
||||||
|
_**Internal Python documentation**_
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class Frame(Generic[T]):
|
||||||
|
"""A `Frame` is the abstract type implemented by `DataFrame`
|
||||||
|
|
||||||
|
A frame contains any number of named columns (see :class:`Column`)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Column(Generic[T]):
|
||||||
|
"""A `Column` is the abstract type implemented by `Series`
|
||||||
|
|
||||||
|
A column contains a any number of values of the same type
|
||||||
|
"""
|
||||||
@@ -3,6 +3,7 @@ from typing import Any, Callable, Optional
|
|||||||
|
|
||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
from midas.checker.types import Type
|
from midas.checker.types import Type
|
||||||
|
from midas.generator.collector import AssertionCollector
|
||||||
|
|
||||||
AllowRepeat = Callable[[object], bool]
|
AllowRepeat = Callable[[object], bool]
|
||||||
|
|
||||||
@@ -63,3 +64,4 @@ class TypedAST:
|
|||||||
stmts: list[p.Stmt]
|
stmts: list[p.Stmt]
|
||||||
judgements: list[tuple[p.Expr, Type]]
|
judgements: list[tuple[p.Expr, Type]]
|
||||||
evaluated_casts: list[p.CastExpr]
|
evaluated_casts: list[p.CastExpr]
|
||||||
|
assertions: AssertionCollector
|
||||||
|
|||||||
43
tests/__main__.py
Normal file
43
tests/__main__.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
from typing import Type
|
||||||
|
|
||||||
|
from midas.cli.ansi import Ansi
|
||||||
|
from tests.base import Tester
|
||||||
|
from tests.checker import CheckerTester
|
||||||
|
from tests.generator import GeneratorTester
|
||||||
|
from tests.midas import MidasTester
|
||||||
|
from tests.python import PythonTester
|
||||||
|
|
||||||
|
|
||||||
|
def print_banner(name: str):
|
||||||
|
horizontal: str = "+" + "-" * (len(name) + 2) + "+"
|
||||||
|
print(horizontal)
|
||||||
|
print(f"| {name} |")
|
||||||
|
print(horizontal)
|
||||||
|
|
||||||
|
|
||||||
|
def run_tests(tester_cls: Type[Tester]) -> bool:
|
||||||
|
print_banner(tester_cls.__name__)
|
||||||
|
tester: Tester = tester_cls()
|
||||||
|
success: bool = tester.run_all_tests()
|
||||||
|
print()
|
||||||
|
return success
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
testers: list[Type[Tester]] = [
|
||||||
|
PythonTester,
|
||||||
|
MidasTester,
|
||||||
|
CheckerTester,
|
||||||
|
GeneratorTester,
|
||||||
|
]
|
||||||
|
|
||||||
|
success: bool = all(map(run_tests, testers))
|
||||||
|
|
||||||
|
if success:
|
||||||
|
print(Ansi.FG(Ansi.BRIGHT_GREEN) + "All tests passed!" + Ansi.RESET)
|
||||||
|
else:
|
||||||
|
print(Ansi.FG(Ansi.BRIGHT_RED) + "Some tests failed!" + Ansi.RESET)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -7,6 +7,8 @@ from abc import ABC, abstractmethod
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterator, Protocol
|
from typing import Iterator, Protocol
|
||||||
|
|
||||||
|
from midas.cli.ansi import Ansi
|
||||||
|
|
||||||
|
|
||||||
class CaseResult(Protocol):
|
class CaseResult(Protocol):
|
||||||
def dumps(self) -> str: ...
|
def dumps(self) -> str: ...
|
||||||
@@ -44,8 +46,11 @@ class Tester(ABC):
|
|||||||
|
|
||||||
print(rule)
|
print(rule)
|
||||||
for i, test in enumerate(tests):
|
for i, test in enumerate(tests):
|
||||||
print(f"Case {i+1}/{n}: {test.resolve().relative_to(self.CASES_DIR)}")
|
path: Path = test.resolve().relative_to(self.CASES_DIR)
|
||||||
|
print(f"{Ansi.FG(Ansi.BRIGHT_CYAN)}Case {i+1}/{n}: {path}{Ansi.RESET}")
|
||||||
|
print(Ansi.DIM, end="")
|
||||||
success: bool = self._run_test(test)
|
success: bool = self._run_test(test)
|
||||||
|
print(Ansi.RESET, end="")
|
||||||
if success:
|
if success:
|
||||||
successes += 1
|
successes += 1
|
||||||
else:
|
else:
|
||||||
@@ -146,8 +151,9 @@ class Tester(ABC):
|
|||||||
if not success:
|
if not success:
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
case None:
|
case None:
|
||||||
print("No subcommand provided. Available subcommands: run, update")
|
success: bool = tester.run_all_tests()
|
||||||
sys.exit(1)
|
if not success:
|
||||||
|
sys.exit(1)
|
||||||
case _:
|
case _:
|
||||||
print(f"Unknown subcommand '{args.subcommand}'")
|
print(f"Unknown subcommand '{args.subcommand}'")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|||||||
@@ -4,7 +4,35 @@
|
|||||||
"type": "Warning",
|
"type": "Warning",
|
||||||
"location": {
|
"location": {
|
||||||
"start": [
|
"start": [
|
||||||
6,
|
8,
|
||||||
|
12
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
8,
|
||||||
|
43
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "ConstraintType not yet supported"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Warning",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
10,
|
||||||
|
10
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
10,
|
||||||
|
18
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Unknown type 'datetime'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Warning",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
13,
|
||||||
4
|
4
|
||||||
],
|
],
|
||||||
"end": [
|
"end": [
|
||||||
@@ -12,7 +40,7 @@
|
|||||||
5
|
5
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"message": "FrameType not yet supported"
|
"message": "Unknown type '_'"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"judgments": []
|
"judgments": []
|
||||||
|
|||||||
@@ -328,6 +328,19 @@
|
|||||||
},
|
},
|
||||||
"type": {}
|
"type": {}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:9",
|
||||||
|
"to": "L6:10"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 1
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L6:5",
|
"from": "L6:5",
|
||||||
@@ -373,19 +386,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L6:9",
|
|
||||||
"to": "L6:10"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 1
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "int"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L6:5",
|
"from": "L6:5",
|
||||||
@@ -407,6 +407,32 @@
|
|||||||
},
|
},
|
||||||
"type": {}
|
"type": {}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L7:9",
|
||||||
|
"to": "L7:10"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 1
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L7:12",
|
||||||
|
"to": "L7:15"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 2.0
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L7:5",
|
"from": "L7:5",
|
||||||
@@ -452,32 +478,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L7:9",
|
|
||||||
"to": "L7:10"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 1
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "int"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L7:12",
|
|
||||||
"to": "L7:15"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 2.0
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "float"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L7:5",
|
"from": "L7:5",
|
||||||
@@ -503,6 +503,32 @@
|
|||||||
},
|
},
|
||||||
"type": {}
|
"type": {}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L8:9",
|
||||||
|
"to": "L8:10"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 1
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L8:14",
|
||||||
|
"to": "L8:17"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 2.0
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L8:5",
|
"from": "L8:5",
|
||||||
@@ -548,32 +574,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L8:9",
|
|
||||||
"to": "L8:10"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 1
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "int"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L8:14",
|
|
||||||
"to": "L8:17"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 2.0
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "float"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L8:5",
|
"from": "L8:5",
|
||||||
@@ -600,6 +600,45 @@
|
|||||||
},
|
},
|
||||||
"type": {}
|
"type": {}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L9:9",
|
||||||
|
"to": "L9:10"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 1
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L9:12",
|
||||||
|
"to": "L9:15"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 2.0
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L9:17",
|
||||||
|
"to": "L9:23"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": "test"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "str"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L9:5",
|
"from": "L9:5",
|
||||||
@@ -645,45 +684,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L9:9",
|
|
||||||
"to": "L9:10"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 1
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "int"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L9:12",
|
|
||||||
"to": "L9:15"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 2.0
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "float"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L9:17",
|
|
||||||
"to": "L9:23"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": "test"
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "str"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L9:5",
|
"from": "L9:5",
|
||||||
@@ -713,6 +713,45 @@
|
|||||||
},
|
},
|
||||||
"type": {}
|
"type": {}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L10:9",
|
||||||
|
"to": "L10:10"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 1
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L10:12",
|
||||||
|
"to": "L10:15"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 2.0
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L10:19",
|
||||||
|
"to": "L10:22"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 3.0
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L10:5",
|
"from": "L10:5",
|
||||||
@@ -758,45 +797,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L10:9",
|
|
||||||
"to": "L10:10"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 1
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "int"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L10:12",
|
|
||||||
"to": "L10:15"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 2.0
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "float"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L10:19",
|
|
||||||
"to": "L10:22"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 3.0
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "float"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L10:5",
|
"from": "L10:5",
|
||||||
@@ -827,6 +827,19 @@
|
|||||||
},
|
},
|
||||||
"type": {}
|
"type": {}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L11:11",
|
||||||
|
"to": "L11:12"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 1
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L11:5",
|
"from": "L11:5",
|
||||||
@@ -872,19 +885,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L11:11",
|
|
||||||
"to": "L11:12"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 1
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "int"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L11:5",
|
"from": "L11:5",
|
||||||
@@ -906,6 +906,19 @@
|
|||||||
},
|
},
|
||||||
"type": {}
|
"type": {}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L12:11",
|
||||||
|
"to": "L12:17"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": "test"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "str"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L12:5",
|
"from": "L12:5",
|
||||||
@@ -951,19 +964,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L12:11",
|
|
||||||
"to": "L12:17"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": "test"
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "str"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L12:5",
|
"from": "L12:5",
|
||||||
@@ -985,6 +985,45 @@
|
|||||||
},
|
},
|
||||||
"type": {}
|
"type": {}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L14:10",
|
||||||
|
"to": "L14:11"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 1
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L14:13",
|
||||||
|
"to": "L14:16"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 2.0
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L14:20",
|
||||||
|
"to": "L14:26"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": "test"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "str"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L14:6",
|
"from": "L14:6",
|
||||||
@@ -1030,45 +1069,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L14:10",
|
|
||||||
"to": "L14:11"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 1
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "int"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L14:13",
|
|
||||||
"to": "L14:16"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 2.0
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "float"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L14:20",
|
|
||||||
"to": "L14:26"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": "test"
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "str"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L14:6",
|
"from": "L14:6",
|
||||||
@@ -1101,6 +1101,45 @@
|
|||||||
"name": "bool"
|
"name": "bool"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L15:10",
|
||||||
|
"to": "L15:11"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 1
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L15:15",
|
||||||
|
"to": "L15:18"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 2.0
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L15:22",
|
||||||
|
"to": "L15:28"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": "test"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "str"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L15:6",
|
"from": "L15:6",
|
||||||
@@ -1146,45 +1185,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L15:10",
|
|
||||||
"to": "L15:11"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 1
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "int"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L15:15",
|
|
||||||
"to": "L15:18"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 2.0
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "float"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L15:22",
|
|
||||||
"to": "L15:28"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": "test"
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "str"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L15:6",
|
"from": "L15:6",
|
||||||
@@ -1217,6 +1217,45 @@
|
|||||||
"name": "bool"
|
"name": "bool"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L16:10",
|
||||||
|
"to": "L16:11"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 1
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L16:15",
|
||||||
|
"to": "L16:21"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": "test"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "str"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L16:25",
|
||||||
|
"to": "L16:28"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 2.0
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L16:6",
|
"from": "L16:6",
|
||||||
@@ -1262,45 +1301,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L16:10",
|
|
||||||
"to": "L16:11"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 1
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "int"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L16:15",
|
|
||||||
"to": "L16:21"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": "test"
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "str"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L16:25",
|
|
||||||
"to": "L16:28"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 2.0
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "float"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L16:6",
|
"from": "L16:6",
|
||||||
@@ -1333,6 +1333,45 @@
|
|||||||
"name": "bool"
|
"name": "bool"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L18:10",
|
||||||
|
"to": "L18:13"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "str"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L18:15",
|
||||||
|
"to": "L18:16"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 3
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L18:20",
|
||||||
|
"to": "L18:25"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": false
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "bool"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L18:6",
|
"from": "L18:6",
|
||||||
@@ -1378,45 +1417,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L18:10",
|
|
||||||
"to": "L18:13"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": "a"
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "str"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L18:15",
|
|
||||||
"to": "L18:16"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 3
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "int"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L18:20",
|
|
||||||
"to": "L18:25"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": false
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "bool"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L18:6",
|
"from": "L18:6",
|
||||||
|
|||||||
@@ -24,12 +24,13 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "Meter",
|
"base": "Meter",
|
||||||
"param": null
|
"args": []
|
||||||
},
|
},
|
||||||
"expr": {
|
"expr": {
|
||||||
"_type": "LiteralExpr",
|
"_type": "LiteralExpr",
|
||||||
"value": 123.45
|
"value": 123.45
|
||||||
}
|
},
|
||||||
|
"unsafe": false
|
||||||
},
|
},
|
||||||
"type": {
|
"type": {
|
||||||
"name": "Meter",
|
"name": "Meter",
|
||||||
@@ -61,12 +62,13 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "Second",
|
"base": "Second",
|
||||||
"param": null
|
"args": []
|
||||||
},
|
},
|
||||||
"expr": {
|
"expr": {
|
||||||
"_type": "LiteralExpr",
|
"_type": "LiteralExpr",
|
||||||
"value": 6.7
|
"value": 6.7
|
||||||
}
|
},
|
||||||
|
"unsafe": false
|
||||||
},
|
},
|
||||||
"type": {
|
"type": {
|
||||||
"name": "Second",
|
"name": "Second",
|
||||||
|
|||||||
@@ -100,6 +100,32 @@
|
|||||||
"name": "float"
|
"name": "float"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L11:13",
|
||||||
|
"to": "L11:15"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "v1"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L11:17",
|
||||||
|
"to": "L11:19"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "v2"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L11:5",
|
"from": "L11:5",
|
||||||
@@ -135,32 +161,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L11:13",
|
|
||||||
"to": "L11:15"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "VariableExpr",
|
|
||||||
"name": "v1"
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "int"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L11:17",
|
|
||||||
"to": "L11:19"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "VariableExpr",
|
|
||||||
"name": "v2"
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "float"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L11:5",
|
"from": "L11:5",
|
||||||
|
|||||||
@@ -72,29 +72,6 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"judgments": [
|
"judgments": [
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L26:0",
|
|
||||||
"to": "L26:5"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "VariableExpr",
|
|
||||||
"name": "print"
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"pos_args": [
|
|
||||||
{
|
|
||||||
"pos": 0,
|
|
||||||
"name": "object",
|
|
||||||
"type": {},
|
|
||||||
"required": true
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"args": [],
|
|
||||||
"kw_args": [],
|
|
||||||
"returns": {}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L27:4",
|
"from": "L27:4",
|
||||||
@@ -325,6 +302,29 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L26:0",
|
||||||
|
"to": "L26:5"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "print"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"pos_args": [
|
||||||
|
{
|
||||||
|
"pos": 0,
|
||||||
|
"name": "object",
|
||||||
|
"type": {},
|
||||||
|
"required": false
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"args": [],
|
||||||
|
"kw_args": [],
|
||||||
|
"returns": {}
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L26:0",
|
"from": "L26:0",
|
||||||
|
|||||||
@@ -63,31 +63,6 @@
|
|||||||
"name": "float"
|
"name": "float"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L6:11",
|
|
||||||
"to": "L6:15"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "VariableExpr",
|
|
||||||
"name": "bool"
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"pos_args": [
|
|
||||||
{
|
|
||||||
"pos": 0,
|
|
||||||
"name": "object",
|
|
||||||
"type": {},
|
|
||||||
"required": false
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"args": [],
|
|
||||||
"kw_args": [],
|
|
||||||
"returns": {
|
|
||||||
"name": "bool"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L6:16",
|
"from": "L6:16",
|
||||||
@@ -135,6 +110,31 @@
|
|||||||
"name": "int"
|
"name": "int"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:11",
|
||||||
|
"to": "L6:15"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "bool"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"pos_args": [
|
||||||
|
{
|
||||||
|
"pos": 0,
|
||||||
|
"name": "object",
|
||||||
|
"type": {},
|
||||||
|
"required": false
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"args": [],
|
||||||
|
"kw_args": [],
|
||||||
|
"returns": {
|
||||||
|
"name": "bool"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L6:11",
|
"from": "L6:11",
|
||||||
@@ -367,6 +367,54 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L12:21",
|
||||||
|
"to": "L12:27"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "double"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"pos_args": [],
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"pos": 0,
|
||||||
|
"name": "value",
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"kw_args": [],
|
||||||
|
"returns": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L12:29",
|
||||||
|
"to": "L12:35"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "floats"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "list",
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "list"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L12:17",
|
"from": "L12:17",
|
||||||
@@ -455,54 +503,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L12:21",
|
|
||||||
"to": "L12:27"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "VariableExpr",
|
|
||||||
"name": "double"
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"pos_args": [],
|
|
||||||
"args": [
|
|
||||||
{
|
|
||||||
"pos": 0,
|
|
||||||
"name": "value",
|
|
||||||
"type": {
|
|
||||||
"name": "float"
|
|
||||||
},
|
|
||||||
"required": true
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"kw_args": [],
|
|
||||||
"returns": {
|
|
||||||
"name": "float"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L12:29",
|
|
||||||
"to": "L12:35"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "VariableExpr",
|
|
||||||
"name": "floats"
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "list",
|
|
||||||
"args": [
|
|
||||||
{
|
|
||||||
"name": "float"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"body": {
|
|
||||||
"name": "list"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L12:17",
|
"from": "L12:17",
|
||||||
@@ -538,6 +538,54 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L13:19",
|
||||||
|
"to": "L13:25"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "double"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"pos_args": [],
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"pos": 0,
|
||||||
|
"name": "value",
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"kw_args": [],
|
||||||
|
"returns": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L13:27",
|
||||||
|
"to": "L13:31"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "ints"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "list",
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "list"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L13:15",
|
"from": "L13:15",
|
||||||
@@ -626,54 +674,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L13:19",
|
|
||||||
"to": "L13:25"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "VariableExpr",
|
|
||||||
"name": "double"
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"pos_args": [],
|
|
||||||
"args": [
|
|
||||||
{
|
|
||||||
"pos": 0,
|
|
||||||
"name": "value",
|
|
||||||
"type": {
|
|
||||||
"name": "float"
|
|
||||||
},
|
|
||||||
"required": true
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"kw_args": [],
|
|
||||||
"returns": {
|
|
||||||
"name": "float"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L13:27",
|
|
||||||
"to": "L13:31"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "VariableExpr",
|
|
||||||
"name": "ints"
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "list",
|
|
||||||
"args": [
|
|
||||||
{
|
|
||||||
"name": "int"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"body": {
|
|
||||||
"name": "list"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L13:15",
|
"from": "L13:15",
|
||||||
@@ -699,6 +699,54 @@
|
|||||||
},
|
},
|
||||||
"type": {}
|
"type": {}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L14:15",
|
||||||
|
"to": "L14:21"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "is_odd"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"pos_args": [],
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"pos": 0,
|
||||||
|
"name": "value",
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"kw_args": [],
|
||||||
|
"returns": {
|
||||||
|
"name": "bool"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L14:23",
|
||||||
|
"to": "L14:27"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "ints"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "list",
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "list"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L14:11",
|
"from": "L14:11",
|
||||||
@@ -787,54 +835,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L14:15",
|
|
||||||
"to": "L14:21"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "VariableExpr",
|
|
||||||
"name": "is_odd"
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"pos_args": [],
|
|
||||||
"args": [
|
|
||||||
{
|
|
||||||
"pos": 0,
|
|
||||||
"name": "value",
|
|
||||||
"type": {
|
|
||||||
"name": "int"
|
|
||||||
},
|
|
||||||
"required": true
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"kw_args": [],
|
|
||||||
"returns": {
|
|
||||||
"name": "bool"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L14:23",
|
|
||||||
"to": "L14:27"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "VariableExpr",
|
|
||||||
"name": "ints"
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "list",
|
|
||||||
"args": [
|
|
||||||
{
|
|
||||||
"name": "int"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"body": {
|
|
||||||
"name": "list"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L14:11",
|
"from": "L14:11",
|
||||||
|
|||||||
117
tests/cases/checker/09_frame_ops.py
Normal file
117
tests/cases/checker/09_frame_ops.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
# type: ignore
|
||||||
|
# ruff: disable [F821]
|
||||||
|
|
||||||
|
df1: Frame[a:int, b:float]
|
||||||
|
df2: Frame[a:int, b:float]
|
||||||
|
|
||||||
|
_: Any
|
||||||
|
|
||||||
|
# Arithmetic
|
||||||
|
_ = df1 + df2
|
||||||
|
_ = df1 - df2
|
||||||
|
_ = df1 * df2
|
||||||
|
_ = df1 / df2
|
||||||
|
_ = df1 // df2
|
||||||
|
_ = df1 % df2
|
||||||
|
_ = df1**df2
|
||||||
|
|
||||||
|
# Comparisons
|
||||||
|
_ = df1 < df2
|
||||||
|
_ = df1 > df2
|
||||||
|
_ = df1 <= df2
|
||||||
|
_ = df1 >= df2
|
||||||
|
_ = df1 != df2
|
||||||
|
_ = df1 == df2
|
||||||
|
|
||||||
|
# Aggregate
|
||||||
|
_ = df1.kurt()
|
||||||
|
_ = df1.kurtosis()
|
||||||
|
_ = df1.max()
|
||||||
|
_ = df1.mean()
|
||||||
|
_ = df1.median()
|
||||||
|
_ = df1.min()
|
||||||
|
_ = df1.mode()
|
||||||
|
_ = df1.prod()
|
||||||
|
_ = df1.product()
|
||||||
|
_ = df1.std()
|
||||||
|
_ = df1.sum()
|
||||||
|
_ = df1.var()
|
||||||
|
|
||||||
|
# Groupby
|
||||||
|
df_gb = df1.groupby(by="a")
|
||||||
|
|
||||||
|
_ = df_gb.kurt()
|
||||||
|
_ = df_gb.max()
|
||||||
|
_ = df_gb.mean()
|
||||||
|
_ = df_gb.median()
|
||||||
|
_ = df_gb.min()
|
||||||
|
_ = df_gb.prod()
|
||||||
|
_ = df_gb.std()
|
||||||
|
_ = df_gb.sum()
|
||||||
|
_ = df_gb.var()
|
||||||
|
|
||||||
|
|
||||||
|
# Columns
|
||||||
|
|
||||||
|
col1 = df1["a"]
|
||||||
|
col2 = df1["a"]
|
||||||
|
|
||||||
|
# Arithmetic
|
||||||
|
_ = col1 + col2
|
||||||
|
_ = col1 - col2
|
||||||
|
_ = col1 * col2
|
||||||
|
_ = col1 / col2
|
||||||
|
_ = col1 // col2
|
||||||
|
_ = col1 % col2
|
||||||
|
_ = col1**col2
|
||||||
|
|
||||||
|
# Comparisons
|
||||||
|
_ = col1 < col2
|
||||||
|
_ = col1 > col2
|
||||||
|
_ = col1 <= col2
|
||||||
|
_ = col1 >= col2
|
||||||
|
_ = col1 != col2
|
||||||
|
_ = col1 == col2
|
||||||
|
|
||||||
|
# Aggregate
|
||||||
|
_ = col1.kurt()
|
||||||
|
_ = col1.kurtosis()
|
||||||
|
_ = col1.max()
|
||||||
|
_ = col1.mean()
|
||||||
|
_ = col1.median()
|
||||||
|
_ = col1.min()
|
||||||
|
_ = col1.mode()
|
||||||
|
_ = col1.prod()
|
||||||
|
_ = col1.product()
|
||||||
|
_ = col1.std()
|
||||||
|
_ = col1.sum()
|
||||||
|
_ = col1.var()
|
||||||
|
|
||||||
|
# Groupby
|
||||||
|
col_gb = col1.groupby(level=0)
|
||||||
|
|
||||||
|
_ = col_gb.kurt()
|
||||||
|
_ = col_gb.max()
|
||||||
|
_ = col_gb.mean()
|
||||||
|
_ = col_gb.median()
|
||||||
|
_ = col_gb.min()
|
||||||
|
_ = col_gb.prod()
|
||||||
|
_ = col_gb.std()
|
||||||
|
_ = col_gb.sum()
|
||||||
|
_ = col_gb.var()
|
||||||
|
|
||||||
|
# Attributes
|
||||||
|
_ = df1.ndim # int
|
||||||
|
_ = df1.size # int
|
||||||
|
_ = df1.shape # (int, int)
|
||||||
|
_ = col1.ndim # int
|
||||||
|
_ = col1.size # int
|
||||||
|
_ = col1.shape # (int)
|
||||||
|
_ = col1.T # Column[int]
|
||||||
|
|
||||||
|
|
||||||
|
# Misc
|
||||||
|
_ = df1.head()
|
||||||
|
_ = df1.tail()
|
||||||
|
_ = col1.head()
|
||||||
|
_ = col1.tail()
|
||||||
4924
tests/cases/checker/09_frame_ops.py.ref.json
Normal file
4924
tests/cases/checker/09_frame_ops.py.ref.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -16,7 +16,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "bool",
|
"base": "bool",
|
||||||
"param": null
|
"args": []
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -25,7 +25,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "int",
|
"base": "int",
|
||||||
"param": null
|
"args": []
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -36,7 +36,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "float",
|
"base": "float",
|
||||||
"param": null
|
"args": []
|
||||||
},
|
},
|
||||||
"constraint": "(_ > 0) + (_ < 250)"
|
"constraint": "(_ > 0) + (_ < 250)"
|
||||||
}
|
}
|
||||||
@@ -47,7 +47,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "str",
|
"base": "str",
|
||||||
"param": null
|
"args": []
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -56,7 +56,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "datetime",
|
"base": "datetime",
|
||||||
"param": null
|
"args": []
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -65,7 +65,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "float",
|
"base": "float",
|
||||||
"param": null
|
"args": []
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -79,7 +79,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "_",
|
"base": "_",
|
||||||
"param": null
|
"args": []
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "GeoLocation",
|
"base": "GeoLocation",
|
||||||
"param": null
|
"args": []
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@@ -28,11 +28,13 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "Column",
|
"base": "Column",
|
||||||
"param": {
|
"args": [
|
||||||
"_type": "BaseType",
|
{
|
||||||
"base": "GeoLocation",
|
"_type": "BaseType",
|
||||||
"param": null
|
"base": "GeoLocation",
|
||||||
}
|
"args": []
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -65,11 +67,13 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "Column",
|
"base": "Column",
|
||||||
"param": {
|
"args": [
|
||||||
"_type": "BaseType",
|
{
|
||||||
"base": "GeoLocation",
|
"_type": "BaseType",
|
||||||
"param": null
|
"base": "GeoLocation",
|
||||||
}
|
"args": []
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -117,7 +121,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "Latitude",
|
"base": "Latitude",
|
||||||
"param": null
|
"args": []
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -146,7 +150,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "Latitude",
|
"base": "Latitude",
|
||||||
"param": null
|
"args": []
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -175,11 +179,13 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "Difference",
|
"base": "Difference",
|
||||||
"param": {
|
"args": [
|
||||||
"_type": "BaseType",
|
{
|
||||||
"base": "Latitude",
|
"_type": "BaseType",
|
||||||
"param": null
|
"base": "Latitude",
|
||||||
}
|
"args": []
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -217,7 +223,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "int",
|
"base": "int",
|
||||||
"param": null
|
"args": []
|
||||||
},
|
},
|
||||||
"constraint": "_ >= 0"
|
"constraint": "_ >= 0"
|
||||||
}
|
}
|
||||||
@@ -230,7 +236,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "float",
|
"base": "float",
|
||||||
"param": null
|
"args": []
|
||||||
},
|
},
|
||||||
"constraint": "_ >= 0"
|
"constraint": "_ >= 0"
|
||||||
}
|
}
|
||||||
@@ -252,7 +258,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "int",
|
"base": "int",
|
||||||
"param": null
|
"args": []
|
||||||
},
|
},
|
||||||
"constraint": "Positive"
|
"constraint": "Positive"
|
||||||
}
|
}
|
||||||
@@ -265,7 +271,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "float",
|
"base": "float",
|
||||||
"param": null
|
"args": []
|
||||||
},
|
},
|
||||||
"constraint": "Positive"
|
"constraint": "Positive"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,15 +14,17 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "Column",
|
"base": "Column",
|
||||||
"param": {
|
"args": [
|
||||||
"_type": "ConstraintType",
|
{
|
||||||
"type": {
|
"_type": "ConstraintType",
|
||||||
"_type": "BaseType",
|
"type": {
|
||||||
"base": "float",
|
"_type": "BaseType",
|
||||||
"param": null
|
"base": "float",
|
||||||
},
|
"args": []
|
||||||
"constraint": "0 <= _ <= 1"
|
},
|
||||||
}
|
"constraint": "0 <= _ <= 1"
|
||||||
|
}
|
||||||
|
]
|
||||||
},
|
},
|
||||||
"default": null
|
"default": null
|
||||||
},
|
},
|
||||||
@@ -31,15 +33,17 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "Column",
|
"base": "Column",
|
||||||
"param": {
|
"args": [
|
||||||
"_type": "ConstraintType",
|
{
|
||||||
"type": {
|
"_type": "ConstraintType",
|
||||||
"_type": "BaseType",
|
"type": {
|
||||||
"base": "float",
|
"_type": "BaseType",
|
||||||
"param": null
|
"base": "float",
|
||||||
},
|
"args": []
|
||||||
"constraint": "0 <= _ <= 1"
|
},
|
||||||
}
|
"constraint": "0 <= _ <= 1"
|
||||||
|
}
|
||||||
|
]
|
||||||
},
|
},
|
||||||
"default": null
|
"default": null
|
||||||
}
|
}
|
||||||
@@ -50,15 +54,17 @@
|
|||||||
"returns": {
|
"returns": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "Column",
|
"base": "Column",
|
||||||
"param": {
|
"args": [
|
||||||
"_type": "ConstraintType",
|
{
|
||||||
"type": {
|
"_type": "ConstraintType",
|
||||||
"_type": "BaseType",
|
"type": {
|
||||||
"base": "float",
|
"_type": "BaseType",
|
||||||
"param": null
|
"base": "float",
|
||||||
},
|
"args": []
|
||||||
"constraint": "0 <= _ <= 2"
|
},
|
||||||
}
|
"constraint": "0 <= _ <= 2"
|
||||||
|
}
|
||||||
|
]
|
||||||
},
|
},
|
||||||
"body": [
|
"body": [
|
||||||
{
|
{
|
||||||
@@ -67,15 +73,17 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "Column",
|
"base": "Column",
|
||||||
"param": {
|
"args": [
|
||||||
"_type": "ConstraintType",
|
{
|
||||||
"type": {
|
"_type": "ConstraintType",
|
||||||
"_type": "BaseType",
|
"type": {
|
||||||
"base": "float",
|
"_type": "BaseType",
|
||||||
"param": null
|
"base": "float",
|
||||||
},
|
"args": []
|
||||||
"constraint": "0 <= _ <= 2"
|
},
|
||||||
}
|
"constraint": "0 <= _ <= 2"
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -117,7 +125,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "int",
|
"base": "int",
|
||||||
"param": null
|
"args": []
|
||||||
},
|
},
|
||||||
"default": null
|
"default": null
|
||||||
}
|
}
|
||||||
@@ -128,7 +136,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "float",
|
"base": "float",
|
||||||
"param": null
|
"args": []
|
||||||
},
|
},
|
||||||
"default": null
|
"default": null
|
||||||
}
|
}
|
||||||
@@ -140,7 +148,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "str",
|
"base": "str",
|
||||||
"param": null
|
"args": []
|
||||||
},
|
},
|
||||||
"default": null
|
"default": null
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,7 +46,8 @@ class GeneratorTester(Tester):
|
|||||||
|
|
||||||
if not any(d.type == DiagnosticType.ERROR for d in checker.diagnostics):
|
if not any(d.type == DiagnosticType.ERROR for d in checker.diagnostics):
|
||||||
generator = Generator(workdir=path.parent, types=checker.types)
|
generator = Generator(workdir=path.parent, types=checker.types)
|
||||||
result.compiled_ast = generator.generate_ast(typed_ast, path)
|
generator.set_src_path(path)
|
||||||
|
result.compiled_ast = generator.generate_ast(typed_ast)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from typing import Optional, Sequence
|
from typing import Optional, Sequence
|
||||||
|
|
||||||
from midas.ast.midas import (
|
from midas.ast.midas import (
|
||||||
|
AliasStmt,
|
||||||
BinaryExpr,
|
BinaryExpr,
|
||||||
CallExpr,
|
CallExpr,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
@@ -8,6 +9,7 @@ from midas.ast.midas import (
|
|||||||
Expr,
|
Expr,
|
||||||
ExtendStmt,
|
ExtendStmt,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
|
FrameType,
|
||||||
FunctionType,
|
FunctionType,
|
||||||
GenericType,
|
GenericType,
|
||||||
GetExpr,
|
GetExpr,
|
||||||
@@ -60,6 +62,13 @@ class MidasAstJsonSerializer(
|
|||||||
"bound": self._serialize_optional(param.bound),
|
"bound": self._serialize_optional(param.bound),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def visit_alias_stmt(self, stmt: AliasStmt) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "AliasStmt",
|
||||||
|
"name": stmt.name.lexeme,
|
||||||
|
"type": stmt.type.accept(self),
|
||||||
|
}
|
||||||
|
|
||||||
def visit_member_stmt(self, stmt: MemberStmt) -> dict:
|
def visit_member_stmt(self, stmt: MemberStmt) -> dict:
|
||||||
return {
|
return {
|
||||||
"_type": "MemberStmt",
|
"_type": "MemberStmt",
|
||||||
@@ -197,3 +206,15 @@ class MidasAstJsonSerializer(
|
|||||||
"base": type.base.accept(self),
|
"base": type.base.accept(self),
|
||||||
"extension": type.extension.accept(self),
|
"extension": type.extension.accept(self),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def visit_frame_type(self, type: FrameType) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "FrameType",
|
||||||
|
"columns": [self._serialize_column(col) for col in type.columns],
|
||||||
|
}
|
||||||
|
|
||||||
|
def _serialize_column(self, column: FrameType.Column):
|
||||||
|
return {
|
||||||
|
"name": column.name.lexeme,
|
||||||
|
"type": column.type.accept(self),
|
||||||
|
}
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from midas.ast.python import (
|
|||||||
Stmt,
|
Stmt,
|
||||||
SubscriptExpr,
|
SubscriptExpr,
|
||||||
TernaryExpr,
|
TernaryExpr,
|
||||||
|
TupleExpr,
|
||||||
TypeAssign,
|
TypeAssign,
|
||||||
UnaryExpr,
|
UnaryExpr,
|
||||||
VariableExpr,
|
VariableExpr,
|
||||||
@@ -98,7 +99,7 @@ class PythonAstJsonSerializer(
|
|||||||
return {
|
return {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": node.base,
|
"base": node.base,
|
||||||
"param": self._serialize_optional(node.param),
|
"args": self._serialize_list(node.args),
|
||||||
}
|
}
|
||||||
|
|
||||||
def visit_constraint_type(self, node: ConstraintType) -> dict:
|
def visit_constraint_type(self, node: ConstraintType) -> dict:
|
||||||
@@ -263,6 +264,7 @@ class PythonAstJsonSerializer(
|
|||||||
"_type": "CastExpr",
|
"_type": "CastExpr",
|
||||||
"type": expr.type.accept(self),
|
"type": expr.type.accept(self),
|
||||||
"expr": expr.expr.accept(self),
|
"expr": expr.expr.accept(self),
|
||||||
|
"unsafe": expr.unsafe,
|
||||||
}
|
}
|
||||||
|
|
||||||
def visit_ternary_expr(self, expr: TernaryExpr) -> dict:
|
def visit_ternary_expr(self, expr: TernaryExpr) -> dict:
|
||||||
@@ -301,6 +303,12 @@ class PythonAstJsonSerializer(
|
|||||||
"step": self._serialize_optional(expr.step),
|
"step": self._serialize_optional(expr.step),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def visit_tuple_expr(self, expr: TupleExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "TupleExpr",
|
||||||
|
"items": [item.accept(self) for item in expr.items],
|
||||||
|
}
|
||||||
|
|
||||||
def visit_raw_expr(self, expr: RawExpr) -> dict:
|
def visit_raw_expr(self, expr: RawExpr) -> dict:
|
||||||
return {
|
return {
|
||||||
"_type": "RawExpr",
|
"_type": "RawExpr",
|
||||||
|
|||||||
Reference in New Issue
Block a user