Compare commits
434 Commits
v0.0.1-pro
...
feat/add-d
| Author | SHA1 | Date | |
|---|---|---|---|
|
9764484fd9
|
|||
|
5b9e322c91
|
|||
|
c18d9c18de
|
|||
|
9229f00375
|
|||
|
6b7a682dc5
|
|||
|
35b97fd17b
|
|||
| 03bc32400b | |||
|
4a93ee45d9
|
|||
|
8197131d8d
|
|||
|
cf91187b7a
|
|||
|
1b2bdf0b79
|
|||
| c6cc38bfeb | |||
|
4d3e3f44a1
|
|||
|
ec80b1e92e
|
|||
|
4ea15519f3
|
|||
|
7a6e01cff8
|
|||
|
733c8736b8
|
|||
|
20173a0b07
|
|||
|
a143972ef1
|
|||
|
0c70048b62
|
|||
|
1c0c917873
|
|||
|
1f6189daa4
|
|||
|
66b585c3d6
|
|||
|
819ab3c2bf
|
|||
|
d8c0b17512
|
|||
|
6e06f9078e
|
|||
|
ece2e3a6a3
|
|||
|
74c07c9afb
|
|||
|
be2fd4c837
|
|||
|
1bc4c704c3
|
|||
|
0288a05901
|
|||
|
b14f46d405
|
|||
|
8e8ed62266
|
|||
|
2fce2f4bfc
|
|||
|
640f2d1771
|
|||
|
b48dfe5301
|
|||
|
0d5840a4ce
|
|||
|
3c92f0867d
|
|||
|
b5acae4078
|
|||
|
5d20f8ec3e
|
|||
|
955c2233ed
|
|||
|
ff69b65171
|
|||
|
8df01afd8c
|
|||
|
47b2dfdd73
|
|||
|
bd4d793ce0
|
|||
|
f7a36f61b6
|
|||
|
ad2fabf471
|
|||
|
a59a58d21a
|
|||
| 3260ae4a1e | |||
|
bd1c9581c7
|
|||
|
663642ea6c
|
|||
|
e2abc04fe4
|
|||
|
a4016b55ce
|
|||
|
1ea5da7024
|
|||
|
a017a8cf1f
|
|||
|
8fc5ab623e
|
|||
|
14007db846
|
|||
|
6ad2ce4b68
|
|||
|
9a276c34c7
|
|||
|
6e717a3f9e
|
|||
|
77aadfa264
|
|||
| c81287df7f | |||
|
ffccc1bedd
|
|||
|
d14f208897
|
|||
|
293953a078
|
|||
|
bccc96e4d0
|
|||
|
9db56adf56
|
|||
|
3f99563ac8
|
|||
|
b36896cc7b
|
|||
|
cb75878ae9
|
|||
|
a5fe985eb2
|
|||
|
e324f414e6
|
|||
|
256536562f
|
|||
|
64f4314f0d
|
|||
|
6f6245d283
|
|||
|
3392bc347d
|
|||
|
7e0319906a
|
|||
|
75bd203d4a
|
|||
|
db40198357
|
|||
|
d79e1dee18
|
|||
|
4ea400265c
|
|||
|
24bffdabd4
|
|||
|
d7bb6326de
|
|||
|
dbf6f9e2db
|
|||
|
3cdc9031d3
|
|||
|
00e2ca8fe3
|
|||
|
4efb01285c
|
|||
|
f84a19159f
|
|||
|
946b2e0d2e
|
|||
|
08dd7408ec
|
|||
|
b33fadf768
|
|||
|
7219109e5d
|
|||
|
cdf1725c26
|
|||
|
7074b074bc
|
|||
|
ede7272c09
|
|||
|
87d5e286d2
|
|||
|
c91b206791
|
|||
|
a31d295eb1
|
|||
|
0d20993f02
|
|||
|
5357ca8e58
|
|||
|
556765fd35
|
|||
| d039a8e4b3 | |||
|
c4533421eb
|
|||
|
73769b42c1
|
|||
|
087f6b4669
|
|||
| d582df5927 | |||
|
6a0401833c
|
|||
|
e15607b763
|
|||
|
e28f324a85
|
|||
|
31e696c938
|
|||
|
759b416bf3
|
|||
|
4b2b0fe476
|
|||
|
4c39504750
|
|||
|
f9f3ade6c7
|
|||
|
386018b956
|
|||
|
bd47d33355
|
|||
|
93ddb28802
|
|||
| f7c43837b5 | |||
|
32ed62a6f1
|
|||
|
66f39acec0
|
|||
|
6c04e2fee4
|
|||
| 2bb2e0a684 | |||
|
5630320d21
|
|||
|
9f05ba3224
|
|||
|
5fbe965919
|
|||
| 252a5abdfd | |||
|
55fba6a088
|
|||
|
70ce263ea2
|
|||
|
e1d5eac8b8
|
|||
|
82666a4918
|
|||
|
45f84a2f23
|
|||
|
dedfcb4dbb
|
|||
|
d9ea6365ea
|
|||
| 9c7a93412c | |||
|
d6b8fbfb60
|
|||
|
b290c59ac4
|
|||
|
093f2bc477
|
|||
|
7c771c4070
|
|||
|
a50a207385
|
|||
|
7e5ea5e414
|
|||
|
0ba0266bae
|
|||
|
216c80f08c
|
|||
|
f75d7722a1
|
|||
|
2f29c47274
|
|||
|
80af2b9048
|
|||
|
577454ee7e
|
|||
|
878693383e
|
|||
|
0b91de75a8
|
|||
| 739871c101 | |||
|
4395e9339b
|
|||
|
29e601128d
|
|||
|
b591f5508f
|
|||
|
41d0c84bbe
|
|||
| cccf2f8f9f | |||
|
3f48c2138f
|
|||
|
e4ab27673d
|
|||
|
b02ecc6326
|
|||
|
9e83079910
|
|||
|
ec468dd982
|
|||
|
3edc25d778
|
|||
|
451e54b009
|
|||
|
0dc14f67aa
|
|||
|
ff79f25628
|
|||
| 12782dda1e | |||
|
48a20b4aa0
|
|||
|
9467187313
|
|||
|
cd8f14153d
|
|||
| 6eea0c02e0 | |||
|
3205e7b961
|
|||
|
0aba134290
|
|||
|
1f0bcab2ca
|
|||
|
db8d88ef35
|
|||
|
7695d50537
|
|||
|
8461d05fa6
|
|||
|
43d2118db7
|
|||
|
6a87b5396f
|
|||
|
e6a581ba6e
|
|||
|
2a7aac69ed
|
|||
|
eb5bf19c61
|
|||
|
657406ea01
|
|||
|
2974386110
|
|||
|
92ca6b6732
|
|||
|
6aacdb98b7
|
|||
|
1b100b6ceb
|
|||
|
6b4c7d27bc
|
|||
|
2523d638f7
|
|||
|
5fc7461e29
|
|||
|
c5154bde81
|
|||
|
d07e8ac0ca
|
|||
|
3380995082
|
|||
|
7efc44c496
|
|||
|
ca94443699
|
|||
|
c513a85cf2
|
|||
|
2a106c5d07
|
|||
| 9672dfd588 | |||
|
7639ccc94d
|
|||
| a4a2ed5d64 | |||
|
e5cb90aff6
|
|||
|
75f8e4af53
|
|||
|
42c2d7a098
|
|||
| 5ce3b4abed | |||
|
2a8b7d559c
|
|||
|
da38cad23d
|
|||
|
591012d059
|
|||
|
4b1087d6b9
|
|||
|
732f7b0796
|
|||
|
c4062c9595
|
|||
|
c3229b557c
|
|||
|
0a8e0fb6c2
|
|||
|
61514d036c
|
|||
| 2e5cf6f8a2 | |||
|
25fabdd6c3
|
|||
|
af1aba41e7
|
|||
|
48e13d3348
|
|||
|
faa98ce0ef
|
|||
|
274e366561
|
|||
| 119c262da4 | |||
|
81181891c4
|
|||
|
59c1a0c7b6
|
|||
|
74f51f361a
|
|||
|
f25341b1e7
|
|||
|
3281324caf
|
|||
| 5b062b46e6 | |||
|
635bf73531
|
|||
|
bd0421b5d8
|
|||
|
37a464d2bc
|
|||
|
1eedcff5aa
|
|||
|
35798e5752
|
|||
|
0a35563aaf
|
|||
|
e1da87eaa0
|
|||
|
2a579c06b1
|
|||
|
46a22797b6
|
|||
|
7598681729
|
|||
|
2df0380815
|
|||
|
178e24cd02
|
|||
|
c92b6b5c18
|
|||
|
6577241af9
|
|||
|
1c71badf24
|
|||
|
064702fe13
|
|||
|
890e2f035a
|
|||
|
0d0115534b
|
|||
|
221b5ca926
|
|||
|
9a227b6d4c
|
|||
|
df2e609c60
|
|||
|
3ee1161680
|
|||
|
eb223c6cb7
|
|||
|
6f5d971c66
|
|||
|
109c8eb35a
|
|||
|
99924ee6c2
|
|||
|
4c9cbd9faa
|
|||
|
84a5f41e62
|
|||
|
6d6bb66c54
|
|||
|
50eaafc388
|
|||
|
2935c71366
|
|||
|
52981f12f2
|
|||
|
2e898ab1e9
|
|||
|
01ff5ca8d5
|
|||
|
b5de28e291
|
|||
|
179b88bfed
|
|||
|
b3665c6462
|
|||
|
42284704de
|
|||
|
650f60e70c
|
|||
|
efea1b29e7
|
|||
|
ae0bd75f3b
|
|||
|
d9100d8300
|
|||
|
900be47d34
|
|||
|
3d5f97a0f4
|
|||
|
9fde115016
|
|||
|
f8897dd075
|
|||
|
380753ca7a
|
|||
|
4715318913
|
|||
|
a78aee1639
|
|||
|
3581b7600b
|
|||
|
32207c3d6f
|
|||
|
9474a7336a
|
|||
|
5a6a279eaf
|
|||
|
c1f95edc96
|
|||
|
098bbc35c5
|
|||
|
314d4d344b
|
|||
|
7236749bd5
|
|||
|
2ff1f27614
|
|||
|
111afe4dd4
|
|||
|
c4c142482a
|
|||
|
f9c15abaf4
|
|||
|
d51d24f865
|
|||
|
1d00875a8c
|
|||
|
f89722fad8
|
|||
| 27917496c1 | |||
|
e0179bc442
|
|||
|
e665d03533
|
|||
|
b8cb2b4273
|
|||
|
d278dc5f5b
|
|||
|
59e73f0fd9
|
|||
|
3e0dc60283
|
|||
|
c24eb5125e
|
|||
|
25bd895dde
|
|||
|
bccd75317e
|
|||
|
f0e3f7574f
|
|||
|
5d44081847
|
|||
|
2a2bb0aec7
|
|||
|
67c40a3909
|
|||
|
1c30188122
|
|||
|
82a0f13242
|
|||
| 288d15a9bc | |||
|
504703d0f7
|
|||
|
e48895d0af
|
|||
| 13d32d0d27 | |||
| 19b9fdd623 | |||
|
ddcaebb51a
|
|||
|
f182312cd2
|
|||
|
73b21789d5
|
|||
|
5d7c724bc8
|
|||
|
74b297c89c
|
|||
|
822a74acce
|
|||
|
9a934fabfd
|
|||
|
828ec9a3fa
|
|||
|
63a43d79dd
|
|||
|
029caf4526
|
|||
|
1c5c418f1c
|
|||
|
a4139d4652
|
|||
|
2fd2071d40
|
|||
|
97b1ee8ab8
|
|||
|
dee479def5
|
|||
|
c8536e20d2
|
|||
|
d70137775f
|
|||
|
35ceda99aa
|
|||
|
7f3d74ee49
|
|||
|
b9f378de6f
|
|||
|
ccb17c7290
|
|||
|
505779310a
|
|||
|
bea3f399ad
|
|||
|
55060bfecd
|
|||
|
dd126f2559
|
|||
|
4151f5373d
|
|||
|
bd31713ab4
|
|||
|
f4dc57cb96
|
|||
|
261fd47494
|
|||
|
1b66a8553d
|
|||
|
65164abadb
|
|||
|
9d45163d9c
|
|||
|
ab0fa1de1a
|
|||
|
5d4df7978b
|
|||
|
86ad348b99
|
|||
|
29f691e38a
|
|||
|
f2c61d24e2
|
|||
|
112ed0e816
|
|||
|
7eb1e13b70
|
|||
|
893e1ba190
|
|||
|
1a1b0e8e15
|
|||
|
4ddde364ed
|
|||
|
4a3363a3d6
|
|||
|
0a3216e07d
|
|||
|
c29c0ed3ec
|
|||
|
fa7e56cb77
|
|||
|
13c19db818
|
|||
|
95b218fbed
|
|||
|
c3722c7438
|
|||
|
9dd547d6c1
|
|||
|
e2d5943517
|
|||
|
86e4763a12
|
|||
|
89ec63cb05
|
|||
|
e6375f1aa9
|
|||
|
d16e192a3a
|
|||
|
3f61f84e5a
|
|||
|
fd5399f50a
|
|||
|
8906ac3db8
|
|||
|
022aebf55b
|
|||
|
5dc6903425
|
|||
|
1b078b832c
|
|||
|
7515716864
|
|||
|
218b0c5b78
|
|||
|
928901ef9c
|
|||
|
4b62c78874
|
|||
|
f882eebaf5
|
|||
|
a872938405
|
|||
|
146be72fd7
|
|||
|
6de54e1da1
|
|||
|
c82b41a4df
|
|||
|
8304760fe0
|
|||
|
6bf91db757
|
|||
|
3f6b650a4b
|
|||
| ec079f32ca | |||
|
6524b3591a
|
|||
|
170101aa37
|
|||
|
0b3f33d7fe
|
|||
|
8a9b4f3989
|
|||
|
bbd0e3ae8d
|
|||
|
4d23e8840e
|
|||
|
c64d626d1c
|
|||
|
ecab1b74a4
|
|||
|
0bbdf04621
|
|||
|
939e5af4ce
|
|||
|
a735113466
|
|||
|
0e0a1b26f2
|
|||
|
e94db2181f
|
|||
|
9b59058881
|
|||
|
d0c54db33a
|
|||
|
5aedddfabb
|
|||
|
8d7c115432
|
|||
|
832c350b61
|
|||
|
3d599b3462
|
|||
|
4f799caaf5
|
|||
|
f4d2be3b1b
|
|||
|
7ce2840f03
|
|||
|
e2f3cabe15
|
|||
|
5a112332f2
|
|||
|
eb79cf6dc3
|
|||
|
8a9bb6ef4e
|
|||
|
6e0190a378
|
|||
| b5969e9a2b | |||
|
409d9f8fa6
|
|||
|
12d762429d
|
|||
|
53929ee514
|
|||
|
2f6e137f1a
|
|||
|
5224e79d9f
|
|||
|
bdcb12c58a
|
|||
|
5cb4d587e3
|
|||
|
8f9ec8d73b
|
|||
|
c1c50a448e
|
|||
|
19229db0b1
|
|||
|
f3b6bd146f
|
|||
|
98c3510bd4
|
|||
|
429d0d98fe
|
|||
|
db8fe5d3ff
|
|||
|
7477ec8d70
|
|||
|
adf7f4e7a2
|
|||
|
abf6787946
|
|||
|
e282b08597
|
|||
|
0a02b9d3d9
|
|||
| 875ca589e4 | |||
|
88f92d6e1f
|
|||
|
db4ed74365
|
|||
|
7cbf4fdece
|
|||
|
1fa9a09bfe
|
5
.gitignore
vendored
5
.gitignore
vendored
@@ -3,4 +3,7 @@ __pycache__
|
||||
.env
|
||||
venv
|
||||
.venv
|
||||
*.pyc
|
||||
*.pyc
|
||||
uv.lock
|
||||
.python-version
|
||||
/out
|
||||
154
README.md
154
README.md
@@ -1,7 +1,159 @@
|
||||
# Midas
|
||||
<h1>Midas</h1>
|
||||
|
||||
*Midas* is a type system to _Maintain Integrity of Data with Annotated Structures_. In Greek mythology, [Midas](https://en.wikipedia.org/wiki/Midas) was a Phrygian king who was blessed with the gift of turning everything he touched into gold.
|
||||
|
||||
*Midas* aims at providing Python developers with a simple annotation system to enable compile-time integrity and data type checks, as well as generating runtime assertions.
|
||||
|
||||
This framework is being developed as part of a Bachelor's Thesis by Louis Heredero at HEI Sion.
|
||||
|
||||
<details>
|
||||
<summary><strong>Table of Contents</strong></summary>
|
||||
|
||||
- [Requirements](#requirements)
|
||||
- [Installation](#installation)
|
||||
- [Commands](#commands)
|
||||
- [Type Checking](#type-checking)
|
||||
- [Compiling](#compiling)
|
||||
- [Formatting](#formatting)
|
||||
- [Highlighting](#highlighting)
|
||||
- [Dumping the AST](#dumping-the-ast)
|
||||
- [Dumping the Registry](#dumping-the-registry)
|
||||
- [Generating Stubs](#generating-stubs)
|
||||
- [Showing Type Judgements](#showing-type-judgements)
|
||||
- [Validating Definitions](#validating-definitions)
|
||||
- [Tests](#tests)
|
||||
</details>
|
||||
|
||||
|
||||
## Requirements
|
||||
|
||||
- Python 3.11+
|
||||
- [uv](https://docs.astral.sh/uv/getting-started/installation/)
|
||||
|
||||
## Installation
|
||||
|
||||
1. Clone the repository
|
||||
```shell
|
||||
git clone https://git.kb28.ch/HEL/midas.git
|
||||
```
|
||||
2. Go into the project directory
|
||||
```shell
|
||||
cd midas
|
||||
```
|
||||
3. Install the CLI as a user-wide tool
|
||||
```shell
|
||||
uv tool install .
|
||||
```
|
||||
4. You can now run the `midas` command from anywhere
|
||||
```shell
|
||||
midas --help
|
||||
```
|
||||
|
||||
## Commands
|
||||
|
||||
<!--
|
||||
check
|
||||
compile
|
||||
format
|
||||
highlight
|
||||
parse
|
||||
dump_registry
|
||||
types
|
||||
validate
|
||||
-->
|
||||
|
||||
### Type Checking
|
||||
|
||||
```shell
|
||||
midas check -t types.midas source.py
|
||||
```
|
||||
|
||||
This command parses the given files and runs the type checkers against the Midas definitions and Python program. Diagnostics are then printed showing warnings and errors.
|
||||
|
||||
### Compiling
|
||||
|
||||
```shell
|
||||
midas compile -t types.midas source.py
|
||||
```
|
||||
|
||||
With the `compile` command, you can process a source Python file, with any number of custom type definition files (`-t FILE` option), and the type checker will verify the coherence of your program and generate the runnable code with valid syntax and runtime assertions.
|
||||
|
||||
### Formatting
|
||||
|
||||
```shell
|
||||
midas format types.midas
|
||||
midas format types.midas -o formatted.midas
|
||||
```
|
||||
|
||||
This command parses the given Midas file and outputs a pretty printed file from the AST.
|
||||
|
||||
### Highlighting
|
||||
|
||||
```shell
|
||||
midas highlight source.py
|
||||
midas highlight source.py -o highlighted.html
|
||||
midas highlight types.midas
|
||||
midas highlight types.midas -o highlighted.html
|
||||
```
|
||||
|
||||
The `highlight` command takes in a source file (Python or Midas), runs the appropriate parser and outputs an HTML file containing the source code with added highlighting. This highlighting takes the form of hoverable annotations showing some of the parsed structures (e.g. a function definition, an assignment, a generic type, etc.).
|
||||
|
||||
The optional `-o FILE` option can be used to specify an output path. By default, the file is printed to stdout (equivalent to `-o -`).
|
||||
|
||||
### Dumping the AST
|
||||
|
||||
```shell
|
||||
midas parse source.py
|
||||
midas parse types.midas
|
||||
```
|
||||
|
||||
For debugging purposes, you can output the AST parsed from a Python or Midas file. For Python files, the `--raw` flag lets you toggle the custom AST parsing. With `--raw`, the raw AST is returned, as produced by the builtin `ast` module. This flag has no effect on Midas files.
|
||||
|
||||
### Dumping the Registry
|
||||
|
||||
```shell
|
||||
midas dump-registry -t types.midas
|
||||
```
|
||||
|
||||
This command processes the given Midas definitions and dumps the contents of the types registry.
|
||||
|
||||
### Generating Stubs
|
||||
|
||||
```shell
|
||||
midas stubs types.midas -o stubs.pyi
|
||||
```
|
||||
|
||||
This command generates Python stubs from a Midas definition file.
|
||||
|
||||
### Showing Type Judgements
|
||||
|
||||
```shell
|
||||
midas types -t types.midas source.py
|
||||
```
|
||||
|
||||
This command type checks the given Python source file and logs all typing judgements made by the type checker.
|
||||
|
||||
### Validating Definitions
|
||||
|
||||
```shell
|
||||
midas validate types.midas
|
||||
```
|
||||
|
||||
This command lets you validate a Midas definition file by running the parser and type checker, verifying syntax and references.
|
||||
|
||||
## Tests
|
||||
|
||||
Several snapshot tests are available to verify the correct behaviour of the parsers and type checker. They can be run as follows:
|
||||
|
||||
```shell
|
||||
uv run -m tests.midas run -a
|
||||
uv run -m tests.python run -a
|
||||
uv run -m tests.checker run -a
|
||||
uv run -m tests.generator run -a
|
||||
```
|
||||
|
||||
**Available subcommands:**
|
||||
- Run all tests: `run -a`
|
||||
- Run specific tests: `run tests/cases/test1.py tests/cases/test2.py ...`
|
||||
- Update all tests: `update -a`
|
||||
- Update specific tests: `update tests/cases/test1.py tests/cases/test2.py ...`
|
||||
|
||||
117
assets/icon.svg
Normal file
117
assets/icon.svg
Normal file
@@ -0,0 +1,117 @@
|
||||
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
|
||||
<!-- Created with Inkscape (http://www.inkscape.org/) -->
|
||||
|
||||
<svg
|
||||
width="128"
|
||||
height="128"
|
||||
viewBox="0 0 128 128"
|
||||
version="1.1"
|
||||
id="svg1"
|
||||
inkscape:export-filename="logo.png"
|
||||
inkscape:export-xdpi="96"
|
||||
inkscape:export-ydpi="96"
|
||||
inkscape:version="1.4.4 (1:1.4.4+202605061436+dcaf3e7d9e)"
|
||||
sodipodi:docname="logo.svg"
|
||||
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
|
||||
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
|
||||
xmlns:xlink="http://www.w3.org/1999/xlink"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
xmlns:svg="http://www.w3.org/2000/svg">
|
||||
<sodipodi:namedview
|
||||
id="namedview1"
|
||||
pagecolor="#ffffff"
|
||||
bordercolor="#000000"
|
||||
borderopacity="0.25"
|
||||
inkscape:showpageshadow="2"
|
||||
inkscape:pageopacity="0.0"
|
||||
inkscape:pagecheckerboard="0"
|
||||
inkscape:deskcolor="#d1d1d1"
|
||||
inkscape:document-units="mm"
|
||||
showgrid="true"
|
||||
inkscape:zoom="1.9332778"
|
||||
inkscape:cx="-8.2760999"
|
||||
inkscape:cy="112.2446"
|
||||
inkscape:window-width="2584"
|
||||
inkscape:window-height="1028"
|
||||
inkscape:window-x="0"
|
||||
inkscape:window-y="24"
|
||||
inkscape:window-maximized="1"
|
||||
inkscape:current-layer="layer1">
|
||||
<inkscape:grid
|
||||
id="grid1"
|
||||
units="px"
|
||||
originx="0"
|
||||
originy="0"
|
||||
spacingx="4"
|
||||
spacingy="4"
|
||||
empcolor="#0099e5"
|
||||
empopacity="0.30196078"
|
||||
color="#0099e5"
|
||||
opacity="0.14901961"
|
||||
empspacing="4"
|
||||
enabled="true"
|
||||
visible="true" />
|
||||
</sodipodi:namedview>
|
||||
<defs
|
||||
id="defs1">
|
||||
<linearGradient
|
||||
inkscape:collect="always"
|
||||
xlink:href="#linearGradient4689"
|
||||
id="linearGradient1478"
|
||||
gradientUnits="userSpaceOnUse"
|
||||
gradientTransform="matrix(0.562541,0,0,0.567972,-9.399749,-5.305317)"
|
||||
x1="26.648937"
|
||||
y1="20.603781"
|
||||
x2="135.66525"
|
||||
y2="114.39767" />
|
||||
<linearGradient
|
||||
id="linearGradient4689">
|
||||
<stop
|
||||
style="stop-color:#e1be1e;stop-opacity:1;"
|
||||
offset="0"
|
||||
id="stop4691" />
|
||||
<stop
|
||||
style="stop-color:#ffeb82;stop-opacity:1;"
|
||||
offset="1"
|
||||
id="stop4693" />
|
||||
</linearGradient>
|
||||
<linearGradient
|
||||
inkscape:collect="always"
|
||||
xlink:href="#linearGradient4671"
|
||||
id="linearGradient1475"
|
||||
gradientUnits="userSpaceOnUse"
|
||||
gradientTransform="matrix(0.562541,0,0,0.567972,-9.399749,-5.305317)"
|
||||
x1="150.96111"
|
||||
y1="192.35176"
|
||||
x2="112.03144"
|
||||
y2="137.27299" />
|
||||
<linearGradient
|
||||
id="linearGradient4671">
|
||||
<stop
|
||||
style="stop-color:#ffdc21;stop-opacity:1;"
|
||||
offset="0"
|
||||
id="stop4673" />
|
||||
<stop
|
||||
style="stop-color:#ffeb82;stop-opacity:1;"
|
||||
offset="1"
|
||||
id="stop4675" />
|
||||
</linearGradient>
|
||||
</defs>
|
||||
<g
|
||||
inkscape:label="Calque 1"
|
||||
inkscape:groupmode="layer"
|
||||
id="layer1">
|
||||
<g
|
||||
id="g1"
|
||||
transform="translate(2.911719,3.414527)">
|
||||
<path
|
||||
style="fill:url(#linearGradient1478);fill-opacity:1"
|
||||
d="m 60.510156,6.3979729 c -4.583653,0.021298 -8.960939,0.4122177 -12.8125,1.09375 C 36.35144,9.4962267 34.291407,13.691825 34.291406,21.429223 v 10.21875 h 26.8125 v 3.40625 h -26.8125 -10.0625 c -7.792459,0 -14.6157592,4.683717 -16.7500002,13.59375 -2.46182,10.212966 -2.5710151,16.586023 0,27.25 1.9059283,7.937852 6.4575432,13.593748 14.2500002,13.59375 h 9.21875 v -12.25 c 0,-8.849902 7.657144,-16.656248 16.75,-16.65625 h 26.78125 c 7.454951,0 13.406253,-6.138164 13.40625,-13.625 v -25.53125 c 0,-7.266339 -6.12998,-12.7247775 -13.40625,-13.9375001 -4.605987,-0.7667253 -9.385097,-1.1150483 -13.96875,-1.09375 z m -14.5,8.2187501 c 2.769547,0 5.03125,2.298646 5.03125,5.125 -2e-6,2.816336 -2.261703,5.09375 -5.03125,5.09375 -2.779476,-1e-6 -5.03125,-2.277415 -5.03125,-5.09375 -1e-6,-2.826353 2.251774,-5.125 5.03125,-5.125 z"
|
||||
id="path1948" />
|
||||
<path
|
||||
style="fill:url(#linearGradient1475);fill-opacity:1"
|
||||
d="m 91.228906,35.054223 v 11.90625 c 0,9.230755 -7.825895,16.999999 -16.75,17 h -26.78125 c -7.335833,0 -13.406249,6.278483 -13.40625,13.625 v 25.531247 c 0,7.26634 6.318588,11.54032 13.40625,13.625 8.487331,2.49561 16.626237,2.94663 26.78125,0 6.750155,-1.95439 13.406253,-5.88761 13.40625,-13.625 V 92.897973 h -26.78125 v -3.40625 h 26.78125 13.406254 c 7.79246,0 10.69625,-5.435408 13.40624,-13.59375 2.79933,-8.398886 2.68022,-16.475776 0,-27.25 -1.92578,-7.757441 -5.60387,-13.59375 -13.40624,-13.59375 z m -15.0625,64.65625 c 2.779478,3e-6 5.03125,2.277417 5.03125,5.093747 -2e-6,2.82635 -2.251775,5.125 -5.03125,5.125 -2.76955,0 -5.03125,-2.29865 -5.03125,-5.125 2e-6,-2.81633 2.261697,-5.093747 5.03125,-5.093747 z"
|
||||
id="path1950" />
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 4.7 KiB |
@@ -1,107 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from lexer.token import Token
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Stmt(ABC):
    """Abstract base class of statement nodes, dispatched through a visitor."""

    class Visitor(ABC, Generic[T]):
        """Visitor interface over statement nodes.

        ``T`` is the type returned by every ``visit_*`` method.
        """

        @abstractmethod
        def visit_annotation_stmt(self, stmt: AnnotationStmt) -> T:
            """Handle an ``AnnotationStmt`` node."""

    @abstractmethod
    def accept(self, visitor: Visitor[T]) -> T:
        """Dispatch this node to the matching ``visit_*`` method of *visitor*."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class AnnotationStmt(Stmt):
    """Statement node made of a name token and an optional schema."""

    # Token whose lexeme is the annotation's name.
    name: Token
    # Schema attached to the annotation, or None when there is none.
    schema: Optional[SchemaExpr]

    def accept(self, visitor: Stmt.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_annotation_stmt`` and return its result."""
        return visitor.visit_annotation_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Expr(ABC):
    """Abstract base class of expression nodes, dispatched through a visitor."""

    class Visitor(ABC, Generic[T]):
        """Visitor interface over expression nodes.

        ``T`` is the type returned by every ``visit_*`` method.
        """

        @abstractmethod
        def visit_wildcard_expr(self, expr: WildcardExpr) -> T:
            """Handle a ``WildcardExpr`` node."""

        @abstractmethod
        def visit_literal_expr(self, expr: LiteralExpr) -> T:
            """Handle a ``LiteralExpr`` node."""

        @abstractmethod
        def visit_type_expr(self, expr: TypeExpr) -> T:
            """Handle a ``TypeExpr`` node."""

        @abstractmethod
        def visit_constraint_expr(self, expr: ConstraintExpr) -> T:
            """Handle a ``ConstraintExpr`` node."""

        @abstractmethod
        def visit_schema_expr(self, expr: SchemaExpr) -> T:
            """Handle a ``SchemaExpr`` node."""

        @abstractmethod
        def visit_schema_element_expr(self, expr: SchemaElementExpr) -> T:
            """Handle a ``SchemaElementExpr`` node."""

    @abstractmethod
    def accept(self, visitor: Visitor[T]) -> T:
        """Dispatch this node to the matching ``visit_*`` method of *visitor*."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class WildcardExpr(Expr):
    """Expression node wrapping a single wildcard token."""

    # Source token of the wildcard.
    token: Token

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_wildcard_expr`` and return its result."""
        return visitor.visit_wildcard_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LiteralExpr(Expr):
    """Expression node carrying a literal value."""

    # The literal's value; its type is not restricted here.
    value: Any

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_literal_expr`` and return its result."""
        return visitor.visit_literal_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TypeExpr(Expr):
    """Expression node naming a type together with a list of constraints."""

    # Token whose lexeme is the type's name.
    name: Token
    # Constraints attached to the type (may be empty).
    constraints: list[ConstraintExpr]

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_type_expr`` and return its result."""
        return visitor.visit_type_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ConstraintExpr(Expr):
    """Binary expression node: two operand expressions joined by an operator token."""

    # Left-hand operand.
    left: Expr
    # Operator token placed between the operands.
    op: Token
    # Right-hand operand.
    right: Expr

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_constraint_expr`` and return its result."""
        return visitor.visit_constraint_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SchemaExpr(Expr):
    """Expression node holding a list of elements between two delimiter tokens."""

    # Token preceding the elements.
    left: Token
    # Elements of the schema, in source order.
    elements: list[Expr]
    # Token following the elements.
    right: Token

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_schema_expr`` and return its result."""
        return visitor.visit_schema_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SchemaElementExpr(Expr):
    """Expression node for one schema element: an optional name and an optional type."""

    # Name token of the element, or None when the element is unnamed.
    name: Optional[Token]
    # Type expression of the element, or None when no type is given.
    type: Optional[Expr]

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_schema_element_expr`` and return its result."""
        return visitor.visit_schema_element_expr(self)
|
||||
@@ -1,138 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from lexer.token import Token
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
# Statements
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Stmt(ABC):
    """Abstract base class of statement nodes, dispatched through a visitor."""

    class Visitor(ABC, Generic[T]):
        """Visitor interface over statement nodes.

        ``T`` is the type returned by every ``visit_*`` method.
        """

        @abstractmethod
        def visit_type_stmt(self, stmt: TypeStmt) -> T:
            """Handle a ``TypeStmt`` node."""

        @abstractmethod
        def visit_property_stmt(self, stmt: PropertyStmt) -> T:
            """Handle a ``PropertyStmt`` node."""

        @abstractmethod
        def visit_op_stmt(self, stmt: OpStmt) -> T:
            """Handle an ``OpStmt`` node."""

        @abstractmethod
        def visit_constraint_stmt(self, stmt: ConstraintStmt) -> T:
            """Handle a ``ConstraintStmt`` node."""

    @abstractmethod
    def accept(self, visitor: Visitor[T]) -> T:
        """Dispatch this node to the matching ``visit_*`` method of *visitor*."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TypeStmt(Stmt):
    """Statement node for a type: a name, a list of bases and an optional body."""

    # Token whose lexeme is the type's name.
    name: Token
    # Base type expressions (may be empty).
    bases: list[TypeExpr]
    # Body of the type, or None when there is none.
    body: Optional[TypeBodyExpr]

    def accept(self, visitor: Stmt.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_type_stmt`` and return its result."""
        return visitor.visit_type_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class PropertyStmt(Stmt):
    """Statement node for a property: a name token and its type expression."""

    # Token whose lexeme is the property's name.
    name: Token
    # Type expression of the property.
    type: TypeExpr

    def accept(self, visitor: Stmt.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_property_stmt`` and return its result."""
        return visitor.visit_property_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class OpStmt(Stmt):
    """Statement node for an operator: two operand types, an operator token and a result type."""

    # Type expression of the left operand.
    left: TypeExpr
    # Operator token.
    op: Token
    # Type expression of the right operand.
    right: TypeExpr
    # Type expression of the result.
    result: TypeExpr

    def accept(self, visitor: Stmt.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_op_stmt`` and return its result."""
        return visitor.visit_op_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ConstraintStmt(Stmt):
    """Statement node pairing a name token with a constraint expression."""

    # Token whose lexeme is the constraint's name.
    name: Token
    # The constraint expression bound to that name.
    constraint: ConstraintExpr

    def accept(self, visitor: Stmt.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_constraint_stmt`` and return its result."""
        return visitor.visit_constraint_stmt(self)
|
||||
|
||||
|
||||
# Expressions
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Expr(ABC):
    """Abstract base class of expression nodes, dispatched through a visitor."""

    class Visitor(ABC, Generic[T]):
        """Visitor interface over expression nodes.

        ``T`` is the type returned by every ``visit_*`` method.
        """

        @abstractmethod
        def visit_wildcard_expr(self, expr: WildcardExpr) -> T:
            """Handle a ``WildcardExpr`` node."""

        @abstractmethod
        def visit_literal_expr(self, expr: LiteralExpr) -> T:
            """Handle a ``LiteralExpr`` node."""

        @abstractmethod
        def visit_type_expr(self, expr: TypeExpr) -> T:
            """Handle a ``TypeExpr`` node."""

        @abstractmethod
        def visit_constraint_expr(self, expr: ConstraintExpr) -> T:
            """Handle a ``ConstraintExpr`` node."""

        @abstractmethod
        def visit_type_body_expr(self, expr: TypeBodyExpr) -> T:
            """Handle a ``TypeBodyExpr`` node."""

    @abstractmethod
    def accept(self, visitor: Visitor[T]) -> T:
        """Dispatch this node to the matching ``visit_*`` method of *visitor*."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class WildcardExpr(Expr):
    """Expression node wrapping a single wildcard token."""

    # Source token of the wildcard.
    token: Token

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_wildcard_expr`` and return its result."""
        return visitor.visit_wildcard_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LiteralExpr(Expr):
    """Expression node carrying a literal value."""

    # The literal's value; its type is not restricted here.
    value: Any

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_literal_expr`` and return its result."""
        return visitor.visit_literal_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TypeExpr(Expr):
    """Expression node naming a type together with a list of constraints."""

    # Token whose lexeme is the type's name.
    name: Token
    # Constraints attached to the type (may be empty).
    constraints: list[ConstraintExpr]

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_type_expr`` and return its result."""
        return visitor.visit_type_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ConstraintExpr(Expr):
    """Binary expression node: two operand expressions joined by an operator token."""

    # Left-hand operand.
    left: Expr
    # Operator token placed between the operands.
    op: Token
    # Right-hand operand.
    right: Expr

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_constraint_expr`` and return its result."""
        return visitor.visit_constraint_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TypeBodyExpr(Expr):
    """Expression node for the body of a type: a list of property statements."""

    # Property statements contained in the body, in source order.
    properties: list[PropertyStmt]

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_type_body_expr`` and return its result."""
        return visitor.visit_type_body_expr(self)
|
||||
@@ -1,360 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum, auto
|
||||
import io
|
||||
from typing import Generator, Generic, Optional, Protocol, TypeVar
|
||||
|
||||
import core.ast.annotations as a
|
||||
import core.ast.midas as m
|
||||
|
||||
|
||||
class _Level(Enum):
    """Drawing state of one indentation level of the tree being printed."""

    # The last child of this level was already written: deeper lines get
    # blank padding for it.
    EMPTY = auto()
    # More siblings may follow: lines get a branch marker or a vertical bar.
    ACTIVE = auto()
    # The next line written at this level is its last child.
    LAST = auto()


class Expr(Protocol):
    """Structural type of any node that can be handed to an ``AstPrinter``."""

    def accept(self, printer: AstPrinter) -> None: ...


T = TypeVar("T", bound=Expr)


class AstPrinter(Generic[T]):
    """Base class rendering a node tree as indented text with box-drawing guides.

    Subclasses implement the ``visit_*`` methods and use ``_write_line``,
    ``_child_level``, ``_mark_last`` and ``_write_optional_child`` to emit
    one line per node or attribute.
    """

    # All four prefixes must have the same width (4 columns); otherwise the
    # guides of nested levels do not line up with the branch markers above.
    LAST_CHILD = "└── "
    CHILD = "├── "
    VERTICAL = "│   "
    EMPTY = "    "

    def __init__(self) -> None:
        # One entry per currently open indentation level, outermost first.
        self._levels: list[_Level] = []
        # Pending "[i]" index prefix, consumed by the next written line.
        self._idx: Optional[int] = None
        # Accumulates the rendered text.
        self._buf: io.StringIO = io.StringIO()

    def print(self, expr: T) -> str:
        """Render *expr* into a fresh buffer and return the resulting text."""
        self._buf = io.StringIO()
        expr.accept(self)
        return self._buf.getvalue()

    @contextmanager
    def _child_level(self, last: bool = False) -> Generator[None, None, None]:
        """Open one indentation level for the duration of the ``with`` block.

        Args:
            last: when true, the level starts in the ``LAST`` state, i.e. the
                first line written in it is drawn as the last child.
        """
        self._levels.append(_Level.LAST if last else _Level.ACTIVE)
        try:
            yield
        finally:
            # Always close the level, even if a visitor raised.
            self._levels.pop()

    def _mark_last(self) -> None:
        """Flag the innermost level so its next line is drawn as the last child."""
        if self._levels:
            self._levels[-1] = _Level.LAST

    def _write_line(self, text: str, *, last: bool = False) -> None:
        """Write *text* on its own line, prefixed by the current indentation.

        Args:
            text: content of the line, without indentation or newline.
            last: when true, mark the innermost level as last before writing.
        """
        if last:
            self._mark_last()
        indent: str = self._build_indent()
        if self._idx is not None:
            # A pending index applies to exactly one line.
            text = f"[{self._idx}] {text}"
            self._idx = None
        self._buf.write(indent + text + "\n")

    def _build_indent(self) -> str:
        """Return the prefix for the next line and update the innermost level.

        Outer levels contribute blank padding when ``EMPTY`` and a vertical
        bar otherwise. The innermost level contributes a branch marker; if it
        was ``LAST`` it is switched to ``EMPTY`` so deeper lines are padded.
        """
        parts: list[str] = [
            self.EMPTY if level == _Level.EMPTY else self.VERTICAL
            for level in self._levels[:-1]
        ]
        if self._levels:
            if self._levels[-1] == _Level.LAST:
                parts.append(self.LAST_CHILD)
                self._levels[-1] = _Level.EMPTY
            else:
                parts.append(self.CHILD)
        return "".join(parts)

    def _write_optional_child(
        self, label: str, child: Optional[T], *, last: bool = False
    ) -> None:
        """Write a labelled child that may be absent.

        Writes ``"<label>: None"`` when *child* is None; otherwise writes the
        label and renders *child* one level deeper as its only (last) entry.

        Args:
            label: name of the attribute being printed.
            child: node to render, or None.
            last: when true, the label line is drawn as the last child.
        """
        if last:
            self._mark_last()
        if child is None:
            self._write_line(f"{label}: None")
        else:
            self._write_line(label)
            with self._child_level(last=True):
                child.accept(self)
|
||||
|
||||
|
||||
class AnnotationAstPrinter(AstPrinter, a.Expr.Visitor[None], a.Stmt.Visitor[None]):
    """Tree printer for annotation AST nodes, built on the ``AstPrinter`` helpers."""

    def _visit_numbered(self, nodes) -> None:
        """Visit *nodes* in order, tagging each with its index and the final one as last."""
        final = len(nodes) - 1
        for position, node in enumerate(nodes):
            self._idx = position
            if position == final:
                self._mark_last()
            node.accept(self)

    def visit_annotation_stmt(self, stmt: a.AnnotationStmt) -> None:
        """Print the statement header, its name and its optional schema."""
        self._write_line("AnnotationStmt")
        with self._child_level():
            self._write_line(f'name: "{stmt.name.lexeme}"')
            self._write_optional_child("schema", stmt.schema, last=True)

    def visit_type_expr(self, expr: a.TypeExpr):
        """Print the type name followed by its indexed constraints."""
        self._write_line("TypeExpr")
        with self._child_level():
            self._write_line(f'name: "{expr.name.lexeme}"')
            self._write_line("constraints", last=True)
            with self._child_level():
                self._visit_numbered(expr.constraints)

    def visit_constraint_expr(self, expr: a.ConstraintExpr) -> None:
        """Print the left operand, the operator lexeme and the right operand."""
        self._write_line("ConstraintExpr")
        with self._child_level():
            self._write_line("left")
            # Each operand is the only entry of its own level, hence "last".
            with self._child_level(last=True):
                expr.left.accept(self)

            self._write_line(f"operator: {expr.op.lexeme}")

            self._write_line("right", last=True)
            with self._child_level(last=True):
                expr.right.accept(self)

    def visit_schema_expr(self, expr: a.SchemaExpr):
        """Print the schema header followed by its indexed elements."""
        self._write_line("SchemaExpr")
        with self._child_level():
            self._visit_numbered(expr.elements)

    def visit_schema_element_expr(self, expr: a.SchemaElementExpr):
        """Print the element's name (or ``None``) and its optional type."""
        self._write_line("SchemaElementExpr")
        with self._child_level():
            if expr.name is None:
                name_text: str = "None"
            else:
                name_text = f'"{expr.name.lexeme}"'
            self._write_line(f"name: {name_text}")
            self._write_optional_child("type", expr.type, last=True)

    def visit_wildcard_expr(self, expr: a.WildcardExpr) -> None:
        """Print a single line for the wildcard."""
        self._write_line("WildcardExpr")

    def visit_literal_expr(self, expr: a.LiteralExpr) -> None:
        """Print the literal header and its value."""
        self._write_line("LiteralExpr")
        with self._child_level():
            self._write_line(f"value: {expr.value}", last=True)
|
||||
|
||||
|
||||
class AnnotationPrinter(a.Expr.Visitor[str], a.Stmt.Visitor[str]):
    """Visitor turning annotation AST nodes into a one-line textual form."""

    def print(self, expr: a.Expr | a.Stmt):
        """Return the text produced by visiting *expr*."""
        return expr.accept(self)

    def visit_annotation_stmt(self, stmt: a.AnnotationStmt) -> str:
        """Return the annotation name immediately followed by its schema, if any."""
        schema: str = "" if stmt.schema is None else stmt.schema.accept(self)
        return f"{stmt.name.lexeme}{schema}"

    def visit_type_expr(self, expr: a.TypeExpr) -> str:
        """Return the type name joined by `` + `` with each parenthesised constraint."""
        pieces: list[str] = [
            expr.name.lexeme,
            *("(" + constraint.accept(self) + ")" for constraint in expr.constraints),
        ]
        return " + ".join(pieces)

    def visit_constraint_expr(self, expr: a.ConstraintExpr) -> str:
        """Return ``left op right`` separated by single spaces."""
        lhs = expr.left.accept(self)
        operator = expr.op.lexeme
        rhs = expr.right.accept(self)
        return " ".join((lhs, operator, rhs))

    def visit_schema_expr(self, expr: a.SchemaExpr) -> str:
        """Return the comma-separated elements wrapped in the schema's delimiters."""
        opening: str = expr.left.lexeme
        body = ", ".join([elmt.accept(self) for elmt in expr.elements])
        return opening + body + expr.right.lexeme

    def visit_schema_element_expr(self, expr: a.SchemaElementExpr) -> str:
        """Return ``name: type``, or just the type when unnamed; a missing type is ``_``."""
        pieces: list[str] = [] if expr.name is None else [expr.name.lexeme]
        pieces.append("_" if expr.type is None else expr.type.accept(self))
        return ": ".join(pieces)

    def visit_wildcard_expr(self, expr: a.WildcardExpr) -> str:
        """Return the wildcard symbol."""
        return "_"

    def visit_literal_expr(self, expr: a.LiteralExpr) -> str:
        """Return ``str()`` of the literal's value."""
        return str(expr.value)
|
||||
|
||||
|
||||
class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
    """Tree printer for Midas AST nodes, built on the inherited ``AstPrinter`` helpers.

    Each ``visit_*`` method writes the node's class name and then its fields
    inside ``self._child_level()``.
    """

    def visit_type_stmt(self, stmt: m.TypeStmt):
        """Write a ``TypeStmt`` node: its name, its bases and its optional body."""
        self._write_line("TypeStmt")
        with self._child_level():
            name_line = f'name: "{stmt.name.lexeme}"'
            self._write_line(name_line)
            self._write_line("bases")
            with self._child_level():
                for position, base in enumerate(stmt.bases):
                    self._idx = position
                    if position + 1 == len(stmt.bases):
                        self._mark_last()
                    base.accept(self)
            self._write_optional_child("body", stmt.body, last=True)

    def visit_property_stmt(self, stmt: m.PropertyStmt):
        """Write a ``PropertyStmt`` node: its name and its type sub-tree."""
        self._write_line("PropertyStmt")
        with self._child_level():
            name_line = f'name: "{stmt.name.lexeme}"'
            self._write_line(name_line)
            self._write_line("type", last=True)
            with self._child_level():
                # The type expression is the only child under "type".
                self._mark_last()
                stmt.type.accept(self)

    def visit_op_stmt(self, stmt: m.OpStmt) -> None:
        """Write an ``OpStmt`` node: left operand, operator, right operand and result."""
        self._write_line("OpStmt")
        with self._child_level():
            # Left operand.
            self._write_line("left")
            with self._child_level():
                self._mark_last()
                stmt.left.accept(self)

            # Operator lexeme.
            operator_line = f'op: "{stmt.op.lexeme}"'
            self._write_line(operator_line)

            # Right operand.
            self._write_line("right")
            with self._child_level():
                self._mark_last()
                stmt.right.accept(self)

            # Result type, written as the final field.
            self._write_line("result", last=True)
            with self._child_level():
                self._mark_last()
                stmt.result.accept(self)

    def visit_constraint_stmt(self, stmt: m.ConstraintStmt):
        """Write a ``ConstraintStmt`` node: its name and its constraint sub-tree."""
        self._write_line("ConstraintStmt")
        with self._child_level():
            name_line = f'name: "{stmt.name.lexeme}"'
            self._write_line(name_line)
            self._write_line("constraint", last=True)
            with self._child_level():
                self._mark_last()
                stmt.constraint.accept(self)

    def visit_type_expr(self, expr: m.TypeExpr):
        """Write a ``TypeExpr`` node: its name and each of its constraints."""
        self._write_line("TypeExpr")
        with self._child_level():
            name_line = f'name: "{expr.name.lexeme}"'
            self._write_line(name_line)
            self._write_line("constraints", last=True)
            with self._child_level():
                for position, constraint in enumerate(expr.constraints):
                    self._idx = position
                    if position + 1 == len(expr.constraints):
                        self._mark_last()
                    constraint.accept(self)

    def visit_constraint_expr(self, expr: m.ConstraintExpr):
        """Write a ``ConstraintExpr`` node: left side, operator and right side."""
        self._write_line("ConstraintExpr")
        with self._child_level():
            # Left-hand side.
            self._write_line("left")
            with self._child_level():
                self._mark_last()
                expr.left.accept(self)

            # Operator lexeme (unquoted here).
            operator_line = f"operator: {expr.op.lexeme}"
            self._write_line(operator_line)

            # Right-hand side, written as the final field.
            self._write_line("right", last=True)
            with self._child_level():
                self._mark_last()
                expr.right.accept(self)

    def visit_type_body_expr(self, expr: m.TypeBodyExpr):
        """Write a ``TypeBodyExpr`` node and each of its properties."""
        self._write_line("TypeBodyExpr")
        with self._child_level():
            self._write_line("properties", last=True)
            with self._child_level():
                for position, member in enumerate(expr.properties):
                    self._idx = position
                    if position + 1 == len(expr.properties):
                        self._mark_last()
                    member.accept(self)

    def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
        """Write the single line ``WildcardExpr``."""
        node_label = "WildcardExpr"
        self._write_line(node_label)

    def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
        """Write a ``LiteralExpr`` header, then ``value: <value>`` as its only child."""
        self._write_line("LiteralExpr")
        with self._child_level():
            value_line = f"value: {expr.value}"
            self._write_line(value_line, last=True)
|
||||
|
||||
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
    """Visitor that renders Midas AST nodes as source-like text.

    ``indent`` is the number of spaces per nesting level and ``level`` is the
    current nesting depth; ``visit_type_stmt`` raises it by one while the
    type body is rendered.
    """

    def __init__(self, indent: int = 4):
        self.level: int = 0
        self.indent: int = indent

    def indented(self, text: str) -> str:
        """Return ``text`` prefixed with ``level * indent`` spaces."""
        padding = " " * (self.level * self.indent)
        return padding + text

    def print(self, expr: m.Expr | m.Stmt):
        """Reset the nesting level to zero and return the rendering of ``expr``."""
        self.level = 0
        return expr.accept(self)

    def visit_type_stmt(self, stmt: m.TypeStmt):
        """Render ``type Name<bases>``, followed by a braced body when one exists."""
        rendered_bases: str = ", ".join([base.accept(self) for base in stmt.bases])
        header: str = self.indented(f"type {stmt.name.lexeme}<{rendered_bases}>")
        if stmt.body is None:
            return header

        # The body is rendered one level deeper; the closing brace is indented
        # only after the level has been restored.
        self.level += 1
        body_text = stmt.body.accept(self)
        self.level -= 1
        return header + " {\n" + body_text + "\n" + self.indented("}")

    def visit_property_stmt(self, stmt: m.PropertyStmt):
        """Render ``name: type`` without indentation."""
        prop_name = stmt.name.lexeme
        prop_type = stmt.type.accept(self)
        return f"{prop_name}: {prop_type}"

    def visit_op_stmt(self, stmt: m.OpStmt):
        """Render ``op <left> OP <right> = <result>`` at the current indentation."""
        lhs: str = stmt.left.accept(self)
        symbol: str = stmt.op.lexeme
        rhs: str = stmt.right.accept(self)
        outcome: str = stmt.result.accept(self)
        return self.indented(f"op <{lhs}> {symbol} <{rhs}> = <{outcome}>")

    def visit_constraint_stmt(self, stmt: m.ConstraintStmt):
        """Render ``constraint name = expr`` at the current indentation."""
        label: str = stmt.name.lexeme
        body: str = stmt.constraint.accept(self)
        return self.indented(f"constraint {label} = {body}")

    def visit_type_expr(self, expr: m.TypeExpr):
        """Return the type name joined by `` + `` with each parenthesised constraint."""
        pieces: list[str] = [expr.name.lexeme]
        pieces.extend(
            "(" + constraint.accept(self) + ")" for constraint in expr.constraints
        )
        return " + ".join(pieces)

    def visit_constraint_expr(self, expr: m.ConstraintExpr):
        """Return ``<left> <op> <right>`` separated by single spaces."""
        lhs: str = expr.left.accept(self)
        operator: str = expr.op.lexeme
        rhs: str = expr.right.accept(self)
        return " ".join((lhs, operator, rhs))

    def visit_type_body_expr(self, expr: m.TypeBodyExpr):
        """Return every property on its own line, indented at the current level."""
        lines: list[str] = []
        for member in expr.properties:
            lines.append(self.indented(member.accept(self)))
        return "\n".join(lines)

    def visit_wildcard_expr(self, expr: m.WildcardExpr):
        """Return the wildcard placeholder ``_``."""
        return "_"

    def visit_literal_expr(self, expr: m.LiteralExpr):
        """Return ``str()`` of the literal's value."""
        return str(expr.value)
|
||||
150
docs/architecture.typ
Normal file
150
docs/architecture.typ
Normal file
@@ -0,0 +1,150 @@
|
||||
#import "@preview/cetz:0.5.2": canvas, draw
|
||||
|
||||
#let diagram-only = false
|
||||
|
||||
#set document(
|
||||
title: [Midas Architecture],
|
||||
//author: "Louis Heredero",
|
||||
)
|
||||
|
||||
#set text(
|
||||
font: "Source Sans 3",
|
||||
)
|
||||
|
||||
#let diagram = canvas({
|
||||
let framed = draw.content.with(
|
||||
padding: (x: .8em, y: 1em),
|
||||
frame: "rect",
|
||||
stroke: black,
|
||||
)
|
||||
let arrow = draw.line.with(mark: (end: ">", fill: black))
|
||||
framed(
|
||||
(0, 0),
|
||||
name: "python-parser",
|
||||
)[Python parser]
|
||||
|
||||
draw.content(
|
||||
(rel: (0, 1), to: "python-parser.north"),
|
||||
padding: 5pt,
|
||||
anchor: "south",
|
||||
name: "source-py",
|
||||
)[_`source.py`_]
|
||||
arrow("source-py", "python-parser")
|
||||
|
||||
framed(
|
||||
(rel: (3, 0), to: "python-parser.east"),
|
||||
anchor: "west",
|
||||
name: "custom-parser",
|
||||
align(center)[Custom python\ parser],
|
||||
)
|
||||
|
||||
arrow("python-parser", "custom-parser", name: "arrow-python-ast")
|
||||
draw.content(
|
||||
"arrow-python-ast",
|
||||
anchor: "south",
|
||||
padding: 5pt,
|
||||
)[`ast.Module`]
|
||||
|
||||
framed(
|
||||
(rel: (-3, -2), to: "custom-parser.south"),
|
||||
anchor: "east",
|
||||
name: "python-resolver",
|
||||
)[Python Resolver]
|
||||
arrow(
|
||||
"custom-parser",
|
||||
((), "|-", "python-resolver.east"),
|
||||
"python-resolver",
|
||||
name: "arrow-python-custom-ast",
|
||||
)
|
||||
draw.content(
|
||||
(rel: (1.5, 0), to: "arrow-python-custom-ast.end"),
|
||||
padding: 5pt,
|
||||
anchor: "south",
|
||||
)[P-AST#footnote[#strong[P]ython *AST*]<fn-past>]
|
||||
draw.content(
|
||||
"python-resolver.west",
|
||||
padding: 5pt,
|
||||
anchor: "south-east",
|
||||
)[Resolved P-AST@fn-past]
|
||||
|
||||
draw.circle(
|
||||
(rel: (1, -2), to: "custom-parser.south-east"),
|
||||
radius: .4,
|
||||
name: "midas-loader",
|
||||
)
|
||||
arrow(
|
||||
"custom-parser",
|
||||
"midas-loader",
|
||||
name: "arrow-load-midas",
|
||||
mark: (end: (symbol: ">", fill: black), start: "o"),
|
||||
)
|
||||
draw.content(
|
||||
"arrow-load-midas",
|
||||
anchor: "west",
|
||||
padding: 5pt,
|
||||
)[```python midas.using("types.midas")```]
|
||||
|
||||
framed(
|
||||
(rel: (0, -2), to: "midas-loader.south"),
|
||||
name: "midas-parser",
|
||||
)[Midas lexer/parser]
|
||||
arrow("midas-loader", "midas-parser", name: "arrow-midas-source")
|
||||
draw.content(
|
||||
"arrow-midas-source",
|
||||
anchor: "west",
|
||||
padding: 5pt,
|
||||
)[_`types.midas`_]
|
||||
|
||||
|
||||
framed(
|
||||
(rel: (-2, 0), to: "midas-parser.west"),
|
||||
anchor: "east",
|
||||
name: "midas-resolver",
|
||||
)[Midas Resolver]
|
||||
arrow("midas-parser", "midas-resolver", name: "arrow-midas-ast")
|
||||
draw.content(
|
||||
"arrow-midas-ast",
|
||||
anchor: "south",
|
||||
padding: 5pt,
|
||||
)[M-AST#footnote[#strong[M]idas *AST*]<fn-mast>]
|
||||
|
||||
framed(
|
||||
(rel: (-3, 0), to: "midas-resolver.west"),
|
||||
anchor: "east",
|
||||
name: "checker",
|
||||
)[Checker]
|
||||
arrow("midas-resolver", "checker", name: "arrow-type-ctx")
|
||||
arrow(
|
||||
"python-resolver",
|
||||
((), "-|", "checker.north"),
|
||||
"checker",
|
||||
)
|
||||
draw.content(
|
||||
"arrow-type-ctx",
|
||||
anchor: "south",
|
||||
padding: 5pt,
|
||||
)[Types context]
|
||||
})
|
||||
|
||||
#show: doc => if diagram-only {
|
||||
set page(width: auto, height: auto, margin: .5cm)
|
||||
diagram
|
||||
} else { doc }
|
||||
|
||||
#align(center, title())
|
||||
|
||||
#v(1cm)
|
||||
|
||||
#figure(
|
||||
diagram,
|
||||
caption: [Midas type-checker architecture],
|
||||
)
|
||||
|
||||
== Components
|
||||
|
||||
- *Python parser*: builtin Python AST parser, extracts abstract syntax from the raw Python source (```python ast.parse(...)```)
|
||||
- *Custom python parser*: converts the raw Python AST into custom, more suitable constructs, especially for type annotations
|
||||
- *Python resolver*: resolves bindings and references, tracks binding scopes
|
||||
- *Midas lexer/parser*: parses a Midas type definition file and extracts its AST
|
||||
- *Midas resolver*: walks the AST and fills the environment with the defined types and operations
|
||||
- *Checker*: evaluates expressions and checks type coherence
|
||||
809
docs/manual.typ
Normal file
809
docs/manual.typ
Normal file
@@ -0,0 +1,809 @@
|
||||
//#import "@preview/codly:1.3.0": codly, codly-init
|
||||
// Fix unaligned highlights in v0.15.0 ()
|
||||
// See https://github.com/Dherse/codly/pull/132
|
||||
#import "@local/codly:1.3.1": codly, codly-init
|
||||
|
||||
#import "@preview/codly-languages:0.1.10": codly-languages
|
||||
#import "template.typ": TODO, project
|
||||
#import "@preview/gentle-clues:1.3.1" as gc
|
||||
|
||||
#let midas-version = toml("../pyproject.toml").project.version
|
||||
#let head-ref = read("../.git/HEAD").split(":").at(1).trim()
|
||||
#let commit-hash = read("../.git/" + head-ref).slice(0, 8)
|
||||
|
||||
#show: project.with(
|
||||
title: [Midas User Manual],
|
||||
author: "Louis Heredero",
|
||||
version: midas-version,
|
||||
hash: commit-hash,
|
||||
icon-path: path("../assets/icon.svg"),
|
||||
)
|
||||
|
||||
#show: codly-init
|
||||
#codly(
|
||||
languages: codly-languages
|
||||
+ (
|
||||
midas: (
|
||||
name: "Midas",
|
||||
color: rgb("#eedd47"),
|
||||
icon: box(
|
||||
image(
|
||||
"../assets/icon.svg",
|
||||
height: 130%,
|
||||
fit: "contain",
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
= Introduction
|
||||
|
||||
Python is a very popular programming language, especially in data sciences.
|
||||
However, it has been designed for simplicity, distancing itself from typed languages such as Java or C to embrace dynamic typing.
|
||||
What this means is that in Python, type checks are deferred to runtime when operations are concretely executed.
|
||||
|
||||
For developers, it might seem like a great way of simplifying the language and making it very flexible, but it does come with a cost.
|
||||
Indeed, type errors are very easy to make in Python. While passing an integer where a string is expected might not be an issue in some cases, it is the sort of thing that can cause crashes or incorrect results without a clear diagnostic to help the user fix it.
|
||||
|
||||
Fortunately, developers using IDEs or properly configured text editors can benefit from external type checkers such as MyPy which will perform static type analysis of their Python code. Some can also be configured to be very strict, forcing the user to make the whole code typeable statically, thus avoiding any runtime type errors.
|
||||
|
||||
This is not the end of the problem though. Some parts of a program, especially in data related fields, may not be available at "compile-time". For example, a dataset can be loaded from an external file, or data can be fetched from an API, with no guarantees of having the expected format when analyzing the code statically.
|
||||
|
||||
In turn, that can cause a range of loud and silent errors at runtime. A malformed number will probably crash the program when trying to convert it, but a NaN in a series of values might just produce wrong results without any exception. Combine this with often long-running data-processing pipelines and this is how developers can waste hours of precious computation time.
|
||||
|
||||
Midas is a type system which can be used on top of Python to provide better type checking capabilities and gradual typing.
|
||||
It aims at providing optional but strict type annotations and casting operations which can produce runtime assertions. It also allows the user to define dependent types with value constraints that are translated into runtime checks.
|
||||
|
||||
= Installation
|
||||
|
||||
Midas comes as a very light Python package that you can install on your system in a few simple steps.
|
||||
|
||||
== Requirements
|
||||
|
||||
Here below are the requirements for installing Midas. All Python dependencies will be installed by `uv` in the installation process described in @install-steps.
|
||||
|
||||
- Python 3.11+
|
||||
- `uv`
|
||||
|
||||
== Steps <install-steps>
|
||||
|
||||
1. Clone the repository
|
||||
```bash
|
||||
git clone https://git.kb28.ch/HEL/midas.git
|
||||
```
|
||||
2. Navigate inside the directory
|
||||
```bash
|
||||
cd midas
|
||||
```
|
||||
3. Install Midas as a tool in your local user space
|
||||
```bash
|
||||
uv tool install .
|
||||
```
|
||||
|
||||
And that's it! You can now use Midas commands anywhere, like this:
|
||||
```bash
|
||||
midas --help
|
||||
```
|
||||
|
||||
= Quick Start
|
||||
|
||||
This chapter will give you the keys to quickly start using Midas in your project.
|
||||
|
||||
== Defining custom types
|
||||
|
||||
To begin with, you might want to define some custom types for your project, to avoid handling anonymous float values everywhere. To do so, create a `*.midas` file in your project, and write some definitions for your types. See @midas-ref for more information on syntax and features.
|
||||
|
||||
@qs-midas shows a simple example of what it might look like.
|
||||
|
||||
#codly(header: [types.midas])
|
||||
|
||||
#figure(
|
||||
```midas
|
||||
type Meter = float
|
||||
extend Meter {
|
||||
def __add__: fn(Meter, /) -> Meter
|
||||
def __sub__: fn(Meter, /) -> Meter
|
||||
}
|
||||
|
||||
type Coordinate = object
|
||||
extend Coordinate {
|
||||
prop x: Meter
|
||||
prop y: Meter
|
||||
}
|
||||
```,
|
||||
caption: [Example Midas type definitions],
|
||||
) <qs-midas>
|
||||
|
||||
You can check for any syntax error using the following command:
|
||||
```bash
|
||||
midas validate types.midas
|
||||
```
|
||||
|
||||
When you are happy with your definitions, you can generate Python stubs to use in your source code. This allows other type checkers like MyPy to recognize your custom types and avoid reporting them as undefined. It can also help catch some type errors in your IDE.
|
||||
```bash
|
||||
midas stubs types.midas -o stubs.pyi
|
||||
```
|
||||
|
||||
This command will generate a file as shown in @qs-stubs, providing stub classes to represent the type lattice including methods and properties.
|
||||
|
||||
#codly(header: [stubs.pyi])
|
||||
|
||||
#figure(
|
||||
```pyi
|
||||
from __future__ import annotations
|
||||
|
||||
class Meter(float):
|
||||
def __add__(self, _0: Meter, /) -> Meter: ...
|
||||
def __sub__(self, _0: Meter, /) -> Meter: ...
|
||||
|
||||
class Coordinate(object):
|
||||
x: Meter
|
||||
y: Meter
|
||||
```,
|
||||
caption: [Generated stubs from example definitions of @qs-midas],
|
||||
) <qs-stubs>
|
||||
|
||||
== Using Midas in Python
|
||||
|
||||
You can now write your Python program as you would normally. You can import your custom types from the generated stubs file and use them in type annotations.
|
||||
|
||||
You can also import the `cast` and `unsafe_cast` functions from `midas.typing` to explicitly cast a value to a specific type (see @cast for more information).
|
||||
|
||||
An example Python script is shown in @qs-python, demonstrating how you can use custom types in type annotations. Notice the comments describing errors that will be caught by the type checker in @qs-type-checking.
|
||||
|
||||
#codly(header: [script.py])
|
||||
|
||||
#figure(
|
||||
```python
|
||||
from lib import load_coordinate
|
||||
from midas.typing import cast
|
||||
from stubs import Coordinate, Meter
|
||||
|
||||
p1 = cast(Coordinate, load_coordinate(0))
|
||||
p2 = cast(Coordinate, load_coordinate(1))
|
||||
|
||||
diff_x = p2.x - p1.x
|
||||
diff_y = p2.y - p1.y
|
||||
|
||||
dist = diff_x + diff_y
|
||||
|
||||
p2.x += cast(Meter, 1)
|
||||
p2.y = True # invalid, wrong type
|
||||
p2.z = 3 # invalid, no property 'z' on Coordinate
|
||||
p2.x.a = 3 # invalid, no properties on Meter
|
||||
```,
|
||||
caption: [Example Python script],
|
||||
) <qs-python>
|
||||
|
||||
== Type checking <qs-type-checking>
|
||||
|
||||
Now that you have defined some types and written a script, you can run the type checker with the following command. You can also skip this step and directly run the compilation command in @qs-compilation.
|
||||
|
||||
```bash
|
||||
midas check -t types.midas script.py
|
||||
```
|
||||
|
||||
== Compiling <qs-compilation>
|
||||
|
||||
The final step is to compile your code. This step will produce a runnable Python script, including runtime assertions generated by `cast` expressions.
|
||||
|
||||
```bash
|
||||
midas compile -t types.midas script.py
|
||||
python3 build/midas/script.py
|
||||
```
|
||||
|
||||
= Midas Language Reference <midas-ref>
|
||||
|
||||
In this chapter, you will find a complete reference for the Midas definition language.
|
||||
|
||||
A `*.midas` file contains a number of statements, which can be:
|
||||
- *`alias`* statements (see @alias-stmt): to define a new type alias
|
||||
- *`type`* statements (see @type-stmt): to define a new type
|
||||
- *`extend`* statements (see @extend-stmt): to define members of a type
|
||||
- *`predicate`* statements (see @predicate-stmt): to define named predicates that can be used in constraint types
|
||||
|
||||
== Alias Statement <alias-stmt>
|
||||
|
||||
An *`alias`* statement lets you define a new type alias. It requires a unique name and base type.
|
||||
|
||||
While a `type` statement (see @type-stmt) allows generic definitions, aliases are purely for giving an alternative name to a type.
|
||||
|
||||
#figure(
|
||||
```midas
|
||||
alias MyType = float
|
||||
```,
|
||||
caption: [Simple `alias` statement declaring a new type "`MyType`" equivalent to `float`],
|
||||
) <midas-simple-alias>
|
||||
|
||||
This statement defines a new type called `MyType` which is equivalent to `float`. `MyType` and `float` can be used interchangeably.
|
||||
|
||||
== Type Statement <type-stmt>
|
||||
|
||||
A *`type`* statement lets you define a new type. It requires a unique name and base type.
|
||||
|
||||
The simplest form of a *`type`* statement is:
|
||||
#figure(
|
||||
```midas
|
||||
type MyType = float
|
||||
```,
|
||||
caption: [Simple `type` statement declaring a new type "`MyType`" as a subtype of `float`],
|
||||
) <midas-simple-type>
|
||||
|
||||
This statement defines a new type called `MyType` which is a subtype of `float`. `MyType` is a `float` but a `float` is not necessarily `MyType`.
|
||||
|
||||
=== Builtin / base types
|
||||
|
||||
A number of base types are provided out of the box, which can be used to derive other types.
|
||||
|
||||
They correspond to Python's builtin types:
|
||||
```py object```,
|
||||
```py str```,
|
||||
```py float```,
|
||||
```py int```,
|
||||
```py bool```,
|
||||
```py list```,
|
||||
```py dict```,
|
||||
```py None```.
|
||||
|
||||
Some differences are to be noted however.
|
||||
1. ```py bool``` is not a subtype of ```py int```
|
||||
2. ```py list``` are homogeneous, i.e. all items must be of the same type
|
||||
3. ```py dict``` keys and values are homogeneous, i.e. all keys must be of the same type and all values must be of the same type (can be different from keys).
|
||||
|
||||
=== Function types
|
||||
|
||||
A function type is written in a similar notation to Python function definitions:
|
||||
#figure(
|
||||
```midas
|
||||
type Repeater = fn(text: str, count: int) -> str
|
||||
```,
|
||||
caption: [Simple function type definition],
|
||||
)
|
||||
|
||||
Midas supports positional-only, keyword-only and mixed arguments (using the `/` and `*` separators). You may omit the name of positional-only arguments. The return type is required.
|
||||
|
||||
Optional parameters can be indicated by adding a question mark (`?`) after their type:
|
||||
#figure(
|
||||
```midas
|
||||
type Repeater = fn(text: str, count: int, *, sep: str?) -> str
|
||||
```,
|
||||
caption: [Function type definition with an optional keyword-only parameter],
|
||||
)
|
||||
|
||||
#gc.warning[
|
||||
Sink arguments (`*args`, `**kwargs`) are not currently supported.
|
||||
]
|
||||
|
||||
=== Constraint types
|
||||
|
||||
A useful feature provided by Midas is the possibility to combine types with custom value constraints. For example, you might want to define a type for positive amounts of money:
|
||||
|
||||
#figure(
|
||||
```midas
|
||||
type Money = float
|
||||
type Income = Money where _ >= 0
|
||||
```,
|
||||
caption: [Simple constraint type definition],
|
||||
)
|
||||
|
||||
Constraints can be combined with any type using the `where` keyword, followed by a constraint expression (see @constraint-expr).
|
||||
|
||||
=== Generic types
|
||||
|
||||
For more complex types, you might want to use type parameters. For example, to define a container, we might write:
|
||||
#figure(
|
||||
```midas
|
||||
type Container[T] = object
|
||||
```,
|
||||
caption: [Simple generic container type definition],
|
||||
)
|
||||
|
||||
To better refine a generic type, you can also bound type parameters using the following syntax:
|
||||
#figure(
|
||||
```midas
|
||||
type Container[T <: float] = object
|
||||
```,
|
||||
caption: [Generic container type definition with a bound],
|
||||
)
|
||||
|
||||
This can be read as "`Container` is a generic type which takes one type parameter `T` that must be a subtype of `float`".\
|
||||
You can use a generic type, i.e. instantiate it, by using a similar syntax with concrete types as arguments:
|
||||
|
||||
#figure(
|
||||
```midas
|
||||
type MyContainer = Container[MyType]
|
||||
```,
|
||||
caption: [Application of a generic type],
|
||||
)
|
||||
|
||||
Generic types can also take multiple parameters, which are then separated by commas:
|
||||
#figure(
|
||||
```midas
|
||||
type ZipCodeRegistry = dict[int, str]
|
||||
```,
|
||||
caption: [Application of a multi-parameter generic type],
|
||||
)
|
||||
|
||||
The _body_ of a generic type, i.e. the right-hand side of the definition, can contain or even be equal to any number of its parameters.#footnote[The latter is not something that is expressible in standard Python, yet it brings a semantic distinction on top of structurally equivalent values.] For example, the following is a valid type statement:
|
||||
#figure(
|
||||
```midas
|
||||
type Price[T <: Currency] = T where _ > 0
|
||||
```,
|
||||
caption: [Type parameters in a generic type's body],
|
||||
)
|
||||
|
||||
=== `Column` / `Frame` types
|
||||
|
||||
To provide useful type-checking for data engineers, Midas offers two special types: `Column` and `Frame`.
|
||||
Their goal is to help type check Pandas' `Series` and `DataFrame` respectively.
|
||||
|
||||
==== `Column`
|
||||
|
||||
The `Column` type is a generic type used to represent a `pandas.Series` object.
|
||||
You can use it like any other generic type and it will provide type checking for some common methods and attributes offered by Pandas.
|
||||
|
||||
#figure(
|
||||
```midas
|
||||
type Temperature = float
|
||||
alias Temperatures = Column[Temperature]
|
||||
```,
|
||||
caption: [Simple column type definition],
|
||||
)
|
||||
|
||||
==== `Frame` <frame-type>
|
||||
|
||||
The `Frame` type is a super-powered generic type used to represent a `pandas.DataFrame` object.
|
||||
In place of type arguments, `Frame` accepts a schema, i.e. a series of column definitions.
|
||||
@simple-frame shows how you can define a simple frame type with 3 columns:
|
||||
- `name`: a column of `Name` values
|
||||
- `age`: a column of `int` values
|
||||
- `height`: a column of `float where _ >= 0` values
|
||||
|
||||
Notice that you don't need to specify `Column` types.
|
||||
|
||||
#figure(
|
||||
```midas
|
||||
type Name = str where len(_) != 0
|
||||
alias Data = Frame[
|
||||
name: Name,
|
||||
age: int,
|
||||
height: float where _ >= 0
|
||||
]
|
||||
```,
|
||||
) <simple-frame>
|
||||
|
||||
#pagebreak()
|
||||
|
||||
== Extend Statement <extend-stmt>
|
||||
|
||||
Type statements allow you to define new types, kind of like type aliases. However, a type might have properties or methods of its own. These might override those of the parent type or be brand new members.
|
||||
|
||||
This is where the `extend` statement comes into play. It allows defining members on a given type. Members can either be properties (`prop`) or methods (`def`). The only difference between the two is that methods must be functions and can be overloaded.
|
||||
|
||||
Here is a simple example showing how to define a property and a method on a custom type:
|
||||
#figure(
|
||||
```midas
|
||||
type MyType = float
|
||||
extend MyType {
|
||||
prop norm: float
|
||||
def double: fn() -> MyType
|
||||
}
|
||||
```,
|
||||
caption: [Simple `extend` statement defining a property and a method],
|
||||
)
|
||||
|
||||
An `extend` statement can appear anywhere after the type it extends has been defined.
|
||||
You may want to override Python's dunder methods to implement type checking for some basic operators, like `__add__` for the `+` operator.
|
||||
|
||||
#figure(
|
||||
```midas
|
||||
type Money = float
|
||||
extend Money {
|
||||
def __add__: fn(Money, /) -> Money
|
||||
def __mul__: fn(float, /) -> Money
|
||||
}
|
||||
```,
|
||||
caption: [Simple `extend` statement overriding some dunder methods],
|
||||
)
|
||||
|
||||
When extending a generic type, you must specify the whole type, including its parameter(s):
|
||||
#figure(
|
||||
```midas
|
||||
type Container[T <: float] = object
|
||||
extend Container[T <: float] {
|
||||
prop content: T
|
||||
def set_content: fn(content: T) -> None
|
||||
}
|
||||
```,
|
||||
caption: [Generic `extend` statement using type parameters in the declared members],
|
||||
)
|
||||
|
||||
#pagebreak()
|
||||
|
||||
== Predicate Statement <predicate-stmt>
|
||||
|
||||
A *`predicate`* statement lets you define a named constraint expression, like a function, which can then be used in other constraint expressions (either in other predicate statements or in constraint types). See @constraint-expr for more information about the syntax of constraint expressions.
|
||||
|
||||
The left-hand side of a predicate statement is written as a function signature, without a return type. The right-hand side is a constraint expression. For example:
|
||||
#figure(
|
||||
```midas
|
||||
predicate is_positive(v: float) = v >= 0
|
||||
```,
|
||||
caption: [Simple `predicate` statement defining an `is_positive` predicate],
|
||||
)
|
||||
|
||||
The left-hand side can also be curried to allow partial application. For example:
|
||||
#figure(
|
||||
```midas
|
||||
predicate in_range(mn: float, mx: float)(v: float) = mn <= v & v <= mx
|
||||
predicate is_ratio = in_range(0.0, 1.0)
|
||||
```,
|
||||
caption: [Curried `predicate` statement and partial application],
|
||||
) <midas-predicate-partial>
|
||||
|
||||
Notice that the second predicate statement doesn't take any parameters. This is simply a partial application of another predicate, kind of like an alias. You can use it in other expressions to finalize the call:
|
||||
#figure(
|
||||
```midas
|
||||
type Efficiency = float where is_ratio(_)
|
||||
```,
|
||||
caption: [Constraint type definition using the partially applied predicate from @midas-predicate-partial],
|
||||
)
|
||||
|
||||
Of course you can also directly call `in_range`:
|
||||
#figure(
|
||||
```midas
|
||||
type Efficiency = float where in_range(0.0, 1.0)(_)
|
||||
```,
|
||||
caption: [Full call of curried predicate from @midas-predicate-partial],
|
||||
)
|
||||
|
||||
When compiled, named predicates are translated to Python functions which are used in runtime assertions. Only predicates that are referenced are compiled.
|
||||
|
||||
#pagebreak()
|
||||
== Constraint Expressions <constraint-expr>
|
||||
|
||||
*Constraint expressions* are Python-like expressions which can appear in *`predicate`* statements or in constraint types.
|
||||
|
||||
They can contain comparisons, simple computations, logical operations and must evaluate to a boolean value.
|
||||
|
||||
Context is quite restricted inside these expressions. You can only reference some builtin functions, such as type constructors (`float(...)`, `str(...)`, etc.), parameters of predicate statements, and named predicates. In constraint types, the special variable `_` can be used to reference the value targeted by the type. For example:
|
||||
|
||||
#figure(
|
||||
```midas
|
||||
predicate not_nan(v: float) = v != float("nan")
|
||||
type RealFloat = float where not_nan(_)
|
||||
```,
|
||||
caption: [Example constraint expressions],
|
||||
) <ex-constraint-expr>
|
||||
|
||||
In the predicate statement (@ex-constraint-expr:1), we reference the parameter `v` and the builtin `float` function.
|
||||
In the constraint type definition (@ex-constraint-expr:2), we then reference the named predicate `not_nan`, passing the value targeted by the type itself ( `_` ).
|
||||
|
||||
= Supported Python Syntax <python-ref>
|
||||
|
||||
Midas integrates naturally in Python via type annotations. Through generated stubs, even other type checkers can detect your custom types (see @cmd-stubs).
|
||||
|
||||
It has been designed to leave the user free to type as much or as little of their code as they wish, while being strict about the parts that are annotated. By default, any untyped Python expression is assigned `UnknownType`.
|
||||
|
||||
Any operation is permitted on `UnknownType` and will result in `UnknownType` values.
|
||||
|
||||
The moment an expression can be typed, be it thanks to an annotation or a literal value, the type checker kicks in and validates your statements.
|
||||
|
||||
Because Python is a very flexible language with many features, some expressions and statements might be more complex to properly type check, thus only a subset of the Python language is fully supported. This chapter lists all supported features of Python and how they affect type checking.
|
||||
|
||||
Some examples are presented in the following sections in the form of code blocks. Highlights in the code blocks indicate the type assigned to each expression by the type checker. Some types may be omitted for readability. For example:
|
||||
|
||||
#codly(
|
||||
highlights: (
|
||||
(
|
||||
line: 1,
|
||||
start: 5,
|
||||
fill: green,
|
||||
tag: [_int_],
|
||||
),
|
||||
(
|
||||
line: 2,
|
||||
start: 7,
|
||||
end: 7,
|
||||
fill: green,
|
||||
tag: [_int_],
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
```python
|
||||
v = 3
|
||||
print(v)
|
||||
```
|
||||
|
||||
== Literals
|
||||
|
||||
Literal Python values are type checked using builtin types. Lists and dictionaries of literals are also typed like literals. This does not include comprehension lists/dicts (```py [. for . in .]```), nor formatted strings (```py f"..."```). @supported-literals shows the list of supported literal values and their type.
|
||||
|
||||
#let supported-literals = table(
|
||||
columns: 2,
|
||||
table.header[*Example value*][*Judged Type*],
|
||||
```py 42```, ```py int```,
|
||||
```py 3.14```, ```py float```,
|
||||
```py True```, ```py bool```,
|
||||
```py "Midas"```, ```py str```,
|
||||
```py None```, ```py None```,
|
||||
```py [1, 2, 3]```, ```py list[int]```,
|
||||
```py {1: "One", 2: "Two"}```, ```py dict[int, str]```,
|
||||
```py ("1", 1, True)```, ```py tuple[str, int, bool]```,
|
||||
)
|
||||
|
||||
#figure(
|
||||
supported-literals,
|
||||
caption: [Supported literal values and their judged types],
|
||||
) <supported-literals>
|
||||
|
||||
== Assignments
|
||||
|
||||
Variable assignments allow assigning a new value to a variable. For the type checker, this implies two things:
|
||||
1. If the variable was not already declared in the current scope, it is declared at that point with the type of the right-hand side expression
|
||||
2. If the variable was already declared, the type of the right-hand side expression is checked against the declared type of the variable. Only a subtype of the variable's type can be assigned to it
|
||||
|
||||
Once a variable has been given a type, it cannot be changed in the same scope.
|
||||
|
||||
The walrus operator (```py :=```) is not currently supported.
|
||||
|
||||
A simple annotation declaration, without assigning a value, is enough to declare a variable. For example:
|
||||
#figure(
|
||||
```python
|
||||
var: float
|
||||
```,
|
||||
caption: [Bare Python variable annotation without assignment],
|
||||
)
|
||||
|
||||
Because unpacking is not supported, assigning to multiple values is also not handled by the type checker.
|
||||
For more information about type annotations, see @type-annotations
|
||||
|
||||
== Arithmetic
|
||||
|
||||
- All basic binary operators are supported, through dunder methods.
|
||||
- All comparison operators except ```py in``` are supported.
|
||||
- All unary operators are supported (`+`, `-`, `~`).
|
||||
- All logical operators are supported (```py and```, ```py or```, ```py not```).
|
||||
|
||||
== Ternary operator
|
||||
|
||||
The ternary operator ```py . if . else .``` is supported. As for `if` statements (see @if-else), the test expression must be a boolean. Additionally, both branches must be of the same type.
|
||||
|
||||
For example:
|
||||
#codly(
|
||||
highlights: (
|
||||
(
|
||||
line: 1,
|
||||
start: 10,
|
||||
end: 44,
|
||||
tag: [_str_],
|
||||
fill: blue,
|
||||
),
|
||||
(
|
||||
line: 1,
|
||||
start: 11,
|
||||
end: 16,
|
||||
tag: [_str_],
|
||||
fill: green,
|
||||
),
|
||||
(
|
||||
line: 1,
|
||||
start: 39,
|
||||
end: 43,
|
||||
tag: [_str_],
|
||||
fill: green,
|
||||
),
|
||||
(
|
||||
line: 1,
|
||||
start: 21,
|
||||
end: 32,
|
||||
tag: [_bool_],
|
||||
fill: green,
|
||||
),
|
||||
),
|
||||
)
|
||||
#figure(
|
||||
```python
|
||||
parity = ("even" if num % 2 == 0 else "odd")
|
||||
```,
|
||||
caption: [Typing of ternary operator],
|
||||
)
|
||||
|
||||
== Control flow
|
||||
|
||||
Some control flow features are supported. For the limited code of this project, not all constructs are supported. The following are those currently handled and type checked by Midas.
|
||||
|
||||
=== `if` / `elif` / `else` <if-else>
|
||||
|
||||
Conditional statements are checked relatively strictly by Midas. The test expression, i.e. what comes after the ```py if``` keyword, must be a boolean. While Python allows introducing and leaking new variables from inside an ```py if``` statement, Midas will strictly forbid leaks by restraining bindings to the scope they are defined in. For example, the following Python code will not compile with Midas:
|
||||
#figure(
|
||||
```python
|
||||
age = 22
|
||||
if age >= 18:
|
||||
msg = "You're an adult"
|
||||
else:
|
||||
msg = "You're still a child"
|
||||
print(msg) # -> unknown variable 'msg'
|
||||
```,
|
||||
caption: [`if`/`else` statement cannot leak variables],
|
||||
)
|
||||
|
||||
=== `for` loops
|
||||
|
||||
Simple forms of `for` loops can be used, that is using a single variable and iterating over an object implementing the `__getitem__` method. Like above in @if-else, leaking variables from inside the loop is ignored.
|
||||
|
||||
`for`-`else` statements are not supported. `while` loops are also not supported.
|
||||
|
||||
== Functions
|
||||
|
||||
You can define functions as usual and the type checker will do its best to type it. Apart from argument sinks (`*args`, `**kwargs`), all forms of parameter specifications are supported (positional-only, keyword-only, mixed, optional).
|
||||
|
||||
As for the rest of your code, type annotations are optional, but recommended. If you omit the return type hint, the type checker will try to infer it from the function body and its return statements. If you did specify a return type, all return paths must return values that are subtypes of the type hint.
|
||||
|
||||
#codly(
|
||||
highlights: (
|
||||
(
|
||||
line: 2,
|
||||
start: 12,
|
||||
end: 16,
|
||||
tag: [_float_],
|
||||
fill: green,
|
||||
),
|
||||
(
|
||||
line: 2,
|
||||
start: 12,
|
||||
tag: [_float_],
|
||||
fill: blue,
|
||||
),
|
||||
(
|
||||
line: 3,
|
||||
start: 10,
|
||||
end: 15,
|
||||
tag: [_(value: float) -> float_],
|
||||
fill: green,
|
||||
),
|
||||
(
|
||||
line: 3,
|
||||
start: 17,
|
||||
end: 19,
|
||||
tag: [_float_],
|
||||
fill: green,
|
||||
),
|
||||
(
|
||||
line: 3,
|
||||
start: 10,
|
||||
tag: [_float_],
|
||||
fill: blue,
|
||||
),
|
||||
),
|
||||
)
|
||||
#figure(
|
||||
```python
|
||||
def double(value: float) -> float:
|
||||
return value * 2
|
||||
result = double(4.0)
|
||||
```,
|
||||
caption: [Typing of function's body and call],
|
||||
)
|
||||
|
||||
Anonymous functions (```py lambda```) are not yet supported
|
||||
|
||||
== Casts <cast>
|
||||
|
||||
#gc.info[
|
||||
The functions discussed in this section are provided by the `midas.typing` submodule. You can import them in your script like so:
|
||||
#figure(
|
||||
```python
|
||||
from midas.typing import cast, unsafe_cast
|
||||
```,
|
||||
caption: [Importing cast functions],
|
||||
)
|
||||
]
|
||||
|
||||
Sometimes, you may want to use a value whose type is not known to the type checker in a place where it expects a particular type. In that case, if you do know that the runtime type will correspond to what is expected, you can use a `cast` expression.
|
||||
|
||||
Similar to the `cast` function from the `typing` package of Python's Standard Library, it allows telling the type checker that a value has a given type. While `typing`'s function doesn't have any runtime side-effect, Midas' will generate runtime assertions, ensuring that your statement is true when running the code. What cannot be checked statically is checked at runtime.
|
||||
|
||||
In the following example, a runtime check would be generated to ensure that the value is indeed a `float` and that it satisfies the type's constraint (i.e. `>= 0`):
|
||||
|
||||
#codly(
|
||||
highlights: (
|
||||
(
|
||||
line: 1,
|
||||
start: 35,
|
||||
end: 47,
|
||||
tag: [_UnknownType_],
|
||||
fill: red,
|
||||
),
|
||||
(
|
||||
line: 2,
|
||||
start: 7,
|
||||
end: 17,
|
||||
tag: [_PositiveFloat_],
|
||||
fill: green,
|
||||
),
|
||||
),
|
||||
)
|
||||
#figure(
|
||||
```python
|
||||
typed_value = cast(PositiveFloat, unknown_value)
|
||||
print(typed_value)
|
||||
```,
|
||||
caption: [Typing of `cast` expression],
|
||||
)
|
||||
|
||||
#gc.warning[
|
||||
Assertions are statements inserted just before a statement using a `cast` expression. This means that the expression is evaluated _before_ its actual intended usage location, which might cause issues if you rely on logical operator short-circuiting. See @eager-eval for more information.
|
||||
]
|
||||
|
||||
There may be some cases where the cost of checking a value at runtime is simply not worth the safety, for example when dealing with a big dataset. If you wish to do so, you can use `unsafe_cast` which will only tell the type checker the type of the value, without generating a runtime assertion. This maps to the default behavior of `typing`'s own `cast` function.
|
||||
|
||||
If the value passed to `cast` or `unsafe_cast` is a literal (e.g. an integer, a string, a list of literals, etc.), the assertion is evaluated _at compile-time_ and no runtime assertion is generated.
|
||||
|
||||
== Annotations / Type Hints <type-annotations>
|
||||
|
||||
Vanilla Python already lets you use type hints to specify the type of variables and function parameters.
|
||||
|
||||
Midas uses them to type check your code. Additionally, it allows you to use a special syntax to define `Frame` types directly in these annotations.
|
||||
|
||||
Because these annotations are not interpretable by Python, your integrated type checker might complain loudly about them being invalid.
|
||||
A workaround is to silence it by adding a type comment at the end of the line, as shown in @silence-errors.
|
||||
|
||||
#figure(
|
||||
```python
|
||||
var: Frame[name: str, age: float] # type: ignore # noqa: F821
|
||||
```,
|
||||
caption: [MyPy's and Pylance's complaints about custom type annotation can be silenced with type comments],
|
||||
) <silence-errors>
|
||||
|
||||
=== Frame type annotation
|
||||
|
||||
The syntax is similar to how you can define frame types in the Midas language (see @frame-type). The only difference is that types can only be name references; you cannot inline constraint types.
|
||||
|
||||
The example of @python-frame-type shows how you can annotate a dataframe with some columns directly in Python.
|
||||
|
||||
#figure(
|
||||
```python
|
||||
df: Frame[name: Name, age: float, height: Length[Meter]] = ...
|
||||
```,
|
||||
caption: [Frame type annotation in Python],
|
||||
) <python-frame-type>
|
||||
|
||||
= Commands <commands>
|
||||
|
||||
#TODO
|
||||
|
||||
== Type Checking (`check`) <cmd-check>
|
||||
== Compiling (`compile`) <cmd-compile>
|
||||
== Formatting (`format`) <cmd-format>
|
||||
== Highlighting (`highlight`) <cmd-highlight>
|
||||
== Dumping the AST (`parse`) <cmd-parse>
|
||||
== Dumping the Registry (`dump-registry`) <cmd-registry>
|
||||
== Generating Stubs (`stubs`) <cmd-stubs>
|
||||
== Showing Type Judgements (`types`) <cmd-types>
|
||||
== Validating Definitions (`validate`) <cmd-validate>
|
||||
|
||||
= Known limitations <limitations>
|
||||
|
||||
== Eager evaluation in runtime assertions <eager-eval>
|
||||
|
||||
The process of generating assertions to ensure safety at runtime, mainly for `cast` expressions, leads to the creation of aliases for the expressions being cast. These alias definitions eagerly evaluate before the assertion, and most importantly before the real usage location. This means that you should avoid using `cast` expressions inside logical expressions like `and` or `or`, because the normal "short-circuit" behavior will be irrelevant to the evaluations of the operands.
|
||||
|
||||
For example:
|
||||
|
||||
#figure(
|
||||
```py
|
||||
def foo():
|
||||
print("Foo")
|
||||
return True
|
||||
def bar():
|
||||
print("Bar")
|
||||
return True
|
||||
result = foo() or bar()
|
||||
# Foo
|
||||
# Bar
|
||||
```,
|
||||
caption: [Runtime assertions may eagerly evaluate expressions and bypass logical operator's short-circuit],
|
||||
)
|
||||
211
docs/midas.sublime-syntax
Normal file
211
docs/midas.sublime-syntax
Normal file
@@ -0,0 +1,211 @@
|
||||
%YAML 1.2
|
||||
---
|
||||
name: Midas
|
||||
file_extensions:
|
||||
- midas
|
||||
scope: source.midas
|
||||
|
||||
variables:
|
||||
identifier: "[a-zA-Z_][a-zA-Z0-9_]*"
|
||||
|
||||
contexts:
|
||||
prototype:
|
||||
- include: comments
|
||||
|
||||
main:
|
||||
- include: keywords
|
||||
- include: types
|
||||
|
||||
comments:
|
||||
- match: "//"
|
||||
scope: punctuation.definition.comment.midas
|
||||
push:
|
||||
- meta_scope: comment.line.midas
|
||||
- match: $
|
||||
pop: true
|
||||
- match: /\*
|
||||
scope: punctuation.definition.comment.midas
|
||||
push:
|
||||
- meta_scope: comment.block.midas
|
||||
- match: \*/
|
||||
pop: true
|
||||
|
||||
string:
|
||||
- meta_include_prototype: false
|
||||
- meta_scope: string.quoted.double.c
|
||||
- match: '"'
|
||||
pop: true
|
||||
|
||||
keywords:
|
||||
- match: \balias\b
|
||||
scope: keyword.declaration.midas
|
||||
push: alias-stmt
|
||||
- match: \btype\b
|
||||
scope: keyword.declaration.midas
|
||||
push: type-stmt
|
||||
- match: \bextend\b
|
||||
scope: keyword.declaration.midas
|
||||
push: extend-stmt
|
||||
- match: \bpredicate\b
|
||||
scope: keyword.declaration.midas
|
||||
push: predicate-stmt
|
||||
|
||||
alias-stmt:
|
||||
- match: "{{identifier}}"
|
||||
scope: entity.name.type
|
||||
- match: "="
|
||||
scope: keyword.operator.equal.midas
|
||||
push: type-expr
|
||||
- match: $
|
||||
pop: true
|
||||
|
||||
type-stmt:
|
||||
- match: "{{identifier}}"
|
||||
scope: entity.name.type
|
||||
- match: \[
|
||||
push: type-params
|
||||
- match: "="
|
||||
scope: keyword.operator.equal.midas
|
||||
push: type-expr
|
||||
- match: $
|
||||
pop: true
|
||||
|
||||
type-expr:
|
||||
- match: \b(fn)\s*(\()
|
||||
captures:
|
||||
1: keyword.other.midas
|
||||
2: punctuation.section.group.begin
|
||||
push: fn-params
|
||||
- match: \b(where)\b
|
||||
scope: keyword.other.midas
|
||||
set: constraint
|
||||
- match: "Frame"
|
||||
scope: entity.name.type
|
||||
push:
|
||||
- match: \[
|
||||
push: frame-schema
|
||||
- match: $
|
||||
pop: true
|
||||
- match: "{{identifier}}"
|
||||
scope: entity.name.type
|
||||
- match: $
|
||||
pop: 2
|
||||
|
||||
fn-params:
|
||||
- match: "({{identifier}})(:)"
|
||||
captures:
|
||||
1: variable.parameter.midas
|
||||
2: punctuation.separator.annotation.midas
|
||||
push:
|
||||
- include: type-expr
|
||||
- match: \?
|
||||
scope: keyword.operator.qmark.midas
|
||||
- match: "(?=,)"
|
||||
scope: punctuation.separator.midas
|
||||
pop: true
|
||||
- match: '(?=\))'
|
||||
pop: true
|
||||
- include: type-expr
|
||||
- match: '\)'
|
||||
set:
|
||||
- match: "->"
|
||||
scope: keyword.operator.arrow.midas
|
||||
set: type-expr
|
||||
|
||||
constraint:
|
||||
- match: $
|
||||
pop: 2
|
||||
|
||||
- match: \d+(\.\d+)?
|
||||
scope: constant.numeric.midas
|
||||
|
||||
- match: \b(true|false|none)\b
|
||||
scope: constant.language.midas
|
||||
|
||||
- match: '"'
|
||||
push: string
|
||||
|
||||
- match: (<=|>=|<|>|==|!=|&)
|
||||
scope: keyword.operator
|
||||
|
||||
- match: _
|
||||
scope: variable.language.midas
|
||||
|
||||
- match: '{{identifier}}(?=\s*\()'
|
||||
scope: variable.function.midas
|
||||
|
||||
- match: "{{identifier}}"
|
||||
scope: variable.other.readwrite.midas
|
||||
|
||||
type-params:
|
||||
- match: "<:"
|
||||
scope: keyword.operator.subtype.midas
|
||||
- match: "[a-zA-Z][a-zA-Z_0-9]*"
|
||||
scope: entity.name.type
|
||||
- match: "]"
|
||||
pop: true
|
||||
|
||||
extend-stmt:
|
||||
- match: "{{identifier}}"
|
||||
scope: entity.name.type
|
||||
- match: \[
|
||||
push: type-params
|
||||
- match: \{
|
||||
scope: punctuation.section.block.begin
|
||||
set: extend-body
|
||||
|
||||
extend-body:
|
||||
- include: member-stmt
|
||||
- match: \}
|
||||
scope: punctuation.section.block.end
|
||||
pop: true
|
||||
|
||||
member-stmt:
|
||||
- match: \b(prop|def)\b
|
||||
scope: keyword.other.midas
|
||||
push:
|
||||
- match: "{{identifier}}"
|
||||
scope: variable.other.member
|
||||
- match: ":"
|
||||
push: type-expr
|
||||
- match: $
|
||||
pop: true
|
||||
|
||||
predicate-stmt:
|
||||
- match: "{{identifier}}"
|
||||
scope: entity.name.function.midas
|
||||
- match: '\('
|
||||
push: predicate-params
|
||||
- match: "="
|
||||
scope: keyword.operator.equal.midas
|
||||
set: constraint
|
||||
- match: $
|
||||
pop: true
|
||||
|
||||
predicate-params:
|
||||
- match: "({{identifier}})(:)"
|
||||
captures:
|
||||
1: variable.parameter.midas
|
||||
2: punctuation.separator.annotation.midas
|
||||
push:
|
||||
- include: type-expr
|
||||
- match: "(?=,)"
|
||||
scope: punctuation.separator.midas
|
||||
pop: true
|
||||
- match: '(?=\))'
|
||||
pop: true
|
||||
|
||||
- match: '\)'
|
||||
pop: true
|
||||
|
||||
frame-schema:
|
||||
- include: frame-column
|
||||
- match: \]
|
||||
# scope: punctuation.section.block.end
|
||||
pop: true
|
||||
|
||||
frame-column:
|
||||
- match: "{{identifier}}"
|
||||
scope: variable.other.member
|
||||
- match: ":"
|
||||
push: type-expr
|
||||
143
docs/template.typ
Normal file
143
docs/template.typ
Normal file
@@ -0,0 +1,143 @@
|
||||
#import "@preview/modpattern:0.2.0": modpattern
|
||||
|
||||
#let TODO = block(
|
||||
width: 6em,
|
||||
height: 3em,
|
||||
stroke: red,
|
||||
fill: modpattern(
|
||||
size: (10pt, 10pt),
|
||||
line(
|
||||
start: (0%, 0%),
|
||||
end: (100%, 100%),
|
||||
stroke: gray.transparentize(60%) + 2pt,
|
||||
),
|
||||
),
|
||||
align(
|
||||
center + horizon,
|
||||
text(fill: red, size: 1.5em)[*TODO*],
|
||||
),
|
||||
)
|
||||
|
||||
#let _render-header(version, hash) = {
|
||||
let last-heading = query(heading.where(level: 1).before(here())).last(default: none)
|
||||
let next-heading = query(heading.where(level: 1).after(here())).first(default: none)
|
||||
|
||||
let current-heading = if next-heading != none and next-heading.location().page() == here().page() {
|
||||
next-heading
|
||||
} else if last-heading != none {
|
||||
last-heading
|
||||
} else { none }
|
||||
|
||||
let chapter = if current-heading != none {
|
||||
let body = current-heading.body
|
||||
if current-heading.numbering != none {
|
||||
let num = counter(heading).display(current-heading.numbering, at: current-heading.location())
|
||||
body = [#num #body]
|
||||
}
|
||||
body
|
||||
} else []
|
||||
|
||||
grid(
|
||||
columns: (1fr, auto, 1fr),
|
||||
align: (left, center, right),
|
||||
document.title, [v#version - #hash], chapter,
|
||||
)
|
||||
}
|
||||
|
||||
#let _unshift-prefix(prefix, content) = context {
|
||||
pad(left: -measure(prefix).width, prefix + content)
|
||||
}
|
||||
|
||||
#let project(
|
||||
title: none,
|
||||
author: none,
|
||||
version: "0.0.1",
|
||||
hash: "abcdefgh",
|
||||
icon-path: none,
|
||||
doc,
|
||||
) = {
|
||||
assert(title != none, message: "Please provide a title")
|
||||
|
||||
set document(
|
||||
title: title,
|
||||
author: author,
|
||||
)
|
||||
set text(
|
||||
font: "Source Sans 3",
|
||||
)
|
||||
|
||||
set raw(syntaxes: path("midas.sublime-syntax"))
|
||||
|
||||
let front-page() = {
|
||||
align(center)[
|
||||
#{
|
||||
set text(size: 1.5em)
|
||||
std.title()
|
||||
}
|
||||
|
||||
v#version - #hash
|
||||
|
||||
#if icon-path != none {
|
||||
v(1cm)
|
||||
image(icon-path)
|
||||
}
|
||||
]
|
||||
pagebreak()
|
||||
}
|
||||
|
||||
let outlines() = {
|
||||
outline()
|
||||
pagebreak()
|
||||
outline(
|
||||
title: [List of Listings],
|
||||
target: figure.where(kind: raw),
|
||||
)
|
||||
|
||||
outline(
|
||||
title: [List of Tables],
|
||||
target: figure.where(kind: table),
|
||||
)
|
||||
}
|
||||
|
||||
let main() = {
|
||||
// Adapted from https://github.com/hei-templates/hei-synd-thesis/blob/7d2b941197babae0bf3afd4e5914754e09a64001/lib/template-thesis.typ#L242-L261
|
||||
show heading.where(level: 1): it => {
|
||||
pagebreak()
|
||||
|
||||
set text(size: 1.5em)
|
||||
set block(above: 1.2em, below: 1.2em)
|
||||
if it.numbering != none {
|
||||
let num = numbering(it.numbering, ..counter(heading).at(it.location()))
|
||||
let prefix = num + h(1em)
|
||||
_unshift-prefix(prefix, it.body)
|
||||
} else {
|
||||
it
|
||||
}
|
||||
}
|
||||
|
||||
show heading.where(level: 2): it => {
|
||||
if it.numbering != none {
|
||||
let num = numbering(it.numbering, ..counter(heading).at(it.location()))
|
||||
_unshift-prefix(num + h(0.8em), it.body)
|
||||
} else {
|
||||
it
|
||||
}
|
||||
}
|
||||
|
||||
set page(
|
||||
header: context _render-header(version, hash),
|
||||
footer: context if page.numbering != none {
|
||||
align(center, counter(page).display(page.numbering, both: true))
|
||||
},
|
||||
numbering: "1 / 1",
|
||||
)
|
||||
|
||||
show heading: set heading(numbering: "I.1.")
|
||||
counter(page).update(1)
|
||||
doc
|
||||
}
|
||||
|
||||
front-page()
|
||||
outlines()
|
||||
main()
|
||||
}
|
||||
@@ -2,10 +2,6 @@
|
||||
# ruff: disable[F821]
|
||||
from __future__ import annotations
|
||||
|
||||
# Prototype of custom type import to use valid Python syntax
|
||||
import midas
|
||||
midas.using("02_custom_types.midas")
|
||||
|
||||
# A data-frame using a custom type
|
||||
df: Frame[
|
||||
location: GeoLocation
|
||||
@@ -21,7 +17,7 @@ lat + lon # Invalid operation
|
||||
# Registered operations are permitted
|
||||
lat1: Latitude = lat[0]
|
||||
lat2: Latitude = lat[1]
|
||||
lat_diff: LatitudeDiff = lat2 - lat1 # Valid operation
|
||||
lat_diff: Difference[Latitude] = lat2 - lat1 # Valid operation
|
||||
|
||||
# In addition to the type, a column can have one or more constraints, either defined inline or in a separate file
|
||||
df2: Frame[
|
||||
|
||||
73
examples/00_syntax_prototype/03_custom_types_v2.midas
Normal file
73
examples/00_syntax_prototype/03_custom_types_v2.midas
Normal file
@@ -0,0 +1,73 @@
|
||||
// Simple custom type derived from float
|
||||
type Custom(float)
|
||||
|
||||
// Simple custom types with constraints
|
||||
type Latitude(float) where (-90 <= _ <= 90)
|
||||
type Longitude(float) where (-180 <= _ <= 180)
|
||||
|
||||
// Generic custom type (a Difference of T is derived from T, e.g. a difference of floats is a float)
|
||||
type Difference[T](T)
|
||||
|
||||
// Complex custom type, containing two values accessible through properties
|
||||
type GeoLocation {
|
||||
lat: Latitude
|
||||
lon: Longitude
|
||||
}
|
||||
|
||||
// Define operations on our custom type
|
||||
extend GeoLocation {
|
||||
// This type is compatible with the `-` operation with another GeoLocation
|
||||
// i.e. you can subtract a GeoLocation from another GeoLocation, resulting
|
||||
// in a Difference of GeoLocations
|
||||
op __sub__(GeoLocation) -> Difference[GeoLocation]
|
||||
}
|
||||
|
||||
// For complex generics, you need to specify how the genericity the properties
|
||||
// are handled
|
||||
type Difference[GeoLocation] {
|
||||
lat: Difference[Latitude]
|
||||
lon: Difference[Longitude]
|
||||
}
|
||||
|
||||
// Simple operation defined on our custom types
|
||||
extend Latitude {
|
||||
op __sub__(Latitude) -> Difference[Latitude]
|
||||
}
|
||||
|
||||
extend Longitude {
|
||||
op __sub__(Longitude) -> Difference[Longitude]
|
||||
}
|
||||
|
||||
// Predefined custom predicates that can be referenced in other definitions
|
||||
predicate Positive(v: float) = v >= 0
|
||||
predicate StrictlyPositive(v: float) = v > 0
|
||||
predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10)
|
||||
predicate Arctic(loc: GeoLocation) = (loc.lat >= 66)
|
||||
|
||||
type Person {
|
||||
name: str
|
||||
|
||||
// Property with an inline constraint
|
||||
age: int? where (0 <= _ < 150)
|
||||
|
||||
// Property referencing a predicate
|
||||
height: float where StrictlyPositive
|
||||
|
||||
home: GeoLocation
|
||||
}
|
||||
|
||||
// Custom complex type derived from another complex type, with a constraint
|
||||
// on a property
|
||||
// Multiple proposed syntaxes, not yet defined
|
||||
|
||||
// Explicit, but new keyword
|
||||
type EquatorialPerson refines Person where Equatorial(_.home)
|
||||
|
||||
// Explicit with existing keyword, might be confusing if expectations regarding 'is'
|
||||
type EquatorialPerson is Person where Equatorial(_.home)
|
||||
|
||||
// Consistent and Python-friendly but can be confused with structural extension
|
||||
type EquatorialPerson(Person) where Equatorial(_.home)
|
||||
|
||||
// Allow new properties, probably not useful
|
||||
type EquatorialPerson extends Person where Equatorial(_.home)
|
||||
15
examples/00_syntax_prototype/04_functions.py
Normal file
15
examples/00_syntax_prototype/04_functions.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# type: ignore
|
||||
# ruff: disable[F821]
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def func(
|
||||
col1: Column[float + (0 <= _ <= 1)],
|
||||
col2: Column[float + (0 <= _ <= 1)],
|
||||
) -> Column[float + (0 <= _ <= 2)]:
|
||||
result: Column[float + (0 <= _ <= 2)] = col1 + col2
|
||||
return result
|
||||
|
||||
|
||||
def func2(a: int, /, b: float, *, c: str):
|
||||
pass
|
||||
33
examples/00_syntax_prototype/05_custom_types_v3.midas
Normal file
33
examples/00_syntax_prototype/05_custom_types_v3.midas
Normal file
@@ -0,0 +1,33 @@
|
||||
type Foo1 = float
|
||||
type Foo2 = float where (_ > 3)
|
||||
type Foo3 = int | float
|
||||
type Foo4 = int where (_ > 3) | float where (_ > 3)
|
||||
type Foo5 = (int | float) where (_ > 3)
|
||||
type Foo6 = {
|
||||
foo: float
|
||||
bar: float where (_ > 3)
|
||||
}
|
||||
|
||||
type Foo7[T] = T where (_ > 3)
|
||||
type Foo8[A, B<:int] = {
|
||||
a: A
|
||||
b: B
|
||||
}
|
||||
|
||||
type Complex = {
|
||||
a: int
|
||||
b: int
|
||||
}
|
||||
type Complex2 = Complex where (_.a > 3 & _.b < 5)
|
||||
|
||||
predicate Positive(n: int) = n >= 0
|
||||
|
||||
extend Foo1 {
|
||||
op __add__(Foo1) -> Foo1
|
||||
}
|
||||
|
||||
extend Foo7[T] {
|
||||
op __add__(Foo7[T]) -> Foo7[T]
|
||||
}
|
||||
|
||||
type Optional[T] = None | T
|
||||
13
examples/01_simple_type_checking/01_simple_operations.py
Normal file
13
examples/01_simple_type_checking/01_simple_operations.py
Normal file
@@ -0,0 +1,13 @@
|
||||
a: int = 3
|
||||
b: int = 4
|
||||
|
||||
c = a + b # -> int
|
||||
|
||||
c = "invalid" # -> can't assign str to int variable
|
||||
|
||||
d = True
|
||||
e = d + d
|
||||
|
||||
f: float = a
|
||||
|
||||
f = -f
|
||||
14
examples/01_simple_type_checking/02_simple_types.midas
Normal file
14
examples/01_simple_type_checking/02_simple_types.midas
Normal file
@@ -0,0 +1,14 @@
|
||||
type Meter = float
|
||||
type Second = float
|
||||
type MeterPerSecond = float
|
||||
|
||||
extend Meter {
|
||||
def __add__: fn(Meter, /) -> Meter
|
||||
def __sub__: fn(Meter, /) -> Meter
|
||||
def __truediv__: fn(Second, /) -> MeterPerSecond
|
||||
}
|
||||
|
||||
extend Second {
|
||||
def __add__: fn(Second, /) -> Second
|
||||
def __sub__: fn(Second, /) -> Second
|
||||
}
|
||||
6
examples/01_simple_type_checking/02_simple_types.py
Normal file
6
examples/01_simple_type_checking/02_simple_types.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# type: ignore
|
||||
# ruff: disable [F821]
|
||||
|
||||
distance: Meter = cast(Meter, 123.45)
|
||||
time: Second = cast(Second, 6.7)
|
||||
speed = distance / time
|
||||
23
examples/01_simple_type_checking/03_control_flow.py
Normal file
23
examples/01_simple_type_checking/03_control_flow.py
Normal file
@@ -0,0 +1,23 @@
|
||||
def minimum(x: int, y: int):
|
||||
if x < y:
|
||||
return x
|
||||
else:
|
||||
return y
|
||||
|
||||
|
||||
a = 15
|
||||
b = 72
|
||||
c = minimum(a, b)
|
||||
|
||||
|
||||
def factorial(n: int) -> int:
|
||||
if n <= 1:
|
||||
return 1
|
||||
return n * factorial(n - 1)
|
||||
|
||||
|
||||
category = "Category 1" if a < 10 else "Category 2"
|
||||
|
||||
|
||||
def foo() -> None:
|
||||
pass
|
||||
21
examples/01_simple_type_checking/04_complex_types.midas
Normal file
21
examples/01_simple_type_checking/04_complex_types.midas
Normal file
@@ -0,0 +1,21 @@
|
||||
type Meter = float
|
||||
|
||||
extend Meter {
|
||||
def __add__: fn(Meter, /) -> Meter
|
||||
def __sub__: fn(Meter, /) -> Meter
|
||||
}
|
||||
|
||||
type Coordinate = object
|
||||
|
||||
extend Coordinate {
|
||||
prop x: Meter
|
||||
prop y: Meter
|
||||
}
|
||||
|
||||
type Difference[T <: float] = T
|
||||
type MeterDifference = Difference[Meter]
|
||||
|
||||
type CompDiff[T <: float] = {
|
||||
prop d1: Difference[T]
|
||||
prop d2: Difference[T]
|
||||
}
|
||||
37
examples/01_simple_type_checking/04_complex_types.py
Normal file
37
examples/01_simple_type_checking/04_complex_types.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# type: ignore
|
||||
# ruff: disable [F821]
|
||||
|
||||
p1: Coordinate
|
||||
p2: Coordinate
|
||||
|
||||
diff_x = p2.x - p1.x
|
||||
diff_y = p2.y - p1.y
|
||||
|
||||
dist = diff_x + diff_y
|
||||
|
||||
p2.x += cast(Meter, 1)
|
||||
p2.y = True # invalid, wrong type
|
||||
p2.z = 3 # invalid, no property 'z' on Coordinate
|
||||
p2.x.a = 3 # invalid, no properties on Meter
|
||||
|
||||
foo: list[float] = []
|
||||
|
||||
append = foo.append
|
||||
|
||||
foo.append("") # invalid, must be float
|
||||
foo.append(2)
|
||||
append(True) # invalid, must be float
|
||||
append(2)
|
||||
|
||||
bar: list[list[Meter]]
|
||||
|
||||
bar.append([p2.x])
|
||||
|
||||
foo2 = foo + foo
|
||||
|
||||
a = foo[0]
|
||||
b = bar[0][1]
|
||||
c = bar[0][1][2] # invalid, no method __getitem__ on Meter
|
||||
c = bar[""] # invalid, wrong index type
|
||||
|
||||
d = foo[1:2]
|
||||
28
examples/01_simple_type_checking/05_functions.py
Normal file
28
examples/01_simple_type_checking/05_functions.py
Normal file
@@ -0,0 +1,28 @@
|
||||
def incr(value: int):
|
||||
return value + 1
|
||||
|
||||
|
||||
def decr(value: int):
|
||||
return value - 1
|
||||
|
||||
|
||||
def foo(a: int, /, b: float, *, c: str):
|
||||
return True
|
||||
|
||||
|
||||
r1 = foo() # foo() missing 2 required positional arguments: 'a' and 'b'
|
||||
r2 = foo(1) # foo() missing 1 required positional argument: 'b'
|
||||
r3 = foo(1, 2.0) # foo() missing 1 required keyword-only argument: 'c'
|
||||
r4 = foo(1, b=2.0) # foo() missing 1 required keyword-only argument: 'c'
|
||||
r5 = foo(1, 2.0, "test") # foo() takes 2 positional arguments but 3 were given
|
||||
r6 = foo(1, 2.0, b=3.0) # foo() got multiple values for argument 'b'
|
||||
r7 = foo(
|
||||
a=1
|
||||
) # foo() got some positional-only arguments passed as keyword arguments: 'a'
|
||||
r8 = foo(g="test") # foo() got an unexpected keyword argument 'g'
|
||||
|
||||
r9a = foo(1, 2.0, c="test")
|
||||
r9b = foo(1, b=2.0, c="test")
|
||||
r9c = foo(1, c="test", b=2.0)
|
||||
|
||||
r10 = foo("a", 3, c=False) # wrong argument types
|
||||
10
examples/01_simple_type_checking/06_overloads.midas
Normal file
10
examples/01_simple_type_checking/06_overloads.midas
Normal file
@@ -0,0 +1,10 @@
|
||||
type T1 = object
|
||||
type T2 = object
|
||||
type Foo = object
|
||||
type T2b = T2
|
||||
|
||||
extend Foo {
|
||||
def bar: fn(T1, /) -> int
|
||||
def bar: fn(T2, /) -> float
|
||||
def bar: fn(T2b, /) -> int
|
||||
}
|
||||
18
examples/01_simple_type_checking/06_overloads.py
Normal file
18
examples/01_simple_type_checking/06_overloads.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# type: ignore
|
||||
# ruff: disable [F821]
|
||||
|
||||
foo: Foo
|
||||
t1: T1
|
||||
t2: T2
|
||||
|
||||
a = foo.bar(t1)
|
||||
b = foo.bar(t2)
|
||||
|
||||
func = foo.bar
|
||||
|
||||
c = func(t1)
|
||||
d = func(t2)
|
||||
|
||||
t2b: T2b
|
||||
|
||||
e = foo.bar(t2b)
|
||||
15
examples/02_demonstration/demo.midas
Normal file
15
examples/02_demonstration/demo.midas
Normal file
@@ -0,0 +1,15 @@
|
||||
predicate in_range(min: float, max: float)(v: float) = min <= v & v <= max
|
||||
predicate is_ratio = in_range(0, 1)
|
||||
|
||||
type Currency = float
|
||||
type Price[T <: Currency] = T where _ >= 0
|
||||
|
||||
extend Price[T <: Currency] {
|
||||
def __add__: fn(Price[T], /) -> Price[T]
|
||||
}
|
||||
|
||||
type EUR = Currency
|
||||
type USD = Currency
|
||||
type CHF = Currency
|
||||
|
||||
type Discount = float where is_ratio(_)
|
||||
35
examples/02_demonstration/demo.py
Normal file
35
examples/02_demonstration/demo.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import TypeVar
|
||||
|
||||
from demo_stubs import CHF, EUR, USD, Currency, Discount, Price
|
||||
|
||||
from midas.typing import cast, unsafe_cast
|
||||
|
||||
T = TypeVar("T", bound=Currency)
|
||||
|
||||
|
||||
def apply_discount(amount: Price[T], discount: Discount) -> Price[T]:
    # Scale `amount` by the fraction that remains after `discount` (1.0 - discount).
    # The product is passed through midas' `cast` so the result is typed as
    # Price[T] again, i.e. with the same currency parameter as the input.
    return cast(Price[T], (1.0 - discount) * amount)
|
||||
|
||||
|
||||
a1 = cast(Price[EUR], 3.2)
|
||||
a2 = cast(Price[USD], 10.4)
|
||||
r1 = cast(Discount, 0.2)
|
||||
|
||||
print(apply_discount(a1, r1))
|
||||
print(apply_discount(a2, r1))
|
||||
|
||||
a3 = a1 + a1
|
||||
a4 = a1 + a2 # cannot add euros and dollars
|
||||
a3 = a2 # cannot change variable type
|
||||
|
||||
dyn_price = float(input("Price (CHF): "))
|
||||
dyn_discount = float(input("Discount (0.0-1.0): "))
|
||||
discounted = apply_discount(
|
||||
cast(Price[CHF], dyn_price),
|
||||
cast(Discount, dyn_discount),
|
||||
)
|
||||
|
||||
print(f"Discounted: CHF {discounted}")
|
||||
|
||||
large_data = [i * 10 for i in range(100)]
|
||||
prices = unsafe_cast(list[Price[EUR]], large_data)
|
||||
14
examples/02_demonstration/demo_stubs.pyi
Normal file
14
examples/02_demonstration/demo_stubs.pyi
Normal file
@@ -0,0 +1,14 @@
|
||||
from __future__ import annotations
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
class Currency(float): ...
|
||||
|
||||
_T0 = TypeVar("_T0", bound=Currency, covariant=True)
|
||||
|
||||
class Price(Currency, Generic[_T0]):
|
||||
def __add__(self, _0: Price[_T0], /) -> Price[_T0]: ...
|
||||
|
||||
class EUR(Currency): ...
|
||||
class USD(Currency): ...
|
||||
class CHF(Currency): ...
|
||||
class Discount(float): ...
|
||||
165
gen/gen.py
Normal file
165
gen/gen.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
Helper script to generate AST nodes for Midas and Python.
|
||||
|
||||
Takes in simple templates and generates full dataclasses and a visitor interface
|
||||
"""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
HEADER = '''"""
|
||||
This file was generated by a script. Any manual changes might be overwritten.
|
||||
Please modify {defs_path} instead and run {gen_path}
|
||||
"""'''
|
||||
|
||||
SECTION_TEMPLATE = """{banner}
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class {base}(ABC):
|
||||
location: Location
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
{visitor_methods}
|
||||
|
||||
|
||||
{classes}"""
|
||||
|
||||
TEMPLATE = """{header}
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
{imports}
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
{preamble}
|
||||
{sections}
|
||||
"""
|
||||
|
||||
VISITOR_METHOD_TEMPLATE = """
|
||||
@abstractmethod
|
||||
def visit_{func_name}(self, {param}: {cls}) -> T: ...
|
||||
"""
|
||||
|
||||
CLASS_TEMPLATE = """
|
||||
@dataclass(frozen=True)
|
||||
class {cls}({base}):
|
||||
{body}
|
||||
|
||||
def accept(self, visitor: {base}.Visitor[T]) -> T:
|
||||
return visitor.visit_{func_name}(self)
|
||||
"""
|
||||
|
||||
SECTION_REGEX = re.compile(
|
||||
r"^###>\s*(?P<base>[^\n]*?)\s*\|\s*(?P<name>[^\n]*?)(\s*\|\s*(?P<param>[^\n]*?))?\s*?\n(?P<body>.*?)\n###<$",
|
||||
re.MULTILINE | re.DOTALL,
|
||||
)
|
||||
|
||||
IMPORTS_REGEX = re.compile(
|
||||
r"^###>\s*Imports\s*?\n(?P<body>.*?)\n###<$",
|
||||
re.MULTILINE | re.DOTALL,
|
||||
)
|
||||
|
||||
PREAMBLE_REGEX = re.compile(
|
||||
r"^###>\s*Preamble\s*?\n(?P<body>.*?)\n###<$",
|
||||
re.MULTILINE | re.DOTALL,
|
||||
)
|
||||
|
||||
|
||||
def snake_case(text: str) -> str:
    """Convert a CamelCase identifier to snake_case.

    An underscore is inserted in front of every ASCII capital letter, the
    whole string is lower-cased, and underscores left at either end are
    removed (so ``"TypeStmt"`` becomes ``"type_stmt"``).
    """
    underscored: str = re.sub(r"([A-Z])", r"_\1", text)
    return underscored.lower().strip("_")
|
||||
|
||||
|
||||
def make_visitor_method(cls: str, param: str):
    """Render the abstract ``visit_*`` method for one node class.

    Args:
        cls: name of the node class; its snake_case form names the method.
        param: name of the method's single node parameter.

    Returns:
        The filled-in ``VISITOR_METHOD_TEMPLATE`` without the newlines that
        surround it.
    """
    rendered: str = VISITOR_METHOD_TEMPLATE.format(
        cls=cls,
        param=param,
        func_name=snake_case(cls),
    )
    return rendered.strip("\n")
|
||||
|
||||
|
||||
def make_class(name: str, cls: str, base: str):
    """Render the dataclass for one node from its template.

    Args:
        name: name of the node class.
        cls: the node's template text, starting with its ``class`` line.
        base: name of the abstract base class the node extends.

    Returns:
        The filled-in ``CLASS_TEMPLATE`` without the newlines that surround it.
    """
    # The template's own header line is dropped; only what follows it is
    # copied into the generated class.
    fields: str = cls.split("\n", 1)[1]
    rendered: str = CLASS_TEMPLATE.format(
        cls=name,
        base=base,
        body=fields,
        func_name=snake_case(name),
    )
    return rendered.strip("\n")
|
||||
|
||||
|
||||
def make_banner(text: str) -> str:
    """Build a three-line comment banner with ``text`` framed by ``#`` rules.

    The rule above and below is exactly as wide as the title line.
    """
    title: str = f"# {text} #"
    border: str = "#" * len(title)
    return f"{border}\n{title}\n{border}"
|
||||
|
||||
|
||||
def make_section(full_name: str, base: str, param: str, body: str) -> str:
    """Render one section: its banner, abstract base, visitor and node classes.

    Args:
        full_name: human-readable section title, used for the banner.
        base: name of the abstract base class of every node in the section.
        param: parameter name used by the generated ``visit_*`` methods.
        body: raw section text; node templates are separated by two blank
            lines.

    Returns:
        The filled-in ``SECTION_TEMPLATE``.

    Raises:
        ValueError: if a node template does not start with a ``class Name:``
            line.
    """
    print(f" Generating {full_name}")
    visitor_methods: list[str] = []
    classes: list[str] = []
    for chunk in body.strip("\n").split("\n\n\n"):
        cls: str = chunk.strip("\n")
        if not cls:
            # An empty section, or more than two blank lines between two
            # templates, leaves an empty chunk: nothing to generate for it.
            continue
        header = re.match("class (.*?):", cls)
        if header is None:
            # Fail with the offending line instead of an AttributeError on None.
            raise ValueError(
                f"Section {full_name!r}: expected a 'class Name:' header, "
                f"got {cls.splitlines()[0]!r}"
            )
        name: str = header.group(1)
        print(f" Processing {name}")
        visitor_methods.append(make_visitor_method(name, param))
        classes.append(make_class(name, cls, base))

    return SECTION_TEMPLATE.format(
        banner=make_banner(full_name),
        base=base,
        visitor_methods="\n\n".join(visitor_methods),
        classes="\n\n\n".join(classes),
    )
|
||||
|
||||
|
||||
def generate(definitions_path: Path, out_path: Path):
    """Generate one AST module from a definitions file.

    The definitions file is searched for an optional ``###> Imports`` block,
    an optional ``###> Preamble`` block and any number of
    ``###> Base | Name [| param]`` sections. The rendered module is written
    to ``out_path``.

    Args:
        definitions_path: the definitions file; it must lie below the
            project root (the parent of this script's directory).
        out_path: file the generated module is written to.

    Raises:
        ValueError: if ``definitions_path`` is not below the project root
            (raised by ``Path.relative_to``).
    """
    print(f"Processing generating {out_path} from {definitions_path}")
    root_dir: Path = Path(__file__).parent.parent
    rel_path: Path = definitions_path.relative_to(root_dir)
    # Read and write as UTF-8 explicitly so the generated module does not
    # depend on the locale of the machine running the script.
    src: str = definitions_path.read_text(encoding="utf-8")

    imports: str = ""
    if m := IMPORTS_REGEX.search(src):
        imports = m.group("body").strip("\n")

    preamble: str = ""
    if m := PREAMBLE_REGEX.search(src):
        preamble = m.group("body")

    sections: list[str] = []
    for section_m in SECTION_REGEX.finditer(src):
        base: str = section_m.group("base")
        sections.append(
            make_section(
                section_m.group("name"),
                base,
                # The parameter name is optional in the section header and
                # falls back to the lower-cased base class name.
                section_m.group("param") or base.lower(),
                section_m.group("body"),
            )
        )

    result: str = TEMPLATE.format(
        header=HEADER.format(
            defs_path=rel_path,
            gen_path=Path(__file__).relative_to(root_dir),
        ),
        imports=imports,
        preamble=preamble,
        sections="\n\n\n".join(sections),
    )
    out_path.write_text(result, encoding="utf-8")
|
||||
|
||||
|
||||
def main():
    """Regenerate both AST modules from their definitions in ``gen/``."""
    project_root: Path = Path(__file__).parent.parent
    templates: Path = project_root / "gen"
    output: Path = project_root / "midas" / "ast"
    # Same order as before: the Midas AST first, then the Python AST.
    for module in ("midas.py", "python.py"):
        generate(templates / module, output / module)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
170
gen/midas.py
Normal file
170
gen/midas.py
Normal file
@@ -0,0 +1,170 @@
|
||||
# type: ignore
|
||||
# ruff: disable[F821, F401]
|
||||
|
||||
###> Imports
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from midas.ast.location import Location
|
||||
from midas.lexer.token import Token
|
||||
|
||||
###<
|
||||
|
||||
|
||||
###> Preamble
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class TypeParam:
|
||||
location: Location
|
||||
name: Token
|
||||
bound: Optional[Type]
|
||||
|
||||
|
||||
class MemberKind(Enum):
|
||||
PROPERTY = auto()
|
||||
METHOD = auto()
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ParamSpec:
|
||||
l_paren: Token
|
||||
pos: list[FunctionType.Parameter]
|
||||
mixed: list[FunctionType.Parameter]
|
||||
kw: list[FunctionType.Parameter]
|
||||
|
||||
|
||||
###<
|
||||
|
||||
|
||||
###> Stmt | Statements
|
||||
class TypeStmt:
|
||||
name: Token
|
||||
params: list[TypeParam]
|
||||
type: Type
|
||||
|
||||
|
||||
class AliasStmt:
|
||||
name: Token
|
||||
type: Type
|
||||
|
||||
|
||||
class MemberStmt:
|
||||
name: Token
|
||||
type: Type
|
||||
kind: MemberKind
|
||||
|
||||
|
||||
class ExtendStmt:
|
||||
name: Token
|
||||
params: list[TypeParam]
|
||||
members: list[MemberStmt]
|
||||
|
||||
|
||||
class PredicateStmt:
|
||||
name: Token
|
||||
params: list[ParamSpec]
|
||||
body: Expr
|
||||
|
||||
|
||||
###<
|
||||
|
||||
|
||||
###> Expr | Expressions
|
||||
|
||||
|
||||
class LogicalExpr:
|
||||
left: Expr
|
||||
operator: Token
|
||||
right: Expr
|
||||
|
||||
|
||||
class BinaryExpr:
|
||||
left: Expr
|
||||
operator: Token
|
||||
right: Expr
|
||||
|
||||
|
||||
class UnaryExpr:
|
||||
operator: Token
|
||||
right: Expr
|
||||
|
||||
|
||||
class CallExpr:
|
||||
callee: Expr
|
||||
arguments: list[Expr]
|
||||
keywords: dict[str, Expr]
|
||||
|
||||
|
||||
class GetExpr:
|
||||
expr: Expr
|
||||
name: Token
|
||||
|
||||
|
||||
class VariableExpr:
|
||||
name: Token
|
||||
|
||||
|
||||
class GroupingExpr:
|
||||
expr: Expr
|
||||
|
||||
|
||||
class LiteralExpr:
|
||||
value: Any
|
||||
|
||||
|
||||
class WildcardExpr:
|
||||
token: Token
|
||||
|
||||
|
||||
###<
|
||||
|
||||
###> Type | Types
|
||||
|
||||
|
||||
class NamedType:
|
||||
name: Token
|
||||
|
||||
|
||||
class GenericType:
|
||||
type: Type
|
||||
args: list[Type]
|
||||
|
||||
|
||||
class ConstraintType:
|
||||
type: Type
|
||||
constraint: Expr
|
||||
|
||||
|
||||
class ComplexType:
|
||||
members: list[MemberStmt]
|
||||
|
||||
|
||||
class ExtensionType:
|
||||
base: Type
|
||||
extension: ComplexType
|
||||
|
||||
|
||||
class FunctionType:
|
||||
params: ParamSpec
|
||||
returns: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Parameter:
|
||||
location: Optional[Location] = None
|
||||
name: Optional[Token]
|
||||
type: Type
|
||||
required: bool
|
||||
|
||||
|
||||
class FrameType:
|
||||
columns: list[Column]
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Column:
|
||||
location: Optional[Location] = None
|
||||
name: Token
|
||||
type: Type
|
||||
|
||||
|
||||
###<
|
||||
192
gen/python.py
Normal file
192
gen/python.py
Normal file
@@ -0,0 +1,192 @@
|
||||
# type: ignore
|
||||
# ruff: disable[F821, F401]
|
||||
|
||||
###> Imports
|
||||
import ast
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from midas.ast.location import Location
|
||||
|
||||
###<
|
||||
|
||||
|
||||
###> Preamble
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ParamSpec:
|
||||
pos: list[Function.Parameter]
|
||||
mixed: list[Function.Parameter]
|
||||
kw: list[Function.Parameter]
|
||||
|
||||
@property
|
||||
def all(self) -> list[Function.Parameter]:
|
||||
return self.pos + self.mixed + self.kw
|
||||
|
||||
|
||||
###<
|
||||
|
||||
|
||||
###> MidasType | Type annotations | node
|
||||
class BaseType:
|
||||
base: str
|
||||
args: tuple[MidasType, ...]
|
||||
|
||||
|
||||
class ConstraintType:
|
||||
type: MidasType
|
||||
constraint: ast.expr
|
||||
|
||||
|
||||
class FrameColumn:
|
||||
name: Optional[str]
|
||||
type: Optional[MidasType]
|
||||
|
||||
|
||||
class FrameType:
|
||||
columns: list[FrameColumn]
|
||||
|
||||
|
||||
###<
|
||||
|
||||
|
||||
###> Stmt | Statements
|
||||
class ExpressionStmt:
|
||||
expr: Expr
|
||||
|
||||
|
||||
class Function:
|
||||
name: str
|
||||
params: ParamSpec
|
||||
returns: Optional[MidasType]
|
||||
body: list[Stmt]
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Parameter:
|
||||
location: Optional[Location] = None
|
||||
name: str
|
||||
type: Optional[MidasType]
|
||||
default: Optional[Expr]
|
||||
|
||||
|
||||
class TypeAssign:
|
||||
name: str
|
||||
type: MidasType
|
||||
|
||||
|
||||
class AssignStmt:
|
||||
targets: list[Expr]
|
||||
value: Expr
|
||||
|
||||
|
||||
class ReturnStmt:
|
||||
value: Optional[Expr]
|
||||
|
||||
|
||||
class IfStmt:
|
||||
test: Expr
|
||||
body: list[Stmt]
|
||||
orelse: list[Stmt]
|
||||
|
||||
|
||||
class Pass:
|
||||
pass
|
||||
|
||||
|
||||
class ForStmt:
|
||||
target: Expr
|
||||
iterator: Expr
|
||||
body: list[Stmt]
|
||||
|
||||
|
||||
class RawStmt:
|
||||
stmt: ast.stmt
|
||||
|
||||
|
||||
###<
|
||||
|
||||
|
||||
###> Expr | Expressions
|
||||
class BinaryExpr:
|
||||
left: Expr
|
||||
operator: ast.operator
|
||||
right: Expr
|
||||
|
||||
|
||||
class CompareExpr:
|
||||
left: Expr
|
||||
operator: ast.cmpop
|
||||
right: Expr
|
||||
|
||||
|
||||
class UnaryExpr:
|
||||
operator: ast.unaryop
|
||||
right: Expr
|
||||
|
||||
|
||||
class CallExpr:
|
||||
callee: Expr
|
||||
arguments: list[Expr]
|
||||
keywords: dict[str, Expr]
|
||||
|
||||
|
||||
class GetExpr:
|
||||
object: Expr
|
||||
name: str
|
||||
|
||||
|
||||
class LiteralExpr:
|
||||
value: Any
|
||||
|
||||
|
||||
class VariableExpr:
|
||||
name: str
|
||||
|
||||
|
||||
class LogicalExpr:
|
||||
left: Expr
|
||||
operator: ast.boolop
|
||||
right: Expr
|
||||
|
||||
|
||||
class CastExpr:
|
||||
type: MidasType
|
||||
expr: Expr
|
||||
unsafe: bool
|
||||
|
||||
|
||||
class TernaryExpr:
|
||||
test: Expr
|
||||
if_true: Expr
|
||||
if_false: Expr
|
||||
|
||||
|
||||
class ListExpr:
|
||||
items: list[Expr]
|
||||
|
||||
|
||||
class DictExpr:
|
||||
keys: list[Optional[Expr]]
|
||||
values: list[Expr]
|
||||
|
||||
|
||||
class SubscriptExpr:
|
||||
object: Expr
|
||||
index: Expr
|
||||
|
||||
|
||||
class SliceExpr:
|
||||
lower: Optional[Expr]
|
||||
upper: Optional[Expr]
|
||||
step: Optional[Expr]
|
||||
|
||||
|
||||
class TupleExpr:
|
||||
items: tuple[Expr, ...]
|
||||
|
||||
|
||||
class RawExpr:
|
||||
expr: ast.expr
|
||||
|
||||
|
||||
###<
|
||||
@@ -1,102 +0,0 @@
|
||||
from lexer.base import Lexer
|
||||
from lexer.keyword import ANNOTATION_KEYWORDS
|
||||
from lexer.token import TokenType
|
||||
|
||||
|
||||
class AnnotationLexer(Lexer):
    """Lexer for annotation source, built on the helpers of the `Lexer` base class"""

    def scan_token(self) -> None:
        """Scan one token, dispatching on the character that was just consumed

        Punctuation and operators are emitted directly; comments, numbers and
        identifiers are handed to the dedicated `scan_*` methods. A character
        that starts no known token is reported through `self.error`
        """
        char: str = self.advance()
        match char:
            case "(":
                self.add_token(TokenType.LEFT_PAREN)
            case ")":
                self.add_token(TokenType.RIGHT_PAREN)
            case "[":
                self.add_token(TokenType.LEFT_BRACKET)
            case "]":
                self.add_token(TokenType.RIGHT_BRACKET)
            # Comparison operators: a following "=" selects the two-character form
            # (assumes `match` consumes the character only when it matches - TODO confirm)
            case "<":
                self.add_token(
                    TokenType.LESS_EQUAL if self.match("=") else TokenType.LESS
                )
            case ">":
                self.add_token(
                    TokenType.GREATER_EQUAL if self.match("=") else TokenType.GREATER
                )
            case "=":
                self.add_token(
                    TokenType.EQUAL_EQUAL if self.match("=") else TokenType.EQUAL
                )
            case "!":
                # A bang is only valid as the first half of "!="
                if self.match("="):
                    self.add_token(TokenType.BANG_EQUAL)
                else:
                    self.error("Unexpected single bang. Did you mean '!=' ?")
            case ":":
                self.add_token(TokenType.COLON)
            case ",":
                self.add_token(TokenType.COMMA)
            case "_":
                # An underscore in token-start position is its own token; it is
                # never the first character of an identifier (see `scan_identifier`)
                self.add_token(TokenType.UNDERSCORE)
            case "+":
                self.add_token(TokenType.PLUS)
            case "#":
                self.scan_comment()
            case "\n":
                self.add_token(TokenType.NEWLINE)
            case " " | "\r" | "\t":
                # Consume all whitespace characters until EOL or EOF
                while (
                    self.peek().isspace()
                    and self.peek() != "\n"
                    and not self.is_at_end()
                ):
                    self.advance()
                # The whole run is reported as a single WHITESPACE token
                self.add_token(TokenType.WHITESPACE)
            case _:
                if char.isdigit():
                    self.scan_number()
                elif char.isalpha():
                    self.scan_identifier()
                else:
                    self.error("Unexpected character")
        return None

    def scan_number(self):
        """Scan the rest of number and add it as a token

        This method handles both simple integers and floats. Scientific notation
        and base prefixes (0x, 0b, 0o) are not supported
        """
        while self.peek().isdigit():
            self.advance()

        # A dot only belongs to the number when a digit follows it
        if self.peek() == "." and self.peek_next().isdigit():
            self.advance()
            while self.peek().isdigit():
                self.advance()

        # Integers are converted to float as well
        value: float = float(self.source[self.start : self.idx])
        self.add_token(TokenType.NUMBER, value)

    def scan_identifier(self):
        """Scan the rest of an identifier and add it as a token

        An identifier starts with a letter, followed by any number of
        alphanumerical characters or underscores
        """
        while self.peek().isalnum() or self.peek() == "_":
            self.advance()

        lexeme: str = self.source[self.start : self.idx]
        # Lexemes listed in ANNOTATION_KEYWORDS get their dedicated token type,
        # everything else is a plain identifier
        token_type: TokenType = ANNOTATION_KEYWORDS.get(lexeme, TokenType.IDENTIFIER)
        self.add_token(token_type)

    def scan_comment(self):
        """Scan the rest of a comment and add it as a token

        A comment starts with a `#` character and ends at the EOL/EOF
        """
        # The terminating newline is left in place for the next `scan_token` call
        while self.peek() != "\n" and not self.is_at_end():
            self.advance()
        self.add_token(TokenType.COMMENT)
|
||||
@@ -1,16 +0,0 @@
|
||||
from lexer.token import TokenType
|
||||
|
||||
ANNOTATION_KEYWORDS: dict[str, TokenType] = {
|
||||
"True": TokenType.TRUE,
|
||||
"False": TokenType.FALSE,
|
||||
"None": TokenType.NONE,
|
||||
}
|
||||
|
||||
MIDAS_KEYWORDS: dict[str, TokenType] = {
|
||||
"type": TokenType.TYPE,
|
||||
"op": TokenType.OP,
|
||||
"constraint": TokenType.CONSTRAINT,
|
||||
"true": TokenType.TRUE,
|
||||
"false": TokenType.FALSE,
|
||||
"none": TokenType.NONE,
|
||||
}
|
||||
@@ -1,59 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Any
|
||||
|
||||
from lexer.position import Position
|
||||
|
||||
|
||||
class TokenType(Enum):
    """Kinds of token a lexer can emit, grouped by role"""

    # Punctuation
    LEFT_PAREN = auto()
    RIGHT_PAREN = auto()
    LEFT_BRACKET = auto()
    RIGHT_BRACKET = auto()
    LEFT_BRACE = auto()
    RIGHT_BRACE = auto()
    COLON = auto()
    COMMA = auto()
    UNDERSCORE = auto()

    # Operators
    PLUS = auto()
    MINUS = auto()
    STAR = auto()
    SLASH = auto()
    GREATER = auto()
    GREATER_EQUAL = auto()
    LESS = auto()
    LESS_EQUAL = auto()
    EQUAL = auto()
    EQUAL_EQUAL = auto()
    BANG_EQUAL = auto()

    # Literals
    IDENTIFIER = auto()
    NUMBER = auto()
    TRUE = auto()
    FALSE = auto()
    NONE = auto()

    # Keywords
    TYPE = auto()
    OP = auto()
    CONSTRAINT = auto()

    # Misc
    COMMENT = auto()
    WHITESPACE = auto()
    EOF = auto()
    NEWLINE = auto()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Token:
    """A scanned token"""

    # Kind of the token
    type: TokenType
    # presumably the source text the token was scanned from - TODO confirm in Lexer.add_token
    lexeme: str
    # presumably the literal payload (e.g. the parsed number), if any - TODO confirm
    value: Any
    # Position of the token, as a lexer.position.Position
    position: Position
|
||||
49
midas/ast/location.py
Normal file
49
midas/ast/location.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Protocol
|
||||
|
||||
|
||||
class HasLocation(Protocol):
    """Structural type for any object carrying the four position attributes

    This is what `Location.from_ast` reads from its argument
    """

    lineno: int
    col_offset: int
    end_lineno: Optional[int]
    end_col_offset: Optional[int]
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Location:
|
||||
"""Information about the location of an AST node"""
|
||||
|
||||
lineno: int
|
||||
col_offset: int
|
||||
end_lineno: Optional[int]
|
||||
end_col_offset: Optional[int]
|
||||
|
||||
@staticmethod
|
||||
def from_ast(obj: HasLocation) -> Location:
|
||||
return Location(
|
||||
lineno=obj.lineno,
|
||||
col_offset=obj.col_offset,
|
||||
end_lineno=obj.end_lineno,
|
||||
end_col_offset=obj.end_col_offset,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def span(start: Location, end: Location) -> Location:
|
||||
"""Create a new location spanning from one location to another
|
||||
|
||||
Args:
|
||||
start (Location): the starting location
|
||||
end (Location): the end location
|
||||
|
||||
Returns:
|
||||
Location: a new location spanning from the start of `start`
|
||||
to the end of `end`
|
||||
"""
|
||||
return Location(
|
||||
lineno=start.lineno,
|
||||
col_offset=start.col_offset,
|
||||
end_lineno=end.lineno,
|
||||
end_col_offset=end.end_col_offset,
|
||||
)
|
||||
342
midas/ast/midas.py
Normal file
342
midas/ast/midas.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
This file was generated by a script. Any manual changes might be overwritten.
|
||||
Please modify gen/midas.py instead and run gen/gen.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from midas.ast.location import Location
|
||||
from midas.lexer.token import Token
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class TypeParam:
|
||||
location: Location
|
||||
name: Token
|
||||
bound: Optional[Type]
|
||||
|
||||
|
||||
class MemberKind(Enum):
|
||||
PROPERTY = auto()
|
||||
METHOD = auto()
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ParamSpec:
|
||||
l_paren: Token
|
||||
pos: list[FunctionType.Parameter]
|
||||
mixed: list[FunctionType.Parameter]
|
||||
kw: list[FunctionType.Parameter]
|
||||
|
||||
|
||||
##############
|
||||
# Statements #
|
||||
##############
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Stmt(ABC):
|
||||
location: Location
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_type_stmt(self, stmt: TypeStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_alias_stmt(self, stmt: AliasStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_member_stmt(self, stmt: MemberStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_extend_stmt(self, stmt: ExtendStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_predicate_stmt(self, stmt: PredicateStmt) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TypeStmt(Stmt):
|
||||
name: Token
|
||||
params: list[TypeParam]
|
||||
type: Type
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_type_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AliasStmt(Stmt):
|
||||
name: Token
|
||||
type: Type
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_alias_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MemberStmt(Stmt):
|
||||
name: Token
|
||||
type: Type
|
||||
kind: MemberKind
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_member_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExtendStmt(Stmt):
|
||||
name: Token
|
||||
params: list[TypeParam]
|
||||
members: list[MemberStmt]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_extend_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PredicateStmt(Stmt):
|
||||
name: Token
|
||||
params: list[ParamSpec]
|
||||
body: Expr
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_predicate_stmt(self)
|
||||
|
||||
|
||||
###############
|
||||
# Expressions #
|
||||
###############
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Expr(ABC):
|
||||
location: Location
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_binary_expr(self, expr: BinaryExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_unary_expr(self, expr: UnaryExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_call_expr(self, expr: CallExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_get_expr(self, expr: GetExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_variable_expr(self, expr: VariableExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_grouping_expr(self, expr: GroupingExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_literal_expr(self, expr: LiteralExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LogicalExpr(Expr):
|
||||
left: Expr
|
||||
operator: Token
|
||||
right: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_logical_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BinaryExpr(Expr):
|
||||
left: Expr
|
||||
operator: Token
|
||||
right: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_binary_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UnaryExpr(Expr):
|
||||
operator: Token
|
||||
right: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_unary_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CallExpr(Expr):
|
||||
callee: Expr
|
||||
arguments: list[Expr]
|
||||
keywords: dict[str, Expr]
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_call_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GetExpr(Expr):
|
||||
expr: Expr
|
||||
name: Token
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_get_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class VariableExpr(Expr):
|
||||
name: Token
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_variable_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GroupingExpr(Expr):
|
||||
expr: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_grouping_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LiteralExpr(Expr):
|
||||
value: Any
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_literal_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WildcardExpr(Expr):
|
||||
token: Token
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_wildcard_expr(self)
|
||||
|
||||
|
||||
#########
|
||||
# Types #
|
||||
#########
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Type(ABC):
|
||||
location: Location
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_named_type(self, type: NamedType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_generic_type(self, type: GenericType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_constraint_type(self, type: ConstraintType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_complex_type(self, type: ComplexType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_extension_type(self, type: ExtensionType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_function_type(self, type: FunctionType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_frame_type(self, type: FrameType) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NamedType(Type):
|
||||
name: Token
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_named_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GenericType(Type):
|
||||
type: Type
|
||||
args: list[Type]
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_generic_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConstraintType(Type):
|
||||
type: Type
|
||||
constraint: Expr
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_constraint_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ComplexType(Type):
|
||||
members: list[MemberStmt]
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_complex_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExtensionType(Type):
|
||||
base: Type
|
||||
extension: ComplexType
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_extension_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FunctionType(Type):
|
||||
params: ParamSpec
|
||||
returns: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Parameter:
|
||||
location: Optional[Location] = None
|
||||
name: Optional[Token]
|
||||
type: Type
|
||||
required: bool
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_function_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FrameType(Type):
|
||||
columns: list[Column]
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Column:
|
||||
location: Optional[Location] = None
|
||||
name: Token
|
||||
type: Type
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_frame_type(self)
|
||||
3
midas/ast/printer/__init__.py
Normal file
3
midas/ast/printer/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .midas import MidasPrinter as MidasPrinter
|
||||
from .midas_ast import MidasAstPrinter as MidasAstPrinter
|
||||
from .python_ast import PythonAstPrinter as PythonAstPrinter
|
||||
103
midas/ast/printer/base.py
Normal file
103
midas/ast/printer/base.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum, auto
|
||||
from typing import Callable, Generator, Generic, Optional, Protocol, Sequence, TypeVar
|
||||
|
||||
|
||||
class _Level(Enum):
    """State of one nesting level while the tree is being printed"""

    # The level's last child has already been written
    EMPTY = auto()
    # More children of this level may follow
    ACTIVE = auto()
    # The next line written at this level is its last child
    LAST = auto()
|
||||
|
||||
|
||||
class Expr(Protocol):
    """Structural type for anything printable: it must accept the printer as a visitor"""

    def accept(self, printer: AstPrinter) -> None: ...
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Expr)
|
||||
|
||||
|
||||
class AstPrinter(Generic[T]):
    """Base class for printers that render an AST as an indented tree

    Nodes are printed by calling `accept` on them with this printer; output
    goes through the `_write_*` helpers, which put the branch markers in front
    of every line
    """

    # Prefixes placed in front of a line, one per open nesting level
    # NOTE(review): VERTICAL and EMPTY are shorter than the two branch markers
    # as written here - confirm that their padding is intended
    LAST_CHILD = "└── "
    CHILD = "├── "
    VERTICAL = "│ "
    EMPTY = " "

    def __init__(self) -> None:
        # One entry per open nesting level, outermost first
        self._levels: list[_Level] = []
        # Index to show in front of the next written line, if any
        self._idx: Optional[int] = None
        self._buf: io.StringIO = io.StringIO()

    def print(self, expr: T) -> str:
        """Render `expr` and return the text that was produced

        The output buffer is replaced by a fresh one on every call
        """
        self._buf = io.StringIO()
        expr.accept(self)
        return self._buf.getvalue()

    @contextmanager
    def _child_level(self, single: bool = False) -> Generator[None, None, None]:
        """Open a nesting level for the duration of the `with` block

        Args:
            single (bool): whether the level starts out as `LAST`, i.e. its
                first line is already drawn as the last child
        """
        self._levels.append(_Level.LAST if single else _Level.ACTIVE)
        try:
            yield
        finally:
            # Always close the level, even if printing a child fails
            self._levels.pop()

    def _mark_last(self) -> None:
        """Flag the innermost level so that its next line is drawn as the last child"""
        if self._levels:
            self._levels[-1] = _Level.LAST

    def _write_line(self, text: str, *, last: bool = False) -> None:
        """Write one line with the indentation of the current nesting

        Args:
            text (str): the content of the line
            last (bool): whether this line is the last child of its level
        """
        if last:
            self._mark_last()
        indent: str = self._build_indent()
        if self._idx is not None:
            # A pending sequence index is shown once and then cleared
            text = f"[{self._idx}] {text}"
            self._idx = None
        self._buf.write(indent + text + "\n")

    def _build_indent(self) -> str:
        """Build the prefix for the next line from the open levels

        Every level but the innermost contributes `EMPTY` or `VERTICAL`; the
        innermost one contributes the branch marker. Note that a `LAST`
        innermost level is switched to `EMPTY` here, so calling this method
        changes the state
        """
        parts: list[str] = []
        for level in self._levels[:-1]:
            parts.append(self.EMPTY if level == _Level.EMPTY else self.VERTICAL)
        if self._levels:
            if self._levels[-1] == _Level.LAST:
                parts.append(self.LAST_CHILD)
                # Lines nested below the last child no longer get a vertical bar
                self._levels[-1] = _Level.EMPTY
            else:
                parts.append(self.CHILD)
        return "".join(parts)

    def _write_optional_child(
        self, label: str, child: Optional[T], *, last: bool = False
    ) -> None:
        """Write a labelled child, or `label: None` if there is no child

        Args:
            label (str): the label of the child
            child (Optional[T]): the child to print below the label
            last (bool): whether this is the last child of its level
        """
        if last:
            self._mark_last()
        if child is None:
            self._write_line(f"{label}: None")
        else:
            self._write_line(label)
            # The child is the only entry below its label
            with self._child_level(single=True):
                child.accept(self)

    def _write_sequence(
        self,
        label: str,
        list_: Sequence[T],
        *,
        last: bool = False,
        print_func: Optional[Callable[[T], None]] = None,
    ) -> None:
        """Write a label followed by the items of a sequence, one level deeper

        Args:
            label (str): the label of the sequence
            list_ (Sequence[T]): the items to print
            last (bool): whether the sequence is the last child of its level
            print_func (Optional[Callable[[T], None]]): called for each item
                instead of `item.accept(self)` when given
        """
        if last:
            self._mark_last()

        self._write_line(label)
        with self._child_level():
            for i, item in enumerate(list_):
                # The index is picked up by the next `_write_line` call
                self._idx = i
                if i == len(list_) - 1:
                    self._mark_last()
                if print_func is not None:
                    print_func(item)
                else:
                    item.accept(self)
|
||||
183
midas/ast/printer/midas.py
Normal file
183
midas/ast/printer/midas.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import midas.ast.midas as m
|
||||
|
||||
|
||||
class MidasPrinter(
    m.Expr.Visitor[str],
    m.Stmt.Visitor[str],
    m.Type.Visitor[str],
):
    """Visitor that renders Midas AST nodes back into source-like text.

    Statements return a full, indented line (or block of lines). Expressions
    and types return an inline fragment that the caller embeds in its own
    line, so they must not indent their first token.
    """

    def __init__(self, indent: int = 4):
        # Number of spaces per nesting level.
        self.indent: int = indent
        # Current nesting depth; changed while printing block bodies.
        self.level: int = 0

    def indented(self, text: str) -> str:
        """Return ``text`` prefixed with the padding of the current level."""
        return " " * (self.level * self.indent) + text

    def print(self, expr: m.Expr | m.Stmt | m.Type) -> str:
        """Render ``expr`` starting at nesting level 0."""
        self.level = 0
        return expr.accept(self)

    # Statements

    def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
        template: str = ""
        if len(stmt.params) != 0:
            params: list[str] = [self._print_type_param(param) for param in stmt.params]
            template = f"[{', '.join(params)}]"
        res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
        return self.indented(res)

    def visit_alias_stmt(self, stmt: m.AliasStmt) -> str:
        return self.indented(f"alias {stmt.name.lexeme} = {stmt.type.accept(self)}")

    def _print_type_param(self, param: m.TypeParam) -> str:
        """Render a type parameter as ``name`` or ``name<:bound``."""
        res: str = param.name.lexeme
        if param.bound is not None:
            res += "<:" + param.bound.accept(self)
        return res

    def visit_member_stmt(self, stmt: m.MemberStmt) -> str:
        # Kinds missing from the table are rendered with an empty keyword.
        keyword: str = {
            m.MemberKind.PROPERTY: "prop",
            m.MemberKind.METHOD: "def",
        }.get(stmt.kind, "")
        res: str = f"{keyword} {stmt.name.lexeme}: {stmt.type.accept(self)}"
        return self.indented(res)

    def visit_extend_stmt(self, stmt: m.ExtendStmt) -> str:
        template: str = ""
        if len(stmt.params) != 0:
            params: list[str] = [self._print_type_param(param) for param in stmt.params]
            template = f"[{', '.join(params)}]"
        res: str = self.indented(f"extend {stmt.name.lexeme}{template}")
        res += " {\n"
        # Members are printed one level deeper than the ``extend`` header.
        self.level += 1
        for member in stmt.members:
            res += member.accept(self) + "\n"
        self.level -= 1
        res += self.indented("}")
        return res

    def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> str:
        name: str = stmt.name.lexeme
        sig: str = "".join(self._visit_param_spec(spec) for spec in stmt.params)
        body: str = stmt.body.accept(self)
        return self.indented(f"predicate {name}{sig} = {body}")

    # Expressions

    def visit_logical_expr(self, expr: m.LogicalExpr) -> str:
        left: str = expr.left.accept(self)
        operator: str = expr.operator.lexeme
        right: str = expr.right.accept(self)
        return f"{left} {operator} {right}"

    def visit_binary_expr(self, expr: m.BinaryExpr) -> str:
        left: str = expr.left.accept(self)
        operator: str = expr.operator.lexeme
        right: str = expr.right.accept(self)
        return f"{left} {operator} {right}"

    def visit_unary_expr(self, expr: m.UnaryExpr) -> str:
        operator: str = expr.operator.lexeme
        right: str = expr.right.accept(self)
        return f"{operator}{right}"

    def visit_call_expr(self, expr: m.CallExpr) -> str:
        # Positional arguments first, then ``name=value`` keyword arguments.
        args: list[str] = [arg.accept(self) for arg in expr.arguments] + [
            f"{name}={arg.accept(self)}" for name, arg in expr.keywords.items()
        ]
        return f"{expr.callee.accept(self)}({', '.join(args)})"

    def visit_get_expr(self, expr: m.GetExpr) -> str:
        expr_: str = expr.expr.accept(self)
        name: str = expr.name.lexeme
        return f"{expr_}.{name}"

    def visit_variable_expr(self, expr: m.VariableExpr) -> str:
        return expr.name.lexeme

    def visit_grouping_expr(self, expr: m.GroupingExpr) -> str:
        expr_: str = expr.expr.accept(self)
        return f"({expr_})"

    def visit_literal_expr(self, expr: m.LiteralExpr) -> str:
        return str(expr.value)

    def visit_wildcard_expr(self, expr: m.WildcardExpr) -> str:
        return "_"

    # Types

    def visit_named_type(self, type: m.NamedType) -> str:
        return type.name.lexeme

    def visit_generic_type(self, type: m.GenericType) -> str:
        res: str = type.type.accept(self)
        if len(type.args) != 0:
            args: list[str] = [param.accept(self) for param in type.args]
            res += f"[{', '.join(args)}]"
        return res

    def visit_constraint_type(self, type: m.ConstraintType) -> str:
        res: str = type.type.accept(self)
        res += " where " + type.constraint.accept(self)
        return res

    def visit_complex_type(self, type: m.ComplexType) -> str:
        # The opening brace is inline; members and the closing brace start
        # their own lines and are therefore indented.
        res: str = "{\n"
        self.level += 1
        for member in type.members:
            res += member.accept(self)
            res += "\n"
        self.level -= 1
        res += self.indented("}")
        return res

    def visit_extension_type(self, type: m.ExtensionType) -> str:
        return f"{type.base.accept(self)} & {type.extension.accept(self)}"

    def visit_function_type(self, type: m.FunctionType) -> str:
        spec: str = self._visit_param_spec(type.params)
        return f"fn {spec} -> {type.returns.accept(self)}"

    def _visit_param_spec(self, spec: m.ParamSpec) -> str:
        """Render a parameter list as ``(pos, /, mixed, *, kw)``.

        The ``/`` marker is emitted only when there are positional-only
        parameters and the ``*`` marker only when there are keyword-only ones.
        """
        pos: list[str] = [self._print_param(param) for param in spec.pos]
        mixed: list[str] = [self._print_param(param) for param in spec.mixed]
        kw: list[str] = [self._print_param(param) for param in spec.kw]

        # Copy instead of aliasing ``pos`` so the marker is not appended to it.
        params: list[str] = list(pos)
        if pos:
            params.append("/")
        params += mixed
        if kw:
            params.append("*")
        params += kw
        return f"({', '.join(params)})"

    def _print_param(self, param: m.FunctionType.Parameter) -> str:
        """Render one parameter as ``[name: ]type[?]`` (``?`` = not required)."""
        res: str = ""
        if param.name is not None:
            res += param.name.lexeme
            res += ": "
        res += param.type.accept(self)
        if not param.required:
            res += "?"
        return res

    def visit_frame_type(self, type: m.FrameType) -> str:
        """Render a frame type, one column per line when it has columns.

        Like every other type, the result is embedded inline by the caller,
        so the opening ``Frame[`` is not indented; the column lines and the
        closing bracket start their own lines and are indented (same layout
        as ``visit_complex_type``).
        """
        if len(type.columns) == 0:
            return "Frame[]"
        self.level += 1
        columns: list[str] = [
            self.indented(self._print_frame_column(column)) for column in type.columns
        ]
        self.level -= 1
        return "Frame[\n" + ",\n".join(columns) + "\n" + self.indented("]")

    def _print_frame_column(self, column: m.FrameType.Column) -> str:
        return f"{column.name.lexeme}: {column.type.accept(self)}"
|
||||
253
midas/ast/printer/midas_ast.py
Normal file
253
midas/ast/printer/midas_ast.py
Normal file
@@ -0,0 +1,253 @@
|
||||
import midas.ast.midas as m
|
||||
from midas.ast.printer.base import AstPrinter
|
||||
|
||||
|
||||
class MidasAstPrinter(
    AstPrinter,
    m.Expr.Visitor[None],
    m.Stmt.Visitor[None],
    m.Type.Visitor[None],
):
    """Dump a Midas AST as a tree: one line per node and per field.

    Every ``visit_*`` method writes the node name, then opens a child level
    and writes its fields; the final field is flagged ``last`` so the
    connector is drawn as the closing one.
    """

    def _write_child(self, label: str, child, *, last: bool = False) -> None:
        """Write ``label`` and render ``child`` as its only sub-node."""
        self._write_line(label, last=last)
        with self._child_level(single=True):
            child.accept(self)

    # Statements

    def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
        self._write_line("TypeStmt")
        with self._child_level():
            self._write_line(f'name: "{stmt.name.lexeme}"')
            self._write_sequence(
                "params",
                stmt.params,
                print_func=self._print_type_param,
            )
            self._write_child("type", stmt.type, last=True)

    def visit_alias_stmt(self, stmt: m.AliasStmt) -> None:
        self._write_line("AliasStmt")
        with self._child_level():
            self._write_line(f'name: "{stmt.name.lexeme}"')
            self._write_child("type", stmt.type, last=True)

    def _print_type_param(self, param: m.TypeParam) -> None:
        """Write a type parameter with its name and optional bound."""
        self._write_line("Param")
        with self._child_level():
            self._write_line(f'name: "{param.name.lexeme}"')
            self._write_optional_child("bound", param.bound, last=True)

    def visit_member_stmt(self, stmt: m.MemberStmt) -> None:
        self._write_line("MemberStmt")
        with self._child_level():
            self._write_line(f"kind: {stmt.kind.name}")
            self._write_line(f'name: "{stmt.name.lexeme}"')
            self._write_child("type", stmt.type, last=True)

    def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
        self._write_line("ExtendStmt")
        with self._child_level():
            self._write_line(f'name: "{stmt.name.lexeme}"')
            self._write_sequence(
                "params",
                stmt.params,
                print_func=self._print_type_param,
            )
            self._write_sequence("members", stmt.members, last=True)

    def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
        self._write_line("PredicateStmt")
        with self._child_level():
            self._write_line(f'name: "{stmt.name.lexeme}"')
            self._write_sequence(
                "params",
                stmt.params,
                print_func=self._visit_param_spec,
            )
            self._write_child("body", stmt.body, last=True)

    # Expressions

    def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
        self._write_line("LogicalExpr")
        with self._child_level():
            self._write_child("left", expr.left)
            self._write_line(f"operator: {expr.operator.lexeme}")
            self._write_child("right", expr.right, last=True)

    def visit_binary_expr(self, expr: m.BinaryExpr) -> None:
        self._write_line("BinaryExpr")
        with self._child_level():
            self._write_child("left", expr.left)
            self._write_line(f"operator: {expr.operator.lexeme}")
            self._write_child("right", expr.right, last=True)

    def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
        self._write_line("UnaryExpr")
        with self._child_level():
            self._write_line(f"operator: {expr.operator.lexeme}")
            self._write_child("right", expr.right, last=True)

    def visit_call_expr(self, expr: m.CallExpr) -> None:
        self._write_line("CallExpr")
        with self._child_level():
            self._write_child("callee", expr.callee)
            self._write_sequence("arguments", expr.arguments)
            self._write_line("keywords", last=True)
            # Keyword arguments are a mapping, so they are laid out by hand:
            # an indexed line with the keyword name, the value one level below.
            final_index = len(expr.keywords) - 1
            with self._child_level():
                for i, (name, arg) in enumerate(expr.keywords.items()):
                    self._idx = i
                    if i == final_index:
                        self._mark_last()
                    self._write_child(name, arg)

    def visit_get_expr(self, expr: m.GetExpr) -> None:
        self._write_line("GetExpr")
        with self._child_level():
            self._write_child("expr", expr.expr)
            self._write_line(f'name: "{expr.name.lexeme}"', last=True)

    def visit_variable_expr(self, expr: m.VariableExpr) -> None:
        self._write_line("VariableExpr")
        with self._child_level():
            self._write_line(f'name: "{expr.name.lexeme}"', last=True)

    def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
        self._write_line("GroupingExpr")
        with self._child_level():
            self._write_child("expr", expr.expr, last=True)

    def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
        self._write_line("LiteralExpr")
        with self._child_level():
            self._write_line(f"value: {expr.value}", last=True)

    def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
        self._write_line("WildcardExpr")

    # Types

    def visit_named_type(self, type: m.NamedType) -> None:
        self._write_line("NamedType")
        with self._child_level():
            self._write_line(f'name: "{type.name.lexeme}"', last=True)

    def visit_generic_type(self, type: m.GenericType) -> None:
        self._write_line("GenericType")
        with self._child_level():
            # ``type`` has exactly one sub-node, so it is drawn as a single
            # (closing) child like every other one-child field.
            self._write_child("type", type.type)
            self._write_sequence("args", type.args, last=True)

    def visit_constraint_type(self, type: m.ConstraintType) -> None:
        self._write_line("ConstraintType")
        with self._child_level():
            self._write_child("type", type.type)
            self._write_child("constraint", type.constraint, last=True)

    def visit_complex_type(self, type: m.ComplexType) -> None:
        self._write_line("ComplexType")
        with self._child_level():
            self._write_sequence("members", type.members, last=True)

    def visit_extension_type(self, type: m.ExtensionType) -> None:
        self._write_line("ExtensionType")
        with self._child_level():
            self._write_child("base", type.base)
            self._write_child("extension", type.extension, last=True)

    def visit_function_type(self, type: m.FunctionType) -> None:
        self._write_line("FunctionType")
        with self._child_level():
            self._write_line("params")
            with self._child_level(single=True):
                self._visit_param_spec(type.params)

            self._write_child("returns", type.returns, last=True)

    def _visit_param_spec(self, spec: m.ParamSpec) -> None:
        """Write a ``ParamSpec`` with its ``pos``, ``mixed`` and ``kw`` groups."""
        self._write_line("ParamSpec")
        with self._child_level():
            self._write_sequence(
                "pos",
                spec.pos,
                print_func=self._print_param,
            )
            self._write_sequence(
                "mixed",
                spec.mixed,
                print_func=self._print_param,
            )
            self._write_sequence(
                "kw",
                spec.kw,
                print_func=self._print_param,
                last=True,
            )

    def _print_param(self, param: m.FunctionType.Parameter) -> None:
        """Write one function-type parameter (name, type, required flag)."""
        self._write_line("Parameter")
        with self._child_level():
            name: str = "None"
            if param.name is not None:
                name = f'"{param.name.lexeme}"'
            self._write_line(f"name: {name}")
            self._write_child("type", param.type)
            self._write_line(f"required: {param.required}", last=True)

    def visit_frame_type(self, type: m.FrameType) -> None:
        self._write_line("FrameType")
        # ``columns`` is the only field, hence the single-child level.
        with self._child_level(single=True):
            self._write_sequence(
                "columns",
                type.columns,
                print_func=self._print_frame_column,
            )

    def _print_frame_column(self, column: m.FrameType.Column) -> None:
        """Write one frame column (name and type)."""
        self._write_line("Column")
        with self._child_level():
            self._write_line(f'name: "{column.name.lexeme}"')
            # ``type`` is the final field of a column, so it closes the level.
            self._write_child("type", column.type, last=True)
|
||||
285
midas/ast/printer/python_ast.py
Normal file
285
midas/ast/printer/python_ast.py
Normal file
@@ -0,0 +1,285 @@
|
||||
import ast
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.printer.base import AstPrinter
|
||||
|
||||
|
||||
class PythonAstPrinter(
    AstPrinter,
    p.MidasType.Visitor[None],
    p.Stmt.Visitor[None],
    p.Expr.Visitor[None],
):
    """Dump the Python-side AST as a tree: one line per node and per field.

    Every ``visit_*`` method writes the node name, then opens a child level
    and writes its fields; the final field is flagged ``last`` so the
    connector is drawn as the closing one.
    """

    def _write_child(self, label: str, child, *, last: bool = False) -> None:
        """Write ``label`` and render ``child`` as its only sub-node."""
        self._write_line(label, last=last)
        with self._child_level(single=True):
            child.accept(self)

    # Types

    def visit_base_type(self, node: p.BaseType) -> None:
        self._write_line("BaseType")
        with self._child_level():
            self._write_line(f"base: {node.base}")
            self._write_sequence("args", node.args, last=True)

    def visit_constraint_type(self, node: p.ConstraintType) -> None:
        self._write_line("ConstraintType")
        with self._child_level():
            self._write_child("type", node.type)
            # The constraint is a raw ``ast`` expression; show it as source.
            self._write_line(f"constraint: {ast.unparse(node.constraint)}", last=True)

    def visit_frame_column(self, node: p.FrameColumn) -> None:
        self._write_line("FrameColumn")
        with self._child_level():
            self._write_line(f"name: {node.name}")
            self._write_optional_child("type", node.type, last=True)

    def visit_frame_type(self, node: p.FrameType) -> None:
        self._write_line("FrameType")
        # ``columns`` is the only field, hence the single-child level.
        with self._child_level(single=True):
            self._write_sequence("columns", node.columns)

    # Statements

    def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
        # No line of its own: the wrapped expression is printed in its place.
        stmt.expr.accept(self)

    def visit_function(self, stmt: p.Function) -> None:
        self._write_line("Function")
        with self._child_level():
            self._write_line(f"name: {stmt.name}")
            self._write_line("params")
            # ``params`` holds exactly one ``ParamSpec``, so it is drawn as a
            # single (closing) child like every other one-child field.
            with self._child_level(single=True):
                self._print_param_spec(stmt.params)

            self._write_optional_child("returns", stmt.returns)
            self._write_sequence("body", stmt.body, last=True)

    def _print_param_spec(self, spec: p.ParamSpec) -> None:
        """Write a ``ParamSpec`` with its ``pos``, ``mixed`` and ``kw`` groups."""
        self._write_line("ParamSpec")
        with self._child_level():
            self._write_sequence(
                "pos",
                spec.pos,
                print_func=self._print_param,
            )
            self._write_sequence(
                "mixed",
                spec.mixed,
                print_func=self._print_param,
            )
            self._write_sequence(
                "kw",
                spec.kw,
                print_func=self._print_param,
                last=True,
            )

    def _print_param(self, param: p.Function.Parameter) -> None:
        """Write one function parameter (name and optional type)."""
        # NOTE(review): ``param.default`` is not printed - confirm whether
        # that omission is intentional.
        self._write_line("Parameter")
        with self._child_level():
            self._write_line(f"name: {param.name}")
            self._write_optional_child("type", param.type, last=True)

    def visit_type_assign(self, stmt: p.TypeAssign) -> None:
        self._write_line("TypeAssign")
        with self._child_level():
            self._write_line(f"name: {stmt.name}")
            self._write_child("type", stmt.type, last=True)

    def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
        self._write_line("AssignStmt")
        with self._child_level():
            self._write_sequence("targets", stmt.targets)
            self._write_child("value", stmt.value, last=True)

    def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
        self._write_line("ReturnStmt")
        with self._child_level():
            self._write_optional_child("value", stmt.value, last=True)

    def visit_if_stmt(self, stmt: p.IfStmt) -> None:
        self._write_line("IfStmt")
        with self._child_level():
            self._write_child("test", stmt.test)
            self._write_sequence("body", stmt.body)
            self._write_sequence("orelse", stmt.orelse, last=True)

    def visit_pass(self, stmt: p.Pass) -> None:
        self._write_line("Pass")

    def visit_for_stmt(self, stmt: p.ForStmt) -> None:
        self._write_line("ForStmt")
        with self._child_level():
            self._write_child("target", stmt.target)
            self._write_child("iterator", stmt.iterator)
            self._write_sequence("body", stmt.body, last=True)

    def visit_raw_stmt(self, stmt: p.RawStmt) -> None:
        self._write_line("RawStmt")
        with self._child_level(single=True):
            self._write_line(f"stmt: {ast.unparse(stmt.stmt)}")

    # Expressions

    def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
        self._write_line("BinaryExpr")
        with self._child_level():
            self._write_child("left", expr.left)
            # The operator object is shown by its class name.
            self._write_line(f"operator: {expr.operator.__class__.__name__}")
            self._write_child("right", expr.right, last=True)

    def visit_compare_expr(self, expr: p.CompareExpr) -> None:
        self._write_line("CompareExpr")
        with self._child_level():
            self._write_child("left", expr.left)
            self._write_line(f"operator: {expr.operator.__class__.__name__}")
            self._write_child("right", expr.right, last=True)

    def visit_unary_expr(self, expr: p.UnaryExpr) -> None:
        self._write_line("UnaryExpr")
        with self._child_level():
            self._write_line(f"operator: {expr.operator.__class__.__name__}")
            self._write_child("right", expr.right, last=True)

    def visit_call_expr(self, expr: p.CallExpr) -> None:
        self._write_line("CallExpr")
        with self._child_level():
            self._write_child("callee", expr.callee)

            self._write_sequence("arguments", expr.arguments)
            self._write_line("keywords", last=True)
            # Keyword arguments are a mapping, so they are laid out by hand:
            # an indexed line with the keyword name, the value one level below.
            final_index = len(expr.keywords) - 1
            with self._child_level():
                for i, (name, arg) in enumerate(expr.keywords.items()):
                    self._idx = i
                    if i == final_index:
                        self._mark_last()
                    self._write_child(name, arg)

    def visit_get_expr(self, expr: p.GetExpr) -> None:
        self._write_line("GetExpr")
        with self._child_level():
            self._write_child("object", expr.object)
            self._write_line(f"name: {expr.name}", last=True)

    def visit_literal_expr(self, expr: p.LiteralExpr) -> None:
        self._write_line("LiteralExpr")
        with self._child_level(single=True):
            self._write_line(f"value: {expr.value!r}")

    def visit_variable_expr(self, expr: p.VariableExpr) -> None:
        self._write_line("VariableExpr")
        with self._child_level(single=True):
            self._write_line(f"name: {expr.name}")

    def visit_logical_expr(self, expr: p.LogicalExpr) -> None:
        self._write_line("LogicalExpr")
        with self._child_level():
            self._write_child("left", expr.left)
            self._write_line(f"operator: {expr.operator.__class__.__name__}")
            self._write_child("right", expr.right, last=True)

    def visit_cast_expr(self, expr: p.CastExpr) -> None:
        self._write_line("CastExpr")
        with self._child_level():
            self._write_child("type", expr.type)
            self._write_child("expr", expr.expr)
            self._write_line(f"unsafe: {expr.unsafe}", last=True)

    def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
        self._write_line("TernaryExpr")
        with self._child_level():
            self._write_child("test", expr.test)
            self._write_child("if_true", expr.if_true)
            self._write_child("if_false", expr.if_false, last=True)

    def visit_list_expr(self, expr: p.ListExpr) -> None:
        self._write_line("ListExpr")
        with self._child_level():
            self._write_sequence("items", expr.items, last=True)

    def visit_dict_expr(self, expr: p.DictExpr) -> None:
        self._write_line("DictExpr")
        with self._child_level():
            # A key may be ``None``; it is then shown as a plain "None" line.
            self._write_sequence(
                "keys",
                expr.keys,
                print_func=lambda k: (
                    self._write_line("None") if k is None else k.accept(self)
                ),
            )
            self._write_sequence("values", expr.values, last=True)

    def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
        self._write_line("SubscriptExpr")
        with self._child_level():
            self._write_child("object", expr.object)
            self._write_child("index", expr.index, last=True)

    def visit_slice_expr(self, expr: p.SliceExpr) -> None:
        self._write_line("SliceExpr")
        with self._child_level():
            self._write_optional_child("lower", expr.lower)
            self._write_optional_child("upper", expr.upper)
            self._write_optional_child("step", expr.step, last=True)

    def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
        self._write_line("TupleExpr")
        with self._child_level():
            self._write_sequence("items", expr.items, last=True)

    def visit_raw_expr(self, expr: p.RawExpr) -> None:
        self._write_line("RawExpr")
        with self._child_level(single=True):
            self._write_line(f"expr: {ast.unparse(expr.expr)}")
|
||||
423
midas/ast/python.py
Normal file
423
midas/ast/python.py
Normal file
@@ -0,0 +1,423 @@
|
||||
"""
|
||||
This file was generated by a script. Any manual changes might be overwritten.
|
||||
Please modify gen/python.py instead and run gen/gen.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from midas.ast.location import Location
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class ParamSpec:
    """Parameters of a function, split into three groups."""

    # NOTE(review): group names suggest positional-only / positional-or-keyword /
    # keyword-only parameters - confirm against the code that builds this.
    pos: list[Function.Parameter]
    mixed: list[Function.Parameter]
    kw: list[Function.Parameter]

    @property
    def all(self) -> list[Function.Parameter]:
        """Return a new list with every parameter, in ``pos``, ``mixed``, ``kw`` order."""
        return self.pos + self.mixed + self.kw
|
||||
|
||||
|
||||
####################
|
||||
# Type annotations #
|
||||
####################
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class MidasType(ABC):
    """Abstract base of the type-annotation nodes; dispatches to a ``Visitor``."""

    # Keyword-only, so subclasses can declare positional fields after it.
    location: Location

    @abstractmethod
    def accept(self, visitor: Visitor[T]) -> T: ...

    class Visitor(ABC, Generic[T]):
        """Visitor interface with one method per concrete type node."""

        @abstractmethod
        def visit_base_type(self, node: BaseType) -> T: ...

        @abstractmethod
        def visit_constraint_type(self, node: ConstraintType) -> T: ...

        @abstractmethod
        def visit_frame_column(self, node: FrameColumn) -> T: ...

        @abstractmethod
        def visit_frame_type(self, node: FrameType) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class BaseType(MidasType):
    """Type node made of a base name and a tuple of type arguments."""

    base: str
    args: tuple[MidasType, ...]

    def accept(self, visitor: MidasType.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_base_type``."""
        return visitor.visit_base_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ConstraintType(MidasType):
    """Type node pairing a type with a constraint held as a Python ``ast`` expression."""

    type: MidasType
    constraint: ast.expr

    def accept(self, visitor: MidasType.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_constraint_type``."""
        return visitor.visit_constraint_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class FrameColumn(MidasType):
    """One column of a frame type; both its name and its type may be absent."""

    name: Optional[str]
    type: Optional[MidasType]

    def accept(self, visitor: MidasType.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_frame_column``."""
        return visitor.visit_frame_column(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class FrameType(MidasType):
    """Type node holding a list of frame columns."""

    columns: list[FrameColumn]

    def accept(self, visitor: MidasType.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_frame_type``."""
        return visitor.visit_frame_type(self)
|
||||
|
||||
|
||||
##############
|
||||
# Statements #
|
||||
##############
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class Stmt(ABC):
    """Abstract base of the statement nodes; dispatches to a ``Visitor``."""

    # Keyword-only, so subclasses can declare positional fields after it.
    location: Location

    @abstractmethod
    def accept(self, visitor: Visitor[T]) -> T: ...

    class Visitor(ABC, Generic[T]):
        """Visitor interface with one method per concrete statement node."""

        @abstractmethod
        def visit_expression_stmt(self, stmt: ExpressionStmt) -> T: ...

        @abstractmethod
        def visit_function(self, stmt: Function) -> T: ...

        @abstractmethod
        def visit_type_assign(self, stmt: TypeAssign) -> T: ...

        @abstractmethod
        def visit_assign_stmt(self, stmt: AssignStmt) -> T: ...

        @abstractmethod
        def visit_return_stmt(self, stmt: ReturnStmt) -> T: ...

        @abstractmethod
        def visit_if_stmt(self, stmt: IfStmt) -> T: ...

        @abstractmethod
        def visit_pass(self, stmt: Pass) -> T: ...

        @abstractmethod
        def visit_for_stmt(self, stmt: ForStmt) -> T: ...

        @abstractmethod
        def visit_raw_stmt(self, stmt: RawStmt) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ExpressionStmt(Stmt):
    """Statement wrapping a single expression."""

    expr: Expr

    def accept(self, visitor: Stmt.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_expression_stmt``."""
        return visitor.visit_expression_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Function(Stmt):
    """Function definition: name, parameters, optional return type and body."""

    name: str
    params: ParamSpec
    returns: Optional[MidasType]
    body: list[Stmt]

    @dataclass(frozen=True, kw_only=True)
    class Parameter:
        """One function parameter; every field is keyword-only."""

        # The only field with a default; allowed before the others because
        # the dataclass is keyword-only.
        location: Optional[Location] = None
        name: str
        type: Optional[MidasType]
        default: Optional[Expr]

    def accept(self, visitor: Stmt.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_function``."""
        return visitor.visit_function(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TypeAssign(Stmt):
    """Statement associating a name with a ``MidasType``."""

    name: str
    type: MidasType

    def accept(self, visitor: Stmt.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_type_assign``."""
        return visitor.visit_type_assign(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class AssignStmt(Stmt):
    """Assignment of one value expression to a list of target expressions."""

    targets: list[Expr]
    value: Expr

    def accept(self, visitor: Stmt.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_assign_stmt``."""
        return visitor.visit_assign_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ReturnStmt(Stmt):
    """Return statement; ``value`` is ``None`` when nothing is returned."""

    value: Optional[Expr]

    def accept(self, visitor: Stmt.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_return_stmt``."""
        return visitor.visit_return_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class IfStmt(Stmt):
    """Conditional statement with a test, a body and an ``orelse`` branch."""

    test: Expr
    body: list[Stmt]
    orelse: list[Stmt]

    def accept(self, visitor: Stmt.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_if_stmt``."""
        return visitor.visit_if_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Pass(Stmt):
    """``pass`` statement; adds no fields to ``Stmt``."""

    pass

    def accept(self, visitor: Stmt.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_pass``."""
        return visitor.visit_pass(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ForStmt(Stmt):
    """`for` loop statement node."""

    # Loop target expression.
    target: Expr
    # Expression being iterated over.
    iterator: Expr
    # Statements of the loop body.
    body: list[Stmt]

    def accept(self, visitor: Stmt.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_for_stmt`` and return its result."""
        return visitor.visit_for_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RawStmt(Stmt):
    """Statement node wrapping an untranslated Python `ast.stmt`."""

    # The wrapped Python AST statement.
    # NOTE(review): presumably used for statements without a dedicated node - confirm.
    stmt: ast.stmt

    def accept(self, visitor: Stmt.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_raw_stmt`` and return its result."""
        return visitor.visit_raw_stmt(self)
|
||||
|
||||
|
||||
###############
|
||||
# Expressions #
|
||||
###############
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class Expr(ABC):
    """Abstract base class of expression nodes.

    The dataclass is declared with ``kw_only=True``, so ``location`` must be
    passed by keyword when constructing a node.
    """

    # Source location of the expression.
    location: Location

    # Visitor entry point: each concrete node calls its matching `visit_*` method.
    @abstractmethod
    def accept(self, visitor: Visitor[T]) -> T: ...

    class Visitor(ABC, Generic[T]):
        """Visitor interface over expression nodes, generic in the result type ``T``.

        Declares one abstract ``visit_*`` method per concrete expression class.
        """

        @abstractmethod
        def visit_binary_expr(self, expr: BinaryExpr) -> T: ...

        @abstractmethod
        def visit_compare_expr(self, expr: CompareExpr) -> T: ...

        @abstractmethod
        def visit_unary_expr(self, expr: UnaryExpr) -> T: ...

        @abstractmethod
        def visit_call_expr(self, expr: CallExpr) -> T: ...

        @abstractmethod
        def visit_get_expr(self, expr: GetExpr) -> T: ...

        @abstractmethod
        def visit_literal_expr(self, expr: LiteralExpr) -> T: ...

        @abstractmethod
        def visit_variable_expr(self, expr: VariableExpr) -> T: ...

        @abstractmethod
        def visit_logical_expr(self, expr: LogicalExpr) -> T: ...

        @abstractmethod
        def visit_cast_expr(self, expr: CastExpr) -> T: ...

        @abstractmethod
        def visit_ternary_expr(self, expr: TernaryExpr) -> T: ...

        @abstractmethod
        def visit_list_expr(self, expr: ListExpr) -> T: ...

        @abstractmethod
        def visit_dict_expr(self, expr: DictExpr) -> T: ...

        @abstractmethod
        def visit_subscript_expr(self, expr: SubscriptExpr) -> T: ...

        @abstractmethod
        def visit_slice_expr(self, expr: SliceExpr) -> T: ...

        @abstractmethod
        def visit_tuple_expr(self, expr: TupleExpr) -> T: ...

        @abstractmethod
        def visit_raw_expr(self, expr: RawExpr) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class BinaryExpr(Expr):
    """Binary operation expression node (``left <operator> right``)."""

    # Left operand.
    left: Expr
    # Operator, kept as the Python `ast.operator` node.
    operator: ast.operator
    # Right operand.
    right: Expr

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_binary_expr`` and return its result."""
        return visitor.visit_binary_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CompareExpr(Expr):
    """Comparison expression node with a single operator and two operands."""

    # Left operand.
    left: Expr
    # Comparison operator, kept as the Python `ast.cmpop` node.
    operator: ast.cmpop
    # Right operand.
    right: Expr

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_compare_expr`` and return its result."""
        return visitor.visit_compare_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class UnaryExpr(Expr):
    """Unary operation expression node (``<operator> right``)."""

    # Operator, kept as the Python `ast.unaryop` node.
    operator: ast.unaryop
    # Operand the operator applies to.
    right: Expr

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_unary_expr`` and return its result."""
        return visitor.visit_unary_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CallExpr(Expr):
    """Call expression node."""

    # Expression being called.
    callee: Expr
    # Positional argument expressions.
    arguments: list[Expr]
    # Keyword argument expressions, keyed by keyword name.
    keywords: dict[str, Expr]

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_call_expr`` and return its result."""
        return visitor.visit_call_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class GetExpr(Expr):
    """Attribute access expression node (``object.name``)."""

    # Expression whose member is accessed (the field name shadows the builtin `object`).
    object: Expr
    # Name of the accessed member.
    name: str

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_get_expr`` and return its result."""
        return visitor.visit_get_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LiteralExpr(Expr):
    """Literal value expression node."""

    # The literal value; its Python type is not constrained here.
    value: Any

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_literal_expr`` and return its result."""
        return visitor.visit_literal_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class VariableExpr(Expr):
    """Expression node referring to a variable by name."""

    # Name of the referenced variable.
    name: str

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_variable_expr`` and return its result."""
        return visitor.visit_variable_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LogicalExpr(Expr):
    """Boolean operation expression node with two operands."""

    # Left operand.
    left: Expr
    # Boolean operator, kept as the Python `ast.boolop` node.
    operator: ast.boolop
    # Right operand.
    right: Expr

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_logical_expr`` and return its result."""
        return visitor.visit_logical_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CastExpr(Expr):
    """Cast expression node: an expression together with a target type."""

    # Target type of the cast.
    type: MidasType
    # Expression being cast.
    expr: Expr
    # NOTE(review): presumably marks a cast that skips compatibility checks - confirm.
    unsafe: bool

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_cast_expr`` and return its result."""
        return visitor.visit_cast_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TernaryExpr(Expr):
    """Conditional expression node (``if_true if test else if_false``)."""

    # Condition expression.
    test: Expr
    # Expression for the true branch.
    if_true: Expr
    # Expression for the false branch.
    if_false: Expr

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_ternary_expr`` and return its result."""
        return visitor.visit_ternary_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ListExpr(Expr):
    """List display expression node."""

    # Item expressions of the list.
    items: list[Expr]

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_list_expr`` and return its result."""
        return visitor.visit_list_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DictExpr(Expr):
    """Dict display expression node, stored as two lists."""

    # Key expressions; an entry may be None.
    # NOTE(review): presumably None marks a `**mapping` unpacking, as in `ast.Dict` - confirm.
    keys: list[Optional[Expr]]
    # Value expressions (NOTE(review): presumably aligned index-wise with `keys` - confirm).
    values: list[Expr]

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_dict_expr`` and return its result."""
        return visitor.visit_dict_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SubscriptExpr(Expr):
    """Subscript expression node (``object[index]``)."""

    # Expression being subscripted (the field name shadows the builtin `object`).
    object: Expr
    # Index expression.
    index: Expr

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_subscript_expr`` and return its result."""
        return visitor.visit_subscript_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SliceExpr(Expr):
    """Slice expression node (``lower:upper:step``); every part is optional."""

    # Lower bound, or None.
    lower: Optional[Expr]
    # Upper bound, or None.
    upper: Optional[Expr]
    # Step, or None.
    step: Optional[Expr]

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_slice_expr`` and return its result."""
        return visitor.visit_slice_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TupleExpr(Expr):
    """Tuple display expression node."""

    # Item expressions, stored as a tuple of any length.
    items: tuple[Expr, ...]

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_tuple_expr`` and return its result."""
        return visitor.visit_tuple_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RawExpr(Expr):
    """Expression node wrapping an untranslated Python `ast.expr`."""

    # The wrapped Python AST expression.
    # NOTE(review): presumably used for expressions without a dedicated node - confirm.
    expr: ast.expr

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_raw_expr`` and return its result."""
        return visitor.visit_raw_expr(self)
|
||||
277
midas/checker/builtins.midas
Normal file
277
midas/checker/builtins.midas
Normal file
@@ -0,0 +1,277 @@
|
||||
// Member declarations for the builtin `float` type.
// `def` entries declare methods, `prop` entries declare properties.
// A name declared more than once (`__pow__`, `__round__`) lists several
// signatures for the same member (presumably resolved as overloads).
// Commented-out `def` lines are signatures that are not declared yet.
extend float {
    def hex: fn() -> str
    def is_integer: fn() -> bool
    prop real: float
    prop imag: float
    def conjugate: fn() -> float
    def __add__: fn(value: float, /) -> float
    def __sub__: fn(value: float, /) -> float
    def __mul__: fn(value: float, /) -> float
    def __floordiv__: fn(value: float, /) -> float
    def __truediv__: fn(value: float, /) -> float
    def __mod__: fn(value: float, /) -> float
    // def __divmod__: fn(value: float, /) -> tuple[float, float]

    def __pow__: fn(value: int, /) -> float
    // positive __value -> float; negative __value -> complex
    // return type must be Any as `float | complex` causes too many false-positive errors
    def __pow__: fn(value: float, /) -> Any
    def __radd__: fn(value: float, /) -> float
    def __rsub__: fn(value: float, /) -> float
    def __rmul__: fn(value: float, /) -> float
    def __rfloordiv__: fn(value: float, /) -> float
    def __rtruediv__: fn(value: float, /) -> float
    def __rmod__: fn(value: float, /) -> float
    // def __rdivmod__: fn(value: float, /) -> tuple[float, float]
    // def __rpow__: fn(value: _PositiveInteger, mod: None = None, /) -> float
    // def __rpow__: fn(value: _NegativeInteger, mod: None = None, /) -> complex
    // Returning `complex` for the general case gives too many false-positive errors.
    // def __rpow__: fn(value: float, mod: None = None, /) -> Any
    // def __getnewargs__: fn() -> tuple[float]
    def __trunc__: fn() -> int
    def __ceil__: fn() -> int
    def __floor__: fn() -> int
    // Rounding without digits yields an int, rounding to `ndigits` yields a float.
    def __round__: fn(ndigits: None?, /) -> int
    def __round__: fn(ndigits: int, /) -> float
    def __eq__: fn(value: object, /) -> bool
    def __ne__: fn(value: object, /) -> bool
    def __lt__: fn(value: float, /) -> bool
    def __le__: fn(value: float, /) -> bool
    def __gt__: fn(value: float, /) -> bool
    def __ge__: fn(value: float, /) -> bool
    def __neg__: fn() -> float
    def __pos__: fn() -> float
    def __int__: fn() -> int
    def __float__: fn() -> float
    def __abs__: fn() -> float
    def __hash__: fn() -> int
    def __bool__: fn() -> bool
    def __format__: fn(format_spec: str, /) -> str
}
|
||||
|
||||
// Member declarations for the builtin `int` type.
// `def` entries declare methods, `prop` entries declare properties.
// A name declared more than once (`__round__`) lists several signatures for
// the same member (presumably resolved as overloads).
// Commented-out `def` lines are signatures that are not declared yet.
extend int {
    prop real: int
    prop imag: int
    prop numerator: int
    prop denominator: int
    def conjugate: fn() -> int
    def bit_length: fn() -> int
    def bit_count: fn() -> int
    // def to_bytes: fn(length: int?, byteorder: str?, *, signed: bool?) -> bytes

    def __add__: fn(value: int, /) -> int
    def __sub__: fn(value: int, /) -> int
    def __mul__: fn(value: int, /) -> int
    def __floordiv__: fn(value: int, /) -> int
    // True division of two ints is declared to yield a float.
    def __truediv__: fn(value: int, /) -> float
    def __mod__: fn(value: int, /) -> int
    // def __divmod__: fn(value: int, /) -> tuple[int, int]
    def __radd__: fn(value: int, /) -> int
    def __rsub__: fn(value: int, /) -> int
    def __rmul__: fn(value: int, /) -> int
    def __rfloordiv__: fn(value: int, /) -> int
    def __rtruediv__: fn(value: int, /) -> float
    def __rmod__: fn(value: int, /) -> int
    // def __rdivmod__: fn(value: int, /) -> tuple[int, int]
    def __pow__: fn(value: int, /) -> int
    // def __pow__: fn(value: _PositiveInteger, mod: None = None, /) -> int
    // def __pow__: fn(value: _NegativeInteger, mod: None = None, /) -> float
    // positive __value -> int; negative __value -> float
    // return type must be Any as `int | float` causes too many false-positive errors
    // def __pow__: fn(value: int, mod: None = None, /) -> Any
    // def __pow__: fn(value: int, mod: int, /) -> int
    def __rpow__: fn(value: int, /) -> Any
    def __and__: fn(value: int, /) -> int
    def __or__: fn(value: int, /) -> int
    def __xor__: fn(value: int, /) -> int
    def __lshift__: fn(value: int, /) -> int
    def __rshift__: fn(value: int, /) -> int
    def __rand__: fn(value: int, /) -> int
    def __ror__: fn(value: int, /) -> int
    def __rxor__: fn(value: int, /) -> int
    def __rlshift__: fn(value: int, /) -> int
    def __rrshift__: fn(value: int, /) -> int
    def __neg__: fn() -> int
    def __pos__: fn() -> int
    def __invert__: fn() -> int
    def __trunc__: fn() -> int
    def __ceil__: fn() -> int
    def __floor__: fn() -> int
    def __round__: fn(ndigits: None?, /) -> int
    def __round__: fn(ndigits: int, /) -> int

    // def __getnewargs__: fn() -> tuple[int]
    def __eq__: fn(value: object, /) -> bool
    def __ne__: fn(value: object, /) -> bool
    def __lt__: fn(value: int, /) -> bool
    def __le__: fn(value: int, /) -> bool
    def __gt__: fn(value: int, /) -> bool
    def __ge__: fn(value: int, /) -> bool
    def __float__: fn() -> float
    def __int__: fn() -> int
    def __abs__: fn() -> int
    def __hash__: fn() -> int
    def __bool__: fn() -> bool
    def __index__: fn() -> int
    def __format__: fn(format_spec: str, /) -> str
}
|
||||
|
||||
// Member declarations for the generic builtin `list[T]` type; `T` is the
// element type parameter.
// A name declared more than once (`__getitem__`, `__setitem__`, `__delitem__`)
// lists an `int` signature and a `slice` signature for the same member.
// Commented-out `def` lines are signatures that are not declared yet.
extend list[T] {
    def copy: fn () -> list[T]
    def append: fn (object: T, /) -> None
    def extend: fn (iterable: list[T], /) -> None
    def pop: fn (index: int?, /) -> T
    def index: fn (value: T, start: int?, stop: int?, /) -> int
    def count: fn (value: T, /) -> int
    def insert: fn (index: int, object: T, /) -> None
    def remove: fn (value: T, /) -> None
    def sort: fn (*, reverse: bool?) -> None
    def __len__: fn () -> int
    // def __iter__: fn () -> Iterator[T]
    def __getitem__: fn (i: int, /) -> T
    def __getitem__: fn (s: slice, /) -> list[T]
    def __setitem__: fn (key: int, value: T, /) -> None
    def __setitem__: fn (key: slice, value: list[T], /) -> None
    def __delitem__: fn (key: int, /) -> None
    def __delitem__: fn (key: slice, /) -> None
    // def __add__: fn[S <: T] (value: list[S], /) -> list[T]
    def __add__: fn (value: list[T], /) -> list[T]
    def __iadd__: fn (value: list[T], /) -> list[T]
    def __mul__: fn (value: int, /) -> list[T]
    def __rmul__: fn (value: int, /) -> list[T]
    def __imul__: fn (value: int, /) -> list[T]
    def __contains__: fn (key: object, /) -> bool
    // def __reversed__: fn (self) -> Iterator[_T]
    def __gt__: fn (value: list[T], /) -> bool
    def __ge__: fn (value: list[T], /) -> bool
    def __lt__: fn (value: list[T], /) -> bool
    def __le__: fn (value: list[T], /) -> bool
    def __eq__: fn (value: object, /) -> bool

    prop __doc__: str
}
|
||||
|
||||
// Member declarations for the generic builtin `dict[K, V]` type; `K` is the
// key type parameter and `V` the value type parameter.
// `keys` and `values` are declared as returning plain lists for now (see TODOs).
// Commented-out `def` lines are signatures that are not declared yet.
extend dict[K, V] {
    def copy: fn() -> dict[K, V]
    def keys: fn() -> list[K] // TODO: use builtin types
    def values: fn() -> list[V] // TODO: use builtin types
    // def items: fn() -> list[tuple[K, V]] // TODO: use builtin types

    // def get: fn(key: K, default: None = None, /) -> V | None
    def get: fn(key: K, default: V, /) -> V
    // def get: fn[T](key: K, default: T, /) -> V | T
    def pop: fn(key: K, /) -> V
    def pop: fn(key: K, default: V, /) -> V
    // def pop: fn[T](key: K, default: T, /) -> V | T
    def __len__: fn() -> int
    def __getitem__: fn(key: K, /) -> V
    def __setitem__: fn(key: K, value: V, /) -> None
    def __delitem__: fn(key: K, /) -> None
    // def __iter__: fn() -> Iterator[K]
    def __eq__: fn(value: object, /) -> bool
    // def __reversed__: fn() -> Iterator[K]

    def __or__: fn(value: dict[K, V], /) -> dict[K, V]
    // def __or__: fn[K2, V2](value: dict[K2, V2], /) -> dict[K | K2, V | V2]
    def __ror__: fn(value: dict[K, V], /) -> dict[K, V]
    // def __ror__: fn[K2, V2](value: dict[K2, V2], /) -> dict[K | K2, V | V2]
    // def __ior__: fn(value: SupportsKeysAndGetItem[K, V], /) -> dict[K, V]
    // def __ior__: fn(value: Iterable[tuple[K, V]], /) -> dict[K, V]

}
|
||||
|
||||
// Member declarations for the builtin `str` type.
// Search-style methods (`count`, `endswith`, `find`, `index`, `rfind`,
// `rindex`, `startswith`) are each declared four times, once per combination
// of `start` / `end` being `None` or `int`.
// Commented-out `def` lines are signatures that are not declared yet.
extend str {
    def capitalize: fn() -> str
    def casefold: fn() -> str
    def center: fn(width: int, fillchar: str?, /) -> str
    def count: fn(sub: str, start: None?, end: None?, /) -> int
    def count: fn(sub: str, start: int, end: None?, /) -> int
    def count: fn(sub: str, start: None, end: int, /) -> int
    def count: fn(sub: str, start: int, end: int, /) -> int
    def encode: fn(encoding: str?, errors: str?) -> bytes
    def endswith: fn(suffix: str, start: None?, end: None?, /) -> bool
    def endswith: fn(suffix: str, start: int, end: None?, /) -> bool
    def endswith: fn(suffix: str, start: None, end: int, /) -> bool
    def endswith: fn(suffix: str, start: int, end: int, /) -> bool
    def expandtabs: fn(tabsize: int?) -> str
    def find: fn(sub: str, start: None?, end: None?, /) -> int
    def find: fn(sub: str, start: int, end: None?, /) -> int
    def find: fn(sub: str, start: None, end: int, /) -> int
    def find: fn(sub: str, start: int, end: int, /) -> int
    // def format: fn(*args: object, **kwargs: object) -> str
    // def format_map: fn(mapping: _FormatMapMapping, /) -> str
    def index: fn(sub: str, start: None?, end: None?, /) -> int
    def index: fn(sub: str, start: int, end: None?, /) -> int
    def index: fn(sub: str, start: None, end: int, /) -> int
    def index: fn(sub: str, start: int, end: int, /) -> int
    def isalnum: fn() -> bool
    def isalpha: fn() -> bool
    def isascii: fn() -> bool
    def isdecimal: fn() -> bool
    def isdigit: fn() -> bool
    def isidentifier: fn() -> bool
    def islower: fn() -> bool
    def isnumeric: fn() -> bool
    def isprintable: fn() -> bool
    def isspace: fn() -> bool
    def istitle: fn() -> bool
    def isupper: fn() -> bool
    def join: fn(iterable: list[str], /) -> str // TODO: use Iterable
    def ljust: fn(width: int, fillchar: str?, /) -> str
    def lower: fn() -> str
    def lstrip: fn(chars: None?, /) -> str
    def lstrip: fn(chars: str, /) -> str
    def partition: fn(sep: str, /) -> tuple[str, str, str]

    def replace: fn(old: str, new: str, count: int?, /) -> str

    def removeprefix: fn(prefix: str, /) -> str
    def removesuffix: fn(suffix: str, /) -> str
    def rfind: fn(sub: str, start: None?, end: None?, /) -> int
    def rfind: fn(sub: str, start: int, end: None?, /) -> int
    def rfind: fn(sub: str, start: None, end: int, /) -> int
    def rfind: fn(sub: str, start: int, end: int, /) -> int
    def rindex: fn(sub: str, start: None?, end: None?, /) -> int
    def rindex: fn(sub: str, start: int, end: None?, /) -> int
    def rindex: fn(sub: str, start: None, end: int, /) -> int
    def rindex: fn(sub: str, start: int, end: int, /) -> int
    def rjust: fn(width: int, fillchar: str?, /) -> str
    def rpartition: fn(sep: str, /) -> tuple[str, str, str]
    def rsplit: fn(sep: None?, maxsplit: int?) -> list[str]
    def rsplit: fn(sep: str, maxsplit: int?) -> list[str]
    def rstrip: fn(chars: None?, /) -> str
    def rstrip: fn(chars: str, /) -> str
    def split: fn(sep: None?, maxsplit: int?) -> list[str]
    def split: fn(sep: str, maxsplit: int?) -> list[str]
    def splitlines: fn(keepends: bool?) -> list[str]
    def startswith: fn(prefix: str, start: None?, end: None?, /) -> bool
    def startswith: fn(prefix: str, start: int, end: None?, /) -> bool
    def startswith: fn(prefix: str, start: None, end: int, /) -> bool
    def startswith: fn(prefix: str, start: int, end: int, /) -> bool
    def strip: fn(chars: None?, /) -> str
    def strip: fn(chars: str, /) -> str
    def swapcase: fn() -> str
    def title: fn() -> str
    // def translate: fn(table: _TranslateTable, /) -> str
    def upper: fn() -> str
    def zfill: fn(width: int, /) -> str
    def __add__: fn(value: str, /) -> str
    // Incompatible with Sequence.__contains__
    def __contains__: fn(key: str, /) -> bool
    def __eq__: fn(value: object, /) -> bool
    def __ge__: fn(value: str, /) -> bool
    def __getitem__: fn(key: slice, /) -> str
    def __getitem__: fn(key: int, /) -> str
    def __gt__: fn(value: str, /) -> bool
    def __hash__: fn() -> int
    // def __iter__: fn() -> Iterator[str]
    def __le__: fn(value: str, /) -> bool
    def __len__: fn() -> int
    def __lt__: fn(value: str, /) -> bool
    def __mod__: fn(value: Any, /) -> str
    def __mul__: fn(value: int, /) -> str
    def __ne__: fn(value: object, /) -> bool
    def __rmul__: fn(value: int, /) -> str
    def __getnewargs__: fn() -> tuple[str]
    def __format__: fn(format_spec: str, /) -> str
}
|
||||
60
midas/checker/builtins.py
Normal file
60
midas/checker/builtins.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from midas.checker.types import (
|
||||
BaseType,
|
||||
GenericType,
|
||||
TopType,
|
||||
TypeVar,
|
||||
UnitType,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.registry import TypesRegistry
|
||||
|
||||
|
||||
# Maps a builtin type name to the set of names declared as its subtypes.
# `int` appears only under `float`, not under `object`, so each name has a
# single parent entry (see the diamond-inheritance warning below).
BUILTIN_SUBTYPES: dict[str, set[str]] = {
    "object": {"float", "list", "dict", "str", "bytes", "tuple"},
    "float": {"int"},
}
"""
Hard-coded subtype relationships between builtin types

Circular dependencies and diamond inheritance MUST be avoided
"""
|
||||
|
||||
|
||||
def define_builtins(reg: TypesRegistry) -> None:
    """Define builtin types in the given registry

    Registrations happen in a fixed order:

    1. `Any` as the top type and `None` as the unit type
    2. the non-generic base types `object`, `bytes`, `bool`, `int`, `float`,
       `str`, `slice` and `tuple`
    3. the generic types `list[T]` and `dict[K, V]`, whose type variables are
       unbounded

    The values returned by `define_type` are intentionally not kept: nothing
    in this function uses them, and binding them to names such as `int` or
    `list` would shadow the Python builtins.

    Args:
        reg (TypesRegistry): the registry in which the types are defined
    """
    reg.define_type("Any", TopType())
    reg.define_type("None", UnitType())

    # Plain base types are registered under the same name as the type itself
    base_type_names: tuple[str, ...] = (
        "object",
        "bytes",
        "bool",
        "int",
        "float",
        "str",
        "slice",
        "tuple",
    )
    for name in base_type_names:
        reg.define_type(name, BaseType(name=name))

    reg.define_type(
        "list",
        GenericType(
            name="list",
            params=[TypeVar(name="T", bound=None)],
            body=BaseType(name="list"),
        ),
    )
    reg.define_type(
        "dict",
        GenericType(
            name="dict",
            params=[
                TypeVar(name="K", bound=None),
                TypeVar(name="V", bound=None),
            ],
            body=BaseType(name="dict"),
        ),
    )
|
||||
41
midas/checker/checker.py
Normal file
41
midas/checker/checker.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from midas.checker.diagnostic import Diagnostic
|
||||
from midas.checker.midas import MidasTyper
|
||||
from midas.checker.python import PythonTyper
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import Reporter
|
||||
from midas.utils import TypedAST
|
||||
|
||||
|
||||
class TypeChecker:
    """Type checking dispatcher

    Owns the types registry and the reporter, and shares both with a typer
    for Midas sources and a typer for Python sources.
    """

    def __init__(self):
        registry = TypesRegistry()
        reporter = Reporter()
        self.types: TypesRegistry = registry
        self.reporter: Reporter = reporter

        # Both typers work on the same registry and report to the same reporter
        self.midas_typer = MidasTyper(registry, reporter)
        self.python_typer = PythonTyper(registry, reporter)

    def import_midas(self, path: Path):
        """Read a Midas file and process it with the Midas typer

        Args:
            path (Path): path of the Midas file
        """
        return self.import_midas_source(path.read_text(), path=str(path))

    def import_midas_source(self, source: str, path: Optional[str] = None):
        """Process Midas source code with the Midas typer

        Args:
            source (str): the Midas source code
            path (Optional[str], optional): path the source comes from. Defaults to None.
        """
        self.midas_typer.process(source, path)

    def type_check(self, path: Path) -> TypedAST:
        """Read a Python file and type check it

        Args:
            path (Path): path of the Python file

        Returns:
            TypedAST: the result of the Python typer
        """
        return self.type_check_source(path.read_text(), path=str(path))

    def type_check_source(self, source: str, path: Optional[str] = None) -> TypedAST:
        """Type check Python source code

        Args:
            source (str): the Python source code
            path (Optional[str], optional): path the source comes from. Defaults to None.

        Returns:
            TypedAST: the result of the Python typer
        """
        typed: TypedAST = self.python_typer.process(source, path)
        return typed

    @property
    def diagnostics(self) -> list[Diagnostic]:
        """The diagnostics collected by the reporter"""
        return self.reporter.diagnostics
|
||||
64
midas/checker/diagnostic.py
Normal file
64
midas/checker/diagnostic.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.location import Location
|
||||
|
||||
|
||||
class DiagnosticType(StrEnum):
    """Kind of a diagnostic.

    Being a `StrEnum`, each member formats as its capitalised string value.
    """

    ERROR = "Error"
    WARNING = "Warning"
    INFO = "Info"
    DEBUG = "Debug"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Diagnostic:
    """Information about a diagnostic (warning, error, etc.)

    Holds a location, a diagnostic type and a message.
    Optionally bound to a file path
    """

    # Path of the file the diagnostic refers to, or None when not bound to a file
    file_path: Optional[str]
    # Position of the diagnostic in the source
    location: Location
    # Kind of diagnostic (error, warning, ...)
    type: DiagnosticType
    # Human readable description
    message: str

    @property
    def location_str(self) -> str:
        """The diagnostic type and location as a human readable string

        The location is formatted as "<Type> in <file> from L<start_line>:<start_col> to L<end_line>:<end_col>",
        for example: "Error in /home/user/Desktop/script.py from L12:5 to L12:8"

        If the file is `None`, the "in ..." section is excluded from the result.
        If the location's end is not specified, the formulation "at L<start_line>:<start_col>" is used.

        Returns:
            str: the formatted diagnostic type and location
        """

        # Column offsets are shifted by one for display
        # (NOTE(review): presumably 0-based offsets, as in Python's `ast` - confirm)
        start_loc: str = f"L{self.location.lineno}:{self.location.col_offset+1}"

        # `None` (not an empty string) marks a missing end position, so that
        # the "at ..." formulation below is actually reachable
        end_loc: Optional[str] = None
        if (
            self.location.end_lineno is not None
            and self.location.end_col_offset is not None
        ):
            end_loc = f"L{self.location.end_lineno}:{self.location.end_col_offset+1}"

        loc: str = ""
        if self.file_path is not None:
            loc += f" in {self.file_path}"
        if end_loc is None:
            loc += f" at {start_loc}"
        else:
            loc += f" from {start_loc} to {end_loc}"

        return f"{self.type}{loc}"

    def __str__(self) -> str:
        """Format the diagnostic as "<location_str>: <message>" """
        return f"{self.location_str}: {self.message}"
|
||||
486
midas/checker/dispatcher.py
Normal file
486
midas/checker/dispatcher.py
Normal file
@@ -0,0 +1,486 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Generic, Optional, Protocol, TypeVar, Union
|
||||
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter
|
||||
from midas.checker.types import (
|
||||
AppliedType,
|
||||
DerivedType,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
Type,
|
||||
UnknownType,
|
||||
)
|
||||
from midas.checker.unifier import Unifier
|
||||
|
||||
|
||||
class HasLocation(Protocol):
    """Structural type for any object exposing a read-only `location`."""

    @property
    def location(self) -> Location: ...
|
||||
|
||||
|
||||
E = TypeVar("E", bound=HasLocation)
|
||||
|
||||
TypedExpr = tuple[E, Type]
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class MappedArgument(Generic[E]):
    """A call argument paired with the parameter it was matched to.

    All fields are keyword-only.
    """

    # The argument expression (anything with a `location`)
    arg_expr: E
    # The type of the argument expression
    arg_type: Type
    # The parameter definition the argument is matched against
    parameter: Function.Parameter
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class OverloadCandidate:
    """A function signature together with the arguments mapped to its parameters.

    All fields are keyword-only.
    """

    # The candidate function signature
    function: Function
    # The argument/parameter pairs for this signature
    mapped: list[MappedArgument]
|
||||
|
||||
|
||||
class CallError(StrEnum):
    """Reasons for which a call can fail to resolve.

    The string values double as default error messages.
    """

    INVALID_ARGS = "Invalid arguments"
    NO_MATCHING_OVERLOAD = "No matching overload"
    IMPOSSIBLE_UNIFICATION = "Parameters unification failed"
    NOT_CALLABLE = "Not callable"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class CallResult:
    """Outcome of resolving a call

    Carries the resulting type and, for failed calls, the error kind and an
    optional custom message. All fields are keyword-only.
    """

    error: Optional[CallError] = None
    result: Type = UnknownType()
    message: Optional[str] = None

    @property
    def is_valid(self) -> bool:
        """Whether the call resolved without error"""
        return self.error is None

    @property
    def error_message(self) -> str:
        """The message describing the failure

        The custom message takes precedence; otherwise the error kind is
        used, and an empty string is returned when neither is set.
        """
        if self.message is None:
            return "" if self.error is None else str(self.error)
        return self.message
|
||||
|
||||
|
||||
class CallDispatcher(Generic[E]):
|
||||
def __init__(self, types: TypesRegistry, reporter: FileReporter) -> None:
|
||||
self.types: TypesRegistry = types
|
||||
self.reporter: FileReporter = reporter
|
||||
self.logger: logging.Logger = logging.getLogger("CallDispatcher")
|
||||
|
||||
def set_reporter(self, reporter: FileReporter) -> None:
    """Replace the reporter stored on the dispatcher

    Args:
        reporter (FileReporter): the new reporter
    """
    self.reporter = reporter
|
||||
|
||||
def get_result(
    self,
    location: Location,
    callee: Type,
    positional: list[TypedExpr[E]],
    keywords: dict[str, TypedExpr[E]],
    report_errors: bool = True,
) -> CallResult:
    """Get the result type of a function call

    If the function has overloads, the function will try to resolve the
    appropriate signature.
    Argument types are matched to the defined parameters.
    The function doesn't take the raw expression as a parameter to accommodate
    for desugared calls such as for operators.

    Wrapper types (applied, derived and generic types) are unwrapped and the
    call is resolved again on the underlying type.

    Args:
        location (Location): the call location
        callee (Type): the called function
        positional (list[TypedExpr]): the list positional arguments
        keywords (dict[str, TypedExpr]): the map of keyword arguments
        report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.

    Returns:
        CallResult: the outcome of the call. On success, `result` holds the
            return type. On failure, `error` is set (and `message` too for
            overload, unification and non-callable failures) while `result`
            keeps its default value
    """
    match callee:
        case Function() as function:
            valid: bool
            mapped: list[MappedArgument[E]]
            # NOTE(review): `report_errors` is not forwarded here - confirm
            # whether `map_call_arguments` reports diagnostics on its own
            valid, mapped = self.map_call_arguments(
                function, location, positional, keywords
            )
            # Argument types are only checked when the mapping itself succeeded
            valid = valid and self._are_arguments_valid(mapped, report_errors)
            if not valid:
                return CallResult(error=CallError.INVALID_ARGS)
            return CallResult(result=function.returns)

        case OverloadedFunction(overloads=overloads):
            # `res` is either (function, None) or (None, error message)
            res = self._match_overload(
                overloads, location, positional, keywords, report_errors
            )
            if res[0] is None:
                return CallResult(
                    error=CallError.NO_MATCHING_OVERLOAD,
                    message=res[1],
                )
            return CallResult(result=res[0].returns)

        case AppliedType(body=body):
            # Resolve the call on the body of the applied type
            return self.get_result(
                location, body, positional, keywords, report_errors
            )

        case UnknownType():
            # Calling an unknown type is not an error, the result is unknown too
            return CallResult(result=UnknownType())

        case DerivedType(type=base):
            # Resolve the call on the type this one derives from
            return self.get_result(
                location, base, positional, keywords, report_errors
            )

        case GenericType():
            # Unify the generic callee with the argument types, then resolve
            # the call on the unified type
            unifier: Unifier = Unifier(self.types)
            pos: list[Type] = [a[1] for a in positional]
            kw: dict[str, Type] = {k: v[1] for k, v in keywords.items()}
            unified: Optional[Type] = unifier.unify_call(callee, pos, kw)
            if unified is None:
                pos_str: str = ", ".join(str(t) for t in pos)
                kw_str: str = ", ".join(f"{k}: {v}" for k, v in kw.items())
                message: str = (
                    f"Could not unify {callee}={callee.body} with pos=[{pos_str}] and kw={{{kw_str}}}"
                )
                if report_errors:
                    self.reporter.error(location, message)
                return CallResult(
                    error=CallError.IMPOSSIBLE_UNIFICATION,
                    message=message,
                )
            return self.get_result(
                location,
                unified,
                positional,
                keywords,
                report_errors,
            )

        case _:
            # Any other type cannot be called
            message: str = f"{callee} ({callee.__class__.__name__}) is not callable"
            if report_errors:
                self.reporter.error(location, message)
            return CallResult(
                error=CallError.NOT_CALLABLE,
                message=message,
            )
|
||||
|
||||
def _unwrap_function(
|
||||
self,
|
||||
callee: Type,
|
||||
positional: list[TypedExpr[E]],
|
||||
keywords: dict[str, TypedExpr[E]],
|
||||
) -> Union[tuple[Function, None], tuple[None, CallError]]:
|
||||
match callee:
|
||||
case DerivedType(type=base):
|
||||
return self._unwrap_function(base, positional, keywords)
|
||||
|
||||
case GenericType():
|
||||
unifier: Unifier = Unifier(self.types)
|
||||
unified: Optional[Type] = unifier.unify_call(
|
||||
callee,
|
||||
[a[1] for a in positional],
|
||||
{k: v[1] for k, v in keywords.items()},
|
||||
)
|
||||
if unified is None:
|
||||
return None, CallError.IMPOSSIBLE_UNIFICATION
|
||||
return self._unwrap_function(unified, positional, keywords)
|
||||
|
||||
case Function():
|
||||
return callee, None
|
||||
|
||||
case AppliedType(body=body):
|
||||
return self._unwrap_function(body, positional, keywords)
|
||||
|
||||
case _:
|
||||
return None, CallError.NOT_CALLABLE
|
||||
|
||||
def _are_arguments_valid(
|
||||
self,
|
||||
arguments: list[MappedArgument[E]],
|
||||
report_errors: bool = True,
|
||||
) -> bool:
|
||||
"""Check whether the passed argument types correspond to their matched parameter definitions
|
||||
|
||||
Args:
|
||||
arguments (list[MappedArgument]): the list of argument/parameter pairs
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
bool: True if all arguments fit the matching parameter definitions, False otherwise
|
||||
"""
|
||||
valid: bool = True
|
||||
for arg in arguments:
|
||||
if not self.types.is_subtype(arg.arg_type, arg.parameter.type):
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
arg.arg_expr.location,
|
||||
f"Wrong type for argument '{arg.parameter.name}', expected {arg.parameter.type}, got {arg.arg_type}",
|
||||
)
|
||||
valid = False
|
||||
return valid
|
||||
|
||||
def _match_overload(
|
||||
self,
|
||||
overloads: list[Type],
|
||||
location: Location,
|
||||
positional: list[TypedExpr[E]],
|
||||
keywords: dict[str, TypedExpr[E]],
|
||||
report_errors: bool = True,
|
||||
) -> Union[tuple[Function, None], tuple[None, str]]:
|
||||
"""Try and resolve the appropriate overload for the given arguments
|
||||
|
||||
Args:
|
||||
overloads (list[Type]): the list of possible overloads
|
||||
location (Location): the call location
|
||||
positional (list[TypedExpr]): the list of positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keywords arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Optional[Function]: the resolved function signature if it can be
|
||||
determined unambiguously, or `None`.
|
||||
"""
|
||||
candidates: list[OverloadCandidate] = []
|
||||
errors: list[CallError] = []
|
||||
for overload in overloads:
|
||||
function, unwrap_error = self._unwrap_function(
|
||||
overload, positional, keywords
|
||||
)
|
||||
if function is None:
|
||||
errors.append(unwrap_error) # type: ignore
|
||||
continue
|
||||
|
||||
valid, mapped = self.map_call_arguments(
|
||||
function=function,
|
||||
location=location,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
report_errors=False,
|
||||
)
|
||||
if valid and self._are_arguments_valid(mapped, report_errors=False):
|
||||
candidates.append(
|
||||
OverloadCandidate(
|
||||
function=function,
|
||||
mapped=mapped,
|
||||
)
|
||||
)
|
||||
|
||||
pos_types: str = ", ".join(str(type) for _, type in positional)
|
||||
kw_types: str = ", ".join(
|
||||
f"{name}: {type}" for name, (_, type) in keywords.items()
|
||||
)
|
||||
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
|
||||
|
||||
n_candidates: int = len(candidates)
|
||||
|
||||
# Exactly 1 match -> return it
|
||||
if n_candidates == 1:
|
||||
return candidates[0].function, None
|
||||
|
||||
# No match -> invalid call
|
||||
if n_candidates == 0:
|
||||
overloads_str: str = ", ".join(map(str, overloads))
|
||||
errors_str: str = ", ".join(errors)
|
||||
message: str = (
|
||||
f"No matching overload in [{overloads_str}] {for_args} (errors: {errors_str})"
|
||||
)
|
||||
if report_errors:
|
||||
self.reporter.error(location, message)
|
||||
return None, message
|
||||
|
||||
# Multiple matches -> see if one <: all others (more specific)
|
||||
for i1, c1 in enumerate(candidates):
|
||||
mapped1: list[MappedArgument[E]] = c1.mapped
|
||||
best_match: bool = True
|
||||
for i2, c2 in enumerate(candidates):
|
||||
if i1 == i2:
|
||||
continue
|
||||
mapped2: list[MappedArgument[E]] = c2.mapped
|
||||
if not self._are_mapped_subtypes(mapped1, mapped2):
|
||||
best_match = False
|
||||
break
|
||||
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
|
||||
if best_match:
|
||||
return c1.function, None
|
||||
|
||||
candidates_str: str = ", ".join(
|
||||
str(candidate.function) for candidate in candidates
|
||||
)
|
||||
message: str = f"Multiple matching overloads {for_args}: {candidates_str}"
|
||||
if report_errors:
|
||||
self.reporter.error(location, message)
|
||||
return None, message
|
||||
|
||||
def map_call_arguments(
    self,
    function: Function,
    location: Location,
    positional: list[TypedExpr[E]],
    keywords: dict[str, TypedExpr[E]],
    report_errors: bool = True,
) -> tuple[bool, list[MappedArgument]]:
    """Pair each call argument with the parameter it binds to

    Positional arguments first fill positional-only parameters, then mixed
    ones. Keyword arguments may target keyword-only parameters and any
    mixed parameter which was not already filled positionally.

    Any mismatched, missing or unexpected argument is reported as a diagnostic,
    unless `report_errors` is set to `False`

    Args:
        function (Function): the function definition
        location (Location): the call location
        positional (list[TypedExpr]): the list of positional arguments
        keywords (dict[str, TypedExpr]): the map of keyword arguments
        report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.

    Returns:
        tuple[bool, list[MappedArgument]]: a boolean reporting whether
            the call is valid and the list of mapped arguments
    """
    spec = function.params
    n_positional_only: int = len(spec.pos)
    # Parameters which can receive a positional argument, in binding order
    slots: list[Function.Parameter] = list(spec.pos) + list(spec.mixed)

    missing_positional: list[str] = [
        param.name for param in spec.pos + spec.mixed if param.required
    ]
    missing_keyword: list[str] = [
        param.name for param in spec.kw if param.required
    ]

    bound_names: set[str] = set()
    mapped: list[MappedArgument[E]] = []
    is_valid: bool = True

    def bind(param: Function.Parameter, typed: TypedExpr[E]) -> None:
        # Record that `param` received `typed` and is no longer missing
        if param.name in missing_positional:
            missing_positional.remove(param.name)
        if param.name in missing_keyword:
            missing_keyword.remove(param.name)
        bound_names.add(param.name)
        mapped.append(
            MappedArgument(
                arg_expr=typed[0],
                arg_type=typed[1],
                parameter=param,
            )
        )

    # TODO: handle *args and **kwargs sinks
    consumed: int = 0
    for typed in positional:
        if consumed >= len(slots):
            if report_errors:
                self.reporter.error(
                    typed[0].location, "Too many positional arguments"
                )
            is_valid = False
            break
        bind(slots[consumed], typed)
        consumed += 1

    # Mixed parameters left unfilled can still be passed by keyword
    by_keyword: dict[str, Function.Parameter] = {
        param.name: param for param in spec.kw
    }
    for param in slots[max(consumed, n_positional_only):]:
        by_keyword[param.name] = param

    for name, typed in keywords.items():
        if name in by_keyword:
            bind(by_keyword.pop(name), typed)
            continue
        if report_errors:
            if name in bound_names:
                self.reporter.error(
                    typed[0].location, f"Multiple values for parameter '{name}'"
                )
            else:
                self.reporter.error(
                    typed[0].location, f"Unknown keyword parameter '{name}'"
                )
        is_valid = False

    def join_params(params: list[str]) -> str:
        # Render names as "'a', 'b' and 'c'"
        quoted: list[str] = [f"'{p}'" for p in params]
        if len(quoted) <= 1:
            return "".join(quoted)
        *init, last = quoted
        return ", ".join(init) + " and " + last

    if missing_positional:
        plural: str = "" if len(missing_positional) == 1 else "s"
        params: str = join_params(missing_positional)
        if report_errors:
            self.reporter.error(
                location,
                f"Missing required positional argument{plural}: {params}",
            )
        is_valid = False

    if missing_keyword:
        plural: str = "" if len(missing_keyword) == 1 else "s"
        params: str = join_params(missing_keyword)
        if report_errors:
            self.reporter.error(
                location,
                f"Missing required keyword argument{plural}: {params}",
            )
        is_valid = False

    return is_valid, mapped
|
||||
|
||||
def _are_mapped_subtypes(
|
||||
self, mapped1: list[MappedArgument[E]], mapped2: list[MappedArgument[E]]
|
||||
) -> bool:
|
||||
"""Check whether the given argument mappings are subtype/supertype of one another
|
||||
|
||||
This function checks whether the argument mappings `mapped1` are subtypes
|
||||
of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
|
||||
of the corresponding parameter in `mapped2`, `False` is returned.
|
||||
|
||||
This is used to check whether a given overload is
|
||||
a more specific function/ a subtype of another.
|
||||
|
||||
Args:
|
||||
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
|
||||
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
|
||||
|
||||
Returns:
|
||||
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
|
||||
"""
|
||||
by_expr: dict[E, Type] = {}
|
||||
for arg in mapped1:
|
||||
by_expr[arg.arg_expr] = arg.parameter.type
|
||||
|
||||
for arg in mapped2:
|
||||
type2: Type = arg.parameter.type
|
||||
type1: Type = by_expr[arg.arg_expr]
|
||||
if not self.types.is_subtype(type1, type2):
|
||||
return False
|
||||
return True
|
||||
142
midas/checker/environment.py
Normal file
142
midas/checker/environment.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from midas.checker.types import Type
|
||||
|
||||
|
||||
class Environment:
    """
    A scoped environment in which variables are defined

    Each environment can inherit from a parent/enclosing environment.
    """

    def __init__(self, enclosing: Optional[Environment] = None) -> None:
        self.enclosing: Optional[Environment] = enclosing
        self.values: dict[str, Type] = {}
        self.return_types: list[Type] = []

        # Child environments are tracked so that `dump` can walk the whole tree
        self._children: list[Environment] = []
        if enclosing is not None:
            enclosing._children.append(self)

    def define(self, name: str, value: Type) -> None:
        """Define a variable in this environment

        Args:
            name (str): the name of the variable
            value (Type): the value
        """
        self.values[name] = value

    def get(self, name: str) -> Optional[Type]:
        """Get a variable in the closest environment which has a definition for it

        Args:
            name (str): the name of the variable

        Returns:
            Optional[Type]: the value of the variable, or None if it was not found
        """
        if name in self.values:
            return self.values[name]
        if self.enclosing is not None:
            return self.enclosing.get(name)
        return None

    def assign(self, name: str, value: Type) -> bool:
        """Assign a new value to a variable in the environment it was defined in

        If neither this environment nor any ancestor defines `name`, a root
        environment leaves everything untouched and returns False, whereas
        a nested environment ends up defining the variable locally.

        Args:
            name (str): the name of the variable
            value (Type): the new value

        Returns:
            bool: True if the variable was assigned in this environment or an ancestor, False otherwise
        """
        if name not in self.values:
            if self.enclosing is None:
                return False
            if self.enclosing.assign(name, value):
                return True
        self.values[name] = value
        return True

    def clear(self) -> None:
        """Clear all definitions in this environment"""
        self.values = {}

    def get_at(self, distance: int, name: str) -> Optional[Type]:
        """Get the value of a variable at a given distance

        A distance of 0 looks up in this environment, 1 in the parent environment, etc.
        This method expects `distance` to be valid. An error will be raised if
        the stack does not extend far enough to reach `distance`

        Args:
            distance (int): the scope distance
            name (str): the name of the variable

        Returns:
            Optional[Type]: the value at the given distance, or None if it is not defined in that environment

        Raises:
            AssertionError: if the stack does not extend far enough to reach `distance`
        """
        return self.ancestor(distance).values.get(name)

    def assign_at(self, distance: int, name: str, value: Type) -> None:
        """Assign a new value to a variable at a given distance

        A distance of 0 assigns in this environment, 1 in the parent environment, etc.

        Args:
            distance (int): the scope distance
            name (str): the name of the variable
            value (Type): the new value

        Raises:
            AssertionError: if the stack does not extend far enough to reach `distance`
        """
        self.ancestor(distance).values[name] = value

    def ancestor(self, distance: int) -> Environment:
        """Get the ancestor at a given distance

        A distance of 0 references this environment, 1 the parent environment, etc.

        Args:
            distance (int): the scope distance

        Returns:
            Environment: the environment

        Raises:
            AssertionError: if the stack does not extend far enough to reach `distance`
        """
        env: Environment = self
        for _ in range(distance):
            assert env.enclosing is not None
            env = env.enclosing
        return env

    def flat_dict(self) -> dict[str, Type]:
        """Get the current environment including definitions in its ancestor as a flat dictionary

        This method recursively combines this environment definitions with its
        ancestor's; definitions of inner environments take precedence. The
        result is always a new dictionary, so mutating it never alters any
        environment.

        Returns:
            dict: the combined environment
        """
        if self.enclosing is None:
            # Copy, so that callers cannot alter the root scope through the result
            return dict(self.values)
        return self.enclosing.flat_dict() | self.values

    def dump(self) -> dict:
        """Export this environment and its children as nested dictionaries

        Returns:
            dict: a dictionary holding this environment's `values` and
                `return_types`, and the dumps of its `children`
        """
        return {
            "values": self.values,
            "return_types": self.return_types,
            "children": [child.dump() for child in self._children],
        }
|
||||
174
midas/checker/evaluator.py
Normal file
174
midas/checker/evaluator.py
Normal file
@@ -0,0 +1,174 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import midas.ast.midas as m
|
||||
from midas.checker.preamble import Preamble
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter
|
||||
from midas.checker.types import Function, Predicate
|
||||
from midas.lexer.token import TokenType
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class PartialPredicate(Predicate):
    """A predicate bundled with the scope it must be evaluated in"""

    # Variable bindings pushed on the evaluator's scope stack when called
    scope: dict[str, Any]
|
||||
|
||||
|
||||
class Evaluator(m.Expr.Visitor[Any]):
|
||||
def __init__(self, types: TypesRegistry, reporter: Optional[FileReporter] = None):
|
||||
self.types: TypesRegistry = types
|
||||
self.reporter: Optional[FileReporter] = reporter
|
||||
self.preamble: Preamble = Preamble(self.types)
|
||||
self.scopes: list[dict[str, Any]] = [{}]
|
||||
|
||||
def evaluate(self, expr: m.Expr) -> Any:
|
||||
value: Any = expr.accept(self)
|
||||
if self.reporter is not None:
|
||||
self.reporter.debug(expr.location, f"Value: {value}")
|
||||
return value
|
||||
|
||||
def get_value(self, name: str) -> Any:
|
||||
scope: dict[str, Any] = self.scopes[-1]
|
||||
return scope[name]
|
||||
|
||||
def set_value(self, name: str, value: Any, force_declare: bool = False):
|
||||
if not force_declare:
|
||||
for scope in reversed(self.scopes):
|
||||
if name in scope:
|
||||
scope[name] = value
|
||||
return
|
||||
self.scopes[-1][name] = value
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> Any:
|
||||
def left():
|
||||
return self.evaluate(expr.left)
|
||||
|
||||
def right():
|
||||
return self.evaluate(expr.right)
|
||||
|
||||
match expr.operator.type:
|
||||
case TokenType.AND:
|
||||
return left() and right()
|
||||
case _:
|
||||
raise NotImplementedError
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> Any:
|
||||
left: Any = self.evaluate(expr.left)
|
||||
right: Any = self.evaluate(expr.right)
|
||||
match expr.operator.type:
|
||||
case TokenType.MINUS:
|
||||
return left - right
|
||||
case TokenType.STAR:
|
||||
return left * right
|
||||
case TokenType.SLASH:
|
||||
return left / right
|
||||
case TokenType.GREATER:
|
||||
return left > right
|
||||
case TokenType.GREATER_EQUAL:
|
||||
return left >= right
|
||||
case TokenType.LESS:
|
||||
return left < right
|
||||
case TokenType.LESS_EQUAL:
|
||||
return left <= right
|
||||
case TokenType.EQUAL_EQUAL:
|
||||
return left == right
|
||||
case TokenType.BANG_EQUAL:
|
||||
return left != right
|
||||
case _:
|
||||
raise NotImplementedError
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> Any:
|
||||
right: Any = self.evaluate(expr.right)
|
||||
match expr.operator.type:
|
||||
case TokenType.MINUS:
|
||||
return -right
|
||||
case _:
|
||||
raise NotImplementedError
|
||||
|
||||
def visit_call_expr(self, expr: m.CallExpr) -> Any:
|
||||
callee: Any = self.evaluate(expr.callee)
|
||||
args: list[Any] = [self.evaluate(arg) for arg in expr.arguments]
|
||||
kwargs: dict[str, Any] = {
|
||||
name: self.evaluate(arg) for name, arg in expr.keywords.items()
|
||||
}
|
||||
|
||||
match callee:
|
||||
case Predicate():
|
||||
return self._evaluate_predicate(callee, args, kwargs)
|
||||
case _ if callable(callee):
|
||||
return callee(*args, **kwargs)
|
||||
case _:
|
||||
return NotImplementedError
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> Any:
|
||||
obj: Any = self.evaluate(expr.expr)
|
||||
return getattr(obj, expr.name.lexeme)
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr) -> Any:
|
||||
name: str = expr.name.lexeme
|
||||
for scope in reversed(self.scopes):
|
||||
if name in scope:
|
||||
return scope[name]
|
||||
|
||||
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
|
||||
if predicate is not None:
|
||||
if predicate.alias:
|
||||
return self.evaluate(predicate.body)
|
||||
return predicate
|
||||
|
||||
glob: Optional[Callable] = self.preamble.get_py_func(name)
|
||||
if glob is not None:
|
||||
return glob
|
||||
raise NameError(f"Unknown variable '{name}'")
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Any:
|
||||
return self.evaluate(expr.expr)
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> Any:
|
||||
return expr.value
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Any:
|
||||
return self.get_value("_")
|
||||
|
||||
def _evaluate_predicate(
|
||||
self, predicate: Predicate, args: list[Any], kwargs: dict[str, Any]
|
||||
) -> Any:
|
||||
res: Any = None
|
||||
if isinstance(predicate, PartialPredicate):
|
||||
self.scopes.append(predicate.scope)
|
||||
else:
|
||||
self.scopes.append({})
|
||||
match predicate.type:
|
||||
case Function(returns=Function() as inner):
|
||||
self._map_args(predicate.type, args, kwargs)
|
||||
res = PartialPredicate(
|
||||
type=inner,
|
||||
body=predicate.body,
|
||||
alias=False,
|
||||
scope=self.scopes[-1],
|
||||
)
|
||||
|
||||
case Function():
|
||||
self._map_args(predicate.type, args, kwargs)
|
||||
res = self.evaluate(predicate.body)
|
||||
|
||||
case _:
|
||||
raise NotImplementedError
|
||||
self.scopes.pop()
|
||||
return res
|
||||
|
||||
def _map_args(self, function: Function, args: list[Any], kwargs: dict[str, Any]):
|
||||
positional: list[Function.Parameter] = (
|
||||
function.params.pos + function.params.mixed
|
||||
)
|
||||
keywords: dict[str, Function.Parameter] = {
|
||||
param.name: param for param in function.params.mixed + function.params.kw
|
||||
}
|
||||
|
||||
for i, arg in enumerate(args):
|
||||
param: Function.Parameter = positional[i]
|
||||
self.set_value(param.name, arg)
|
||||
|
||||
for name, arg in kwargs.items():
|
||||
param: Function.Parameter = keywords[name]
|
||||
self.set_value(param.name, arg)
|
||||
210
midas/checker/frames/column_groupby_methods.py
Normal file
210
midas/checker/frames/column_groupby_methods.py
Normal file
@@ -0,0 +1,210 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallResult
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import (
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
Function,
|
||||
ParamSpec,
|
||||
TopType,
|
||||
Type,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import TypedExpr
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class Call:
    """Everything describing one method call on a column group-by"""

    # Location passed to the dispatcher when resolving the call
    location: Location
    # The call expression itself
    call_expr: p.Expr
    # Type of the group-by the method is called on
    groupby: ColumnGroupBy
    # Expression yielding the group-by
    groupby_expr: p.Expr
    # Positional and keyword arguments, each as an (expression, type) pair
    positional: list[TypedExpr]
    keywords: dict[str, TypedExpr]

    @property
    def subject(self) -> TypedExpr:
        """The receiver of the call, as a `(groupby_expr, groupby)` pair"""
        return (self.groupby_expr, self.groupby)
|
||||
|
||||
|
||||
class ColumnGroupByMethodRegistry(MethodRegistry[Call]):
|
||||
NAMED_ARGS: dict[str, str] = {
|
||||
"numeric_only": "bool",
|
||||
"skipna": "bool",
|
||||
"engine": "str",
|
||||
"engine_kwargs": "dict",
|
||||
}
|
||||
|
||||
def _aggregate(
|
||||
self,
|
||||
call: Call,
|
||||
params: list[str | tuple[str, str, bool]] = [],
|
||||
*,
|
||||
preserve_inner_type: bool = False,
|
||||
) -> Type:
|
||||
real_params: list[Function.Parameter] = []
|
||||
for i, param in enumerate(params):
|
||||
match param:
|
||||
case str() as name:
|
||||
param = Function.Parameter(
|
||||
pos=i,
|
||||
name=name,
|
||||
type=self.types.get_type(self.NAMED_ARGS[name]),
|
||||
required=False,
|
||||
)
|
||||
case (name, type, required):
|
||||
param = Function.Parameter(
|
||||
pos=i,
|
||||
name=name,
|
||||
type=self.types.get_type(type),
|
||||
required=required,
|
||||
)
|
||||
real_params.append(param)
|
||||
|
||||
signature = Function(
|
||||
params=ParamSpec(mixed=real_params),
|
||||
returns=(
|
||||
call.groupby.column
|
||||
if preserve_inner_type
|
||||
else ColumnType(type=TopType())
|
||||
),
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def kurt(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
["skipna", "numeric_only"],
|
||||
)
|
||||
|
||||
@method()
|
||||
def max(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
],
|
||||
preserve_inner_type=True,
|
||||
)
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
["numeric_only", "skipna", "engine", "engine_kwargs"],
|
||||
)
|
||||
|
||||
@method()
|
||||
def median(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
["numeric_only", "skipna"],
|
||||
preserve_inner_type=True,
|
||||
)
|
||||
|
||||
@method()
|
||||
def min(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
],
|
||||
preserve_inner_type=True,
|
||||
)
|
||||
|
||||
@method()
|
||||
def prod(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def std(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
(
|
||||
"ddof",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
"numeric_only",
|
||||
"skipna",
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def sum(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def var(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
(
|
||||
"var",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
"numeric_only",
|
||||
"skipna",
|
||||
],
|
||||
)
|
||||
78
midas/checker/frames/column_manager.py
Normal file
78
midas/checker/frames/column_manager.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.frames.column_groupby_methods import Call as GroupByCall
|
||||
from midas.checker.frames.column_groupby_methods import ColumnGroupByMethodRegistry
|
||||
from midas.checker.frames.column_methods import Call, ColumnMethodRegistry
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import ColumnGroupBy, ColumnType, Type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import PythonTyper, TypedExpr
|
||||
|
||||
|
||||
class ColumnManager:
    """Entry point for typing method calls and attribute accesses on columns"""

    def __init__(self, typer: PythonTyper) -> None:
        self.typer: PythonTyper = typer
        self.method_resolver: ColumnMethodRegistry = ColumnMethodRegistry(typer)
        self.groupby_method_resolver: ColumnGroupByMethodRegistry = (
            ColumnGroupByMethodRegistry(typer)
        )

    def call(
        self,
        method: str,
        location: Location,
        call_expr: p.Expr,
        column: ColumnType,
        column_expr: p.Expr,
        positional: list[TypedExpr],
        keywords: dict[str, TypedExpr],
    ) -> Type:
        """Type a method call on a column

        Args:
            method (str): the name of the called method
            location (Location): the call location
            call_expr (p.Expr): the call expression
            column (ColumnType): the type of the column
            column_expr (p.Expr): the expression yielding the column
            positional (list[TypedExpr]): the list of positional arguments
            keywords (dict[str, TypedExpr]): the map of keyword arguments

        Returns:
            Type: whatever the column method registry returns for this call
        """
        return self.method_resolver.call(
            method,
            Call(
                location=location,
                call_expr=call_expr,
                column=column,
                column_expr=column_expr,
                positional=positional,
                keywords=keywords,
            ),
        )

    def groupby_call(
        self,
        method: str,
        location: Location,
        call_expr: p.Expr,
        groupby: ColumnGroupBy,
        groupby_expr: p.Expr,
        positional: list[TypedExpr],
        keywords: dict[str, TypedExpr],
    ) -> Type:
        """Type a method call on a column group-by

        Args:
            method (str): the name of the called method
            location (Location): the call location
            call_expr (p.Expr): the call expression
            groupby (ColumnGroupBy): the type of the group-by
            groupby_expr (p.Expr): the expression yielding the group-by
            positional (list[TypedExpr]): the list of positional arguments
            keywords (dict[str, TypedExpr]): the map of keyword arguments

        Returns:
            Type: whatever the group-by method registry returns for this call
        """
        return self.groupby_method_resolver.call(
            method,
            GroupByCall(
                location=location,
                call_expr=call_expr,
                groupby=groupby,
                groupby_expr=groupby_expr,
                positional=positional,
                keywords=keywords,
            ),
        )

    def get_attribute(self, column: ColumnType, name: str) -> Optional[Type]:
        """Get the type of an attribute of a column

        Args:
            column (ColumnType): the type of the column
            name (str): the name of the attribute

        Returns:
            Optional[Type]: the type of the attribute, or None if it is not supported
        """
        registry: TypesRegistry = self.typer.types
        if name in ("ndim", "size"):
            return registry.get_type("int")
        if name == "shape":
            return registry.tuple_of("int")
        if name == "T":
            return column
        return None
|
||||
400
midas/checker/frames/column_methods.py
Normal file
400
midas/checker/frames/column_methods.py
Normal file
@@ -0,0 +1,400 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallResult
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import (
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
Function,
|
||||
GenericType,
|
||||
ParamSpec,
|
||||
TopType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnknownType,
|
||||
unfold_type,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import TypedExpr
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class Call:
    """Everything describing one method call on a column"""

    # Location passed to the dispatcher when resolving the call
    location: Location
    # The call expression itself
    call_expr: p.Expr
    # Type of the column the method is called on
    column: ColumnType
    # Expression yielding the column
    column_expr: p.Expr
    # Positional and keyword arguments, each as an (expression, type) pair
    positional: list[TypedExpr]
    keywords: dict[str, TypedExpr]

    @property
    def subject(self) -> TypedExpr:
        """The receiver of the call, as a `(column_expr, column)` pair"""
        return (self.column_expr, self.column)
|
||||
|
||||
|
||||
class ColumnMethodRegistry(MethodRegistry[Call]):
|
||||
def _element_binary_op(self, call: Call, method: str) -> ColumnType:
    """Compute the result of an element-wise binary operation

    This function delegates to the inner types for computing the resulting
    type. The operand is the first positional argument or, if there is
    none, the `other` keyword argument.

    Args:
        call (Call): the call that triggered this resolution
        method (str): the method name

    Returns:
        ColumnType: the resulting column type. Its inner type is
            `UnknownType` if no operand was passed or if the operand
            is not a column
    """
    operand: Optional[TypedExpr] = None
    if call.positional:
        operand = call.positional[0]
    elif "other" in call.keywords:
        # The operand is a mixed parameter: it may be passed by name
        operand = call.keywords["other"]

    if operand is None:
        return ColumnType(type=UnknownType())

    other_expr, other_type = operand
    unfolded_other: Type = unfold_type(other_type)
    if not isinstance(unfolded_other, ColumnType):
        return ColumnType(type=UnknownType())

    new_inner_type = self.typer.result_of_binary_op(
        location=call.location,
        expr=call.call_expr,
        left=(call.column_expr, call.column.type),
        right=(other_expr, unfolded_other.type),
        method=method,
    )
    return ColumnType(type=new_inner_type)
|
||||
|
||||
def _element_wise(self, call: Call, method: str) -> Type:
    """Type an element-wise operation between this column and another one

    A generic signature accepting a single column operand named `other`
    is built, then the call is resolved against it. If the call is valid,
    both columns are asserted to have the same length.

    Args:
        call (Call): the call that triggered this resolution
        method (str): the dunder method implementing the operation on the
            inner types, e.g. `__add__`

    Returns:
        Type: the `result` of the dispatcher for this call
    """
    # TODO: support add with scalar

    # Build signature with new column type and generic operand
    param_type: TypeVar = TypeVar(name="T", bound=None)
    signature = GenericType(
        # Named after the operation ("add" for `__add__`, "sub" for `__sub__`, ...)
        name=method.strip("_"),
        params=[param_type],
        body=Function(
            params=ParamSpec(
                mixed=[
                    Function.Parameter(
                        pos=0,
                        name="other",
                        type=ColumnType(type=param_type),
                        required=True,
                    ),
                ],
            ),
            returns=self._element_binary_op(call, method),
        ),
    )

    # Map arguments and compute result type
    result: CallResult = self.dispatcher.get_result(
        location=call.location,
        callee=signature,
        positional=call.positional,
        keywords=call.keywords,
    )
    if result.is_valid:
        # A valid call passed `other` either positionally or by keyword
        other_expr = (
            call.positional[0][0]
            if call.positional
            else call.keywords["other"][0]
        )
        self._assert_same_length(call.call_expr, call.column_expr, other_expr)

    return result.result
|
||||
|
||||
@method("add", "__add__")
def add(self, call: Call) -> Type:
    """Type an element-wise addition, delegating to `_element_wise` with `__add__`"""
    return self._element_wise(call, "__add__")
|
||||
|
||||
@method("sub", "__sub__")
def sub(self, call: Call) -> Type:
    """Type an element-wise subtraction, delegating to `_element_wise` with `__sub__`"""
    return self._element_wise(call, "__sub__")
|
||||
|
||||
@method("mul", "__mul__")
def mul(self, call: Call) -> Type:
    """Type an element-wise multiplication, delegating to `_element_wise` with `__mul__`"""
    return self._element_wise(call, "__mul__")
|
||||
|
||||
@method("div", "truediv", "__truediv__")
def truediv(self, call: Call) -> Type:
    """Type an element-wise true division, delegating to `_element_wise` with `__truediv__`"""
    return self._element_wise(call, "__truediv__")
|
||||
|
||||
@method("floordiv", "__floordiv__")
def floordiv(self, call: Call) -> Type:
    """Type an element-wise floor division, delegating to `_element_wise` with `__floordiv__`"""
    return self._element_wise(call, "__floordiv__")
|
||||
|
||||
@method("mod", "__mod__")
def mod(self, call: Call) -> Type:
    """Type the element-wise modulo of this column by another one."""
    return self._element_wise(call=call, method="__mod__")
|
||||
|
||||
@method("pow", "__pow__")
def pow(self, call: Call) -> Type:
    """Type the element-wise exponentiation of this column by another one."""
    return self._element_wise(call=call, method="__pow__")
|
||||
|
||||
@method("lt", "__lt__")
def lt(self, call: Call) -> Type:
    """Type the element-wise `<` comparison of this column with another one."""
    return self._element_wise(call=call, method="__lt__")
|
||||
|
||||
@method("gt", "__gt__")
def gt(self, call: Call) -> Type:
    """Type the element-wise `>` comparison of this column with another one."""
    return self._element_wise(call=call, method="__gt__")
|
||||
|
||||
@method("le", "__le__")
def le(self, call: Call) -> Type:
    """Type the element-wise `<=` comparison of this column with another one."""
    return self._element_wise(call=call, method="__le__")
|
||||
|
||||
@method("ge", "__ge__")
def ge(self, call: Call) -> Type:
    """Type the element-wise `>=` comparison of this column with another one."""
    return self._element_wise(call=call, method="__ge__")
|
||||
|
||||
@method("ne", "__ne__")
def ne(self, call: Call) -> Type:
    """Type the element-wise `!=` comparison of this column with another one."""
    return self._element_wise(call=call, method="__ne__")
|
||||
|
||||
@method("eq", "__eq__")
def eq(self, call: Call) -> Type:
    """Type the element-wise `==` comparison of this column with another one."""
    return self._element_wise(call=call, method="__eq__")
|
||||
|
||||
def _aggregate(
    self,
    call: Call,
    kwargs: list[Function.Parameter] | None = None,
    *,
    preserve_inner_type: bool = False,
) -> Type:
    """Type a call to an aggregation method of a column.

    The signature always accepts an optional `axis` parameter of any type,
    followed by the method-specific parameters given in `kwargs`. All of them
    are declared in the `kw` section of the parameter spec.

    Args:
        call (Call): the call that triggered this resolution
        kwargs (list[Function.Parameter] | None): additional parameters
            accepted by the method, after `axis`. Defaults to none.
        preserve_inner_type (bool): if True, the call returns the type of the
            column itself; otherwise it returns `ColumnType(type=TopType())`

    Returns:
        Type: the result type computed by the dispatcher
    """
    # A `None` default replaces the former shared mutable `[]` default
    extra_params: list[Function.Parameter] = [] if kwargs is None else kwargs

    signature = Function(
        params=ParamSpec(
            kw=[
                Function.Parameter(
                    pos=0,
                    name="axis",
                    type=TopType(),
                    required=False,
                ),
                *extra_params,
            ],
        ),
        returns=call.column if preserve_inner_type else ColumnType(type=TopType()),
    )

    result: CallResult = self.dispatcher.get_result(
        location=call.location,
        callee=signature,
        positional=call.positional,
        keywords=call.keywords,
    )
    return result.result
|
||||
|
||||
@method("kurtosis", "kurt")
def kurtosis(self, call: Call) -> Type:
    """Type a kurtosis aggregation; no parameter beyond those of `_aggregate`."""
    return self._aggregate(call=call)
|
||||
|
||||
@method()
def max(self, call: Call) -> Type:
    """Type a `max` aggregation; the result keeps this column's type."""
    return self._aggregate(call=call, preserve_inner_type=True)
|
||||
|
||||
@method()
def mean(self, call: Call) -> Type:
    """Type a `mean` aggregation; no parameter beyond those of `_aggregate`."""
    return self._aggregate(call=call)
|
||||
|
||||
@method()
def median(self, call: Call) -> Type:
    """Type a `median` aggregation; the result keeps this column's type."""
    return self._aggregate(call=call, preserve_inner_type=True)
|
||||
|
||||
@method()
def min(self, call: Call) -> Type:
    """Type a `min` aggregation; the result keeps this column's type."""
    return self._aggregate(call=call, preserve_inner_type=True)
|
||||
|
||||
@method()
def mode(self, call: Call) -> Type:
    """Type a `mode` aggregation; the result keeps this column's type."""
    return self._aggregate(call=call, preserve_inner_type=True)
|
||||
|
||||
@method("product", "prod")
def product(self, call: Call) -> Type:
    """Type a product aggregation; no parameter beyond those of `_aggregate`."""
    return self._aggregate(call=call)
|
||||
|
||||
@method()
def std(self, call: Call) -> Type:
    """Type a `std` aggregation, which also accepts an optional int `ddof`."""
    ddof_param = Function.Parameter(
        pos=1,
        name="ddof",
        type=self.types.get_type("int"),
        required=False,
    )
    return self._aggregate(call, [ddof_param])
|
||||
|
||||
@method()
def sum(self, call: Call) -> Type:
    """Type a `sum` aggregation; no parameter beyond those of `_aggregate`."""
    return self._aggregate(call=call)
|
||||
|
||||
@method()
def var(self, call: Call) -> Type:
    """Type a `var` aggregation, which also accepts an optional int `ddof`.

    The extra keyword used to be declared under the name `var`, so that
    `column.var(ddof=...)` could not be matched against the signature. It is
    now named `ddof`, like in `std` and in pandas' `Series.var`.
    """
    return self._aggregate(
        call,
        [
            Function.Parameter(
                pos=1,
                name="ddof",
                type=self.types.get_type("int"),
                required=False,
            )
        ],
    )
|
||||
|
||||
@method()
def head(self, call: Call) -> Type:
    """Type a call to `head`: an optional int `n`, returning this column's type."""
    n_param = Function.Parameter(
        pos=0,
        name="n",
        type=self.types.get_type("int"),
        required=False,
    )
    signature = Function(
        params=ParamSpec(mixed=[n_param]),
        returns=call.column,
    )

    # Map arguments onto the signature and return the resulting type
    return self.dispatcher.get_result(
        location=call.location,
        callee=signature,
        positional=call.positional,
        keywords=call.keywords,
    ).result
|
||||
|
||||
@method()
def tail(self, call: Call) -> Type:
    """Type a call to `tail`: an optional int `n`, returning this column's type."""
    n_param = Function.Parameter(
        pos=0,
        name="n",
        type=self.types.get_type("int"),
        required=False,
    )
    signature = Function(
        params=ParamSpec(mixed=[n_param]),
        returns=call.column,
    )

    # Map arguments onto the signature and return the resulting type
    return self.dispatcher.get_result(
        location=call.location,
        callee=signature,
        positional=call.positional,
        keywords=call.keywords,
    ).result
|
||||
|
||||
@method()
def groupby(self, call: Call) -> Type:
    """Type a call to `groupby` on a column.

    `by` and `level` accept any type and are declared as `mixed` parameters.
    The boolean flags `as_index`, `sort`, `group_keys`, `observed` and
    `dropna` are declared as `kw` parameters. Every parameter is optional.

    Returns:
        Type: the dispatcher's result for a signature returning a
            `ColumnGroupBy` wrapping this column
    """
    bool_type: Type = self.types.get_type("bool")

    leading_params = [
        Function.Parameter(pos=0, name="by", type=TopType(), required=False),
        Function.Parameter(pos=1, name="level", type=TopType(), required=False),
    ]

    # Flags are numbered right after the two leading parameters
    flag_names = ["as_index", "sort", "group_keys", "observed", "dropna"]
    flag_params = [
        Function.Parameter(pos=position, name=flag, type=bool_type, required=False)
        for position, flag in enumerate(flag_names, start=2)
    ]

    signature: Function = Function(
        params=ParamSpec(mixed=leading_params, kw=flag_params),
        returns=ColumnGroupBy(column=call.column),
    )

    outcome: CallResult = self.dispatcher.get_result(
        location=call.location,
        callee=signature,
        positional=call.positional,
        keywords=call.keywords,
    )
    return outcome.result
|
||||
|
||||
def _assert_same_length(self, call_expr: p.Expr, column1: p.Expr, column2: p.Expr):
    """Register an assertion that two columns have the same length.

    A helper function `__midas_column_same_length__(column1, column2)`
    returning `len(column1.index) == len(column2.index)` is defined through
    the assertion collector, then an assertion calling it on both column
    expressions is attached to `call_expr`.

    Args:
        call_expr (p.Expr): the expression the assertion is bound to
        column1 (p.Expr): expression of the first column
        column2 (p.Expr): expression of the second column
    """
    helper_name: str = "__midas_column_same_length__"

    def index_length(arg_name: str) -> ast.expr:
        # Efficiently compute length: `len(<arg>.index)`
        # https://stackoverflow.com/a/15943975/11109181
        index_attr = ast.Attribute(value=ast.Name(id=arg_name), attr="index")
        return ast.Call(func=ast.Name(id="len"), args=[index_attr], keywords=[])

    parameters = ast.arguments(
        posonlyargs=[],
        args=[ast.arg(arg="column1"), ast.arg(arg="column2")],
        kwonlyargs=[],
        defaults=[],
        kw_defaults=[],
    )
    same_length = ast.Compare(
        left=index_length("column1"),
        ops=[ast.Eq()],
        comparators=[index_length("column2")],
    )
    helper_def = ast.FunctionDef(
        name=helper_name,
        args=parameters,
        body=[ast.Return(value=same_length)],
        decorator_list=[],
    )
    self.assertions.define(helper_name, helper_def)

    def build_check(c1, c2):
        # Call the helper on the two column expressions
        return ast.Call(func=ast.Name(id=helper_name), args=[c1, c2], keywords=[])

    self.assertions.add(
        bound_expr=call_expr,
        inputs=[column1, column2],
        builder=build_check,
        message="Columns must have the same length",
    )
|
||||
103
midas/checker/frames/frame_groupby_methods.py
Normal file
103
midas/checker/frames/frame_groupby_methods.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import (
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
DataFrameType,
|
||||
FrameGroupBy,
|
||||
Type,
|
||||
UnknownType,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import TypedExpr
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Call:
|
||||
location: Location
|
||||
call_expr: p.Expr
|
||||
groupby: FrameGroupBy
|
||||
groupby_expr: p.Expr
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
@property
|
||||
def subject(self) -> TypedExpr:
|
||||
return (self.groupby_expr, self.groupby)
|
||||
|
||||
|
||||
class FrameGroupByMethodRegistry(MethodRegistry[Call]):
    """Typing rules for the aggregation methods of a grouped dataframe."""

    NAMED_ARGS: dict[str, str] = {
        "numeric_only": "bool",
        "skipna": "bool",
        "engine": "str",
        "engine_kwargs": "dict",
    }

    def _aggregate_column(
        self, call: Call, column: DataFrameType.Column, method: str
    ) -> DataFrameType.Column:
        """Type the aggregation of a single column of the grouped frame.

        The method is called, through the main typer, on a `ColumnGroupBy`
        wrapping the column's type. Any result which is not a `ColumnType`
        is replaced by `ColumnType(type=UnknownType())`.
        """
        aggregated: Type = self.typer.call_method(
            location=call.location,
            call_expr=call.call_expr,
            obj=(call.groupby_expr, ColumnGroupBy(column=column.type)),
            method_name=method,
            positional=call.positional,
            keywords=call.keywords,
        )
        if not isinstance(aggregated, ColumnType):
            aggregated = ColumnType(type=UnknownType())
        return DataFrameType.Column(
            index=column.index,
            name=column.name,
            type=aggregated,
        )

    def _aggregate(self, call: Call, method: str) -> Type:
        """Type an aggregation by applying `method` to every column.

        Args:
            call (Call): the call that triggered this resolution
            method (str): the name of the method called on each column

        Returns:
            Type: a dataframe with the same column names and indices, each
                typed with the result of the aggregation on that column
        """
        return DataFrameType(
            columns=[
                self._aggregate_column(call, column, method)
                for column in call.groupby.frame.columns
            ]
        )

    @method()
    def kurt(self, call: Call) -> Type:
        """Type `kurt` by delegating to the `kurt` method of each column."""
        return self._aggregate(call, "kurt")

    @method()
    def max(self, call: Call) -> Type:
        """Type `max` by delegating to the `max` method of each column."""
        return self._aggregate(call, "max")

    @method()
    def mean(self, call: Call) -> Type:
        """Type `mean` by delegating to the `mean` method of each column."""
        return self._aggregate(call, "mean")

    @method()
    def median(self, call: Call) -> Type:
        """Type `median` by delegating to the `median` method of each column."""
        return self._aggregate(call, "median")

    @method()
    def min(self, call: Call) -> Type:
        """Type `min` by delegating to the `min` method of each column."""
        return self._aggregate(call, "min")

    @method()
    def prod(self, call: Call) -> Type:
        """Type `prod` by delegating to the `prod` method of each column."""
        return self._aggregate(call, "prod")

    @method()
    def std(self, call: Call) -> Type:
        """Type `std` by delegating to the `std` method of each column."""
        return self._aggregate(call, "std")

    @method()
    def sum(self, call: Call) -> Type:
        """Type `sum` by delegating to the `sum` method of each column."""
        return self._aggregate(call, "sum")

    @method()
    def var(self, call: Call) -> Type:
        """Type `var` by delegating to the `var` method of each column."""
        return self._aggregate(call, "var")
|
||||
255
midas/checker/frames/frame_manager.py
Normal file
255
midas/checker/frames/frame_manager.py
Normal file
@@ -0,0 +1,255 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, TypeGuard, cast
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.frames.frame_groupby_methods import Call as GroupByCall
|
||||
from midas.checker.frames.frame_groupby_methods import FrameGroupByMethodRegistry
|
||||
from midas.checker.frames.frame_methods import Call, FrameMethodRegistry
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter
|
||||
from midas.checker.types import (
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
DataFrameType,
|
||||
FrameGroupBy,
|
||||
TupleType,
|
||||
Type,
|
||||
UnknownType,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import PythonTyper, TypedExpr
|
||||
|
||||
|
||||
def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
    """Tell whether every expression of the list is a `p.LiteralExpr`.

    An empty list is accepted.
    """
    for expr in exprs:
        if not isinstance(expr, p.LiteralExpr):
            return False
    return True
|
||||
|
||||
|
||||
class FrameManager:
|
||||
def __init__(self, typer: PythonTyper) -> None:
|
||||
self.typer: PythonTyper = typer
|
||||
self.method_resolver: FrameMethodRegistry = FrameMethodRegistry(self.typer)
|
||||
self.groupby_method_resolver: FrameGroupByMethodRegistry = (
|
||||
FrameGroupByMethodRegistry(self.typer)
|
||||
)
|
||||
|
||||
def assign(
|
||||
self,
|
||||
reporter: FileReporter,
|
||||
location: Location,
|
||||
frame: DataFrameType,
|
||||
index: p.Expr,
|
||||
value_type: Type,
|
||||
) -> Type:
|
||||
match index:
|
||||
case p.LiteralExpr(value=str() as name):
|
||||
return self.assign_column(reporter, location, frame, name, value_type)
|
||||
|
||||
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
|
||||
isinstance(index.value, str) for index in indices
|
||||
):
|
||||
names: list[str] = [cast(str, index.value) for index in indices]
|
||||
|
||||
if not isinstance(value_type, TupleType):
|
||||
reporter.error(
|
||||
location,
|
||||
f"Cannot assign {type} to dataframe columns. Must be a tuple of columns",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
if len(names) != len(value_type.items):
|
||||
reporter.error(
|
||||
location,
|
||||
f"Wrong number of columns. Cannot assign {len(value_type.items)} to {len(names)} targets",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
new_frame: Type = frame
|
||||
for name, value in zip(names, value_type.items):
|
||||
new_frame = self.assign_column(
|
||||
reporter,
|
||||
location,
|
||||
new_frame,
|
||||
name,
|
||||
value,
|
||||
)
|
||||
if not isinstance(new_frame, DataFrameType):
|
||||
return new_frame
|
||||
return new_frame
|
||||
|
||||
case _:
|
||||
reporter.error(
|
||||
location, f"Invalid index type {index} on {frame} (assignment)"
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
def assign_column(
|
||||
self,
|
||||
reporter: FileReporter,
|
||||
location: Location,
|
||||
frame: DataFrameType,
|
||||
name: str,
|
||||
type: Type,
|
||||
) -> Type:
|
||||
if not isinstance(type, ColumnType):
|
||||
reporter.error(
|
||||
location,
|
||||
f"Cannot assign {type} to dataframe column. Must be a ColumnType",
|
||||
)
|
||||
return frame
|
||||
return self._set_column(frame, name, type)
|
||||
|
||||
def get(
|
||||
self,
|
||||
reporter: FileReporter,
|
||||
location: Location,
|
||||
frame: DataFrameType,
|
||||
index: p.Expr,
|
||||
) -> Type:
|
||||
match index:
|
||||
case p.LiteralExpr(value=str() as name):
|
||||
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
|
||||
if column is None:
|
||||
reporter.error(location, f"Unknown column '{name}' on {frame}")
|
||||
return UnknownType()
|
||||
return column
|
||||
|
||||
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
|
||||
isinstance(index.value, str) for index in indices
|
||||
):
|
||||
names: list[str] = [cast(str, index.value) for index in indices]
|
||||
columns: list[ColumnType] = []
|
||||
for name in names:
|
||||
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
|
||||
if column is None:
|
||||
reporter.error(location, f"Unknown column '{name}' on {frame}")
|
||||
return UnknownType()
|
||||
columns.append(column)
|
||||
return TupleType(items=tuple(columns))
|
||||
|
||||
case _:
|
||||
reporter.error(
|
||||
location, f"Invalid index type {index} on {frame} (access)"
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
def groupby_get(
|
||||
self,
|
||||
reporter: FileReporter,
|
||||
location: Location,
|
||||
groupby: FrameGroupBy,
|
||||
index: p.Expr,
|
||||
) -> Type:
|
||||
result: Type = self.get(reporter, location, groupby.frame, index)
|
||||
match result:
|
||||
case ColumnType():
|
||||
result = ColumnGroupBy(column=result)
|
||||
case TupleType(items=columns):
|
||||
result = TupleType(
|
||||
items=tuple(
|
||||
ColumnGroupBy(column=cast(ColumnType, column))
|
||||
for column in columns
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _set_column(
|
||||
cls, frame: DataFrameType, name: str, column: ColumnType
|
||||
) -> DataFrameType:
|
||||
new_columns: list[DataFrameType.Column] = []
|
||||
index: int = len(frame.columns)
|
||||
replace: bool = False
|
||||
for i, col in enumerate(frame.columns):
|
||||
if col.name == name:
|
||||
index = i
|
||||
replace = True
|
||||
# TODO: check column type here to prevent changing it
|
||||
new_columns.append(col)
|
||||
|
||||
new_col: DataFrameType.Column = DataFrameType.Column(
|
||||
index=index,
|
||||
name=name,
|
||||
type=column,
|
||||
)
|
||||
if replace:
|
||||
new_columns[index] = new_col
|
||||
else:
|
||||
new_columns.append(new_col)
|
||||
|
||||
return DataFrameType(columns=new_columns)
|
||||
|
||||
@classmethod
|
||||
def _set_columns(
|
||||
cls, frame: DataFrameType, names: list[str], columns: list[ColumnType]
|
||||
) -> DataFrameType:
|
||||
for name, col in zip(names, columns):
|
||||
frame = cls._set_column(frame, name, col)
|
||||
return frame
|
||||
|
||||
@classmethod
|
||||
def _get_column(cls, frame: DataFrameType, name: str) -> Optional[ColumnType]:
|
||||
for col in frame.columns:
|
||||
if col.name == name:
|
||||
return col.type
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _get_columns(
|
||||
cls, frame: DataFrameType, names: list[str]
|
||||
) -> list[Optional[ColumnType]]:
|
||||
return [cls._get_column(frame, name) for name in names]
|
||||
|
||||
def call(
|
||||
self,
|
||||
method: str,
|
||||
location: Location,
|
||||
call_expr: p.Expr,
|
||||
frame: DataFrameType,
|
||||
frame_expr: p.Expr,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Type:
|
||||
call: Call = Call(
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
frame=frame,
|
||||
frame_expr=frame_expr,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
return self.method_resolver.call(method, call)
|
||||
|
||||
def groupby_call(
|
||||
self,
|
||||
method: str,
|
||||
location: Location,
|
||||
call_expr: p.Expr,
|
||||
groupby: FrameGroupBy,
|
||||
groupby_expr: p.Expr,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Type:
|
||||
call: GroupByCall = GroupByCall(
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
groupby=groupby,
|
||||
groupby_expr=groupby_expr,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
return self.groupby_method_resolver.call(method, call)
|
||||
|
||||
def get_attribute(self, frame: DataFrameType, name: str) -> Optional[Type]:
|
||||
types: TypesRegistry = self.typer.types
|
||||
match name:
|
||||
case "ndim" | "size":
|
||||
return types.get_type("int")
|
||||
|
||||
case "shape":
|
||||
return types.tuple_of("int", "int")
|
||||
|
||||
case _:
|
||||
return None
|
||||
479
midas/checker/frames/frame_methods.py
Normal file
479
midas/checker/frames/frame_methods.py
Normal file
@@ -0,0 +1,479 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallResult
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import (
|
||||
ColumnType,
|
||||
DataFrameType,
|
||||
FrameGroupBy,
|
||||
Function,
|
||||
OverloadedFunction,
|
||||
ParamSpec,
|
||||
TopType,
|
||||
Type,
|
||||
UnknownType,
|
||||
unfold_type,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import TypedExpr
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Call:
|
||||
location: Location
|
||||
call_expr: p.Expr
|
||||
frame: DataFrameType
|
||||
frame_expr: p.Expr
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
@property
|
||||
def subject(self) -> TypedExpr:
|
||||
return (self.frame_expr, self.frame)
|
||||
|
||||
|
||||
class FrameMethodRegistry(MethodRegistry[Call]):
    """Typing rules for the methods available on dataframes."""

    def _operand(self, call: Call) -> Optional[TypedExpr]:
        """Return the `other` operand of a binary operation, if any.

        The operand is declared as a `mixed` parameter named `other`, so it
        is looked up first in the positional arguments, then in the keyword
        arguments (assuming `mixed` parameters can be passed either way).
        """
        if call.positional:
            return call.positional[0]
        return call.keywords.get("other")

    def _get_method_result(
        self,
        call: Call,
        column1: ColumnType,
        column2: ColumnType,
        method: str,
    ) -> ColumnType:
        """Get the result of calling a method on a column, passing a second
        column as operand

        This function delegates to the main typer the resolution of the method
        member, as well as computing the result type. Because we don't have any
        AST expression for the individual columns, the frame expressions are
        used instead.

        Args:
            call (Call): the call that triggered this resolution
            column1 (ColumnType): the first column, i.e. left operand
            column2 (ColumnType): the second column, i.e. right operand
            method (str): the method name

        Returns:
            ColumnType: the resulting column.
                If the operation is invalid / doesn't exist,
                `ColumnType(type=UnknownType())` is returned
        """
        operand: Optional[TypedExpr] = self._operand(call)
        if operand is None:
            # No operand expression to attach the operation to
            return ColumnType(type=UnknownType())

        result: Type = self.typer.result_of_binary_op(
            location=call.location,
            expr=call.call_expr,
            left=(call.frame_expr, column1),
            right=(operand[0], column2),
            method=method,
        )

        if not isinstance(result, ColumnType):
            return ColumnType(type=UnknownType())
        return result

    def _element_binary_op(self, call: Call, method: str) -> DataFrameType:
        """Compute the result of an element-wise binary operation

        This function delegates to the matching columns for computing resulting
        types. Any column only present in one of the frames is forwarded as a
        generic `ColumnType(type=UnknownType())`. Columns only in the second
        frame are appended at the end of the schema.

        Args:
            call (Call): the call that triggered this resolution
            method (str): the method name

        Returns:
            DataFrameType: the resulting frame type
        """
        new_columns: list[DataFrameType.Column] = []

        by_name: dict[str, DataFrameType.Column] = {}
        frame2: Optional[DataFrameType] = None
        # Get map of operand's columns by name, if there is an operand
        # (positional or keyword) which is a dataframe
        operand: Optional[TypedExpr] = self._operand(call)
        if operand is not None:
            unfolded_other: Type = unfold_type(operand[1])
            if isinstance(unfolded_other, DataFrameType):
                frame2 = unfolded_other
                by_name = {
                    col.name: col for col in frame2.columns if col.name is not None
                }

        # Compute new schema:
        # Step 1: for all columns in frame1:
        #   - if present in frame2 -> delegate operation to columns
        #   - if not -> add to schema as unknown
        in_frame1: set[str] = set()
        for column in call.frame.columns:
            if column.name is not None:
                in_frame1.add(column.name)

            col_type1: ColumnType = column.type
            col_type: ColumnType = ColumnType(type=UnknownType())
            if column.name in by_name:
                column2 = by_name[column.name]
                col_type2: ColumnType = column2.type

                col_type = self._get_method_result(call, col_type1, col_type2, method)

            new_column = DataFrameType.Column(
                index=column.index,
                name=column.name,
                type=col_type,
            )
            new_columns.append(new_column)

        # Step 2: for all columns in frame2
        #   - if not in frame1 -> add to schema as unknown
        if frame2 is not None:
            for column in frame2.columns:
                if column.name in in_frame1:
                    continue
                new_columns.append(
                    DataFrameType.Column(
                        index=len(new_columns),
                        name=column.name,
                        type=ColumnType(type=UnknownType()),
                    )
                )

        return DataFrameType(columns=new_columns)

    def _element_wise(self, call: Call, method: str) -> Type:
        """Type an element-wise binary operation between two dataframes.

        Args:
            call (Call): the call that triggered this resolution
            method (str): the dunder name of the operator, e.g. `"__add__"`

        Returns:
            Type: the result type computed by the dispatcher
        """
        # TODO: support scalar, sequence, Series, dict operand
        # Build signature with new schema and generic operand
        signature = Function(
            params=ParamSpec(
                mixed=[
                    Function.Parameter(
                        pos=0,
                        name="other",
                        type=DataFrameType(columns=[]),
                        required=True,
                    ),
                ],
            ),
            returns=self._element_binary_op(call, method),
        )

        # Map arguments and compute result type
        result: CallResult = self.dispatcher.get_result(
            location=call.location,
            callee=signature,
            positional=call.positional,
            keywords=call.keywords,
        )
        if result.is_valid:
            # The operand may have been passed as `other=...`, in which case
            # indexing `call.positional[0]` would raise IndexError
            operand: Optional[TypedExpr] = self._operand(call)
            if operand is not None:
                self._assert_same_length(call.call_expr, call.frame_expr, operand[0])

        return result.result

    @method("add", "__add__")
    def add(self, call: Call) -> Type:
        """Type the element-wise addition of two dataframes."""
        return self._element_wise(call, "__add__")

    @method("sub", "__sub__")
    def sub(self, call: Call) -> Type:
        """Type the element-wise subtraction of two dataframes."""
        return self._element_wise(call, "__sub__")

    @method("mul", "__mul__")
    def mul(self, call: Call) -> Type:
        """Type the element-wise multiplication of two dataframes."""
        return self._element_wise(call, "__mul__")

    @method("div", "truediv", "__truediv__")
    def truediv(self, call: Call) -> Type:
        """Type the element-wise true division of two dataframes."""
        return self._element_wise(call, "__truediv__")

    @method("floordiv", "__floordiv__")
    def floordiv(self, call: Call) -> Type:
        """Type the element-wise floor division of two dataframes."""
        return self._element_wise(call, "__floordiv__")

    @method("mod", "__mod__")
    def mod(self, call: Call) -> Type:
        """Type the element-wise modulo of two dataframes."""
        return self._element_wise(call, "__mod__")

    @method("pow", "__pow__")
    def pow(self, call: Call) -> Type:
        """Type the element-wise exponentiation of two dataframes."""
        return self._element_wise(call, "__pow__")

    @method("lt", "__lt__")
    def lt(self, call: Call) -> Type:
        """Type the element-wise `<` comparison of two dataframes."""
        return self._element_wise(call, "__lt__")

    @method("gt", "__gt__")
    def gt(self, call: Call) -> Type:
        """Type the element-wise `>` comparison of two dataframes."""
        return self._element_wise(call, "__gt__")

    @method("le", "__le__")
    def le(self, call: Call) -> Type:
        """Type the element-wise `<=` comparison of two dataframes."""
        return self._element_wise(call, "__le__")

    @method("ge", "__ge__")
    def ge(self, call: Call) -> Type:
        """Type the element-wise `>=` comparison of two dataframes."""
        return self._element_wise(call, "__ge__")

    @method("ne", "__ne__")
    def ne(self, call: Call) -> Type:
        """Type the element-wise `!=` comparison of two dataframes."""
        return self._element_wise(call, "__ne__")

    @method("eq", "__eq__")
    def eq(self, call: Call) -> Type:
        """Type the element-wise `==` comparison of two dataframes."""
        return self._element_wise(call, "__eq__")

    def _aggregate(
        self, call: Call, kwargs: Optional[list[Function.Parameter]] = None
    ) -> Type:
        """Type a call to an aggregation method of a dataframe.

        Two overloads are built, both taking `axis` followed by the
        method-specific parameters given in `kwargs`:
        - `axis` is an optional int: the call returns a column of `TopType`
        - `axis` is a required `None`: the call returns `TopType`

        Args:
            call (Call): the call that triggered this resolution
            kwargs (Optional[list[Function.Parameter]]): additional
                parameters accepted by the method. Defaults to none.

        Returns:
            Type: the result type computed by the dispatcher
        """
        # A `None` default replaces the former shared mutable `[]` default
        extra_params: list[Function.Parameter] = [] if kwargs is None else kwargs

        with_axis = Function(
            params=ParamSpec(
                kw=[
                    Function.Parameter(
                        pos=0,
                        name="axis",
                        type=self.types.get_type("int"),
                        required=False,
                    ),
                    *extra_params,
                ],
            ),
            returns=ColumnType(type=TopType()),
        )
        without_axis = Function(
            params=ParamSpec(
                kw=[
                    Function.Parameter(
                        pos=0,
                        name="axis",
                        type=self.types.get_type("None"),
                        required=True,
                    ),
                    *extra_params,
                ],
            ),
            returns=TopType(),
        )
        overload = OverloadedFunction(
            overloads=[
                with_axis,
                without_axis,
            ]
        )

        result: CallResult = self.dispatcher.get_result(
            location=call.location,
            callee=overload,
            positional=call.positional,
            keywords=call.keywords,
        )
        return result.result

    @method("kurtosis", "kurt")
    def kurtosis(self, call: Call) -> Type:
        """Type a kurtosis aggregation."""
        return self._aggregate(call)

    @method()
    def max(self, call: Call) -> Type:
        """Type a `max` aggregation."""
        return self._aggregate(call)

    @method()
    def mean(self, call: Call) -> Type:
        """Type a `mean` aggregation."""
        return self._aggregate(call)

    @method()
    def median(self, call: Call) -> Type:
        """Type a `median` aggregation."""
        return self._aggregate(call)

    @method()
    def min(self, call: Call) -> Type:
        """Type a `min` aggregation."""
        return self._aggregate(call)

    @method()
    def mode(self, call: Call) -> Type:
        """Type a `mode` aggregation."""
        return self._aggregate(call)

    @method("product", "prod")
    def product(self, call: Call) -> Type:
        """Type a product aggregation."""
        return self._aggregate(call)

    @method()
    def std(self, call: Call) -> Type:
        """Type a `std` aggregation, which also accepts an optional int `ddof`."""
        return self._aggregate(
            call,
            [
                Function.Parameter(
                    pos=1,
                    name="ddof",
                    type=self.types.get_type("int"),
                    required=False,
                )
            ],
        )

    @method()
    def sum(self, call: Call) -> Type:
        """Type a `sum` aggregation."""
        return self._aggregate(call)

    @method()
    def var(self, call: Call) -> Type:
        """Type a `var` aggregation, which also accepts an optional int `ddof`.

        The extra keyword used to be declared under the name `var`; it is now
        named `ddof`, like in `std` and in pandas' `DataFrame.var`.
        """
        return self._aggregate(
            call,
            [
                Function.Parameter(
                    pos=1,
                    name="ddof",
                    type=self.types.get_type("int"),
                    required=False,
                )
            ],
        )

    def _take_rows(self, call: Call) -> Type:
        """Type a call taking an optional int `n` and returning the same frame.

        Shared by `head` and `tail`, whose signatures are identical.
        """
        signature = Function(
            params=ParamSpec(
                mixed=[
                    Function.Parameter(
                        pos=0,
                        name="n",
                        type=self.types.get_type("int"),
                        required=False,
                    ),
                ],
            ),
            returns=call.frame,
        )

        result: CallResult = self.dispatcher.get_result(
            location=call.location,
            callee=signature,
            positional=call.positional,
            keywords=call.keywords,
        )
        return result.result

    @method()
    def head(self, call: Call) -> Type:
        """Type a call to `head`: optional int `n`, returns the frame's type."""
        return self._take_rows(call)

    @method()
    def tail(self, call: Call) -> Type:
        """Type a call to `tail`: optional int `n`, returns the frame's type."""
        return self._take_rows(call)

    @method()
    def groupby(self, call: Call) -> Type:
        """Type a call to `groupby` on a dataframe.

        `by` and `level` accept any type and are declared as `mixed`
        parameters. The boolean flags `as_index`, `sort`, `group_keys`,
        `observed` and `dropna` are declared as `kw` parameters. Every
        parameter is optional.

        Returns:
            Type: the dispatcher's result for a signature returning a
                `FrameGroupBy` wrapping this frame
        """
        bool_: Type = self.types.get_type("bool")
        function: Function = Function(
            params=ParamSpec(
                mixed=[
                    Function.Parameter(
                        pos=0,
                        name="by",
                        type=TopType(),
                        required=False,
                    ),
                    Function.Parameter(
                        pos=1,
                        name="level",
                        type=TopType(),
                        required=False,
                    ),
                ],
                kw=[
                    Function.Parameter(
                        pos=i + 2,
                        name=name,
                        type=bool_,
                        required=False,
                    )
                    for i, name in enumerate(
                        ["as_index", "sort", "group_keys", "observed", "dropna"]
                    )
                ],
            ),
            returns=FrameGroupBy(frame=call.frame),
        )

        result: CallResult = self.dispatcher.get_result(
            location=call.location,
            callee=function,
            positional=call.positional,
            keywords=call.keywords,
        )
        return result.result

    def _assert_same_length(self, call_expr: p.Expr, frame1: p.Expr, frame2: p.Expr):
        """Register an assertion that two dataframes have the same length.

        A helper function `__midas_frame_same_length__(frame1, frame2)`
        returning `len(frame1.index) == len(frame2.index)` is defined through
        the assertion collector, then an assertion calling it on both frame
        expressions is attached to `call_expr`.

        Args:
            call_expr (p.Expr): the expression the assertion is bound to
            frame1 (p.Expr): expression of the first dataframe
            frame2 (p.Expr): expression of the second dataframe
        """
        func_name: str = "__midas_frame_same_length__"

        # Efficiently compute length
        # https://stackoverflow.com/a/15943975/11109181
        def len_of_df(df: ast.expr) -> ast.expr:
            return ast.Call(
                func=ast.Name(id="len"),
                args=[
                    ast.Attribute(
                        value=df,
                        attr="index",
                    )
                ],
                keywords=[],
            )

        self.assertions.define(
            func_name,
            ast.FunctionDef(
                name=func_name,
                args=ast.arguments(
                    posonlyargs=[],
                    args=[
                        ast.arg(arg="frame1"),
                        ast.arg(arg="frame2"),
                    ],
                    kwonlyargs=[],
                    defaults=[],
                    kw_defaults=[],
                ),
                body=[
                    ast.Return(
                        value=ast.Compare(
                            left=len_of_df(ast.Name(id="frame1")),
                            ops=[ast.Eq()],
                            comparators=[len_of_df(ast.Name(id="frame2"))],
                        )
                    )
                ],
                decorator_list=[],
            ),
        )
        self.assertions.add(
            bound_expr=call_expr,
            inputs=[frame1, frame2],
            builder=lambda f1, f2: ast.Call(
                func=ast.Name(id=func_name),
                args=[f1, f2],
                keywords=[],
            ),
            message="DataFrames must have the same length",
        )
|
||||
100
midas/checker/frames/utils.py
Normal file
100
midas/checker/frames/utils.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Generic,
|
||||
Optional,
|
||||
Protocol,
|
||||
Self,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallDispatcher
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter
|
||||
from midas.checker.types import Type, UnknownType
|
||||
from midas.generator.collector import AssertionCollector
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import PythonTyper, TypedExpr
|
||||
|
||||
|
||||
class _MethodRegistryMeta(type):
    """Metaclass building the ``_methods`` lookup table of a registry class.

    Every callable of the class namespace carrying a ``__method_names__``
    attribute is registered once per listed name. Only the namespace of the
    class being created is scanned: each class starts from an empty table.
    """

    _methods: dict[str, Callable[..., Type]] = {}

    def __new__(
        cls,
        name: str,
        bases: tuple[type, ...],
        namespace: dict[str, Any],
    ):
        new_class = super().__new__(cls, name, bases, namespace)
        tagged = (
            attr
            for attr in namespace.values()
            if callable(attr) and hasattr(attr, "__method_names__")
        )
        # Later definitions win when two handlers declare the same name
        new_class._methods = {
            alias: handler
            for handler in tagged
            for alias in handler.__method_names__  # type: ignore
        }
        return new_class
|
||||
|
||||
|
||||
class MethodCall(Protocol):
    """Structural type of the call objects handed to a method registry."""

    @property
    def location(self) -> Location:
        """Location of the call, used when reporting diagnostics."""
        ...

    @property
    def call_expr(self) -> p.Expr:
        """Expression of the call."""
        ...

    @property
    def subject(self) -> TypedExpr:
        """Typed expression the method is called on."""
        ...
|
||||
|
||||
|
||||
T = TypeVar("T", bound=MethodCall)
|
||||
|
||||
|
||||
class MethodRegistry(Generic[T], metaclass=_MethodRegistryMeta):
    """Base class dispatching method calls to the handlers of ``_methods``.

    The reporter, types registry, dispatcher and assertion collector are all
    borrowed from the typer given at construction.
    """

    def __init__(self, typer: PythonTyper) -> None:
        self.typer: PythonTyper = typer

    @property
    def reporter(self) -> FileReporter:
        """Reporter of the underlying typer."""
        return self.typer.reporter

    @property
    def types(self) -> TypesRegistry:
        """Types registry of the underlying typer."""
        return self.typer.types

    @property
    def dispatcher(self) -> CallDispatcher[p.Expr]:
        """Call dispatcher of the underlying typer."""
        return self.typer.dispatcher

    @property
    def assertions(self) -> AssertionCollector:
        """Assertion collector of the underlying typer."""
        return self.typer.assertions

    def call(self, method: str, call: T) -> Type:
        """Type a call to `method` using the handler registered for it

        Args:
            method (str): the name of the called method
            call (T): the call to type

        Returns:
            Type: the handler's result, or an unknown type (after reporting a
                warning) if no handler is registered under that name
        """
        handler: Optional[Callable[[Self, T], Type]] = self._methods.get(method)
        if handler is not None:
            return handler(self, call)

        self.reporter.warning(
            call.location, f"Unknown method {method} on {call.subject[1]}"
        )
        return UnknownType()
|
||||
|
||||
|
||||
_Self = TypeVar("_Self", bound=MethodRegistry[Any])
|
||||
Method = Callable[[_Self, T], Type]
|
||||
|
||||
|
||||
def method(*names: str) -> Callable[[Method[_Self, T]], Method[_Self, T]]:
    """Tag a function with the method names it handles

    The names are stored on the function as ``__method_names__``; the function
    itself is returned unchanged.

    Args:
        *names (str): the handled names, defaults to the function's own name

    Returns:
        the decorator to apply on the function
    """

    def wrapper(func: Method[_Self, T]) -> Method[_Self, T]:
        aliases: tuple[str, ...] = names if names else (func.__name__,)
        setattr(func, "__method_names__", aliases)
        return func

    return wrapper
|
||||
431
midas/checker/midas.py
Normal file
431
midas/checker/midas.py
Normal file
@@ -0,0 +1,431 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import midas.ast.midas as m
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.builtins import define_builtins
|
||||
from midas.checker.dispatcher import CallDispatcher, CallResult
|
||||
from midas.checker.environment import Environment
|
||||
from midas.checker.operators import MIDAS_BINARY_METHODS, MIDAS_UNARY_METHODS
|
||||
from midas.checker.preamble import Preamble
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter, Reporter
|
||||
from midas.checker.types import (
|
||||
ColumnType,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
DataFrameType,
|
||||
DerivedType,
|
||||
ExtensionType,
|
||||
Function,
|
||||
GenericType,
|
||||
ParamSpec,
|
||||
Predicate,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnknownType,
|
||||
)
|
||||
from midas.checker.variance import VarianceInferrer
|
||||
from midas.lexer.midas import MidasLexer
|
||||
from midas.lexer.token import Token
|
||||
from midas.parser.midas import MidasParser
|
||||
|
||||
|
||||
class ReturnException(Exception):
    """Exception type carrying no data of its own."""
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class MappedArgument:
|
||||
expr: m.Expr
|
||||
type: Type
|
||||
argument: Function.Parameter
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class OverloadCandidate:
|
||||
function: Function
|
||||
mapped: list[MappedArgument]
|
||||
|
||||
|
||||
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type]):
|
||||
"""A resolver which evaluates Midas type definitions and build a registry"""
|
||||
|
||||
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
|
||||
self.logger: logging.Logger = logging.getLogger("MidasTyper")
|
||||
self.reporter: FileReporter = reporter.for_file(None)
|
||||
self.types: TypesRegistry = types
|
||||
self.dispatcher: CallDispatcher[m.Expr] = CallDispatcher[m.Expr](
|
||||
self.types, self.reporter
|
||||
)
|
||||
|
||||
self._local_variables: dict[str, TypeVar] = {}
|
||||
|
||||
self._predicate_params: dict[str, Type] = {}
|
||||
|
||||
self._current_name: Optional[str] = None
|
||||
|
||||
define_builtins(self.types)
|
||||
builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve()
|
||||
self.process(builtins_path.read_text(), str(builtins_path))
|
||||
|
||||
self._bool: Type = self.get_type("bool")
|
||||
|
||||
self._preamble: Environment = Preamble(self.types)
|
||||
|
||||
def set_reporter(self, reporter: FileReporter):
|
||||
self.reporter = reporter
|
||||
self.dispatcher.set_reporter(reporter)
|
||||
|
||||
def process(self, source: str, path: Optional[str]):
|
||||
reporter: FileReporter = self.reporter.for_file(path)
|
||||
self.set_reporter(reporter)
|
||||
|
||||
lexer: MidasLexer = MidasLexer(source)
|
||||
tokens: list[Token] = lexer.process()
|
||||
parser: MidasParser = MidasParser(tokens)
|
||||
stmts: list[m.Stmt] = parser.parse()
|
||||
for error in parser.errors:
|
||||
self.reporter.error(error.token.get_location(), error.message)
|
||||
self.resolve(stmts)
|
||||
|
||||
def type_of(self, expr: m.Expr) -> Type:
|
||||
type: Type = expr.accept(self)
|
||||
return type
|
||||
|
||||
def get_type(self, name: str) -> Type:
|
||||
"""Get a type from its name
|
||||
|
||||
Args:
|
||||
name (str): the name of the type
|
||||
|
||||
Raises:
|
||||
NameError: if the type is not defined
|
||||
|
||||
Returns:
|
||||
Type: the type
|
||||
"""
|
||||
if name in self._local_variables:
|
||||
return self._local_variables[name]
|
||||
return self.types.get_type(name)
|
||||
|
||||
def get_variable(self, name: str) -> Type:
|
||||
if name in self._predicate_params:
|
||||
return self._predicate_params[name]
|
||||
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
|
||||
if predicate is not None:
|
||||
return predicate.type
|
||||
|
||||
global_: Optional[Type] = self._preamble.get(name)
|
||||
if global_ is not None:
|
||||
return global_
|
||||
|
||||
raise NameError(f"Unknown variable '{name}'")
|
||||
|
||||
def resolve(self, stmts: list[m.Stmt]):
|
||||
"""Process a sequence of statements
|
||||
|
||||
Args:
|
||||
stmts (list[m.Stmt]): the statements
|
||||
"""
|
||||
for stmt in stmts:
|
||||
stmt.accept(self)
|
||||
|
||||
for name, type in self.types._types.items():
|
||||
if isinstance(type, GenericType):
|
||||
inferrer = VarianceInferrer(self.types)
|
||||
self.types._types[name] = inferrer.infer(type)
|
||||
|
||||
def assert_bool(self, expr: m.Expr):
|
||||
type: Type = self.type_of(expr)
|
||||
if not self.types.is_subtype(type, self._bool):
|
||||
self.reporter.error(expr.location, f"Must be a boolean but is {type}")
|
||||
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||
name: str = stmt.name.lexeme
|
||||
self._current_name = name
|
||||
params: list[TypeVar] = self._resolve_type_params(stmt.params)
|
||||
|
||||
type: Type = stmt.type.accept(self)
|
||||
if len(params) != 0:
|
||||
type = GenericType(name=name, params=params, body=type)
|
||||
else:
|
||||
type = DerivedType(name=name, type=type)
|
||||
self.types.define_type(name, type)
|
||||
self._local_variables.clear()
|
||||
self._current_name = None
|
||||
|
||||
def visit_alias_stmt(self, stmt: m.AliasStmt) -> None:
|
||||
name: str = stmt.name.lexeme
|
||||
self._current_name = name
|
||||
type: Type = stmt.type.accept(self)
|
||||
self.types.define_type(name, type)
|
||||
self._current_name = None
|
||||
|
||||
def visit_member_stmt(self, stmt: m.MemberStmt) -> None: ...
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||
self._resolve_type_params(stmt.params)
|
||||
base_name: str = stmt.name.lexeme
|
||||
try:
|
||||
_ = self.get_type(base_name)
|
||||
except NameError:
|
||||
self.reporter.error(stmt.name.get_location(), f"Unknown type '{base_name}'")
|
||||
|
||||
for member in stmt.members:
|
||||
member_type: Type = member.type.accept(self)
|
||||
self.types.define_member(
|
||||
base_name,
|
||||
member.name.lexeme,
|
||||
member_type,
|
||||
member.kind,
|
||||
)
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
||||
for spec in stmt.params:
|
||||
for param in spec.mixed:
|
||||
assert param.name is not None
|
||||
self._predicate_params[param.name.lexeme] = param.type.accept(self)
|
||||
|
||||
type: Type = self.type_of(stmt.body)
|
||||
params: list[ParamSpec] = [self._visit_param_spec(spec) for spec in stmt.params]
|
||||
|
||||
if not self._is_valid_predicate(type):
|
||||
self.reporter.error(
|
||||
stmt.body.location,
|
||||
f"Predicate function body must evaluate to a boolean, got {type}",
|
||||
)
|
||||
if len(params) != 0:
|
||||
type = self._bool
|
||||
for spec in reversed(params):
|
||||
type = Function(
|
||||
params=spec,
|
||||
returns=type,
|
||||
)
|
||||
self._predicate_params = {}
|
||||
self.types.define_predicate(
|
||||
stmt.name.lexeme,
|
||||
Predicate(
|
||||
type=type,
|
||||
body=stmt.body,
|
||||
alias=len(params) == 0,
|
||||
),
|
||||
)
|
||||
|
||||
def _is_valid_predicate(self, body: Type) -> bool:
|
||||
match body:
|
||||
case Function(returns=returns):
|
||||
return self._is_valid_predicate(returns)
|
||||
case _ if self.types.is_subtype(body, self._bool):
|
||||
return True
|
||||
case _:
|
||||
return False
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> Type:
|
||||
self.assert_bool(expr.left)
|
||||
self.assert_bool(expr.right)
|
||||
return self._bool
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> Type:
|
||||
method: Optional[str] = MIDAS_BINARY_METHODS.get(expr.operator.type)
|
||||
if method is None:
|
||||
self.logger.warning(f"Unsupported operator {expr.operator.lexeme}")
|
||||
self.reporter.warning(
|
||||
expr.location, f"Unsupported operator {expr.operator.lexeme}"
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
|
||||
|
||||
def _visit_binary_expr(
|
||||
self, location: Location, left_expr: m.Expr, right_expr: m.Expr, method: str
|
||||
) -> Type:
|
||||
left: Type = self.type_of(left_expr)
|
||||
right: Type = self.type_of(right_expr)
|
||||
|
||||
operation: Optional[Type] = self.types.lookup_member(left, method)
|
||||
if operation is None:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Undefined operation {method} between {left} and {right}",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=location,
|
||||
callee=operation,
|
||||
positional=[(right_expr, right)],
|
||||
keywords={},
|
||||
)
|
||||
return result.result
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
|
||||
method: Optional[str] = MIDAS_UNARY_METHODS.get(expr.operator.type)
|
||||
if method is None:
|
||||
self.logger.warning(f"Unsupported operator {expr.operator.lexeme}")
|
||||
self.reporter.warning(
|
||||
expr.location, f"Unsupported operator {expr.operator.lexeme}"
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
operand: Type = self.type_of(expr.right)
|
||||
operation: Optional[Type] = self.types.lookup_member(operand, method)
|
||||
if operation is None:
|
||||
self.reporter.error(
|
||||
expr.location,
|
||||
f"Undefined operation {method} for {operand}",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=expr.location,
|
||||
callee=operation,
|
||||
positional=[],
|
||||
keywords={},
|
||||
)
|
||||
return result.result
|
||||
|
||||
def visit_call_expr(self, expr: m.CallExpr) -> Type:
|
||||
callee: Type = expr.callee.accept(self)
|
||||
positional: list[tuple[m.Expr, Type]] = [
|
||||
(arg, self.type_of(arg)) for arg in expr.arguments
|
||||
]
|
||||
keywords: dict[str, tuple[m.Expr, Type]] = {
|
||||
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
|
||||
}
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=expr.location,
|
||||
callee=callee,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> Type:
|
||||
object: Type = expr.expr.accept(self)
|
||||
member: Optional[Type] = self.types.lookup_member(object, expr.name.lexeme)
|
||||
if member is None:
|
||||
self.reporter.error(
|
||||
expr.location, f"Unknown member '{expr.name.lexeme}' of {object}"
|
||||
)
|
||||
return UnknownType()
|
||||
return member
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr) -> Type:
|
||||
return self.get_variable(expr.name.lexeme)
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type:
|
||||
return expr.expr.accept(self)
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> Type:
|
||||
match expr.value:
|
||||
case bool(): # Must be before int
|
||||
return self.types.get_type("bool")
|
||||
case int():
|
||||
return self.types.get_type("int")
|
||||
case float():
|
||||
return self.types.get_type("float")
|
||||
case str():
|
||||
return self.types.get_type("str")
|
||||
case _:
|
||||
self.reporter.warning(expr.location, f"Unknown literal {expr}")
|
||||
return UnknownType()
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type:
|
||||
return self.get_variable("_")
|
||||
|
||||
def visit_named_type(self, type: m.NamedType) -> Type:
|
||||
name: str = type.name.lexeme
|
||||
try:
|
||||
return self.get_type(name)
|
||||
except NameError:
|
||||
msg: str = f"Undefined type {name}"
|
||||
if self._current_name == name:
|
||||
msg += ". Recursive types are not supported, use an extend block"
|
||||
self.reporter.error(type.name.get_location(), msg)
|
||||
return UnknownType()
|
||||
|
||||
def visit_generic_type(self, type: m.GenericType) -> Type:
|
||||
type_: Type = type.type.accept(self)
|
||||
args: list[Type] = [arg.accept(self) for arg in type.args]
|
||||
try:
|
||||
return self.types.apply_generic(type_, args)
|
||||
except Exception as e:
|
||||
self.reporter.error(type.location, f"Cannot apply generic type: {e}")
|
||||
return UnknownType()
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
|
||||
return ConstraintType(
|
||||
type=type.type.accept(self),
|
||||
constraint=type.constraint,
|
||||
)
|
||||
|
||||
def visit_complex_type(self, type: m.ComplexType) -> ComplexType:
|
||||
return ComplexType(
|
||||
members={
|
||||
member.name.lexeme: member.type.accept(self) for member in type.members
|
||||
}
|
||||
)
|
||||
|
||||
def visit_extension_type(self, type: m.ExtensionType) -> Type:
|
||||
return ExtensionType(
|
||||
base=type.base.accept(self),
|
||||
extension=self.visit_complex_type(type.extension),
|
||||
)
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> Type:
|
||||
return Function(
|
||||
params=self._visit_param_spec(type.params),
|
||||
returns=type.returns.accept(self),
|
||||
)
|
||||
|
||||
def _visit_param_spec(self, spec: m.ParamSpec) -> ParamSpec:
|
||||
n_pos: int = len(spec.pos)
|
||||
n_mixed: int = len(spec.mixed)
|
||||
|
||||
def process_param(
|
||||
param: m.FunctionType.Parameter, i: int
|
||||
) -> Function.Parameter:
|
||||
return Function.Parameter(
|
||||
pos=i,
|
||||
name=param.name.lexeme if param.name is not None else str(i),
|
||||
type=param.type.accept(self),
|
||||
required=param.required,
|
||||
)
|
||||
|
||||
return ParamSpec(
|
||||
pos=[process_param(param, i) for i, param in enumerate(spec.pos)],
|
||||
mixed=[
|
||||
process_param(param, i + n_pos) for i, param in enumerate(spec.mixed)
|
||||
],
|
||||
kw=[
|
||||
process_param(param, i + n_pos + n_mixed)
|
||||
for i, param in enumerate(spec.kw)
|
||||
],
|
||||
)
|
||||
|
||||
def visit_frame_type(self, type: m.FrameType) -> Type:
|
||||
def process_column(i: int, col: m.FrameType.Column) -> DataFrameType.Column:
|
||||
return DataFrameType.Column(
|
||||
index=i,
|
||||
name=col.name.lexeme,
|
||||
type=ColumnType(type=col.type.accept(self)),
|
||||
)
|
||||
|
||||
return DataFrameType(
|
||||
columns=[process_column(i, col) for i, col in enumerate(type.columns)]
|
||||
)
|
||||
|
||||
def _resolve_type_params(self, params: list[m.TypeParam]):
|
||||
vars: list[TypeVar] = []
|
||||
for param in params:
|
||||
name: str = param.name.lexeme
|
||||
bound: Optional[Type] = None
|
||||
if param.bound is not None:
|
||||
bound = param.bound.accept(self)
|
||||
var = TypeVar(name=name, bound=bound)
|
||||
self._local_variables[name] = var
|
||||
vars.append(var)
|
||||
return vars
|
||||
71
midas/checker/operators.py
Normal file
71
midas/checker/operators.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import ast
|
||||
from typing import Type
|
||||
|
||||
from midas.lexer.token import TokenType
|
||||
|
||||
# Name of the method implementing each Python binary operator
PY_OPERATOR_METHODS: dict[Type[ast.operator], str] = {
    # Arithmetic
    ast.Add: "__add__",
    ast.Sub: "__sub__",
    ast.Mult: "__mul__",
    ast.MatMult: "__matmul__",
    ast.Div: "__truediv__",
    ast.Mod: "__mod__",
    ast.Pow: "__pow__",
    # Bitwise
    ast.LShift: "__lshift__",
    ast.RShift: "__rshift__",
    ast.BitOr: "__or__",
    ast.BitXor: "__xor__",
    ast.BitAnd: "__and__",
    ast.FloorDiv: "__floordiv__",
}

# Name of the method implementing each Python comparison operator.
# Identity and membership tests (Is, IsNot, In, NotIn) have no entry.
PY_COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
    ast.Eq: "__eq__",
    # NOTE(review): `!=` is resolved through `__eq__` as well, presumably
    # because only its type matters here - confirm this is intended
    ast.NotEq: "__eq__",
    ast.Lt: "__lt__",
    ast.LtE: "__le__",
    ast.Gt: "__gt__",
    ast.GtE: "__ge__",
}

# Name of the method implementing each Python unary operator.
# The boolean negation (Not) has no entry.
PY_UNARY_METHODS: dict[Type[ast.unaryop], str] = {
    ast.Invert: "__invert__",
    ast.UAdd: "__pos__",
    ast.USub: "__neg__",
}
|
||||
|
||||
|
||||
# Name of the method implementing each Midas binary operator.
# Modulo, power, bitwise, floor division, identity and membership operators
# have no entry.
MIDAS_BINARY_METHODS: dict[TokenType, str] = {
    # Arithmetic
    TokenType.PLUS: "__add__",
    TokenType.MINUS: "__sub__",
    TokenType.STAR: "__mul__",
    TokenType.SLASH: "__truediv__",
    # Comparison
    TokenType.EQUAL_EQUAL: "__eq__",
    # NOTE(review): `!=` is resolved through `__eq__` as well, presumably
    # because only its type matters here - confirm this is intended
    TokenType.BANG_EQUAL: "__eq__",
    TokenType.LESS: "__lt__",
    TokenType.LESS_EQUAL: "__le__",
    TokenType.GREATER: "__gt__",
    TokenType.GREATER_EQUAL: "__ge__",
}

# Name of the method implementing each Midas unary operator.
# Only the negation has an entry.
MIDAS_UNARY_METHODS: dict[TokenType, str] = {
    TokenType.MINUS: "__neg__",
}
|
||||
213
midas/checker/preamble.py
Normal file
213
midas/checker/preamble.py
Normal file
@@ -0,0 +1,213 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from midas.checker.environment import Environment
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import (
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
ParamSpec,
|
||||
TopType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnitType,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Param:
    """Immutable description of a function parameter of the preamble."""

    # Name of the parameter
    name: str
    # Type of the parameter
    type: Type
    # Whether the parameter is required, true unless stated otherwise
    required: bool = True
|
||||
|
||||
|
||||
class Preamble(Environment):
    """Environment holding the types of some Python builtin functions

    Besides its type, a function can be associated to the Python callable
    implementing it, available through `get_py_func`.
    """

    def __init__(self, types: TypesRegistry) -> None:
        super().__init__()
        self._types: TypesRegistry = types
        self._python_funcs: dict[str, Callable[..., Any]] = {}

        self._def_type_constructor("object", object)
        self._def_type_constructor("float", float)
        self._def_type_constructor("int", int)
        self._def_type_constructor("bool", bool)
        self._def_type_constructor("str", str)
        self._def_function(
            name="list",
            pos=[Param("object", TopType())],
            returns=self._list_of(TopType()),
            py_function=list,
        )

        # TODO: use sink
        self._def_function(
            name="print",
            pos=[Param("object", TopType(), required=False)],
            returns=UnitType(),
            py_function=print,
        )

        map_in = TypeVar(name="T", bound=None)
        map_out = TypeVar(name="U", bound=None)
        mapper = self._make_function(
            name="MapTransform",
            pos=[Param("v", map_in)],
            returns=map_out,
        )
        self._def_function(
            name="map",
            pos=[
                Param("transform", mapper),
                Param(
                    "iterable",
                    self._list_of(map_in),  # TODO: replace with Iterable[T]
                ),
            ],
            returns=self._list_of(map_out),  # TODO: replace with Iterable[U]
            type_vars=[map_in, map_out],
            py_function=map,
        )
        # `input` and `len` are not associated to a Python callable
        self._def_function(
            name="input",
            pos=[Param("prompt", TopType(), required=False)],
            returns=self._types.get_type("str"),
        )
        self._def_function(
            name="len",
            pos=[Param("object", TopType())],
            returns=self._types.get_type("int"),
        )

        # `max` and `min` share the same two overloads: two values of the same
        # type, or a list of values
        T = TypeVar(name="T", bound=None)
        for name, py_function in (("max", max), ("min", min)):
            self._def_overloads(
                name=name,
                py_function=py_function,
                signatures=[
                    (
                        [Param("arg1", T), Param("arg2", T)],
                        [],
                        [],
                        T,
                        [T],
                    ),
                    ([Param("iterable", self._list_of(T))], [], [], T, [T]),
                ],
            )

    def _list_of(self, item_type: str | Type) -> Type:
        """Shortcut to the registry's `list_of`"""
        return self._types.list_of(item_type)

    def _def_type_constructor(
        self, name: str, py_function: Optional[Callable[..., Any]] = None
    ):
        """Define the constructor function of the registered type `name`

        The constructor takes one optional positional argument of any type and
        returns the type.
        """
        # TODO: more specific arg types
        self._def_function(
            name=name,
            pos=[Param("object", TopType(), required=False)],
            returns=self._types.get_type(name),
            py_function=py_function,
        )

    def _make_function(
        self,
        *,
        name: str,
        pos: Optional[list[Param]] = None,
        mixed: Optional[list[Param]] = None,
        kw: Optional[list[Param]] = None,
        returns: Type = UnitType(),
        type_vars: Optional[list[TypeVar]] = None,
    ) -> Type:
        """Build a function type

        Args:
            name (str): the name of the function, used for generic functions
            pos (Optional[list[Param]]): the positional-only parameters
            mixed (Optional[list[Param]]): the positional or keyword parameters
            kw (Optional[list[Param]]): the keyword-only parameters
            returns (Type): the return type
            type_vars (Optional[list[TypeVar]]): the type variables of the
                function, if it is generic

        Returns:
            Type: a `Function`, wrapped in a `GenericType` if type variables
                are given
        """
        # None stands for "no parameter", avoiding mutable default arguments
        pos = pos or []
        mixed = mixed or []
        kw = kw or []

        def map_params(params: list[Param], offset: int) -> list[Function.Parameter]:
            return [
                Function.Parameter(
                    pos=i + offset,
                    name=param.name,
                    type=param.type,
                    required=param.required,
                )
                for i, param in enumerate(params)
            ]

        # Positions are numbered continuously across the three groups
        function = Function(
            params=ParamSpec(
                pos=map_params(pos, 0),
                mixed=map_params(mixed, len(pos)),
                kw=map_params(kw, len(pos) + len(mixed)),
            ),
            returns=returns,
        )
        if type_vars:
            function = GenericType(
                name=name,
                params=type_vars,
                body=function,
            )
        return function

    def _register_py_function(
        self, name: str, py_function: Optional[Callable[..., Any]]
    ) -> None:
        """Associate a Python callable to a function name, if one is given"""
        if py_function is not None:
            self._python_funcs[name] = py_function

    def _def_function(
        self,
        *,
        name: str,
        pos: Optional[list[Param]] = None,
        mixed: Optional[list[Param]] = None,
        kw: Optional[list[Param]] = None,
        returns: Type = UnitType(),
        type_vars: Optional[list[TypeVar]] = None,
        py_function: Optional[Callable[..., Any]] = None,
    ):
        """Build a function type and define it in the environment

        The arguments are those of `_make_function`, plus the optional Python
        callable to associate to the function.
        """
        function: Type = self._make_function(
            name=name,
            pos=pos,
            mixed=mixed,
            kw=kw,
            returns=returns,
            type_vars=type_vars,
        )
        self.define(name, function)
        self._register_py_function(name, py_function)

    def _def_overloads(
        self,
        *,
        name: str,
        signatures: list[
            tuple[list[Param], list[Param], list[Param], Type, list[TypeVar]]
        ],
        py_function: Optional[Callable[..., Any]] = None,
    ):
        """Define an overloaded function in the environment

        Args:
            name (str): the name of the function
            signatures: one tuple per overload, made of its positional, mixed
                and keyword parameters, its return type and its type variables
            py_function: the Python callable to associate to the function
        """
        overloads: list[Type] = [
            self._make_function(
                name=name,
                pos=pos,
                mixed=mixed,
                kw=kw,
                returns=returns,
                type_vars=type_vars,
            )
            for pos, mixed, kw, returns, type_vars in signatures
        ]
        function: Type = OverloadedFunction(overloads=overloads)
        self.define(name, function)
        self._register_py_function(name, py_function)

    def get_py_func(self, name: str) -> Optional[Callable[..., Any]]:
        """Get the Python callable associated to a function name, if any"""
        return self._python_funcs.get(name)
|
||||
1153
midas/checker/python.py
Normal file
1153
midas/checker/python.py
Normal file
File diff suppressed because it is too large
Load Diff
488
midas/checker/registry.py
Normal file
488
midas/checker/registry.py
Normal file
@@ -0,0 +1,488 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.midas import MemberKind
|
||||
from midas.checker.builtins import BUILTIN_SUBTYPES
|
||||
from midas.checker.types import (
|
||||
AppliedType,
|
||||
BaseType,
|
||||
ColumnType,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
DataFrameType,
|
||||
DerivedType,
|
||||
ExtensionType,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
Predicate,
|
||||
TopType,
|
||||
TupleType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnknownType,
|
||||
Variance,
|
||||
substitute_typevars,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class Member:
    """Member registered on a type."""

    # Kind of member, two definitions of a member must agree on it
    kind: MemberKind
    # Type of the member
    type: Type
|
||||
|
||||
|
||||
class TypesRegistry:
|
||||
def __init__(self) -> None:
    """Create an empty registry"""
    self.logger: logging.Logger = logging.getLogger("TypesRegistry")
    # Types, by name
    self._types: dict[str, Type] = {}
    # Members, by name of their owner type then by member name
    self._members: dict[str, dict[str, Member]] = {}
    # Predicates, by name
    self._predicates: dict[str, Predicate] = {}
|
||||
|
||||
def get_type(self, name: str) -> Type:
    """Look a type up by name

    Args:
        name (str): the name of the type

    Raises:
        NameError: if no type is registered under that name

    Returns:
        Type: the registered type
    """
    try:
        return self._types[name]
    except KeyError:
        raise NameError(f"Undefined type {name}") from None
|
||||
|
||||
def define_type(self, name: str, type: Type) -> Type:
    """Register a type under a new name

    Args:
        name (str): the name to register the type under
        type (Type): the type to register

    Raises:
        ValueError: if the name is already taken

    Returns:
        Type: the type that was just registered
    """
    registered: dict[str, Type] = self._types
    if name in registered:
        raise ValueError(f"Type {name} already defined")
    registered[name] = type
    return type
|
||||
|
||||
def define_member(
    self,
    type_name: str,
    member_name: str,
    member_type: Type,
    kind: MemberKind,
):
    """Register a member on a type, overloading an existing method if any

    A conflicting definition (different kind, or redefinition of something
    else than a method) is logged as an error and ignored.

    Args:
        type_name (str): the name of the type owning the member
        member_name (str): the name of the member
        member_type (Type): the type of the member
        kind (MemberKind): the kind of member
    """
    members: dict[str, Member] = self._members.setdefault(type_name, {})

    # First definition: nothing to merge
    if member_name not in members:
        members[member_name] = Member(kind=kind, type=member_type)
        return

    current: Member = members[member_name]
    if current.kind != kind:
        self.logger.error(
            f"Member '{member_name}' is already defined as a {current.kind},"
            + f" cannot define a {kind} with the same name"
        )
        return
    if kind != MemberKind.METHOD:
        self.logger.error(
            f"Member '{member_name}' already defined for type {type_name},"
            + " only methods can be overloaded"
        )
        return

    # Extend the existing overload set, or create one from both definitions
    previous: Type = current.type
    if isinstance(previous, OverloadedFunction):
        alternatives = previous.overloads + [member_type]
    else:
        alternatives = [previous, member_type]
    members[member_name] = Member(
        kind=current.kind,
        type=OverloadedFunction(overloads=alternatives),
    )
|
||||
|
||||
def define_predicate(self, name: str, predicate: Predicate):
    """Register a predicate under a new name

    Args:
        name (str): the name to register the predicate under
        predicate (Predicate): the predicate to register

    Raises:
        ValueError: if the name is already taken
    """
    known: dict[str, Predicate] = self._predicates
    if name in known:
        raise ValueError(f"Predicate {name} already defined")
    known[name] = predicate
|
||||
|
||||
def is_builtin_subtype(self, name1: str, name2: str) -> bool:
    """Check whether builtin `name1` is a direct or indirect subtype of `name2`

    Args:
        name1 (str): the name of the potential subtype
        name2 (str): the name of the potential supertype

    Returns:
        bool: whether `name1` can be reached from `name2` in BUILTIN_SUBTYPES
    """
    direct: set[str] = BUILTIN_SUBTYPES.get(name2, set())
    if name1 in direct:
        return True
    # Otherwise look for name1 below each direct subtype
    return any(self.is_builtin_subtype(name1, child) for child in direct)
|
||||
|
||||
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||
"""Check whether `type1` is a subtype of `type2`
|
||||
|
||||
For more details on the rules checked here, see TAPL Chap. 15-16-17
|
||||
|
||||
Args:
|
||||
type1 (Type): the potential subtype
|
||||
type2 (Type): the potential supertype
|
||||
|
||||
Returns:
|
||||
bool: whether `type1` is a subtype of `type2`
|
||||
"""
|
||||
|
||||
if type1 == type2:
|
||||
return True
|
||||
|
||||
match (type1, type2):
|
||||
case (_, TopType()):
|
||||
return True
|
||||
|
||||
case (_, UnknownType()):
|
||||
return True
|
||||
|
||||
case (TypeVar(bound=bound), _):
|
||||
if bound is None:
|
||||
return False
|
||||
return self.is_subtype(bound, type2)
|
||||
|
||||
case (_, TypeVar(bound=bound)):
|
||||
if bound is None:
|
||||
return True
|
||||
return self.is_subtype(type1, bound)
|
||||
|
||||
case (DerivedType(type=base1), _):
|
||||
return self.is_subtype(base1, type2)
|
||||
|
||||
case (BaseType(name=name1), BaseType(name=name2)):
|
||||
return self.is_builtin_subtype(name1, name2)
|
||||
|
||||
case (ComplexType(properties=props1), ComplexType(properties=props2)):
|
||||
for k, t in props2.items():
|
||||
if k not in props1:
|
||||
return False
|
||||
if not self.is_subtype(props1[k], t):
|
||||
return False
|
||||
return True
|
||||
|
||||
case (DataFrameType(columns=columns1), DataFrameType(columns=columns2)):
|
||||
# TODO: check order?
|
||||
by_name1: dict[str, DataFrameType.Column] = {
|
||||
col.name: col for col in columns1 if col.name is not None
|
||||
}
|
||||
for col2 in columns2:
|
||||
if col2.name not in by_name1:
|
||||
return False
|
||||
if not self.is_subtype(by_name1[col2.name].type, col2.type):
|
||||
return False
|
||||
return True
|
||||
|
||||
case (ColumnType(type=inner1), ColumnType(type=inner2)):
|
||||
# TODO: invariant, replace ColumnType with simple GenericType
|
||||
if not self.are_equivalent(inner1, inner2):
|
||||
return False
|
||||
return True
|
||||
|
||||
case (Function(), Function()):
|
||||
return self.is_func_subtype(type1, type2)
|
||||
|
||||
case (ConstraintType(type=base1), _):
|
||||
return self.is_subtype(base1, type2)
|
||||
|
||||
case (
|
||||
AppliedType(name=name1, args=args1),
|
||||
AppliedType(name=name2, args=args2),
|
||||
) if (
|
||||
name1 == name2
|
||||
):
|
||||
generic: Type = self.get_type(name1)
|
||||
assert isinstance(generic, GenericType)
|
||||
for param, arg1, arg2 in zip(generic.params, args1, args2):
|
||||
variance: Variance = param.variance
|
||||
if variance in {Variance.INVARIANT, Variance.COVARIANT}:
|
||||
if not self.is_subtype(arg1, arg2):
|
||||
return False
|
||||
if variance in {Variance.INVARIANT, Variance.CONTRAVARIANT}:
|
||||
if not self.is_subtype(arg2, arg1):
|
||||
return False
|
||||
return True
|
||||
|
||||
# TODO: verify legitimacy
|
||||
case (AppliedType(body=body), _):
|
||||
return self.is_subtype(body, type2)
|
||||
|
||||
return False
|
||||
|
||||
def are_equivalent(self, type1: Type, type2: Type) -> bool:
    """Check whether two types are subtypes of each other

    Args:
        type1 (Type): the first type
        type2 (Type): the second type

    Returns:
        bool: whether `type1` is a subtype of `type2` and `type2` is a
            subtype of `type1`
    """
    forward: bool = self.is_subtype(type1, type2)
    # The reverse direction is only evaluated when the forward one holds
    return forward and self.is_subtype(type2, type1)
|
||||
|
||||
# TODO: verify the logic in here
|
||||
def is_func_subtype(self, func1: Function, func2: Function) -> bool:
    """Check whether a function is a subtype of another

    The return type of `func1` must be a subtype of the return type of
    `func2`. Each parameter of `func1` is then paired with the parameter of
    `func2` found at the same position or under the same name, and the
    pair is compared with the parameter of `func2` on the subtype side.
    Finally, parameters of `func2` are checked for a counterpart in `func1`.

    Args:
        func1 (Function): the potential function subtype
        func2 (Function): the potential function supertype

    Returns:
        bool: whether `func1` is a subtype of `func2`
    """
    # Return types: func1's must be a subtype of func2's
    if not self.is_subtype(func1.returns, func2.returns):
        return False

    # Positional-only and mixed parameters stay as lists (looked up by
    # position), keyword-only parameters are indexed by name
    pos1: list[Function.Parameter] = func1.params.pos
    mixed1: list[Function.Parameter] = func1.params.mixed
    kw1: dict[str, Function.Parameter] = {
        param.name: param for param in func1.params.kw
    }
    pos2: list[Function.Parameter] = func2.params.pos
    mixed2: list[Function.Parameter] = func2.params.mixed
    kw2: dict[str, Function.Parameter] = {
        param.name: param for param in func2.params.kw
    }

    # Mixed parameters of func2 can be matched either way, so index them by
    # position and by name
    mixed_by_pos: dict[int, Function.Parameter] = {
        param.pos: param for param in mixed2
    }
    mixed_by_name: dict[str, Function.Parameter] = {
        param.name: param for param in mixed2
    }

    def is_arg_subtype(sub: Function.Parameter, sup: Function.Parameter) -> bool:
        # The type of `sub` must be a subtype of the type of `sup`
        if not self.is_subtype(sub.type, sup.type):
            return False
        # A required `sub` is rejected against an optional `sup`
        if not sup.required and sub.required:
            return False
        return True

    # Positional-only parameters of func1: match by position in func2's
    # positional-only list first, then in its mixed parameters
    for param1 in pos1:
        param2: Function.Parameter
        if param1.pos < len(pos2):
            param2 = pos2[param1.pos]
        elif param1.pos in mixed_by_pos:
            param2 = mixed_by_pos[param1.pos]
        elif not param1.required:
            # Unmatched but optional: nothing to compare
            continue
        else:
            return False
        if not is_arg_subtype(param2, param1):
            return False

    # Keyword-only parameters of func1: match by name in func2's
    # keyword-only parameters first, then in its mixed parameters
    for name, param1 in kw1.items():
        param2: Function.Parameter
        if name in kw2:
            param2 = kw2[name]
        elif name in mixed_by_name:
            param2 = mixed_by_name[name]
        elif not param1.required:
            continue
        else:
            return False
        if not is_arg_subtype(param2, param1):
            return False

    # Mixed parameters of func1 may be matched both by name and by position;
    # every match that exists has to be compatible
    for param1 in mixed1:
        pos_param2: Optional[Function.Parameter] = None
        kw_param2: Optional[Function.Parameter] = None
        if param1.name in kw2:
            kw_param2 = kw2[param1.name]
        elif param1.name in mixed_by_name:
            kw_param2 = mixed_by_name[param1.name]
        if param1.pos < len(pos2):
            pos_param2 = pos2[param1.pos]
        elif param1.pos in mixed_by_pos:
            pos_param2 = mixed_by_pos[param1.pos]

        # No match in func2 and arg is required
        if pos_param2 is None and kw_param2 is None and param1.required:
            return False

        # Matching keyword argument
        if kw_param2 is not None and not is_arg_subtype(kw_param2, param1):
            return False

        # Matching positional argument
        if pos_param2 is not None and not is_arg_subtype(pos_param2, param1):
            return False

    # Reverse direction: parameters of func2 need a counterpart in func1
    mixed_positions: set[int] = {param.pos for param in mixed1}
    mixed_names: set[str] = {param.name for param in mixed1}
    for param2 in pos2:
        if not param2.required:
            continue
        # Required positional-only parameter with no slot in func1
        if param2.pos >= len(pos1) and param2.pos not in mixed_positions:
            return False

    for name, param2 in kw2.items():
        if not param2.required:
            continue
        # Required keyword-only parameter with no same-named one in func1
        if name not in kw1 and name not in mixed_names:
            return False

    for param2 in mixed2:
        # NOTE(review): this loop skips *required* parameters, whereas the
        # two loops above skip *optional* ones. This looks like an inverted
        # condition, confirm the intended behaviour before relying on it.
        if param2.required:
            continue
        # A mixed parameter must be matched both by position and by name
        pos_match: bool = param2.pos < len(pos1) or param2.pos in mixed_positions
        kw_match: bool = param2.name in kw1 or param2.name in mixed_names
        if not pos_match or not kw_match:
            return False

    return True
|
||||
|
||||
def apply_generic(self, type: Type, args: list[Type]) -> Type:
    """Apply type arguments to a generic type

    Args:
        type (Type): the type to apply the arguments to
        args (list[Type]): the type arguments, in type parameter order

    Returns:
        Type: a `DerivedType` wrapping the applied base, an `AppliedType`
            for a `GenericType`, or a `TupleType` for the `tuple` base type

    Raises:
        ValueError: if the number of arguments differs from the number of
            type parameters, if an argument is not a subtype of its
            parameter's bound, or if `type` cannot take type arguments
    """
    # A derived type keeps its name, the arguments go to the wrapped type
    if isinstance(type, DerivedType):
        return DerivedType(
            name=type.name,
            type=self.apply_generic(type.type, args),
        )

    if isinstance(type, GenericType):
        expected: int = len(type.params)
        provided: int = len(args)
        if provided < expected:
            raise ValueError(
                f"Missing type arguments, expected {expected} but only {provided} provided"
            )
        if provided > expected:
            raise ValueError(
                f"Too many type arguments, expected {expected} but {provided} provided"
            )

        bindings: dict[str, Type] = {}
        for type_var, arg in zip(type.params, args):
            bound = type_var.bound
            # Bounded type variables only accept subtypes of their bound
            if bound is not None and not self.is_subtype(arg, bound):
                raise ValueError(f"Type argument {arg} is not a subtype of {bound}")
            bindings[type_var.name] = arg

        return AppliedType(
            name=type.name,
            args=args,
            body=substitute_typevars(type.body, bindings),
        )

    # The `tuple` base type accepts any number of arguments
    if isinstance(type, BaseType) and type.name == "tuple":
        return TupleType(items=tuple(args))

    raise ValueError(f"{type} is not a generic type")
|
||||
|
||||
def reduce_types(self, types: list[Type]) -> list[Type]:
    """Reduce a list of types to remove subtypes and only keep the highest types

    Args:
        types (list[Type]): the types to reduce

    Returns:
        list[Type]: the reduced list of types
    """
    # Indices into `types` that have not been eliminated yet
    survivors: list[int] = list(range(len(types)))
    changed: bool = True
    while changed:
        changed = False
        cursor: int = 0
        # `survivors` shrinks during the scan, so its length is re-read at
        # every step
        while cursor < len(survivors):
            left_index: int = survivors[cursor]
            left: Type = types[left_index]
            for right_index in survivors[cursor + 1 :]:
                right: Type = types[right_index]
                if self.is_subtype(left, right):
                    survivors.remove(left_index)
                    changed = True
                    break
                if self.is_subtype(right, left):
                    survivors.remove(right_index)
                    changed = True
                    break
            cursor += 1
    return [types[index] for index in survivors]
|
||||
|
||||
def lookup_member(self, type: Type, member_name: str) -> Optional[Type]:
|
||||
match type:
|
||||
case BaseType(name=name):
|
||||
if name in self._members:
|
||||
if member_name in self._members[name]:
|
||||
return self._members[name][member_name].type
|
||||
return None
|
||||
|
||||
case DerivedType(name=name, type=base):
|
||||
if name in self._members:
|
||||
if member_name in self._members[name]:
|
||||
return self._members[name][member_name].type
|
||||
return self.lookup_member(base, member_name)
|
||||
|
||||
case AppliedType(name=name, body=body, args=args):
|
||||
generic: Type = self.get_type(name)
|
||||
|
||||
if not isinstance(generic, GenericType):
|
||||
raise ValueError("AppliedType not derived from a GenericType")
|
||||
|
||||
substitutions = {
|
||||
type_var.name: arg for arg, type_var in zip(args, generic.params)
|
||||
}
|
||||
if name in self._members:
|
||||
if member_name in self._members[name]:
|
||||
member_type: Type = self._members[name][member_name].type
|
||||
return substitute_typevars(member_type, substitutions)
|
||||
|
||||
member_type2: Optional[Type] = self.lookup_member(body, member_name)
|
||||
if member_type2 is not None:
|
||||
member_type2 = substitute_typevars(member_type2, substitutions)
|
||||
return member_type2
|
||||
|
||||
case ComplexType(members=members):
|
||||
if member_name in members:
|
||||
return members[member_name]
|
||||
self.logger.debug(f"No member '{member_name}' in {type}")
|
||||
return None
|
||||
|
||||
case ExtensionType(base=base, extension=ComplexType(members=members)):
|
||||
if member_name in members:
|
||||
return members[member_name]
|
||||
self.logger.debug(
|
||||
f"No member '{member_name}' on {type}, looking up in base"
|
||||
)
|
||||
return self.lookup_member(base, member_name)
|
||||
|
||||
case ConstraintType(type=base):
|
||||
return self.lookup_member(base, member_name)
|
||||
|
||||
case TypeVar(bound=bound) if bound is not None:
|
||||
return self.lookup_member(bound, member_name)
|
||||
|
||||
case UnknownType():
|
||||
return UnknownType()
|
||||
|
||||
case _:
|
||||
self.logger.debug(f"Can't get member on {type}")
|
||||
return None
|
||||
|
||||
def lookup_predicate(self, name: str) -> Optional[Predicate]:
    """Look up a registered predicate by name

    Args:
        name (str): the name of the predicate

    Returns:
        Optional[Predicate]: the predicate, or None if none has that name
    """
    predicate: Optional[Predicate] = self._predicates.get(name)
    return predicate
|
||||
|
||||
def _by_name_or_type(self, name_or_type: str | Type) -> Type:
|
||||
if isinstance(name_or_type, str):
|
||||
return self.get_type(name_or_type)
|
||||
return name_or_type
|
||||
|
||||
def list_of(self, item_type: str | Type) -> Type:
    """Build the `list` type applied to the given item type

    Args:
        item_type (str | Type): the item type, or its name

    Returns:
        Type: the result of applying the registered `list` type
    """
    generic_list = self.get_type("list")
    type_args = [self._by_name_or_type(item_type)]
    return self.apply_generic(generic_list, type_args)
|
||||
|
||||
def tuple_of(self, *item_types: str | Type) -> Type:
    """Build the `tuple` type applied to the given item types

    Args:
        *item_types (str | Type): the item types, or their names, in order

    Returns:
        Type: the result of applying the registered `tuple` type
    """
    generic_tuple = self.get_type("tuple")
    type_args = list(map(self._by_name_or_type, item_types))
    return self.apply_generic(generic_tuple, type_args)
|
||||
|
||||
def dict_of(self, key_type: str | Type, value_type: str | Type) -> Type:
    """Build the `dict` type applied to the given key and value types

    Args:
        key_type (str | Type): the key type, or its name
        value_type (str | Type): the value type, or its name

    Returns:
        Type: the result of applying the registered `dict` type
    """
    generic_dict = self.get_type("dict")
    resolved_key = self._by_name_or_type(key_type)
    resolved_value = self._by_name_or_type(value_type)
    return self.apply_generic(generic_dict, [resolved_key, resolved_value])
|
||||
70
midas/checker/reporter.py
Normal file
70
midas/checker/reporter.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||
|
||||
|
||||
class Reporter:
    """Accumulate the diagnostics reported during a run

    Every reported diagnostic is appended to `diagnostics`, in report order.
    """

    def __init__(self):
        self.diagnostics: list[Diagnostic] = []

    def report(
        self,
        path: Optional[str],
        type: DiagnosticType,
        location: Location,
        message: str,
    ):
        """Record a new diagnostic

        Args:
            path (Optional[str]): the path of the file concerned, if any
            type (DiagnosticType): the kind of diagnostic
            location (Location): where the diagnostic applies
            message (str): the diagnostic message
        """
        diagnostic = Diagnostic(
            file_path=path,
            location=location,
            type=type,
            message=message,
        )
        self.diagnostics.append(diagnostic)

    def for_file(self, path: Optional[str]) -> FileReporter:
        """Create a reporter bound to `path` that reports into this one"""
        return FileReporter(self, path)
|
||||
|
||||
|
||||
class FileReporter:
    """A reporter bound to a file path

    Every diagnostic is forwarded to the base reporter together with the
    path this reporter was created for.
    """

    def __init__(self, base_reporter: Reporter, path: Optional[str]) -> None:
        self.base_reporter: Reporter = base_reporter
        self.path: Optional[str] = path

    def for_file(self, path: Optional[str]) -> FileReporter:
        """Create a reporter for another path sharing the same base reporter"""
        return FileReporter(self.base_reporter, path)

    def report(self, type: DiagnosticType, location: Location, message: str):
        """Forward a diagnostic to the base reporter, adding this reporter's path"""
        self.base_reporter.report(self.path, type, location, message)

    def error(self, location: Location, message: str):
        """Report a diagnostic of type `DiagnosticType.ERROR`"""
        self.report(DiagnosticType.ERROR, location, message)

    def warning(self, location: Location, message: str):
        """Report a diagnostic of type `DiagnosticType.WARNING`"""
        self.report(DiagnosticType.WARNING, location, message)

    def info(self, location: Location, message: str):
        """Report a diagnostic of type `DiagnosticType.INFO`"""
        self.report(DiagnosticType.INFO, location, message)

    def debug(self, location: Location, message: str):
        """Report a diagnostic of type `DiagnosticType.DEBUG`"""
        self.report(DiagnosticType.DEBUG, location, message)
|
||||
244
midas/checker/resolver.py
Normal file
244
midas/checker/resolver.py
Normal file
@@ -0,0 +1,244 @@
|
||||
import midas.ast.python as p
|
||||
|
||||
|
||||
# Raised by the Resolver on an invalid declaration or variable reference
class ResolverError(Exception): ...
|
||||
|
||||
|
||||
class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
    """A variable assignment and reference resolver

    This class keeps track of which scope a variable is defined in and which
    scope is referred to when a variable is referenced
    """

    def __init__(self) -> None:
        # Scope distance recorded for each resolved variable expression
        self.locals: dict[p.Expr, int] = {}
        # Stack of scopes, innermost last. Each scope maps a name to whether
        # it is fully defined (True) or only declared (False)
        self.scopes: list[dict[str, bool]] = [{}]

    def resolve(self, *objects: p.Stmt | p.Expr) -> None:
        """Resolve the given statements or expressions"""

        for obj in objects:
            obj.accept(self)

    def begin_scope(self) -> None:
        """Begin a new scope inside the current one"""
        self.scopes.append({})

    def end_scope(self) -> None:
        """Close the current scope"""
        self.scopes.pop()

    def declare(self, name: str) -> None:
        """Declare a variable in the current scope

        This method must be called *before* evaluating the variable initializer

        Args:
            name (str): the name of the variable

        Raises:
            ResolverError: if the variable has already been declared in the current scope
        """
        if len(self.scopes) == 0:
            return
        scope: dict[str, bool] = self.scopes[-1]
        if name in scope:
            raise ResolverError(
                f"A variable with the name {name} is already declared in this scope"
            )
        # Declared but not yet usable
        scope[name] = False

    def define(self, name: str) -> None:
        """Define a variable in the current scope

        This method must be called *after* evaluating the variable initializer

        Args:
            name (str): the name of the variable
        """
        if len(self.scopes) == 0:
            return
        self.scopes[-1][name] = True

    def resolve_local(self, expr: p.Expr, name: str) -> None:
        """Resolve a variable reference and store the scope distance

        This method associates to the variable expression a number representing
        the "distance" of the variable declaration, i.e. the number of scope
        levels to go "up" to find the closest declaration for that variable.
        Nothing is recorded when no scope contains the name.

        Args:
            expr (p.Expr): the variable expression
            name (str): the name of the variable
        """
        # Innermost scope first, so `i` is the number of levels to go up
        for i, scope in enumerate(reversed(self.scopes)):
            if name in scope:
                self.locals[expr] = i
                return

    def is_defined(self, name: str) -> bool:
        """Check whether a name is present in any of the open scopes

        Args:
            name (str): the name of the variable

        Returns:
            bool: whether any scope on the stack contains `name`
        """
        for scope in self.scopes:
            if name in scope:
                return True
        return False

    def resolve_function(self, function: p.Function) -> None:
        """Resolve a function definition

        This method creates a new scope for the function, resolves all the
        parameter declarations and then the body.

        Args:
            function (p.Function): the function to resolve
        """
        self.begin_scope()
        for param in function.params.all:
            self.declare(param.name)
            self.define(param.name)
        self.resolve(*function.body)
        self.end_scope()

    def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
        """Resolve the expression wrapped by the statement"""
        stmt.expr.accept(self)

    def visit_function(self, stmt: p.Function) -> None:
        """Declare the function name, then resolve its parameters and body"""
        # Declare before resolving body to allow recursion
        self.declare(stmt.name)
        self.define(stmt.name)
        self.resolve_function(stmt)

    def visit_type_assign(self, stmt: p.TypeAssign) -> None:
        """Declare and define the name introduced by a type assignment"""
        self.declare(stmt.name)
        # NOTE: resolve type here?
        self.define(stmt.name)

    def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
        """Resolve the assigned value, then every assignment target"""
        self.resolve(stmt.value)
        for target in stmt.targets:
            self._visit_assign(target)

    def _visit_assign(self, target: p.Expr) -> None:
        """Resolve a single assignment target

        Args:
            target (p.Expr): the expression being assigned to

        Raises:
            Exception: if the target is not a variable, attribute access or subscript
        """
        match target:
            case p.VariableExpr(name=name):
                # NOTE(review): indentation was lost in this copy of the
                # file. `define` is assumed to sit inside the `if`, so that
                # a name known in any open scope is treated as a
                # re-assignment. Confirm against the original source.
                if not self.is_defined(name):
                    self.declare(name)
                    self.define(name)
                target.accept(self)

            case p.GetExpr():
                target.accept(self)

            case p.SubscriptExpr():
                target.accept(self)

            case _:
                raise Exception(f"Unsupported assignment to {target}")

    def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
        """Resolve the returned value, if any"""
        if stmt.value is not None:
            self.resolve(stmt.value)

    def visit_if_stmt(self, stmt: p.IfStmt) -> None:
        """Resolve the test, then the body and the else branch in their own scopes"""
        # Not resolved in sub-environment because assignments in the test leak out of the if
        # For example:
        # if (m := 1 + 1) < 2:
        #     ...
        # print(m)  # <- m is still defined
        self.resolve(stmt.test)

        # Body
        self.begin_scope()
        self.resolve(*stmt.body)
        self.end_scope()

        # Else
        self.begin_scope()
        self.resolve(*stmt.orelse)
        self.end_scope()

    def visit_pass(self, stmt: p.Pass) -> None:
        """Nothing to resolve"""
        pass

    def visit_for_stmt(self, stmt: p.ForStmt) -> None:
        """Resolve the iterator and the target, then the body in its own scope"""
        self.resolve(stmt.iterator)
        # The loop target is resolved in the enclosing scope
        self._visit_assign(stmt.target)
        self.begin_scope()
        self.resolve(*stmt.body)
        self.end_scope()

    def visit_raw_stmt(self, stmt: p.RawStmt) -> None:
        """Nothing to resolve"""
        pass

    def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
        """Resolve both operands"""
        self.resolve(expr.left)
        self.resolve(expr.right)

    def visit_compare_expr(self, expr: p.CompareExpr) -> None:
        """Resolve both operands"""
        self.resolve(expr.left)
        self.resolve(expr.right)

    def visit_unary_expr(self, expr: p.UnaryExpr) -> None:
        """Resolve the operand"""
        self.resolve(expr.right)

    def visit_call_expr(self, expr: p.CallExpr) -> None:
        """Resolve the callee, then positional and keyword arguments"""
        self.resolve(expr.callee)
        for arg in expr.arguments:
            self.resolve(arg)
        for arg in expr.keywords.values():
            self.resolve(arg)

    def visit_get_expr(self, expr: p.GetExpr) -> None:
        """Resolve the object whose attribute is accessed"""
        self.resolve(expr.object)

    def visit_literal_expr(self, expr: p.LiteralExpr) -> None:
        """Nothing to resolve"""
        pass

    def visit_variable_expr(self, expr: p.VariableExpr) -> None:
        """Resolve a variable reference

        Raises:
            ResolverError: if the variable is declared but not yet defined
                in the current scope
        """
        if len(self.scopes) != 0 and self.scopes[-1].get(expr.name) is False:
            raise ResolverError(
                f"Cannot use local variable '{expr.name}' in its own initializer"
            )  # aka. UnboundLocalError
        self.resolve_local(expr, expr.name)

    def visit_logical_expr(self, expr: p.LogicalExpr) -> None:
        """Resolve both operands"""
        self.resolve(expr.left)
        self.resolve(expr.right)

    def visit_cast_expr(self, expr: p.CastExpr) -> None:
        """Resolve the expression being cast"""
        self.resolve(expr.expr)

    def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
        """Resolve the test and both branches"""
        self.resolve(expr.test)
        self.resolve(expr.if_true)
        self.resolve(expr.if_false)

    def visit_list_expr(self, expr: p.ListExpr) -> None:
        """Resolve every item"""
        for item in expr.items:
            self.resolve(item)

    def visit_dict_expr(self, expr: p.DictExpr) -> None:
        """Resolve every key that is not None, then every value"""
        for key in expr.keys:
            if key is not None:
                self.resolve(key)
        for value in expr.values:
            self.resolve(value)

    def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
        """Resolve the subscripted object and the index"""
        self.resolve(expr.object)
        self.resolve(expr.index)

    def visit_slice_expr(self, expr: p.SliceExpr) -> None:
        """Resolve whichever of the lower, upper and step bounds are present"""
        if expr.lower is not None:
            self.resolve(expr.lower)
        if expr.upper is not None:
            self.resolve(expr.upper)
        if expr.step is not None:
            self.resolve(expr.step)

    def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
        """Resolve every item"""
        for item in expr.items:
            self.resolve(item)

    def visit_raw_expr(self, expr: p.RawExpr) -> None:
        """Nothing to resolve"""
        pass
|
||||
457
midas/checker/types.py
Normal file
457
midas/checker/types.py
Normal file
@@ -0,0 +1,457 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
from typing import Optional, assert_never, cast
|
||||
|
||||
import midas.ast.midas as m
|
||||
from midas.ast.printer import MidasPrinter
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class TopType:
    # Field-less type rendered as "Any"; presumably the supertype of every
    # other type (that handling is not visible here, confirm in is_subtype)
    def __str__(self) -> str:
        return "Any"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class BaseType:
    # A type identified only by its name, which is also how it is rendered
    name: str

    def __str__(self) -> str:
        return self.name
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class DerivedType:
    # A named type wrapping another type; rendered by name only
    name: str
    # The wrapped type (what `unfold_type` unwraps to)
    type: Type

    def __str__(self) -> str:
        return self.name
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class UnknownType:
    # Field-less placeholder type rendered as "<Unknown>"
    def __str__(self) -> str:
        return "<Unknown>"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class UnitType:
    # Field-less type rendered as "None"
    def __str__(self) -> str:
        return "None"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Function:
|
||||
params: ParamSpec
|
||||
returns: Type
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.params} -> {self.returns}"
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Parameter:
|
||||
pos: int
|
||||
name: str
|
||||
type: Type
|
||||
required: bool
|
||||
|
||||
def __str__(self) -> str:
|
||||
opt: str = "" if self.required else "?"
|
||||
return f"{self.name}: {self.type}{opt}"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ParamSpec:
|
||||
pos: list[Function.Parameter] = field(default_factory=list)
|
||||
mixed: list[Function.Parameter] = field(default_factory=list)
|
||||
kw: list[Function.Parameter] = field(default_factory=list)
|
||||
|
||||
def __str__(self) -> str:
|
||||
params: list[str] = []
|
||||
if len(self.pos) != 0:
|
||||
params += list(map(str, self.pos))
|
||||
params.append("/")
|
||||
|
||||
if len(self.mixed) != 0:
|
||||
params += list(map(str, self.mixed))
|
||||
|
||||
if len(self.kw) != 0:
|
||||
params.append("*")
|
||||
params += list(map(str, self.kw))
|
||||
|
||||
return f"({', '.join(params)})"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class OverloadedFunction:
    # The alternative signatures; the rendering does not list them
    overloads: list[Type]

    def __str__(self) -> str:
        return "<overloaded function>"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ComplexType:
|
||||
members: dict[str, Type]
|
||||
|
||||
def __str__(self) -> str:
|
||||
props: list[str] = [f"{name}: {type}" for name, type in self.members.items()]
|
||||
return f"{{{', '.join(props)}}}"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class ExtensionType:
    # The type being extended
    base: Type
    # The extra members added on top of `base`
    extension: ComplexType

    def __str__(self) -> str:
        return f"{self.base} & {self.extension}"
|
||||
|
||||
|
||||
# Variance of a type parameter; a TypeVar defaults to INVARIANT
class Variance(StrEnum):
    INVARIANT = "INVARIANT"
    COVARIANT = "COVARIANT"
    CONTRAVARIANT = "CONTRAVARIANT"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class TypeVar:
    """A type variable with an optional upper bound and a variance

    Rendered as its name, prefixed by "+" when covariant or "-" when
    contravariant, and followed by " <: bound" when a bound is set.
    """

    name: str
    bound: Optional[Type]
    variance: Variance = Variance.INVARIANT

    def __str__(self) -> str:
        if self.variance == Variance.COVARIANT:
            prefix = "+"
        elif self.variance == Variance.CONTRAVARIANT:
            prefix = "-"
        else:
            prefix = ""

        text: str = f"{prefix}{self.name}"
        if self.bound is None:
            return text
        return f"{text} <: {self.bound}"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class GenericType:
|
||||
name: str
|
||||
params: list[TypeVar]
|
||||
body: Type
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.name}[{', '.join(map(str, self.params))}]"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class AppliedType:
|
||||
name: str
|
||||
args: list[Type]
|
||||
body: Type
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.name}[{', '.join(map(str, self.args))}]"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class ConstraintType:
    # The constrained type
    type: Type
    # The constraint expression, printed after "where"
    constraint: m.Expr

    def __str__(self) -> str:
        # The constraint is rendered with a fresh MidasPrinter
        printer = MidasPrinter()
        return f"{self.type} where {printer.print(self.constraint)}"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class TupleType:
|
||||
items: tuple[Type, ...]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"({', '.join(map(str, self.items))})"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class ColumnType:
    # The type wrapped by the column, shown between the brackets
    type: Type

    def __str__(self) -> str:
        return f"Column[{self.type}]"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class DataFrameType:
|
||||
columns: list[Column]
|
||||
|
||||
def __str__(self) -> str:
|
||||
schema: list[str] = [f"{col.name}: {col.type}" for col in self.columns]
|
||||
return f"Frame[{', '.join(schema)}]"
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Column:
|
||||
index: int
|
||||
name: Optional[str]
|
||||
type: ColumnType
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class FrameGroupBy:
    # The data frame type wrapped by this group-by type
    frame: DataFrameType

    def __str__(self) -> str:
        return f"FrameGroupBy[{self.frame}]"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class ColumnGroupBy:
    # The column type wrapped by this group-by type
    column: ColumnType

    def __str__(self) -> str:
        return f"ColumnGroupBy[{self.column}]"
|
||||
|
||||
|
||||
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||
def sub_parameter(param: Function.Parameter):
|
||||
return Function.Parameter(
|
||||
pos=param.pos,
|
||||
name=param.name,
|
||||
type=substitute_typevars(param.type, substitutions),
|
||||
required=param.required,
|
||||
)
|
||||
|
||||
def sub_param_spec(spec: ParamSpec):
|
||||
return ParamSpec(
|
||||
pos=list(map(sub_parameter, spec.pos)),
|
||||
mixed=list(map(sub_parameter, spec.mixed)),
|
||||
kw=list(map(sub_parameter, spec.kw)),
|
||||
)
|
||||
|
||||
def sub_column(col: DataFrameType.Column):
|
||||
return DataFrameType.Column(
|
||||
index=col.index,
|
||||
name=col.name,
|
||||
type=cast(ColumnType, substitute_typevars(col.type, substitutions)),
|
||||
)
|
||||
|
||||
match type:
|
||||
case TopType():
|
||||
return type
|
||||
|
||||
case BaseType(name=name) if name in substitutions:
|
||||
return substitutions[name]
|
||||
|
||||
case BaseType():
|
||||
return type
|
||||
|
||||
case DerivedType(name=name, type=type2):
|
||||
return DerivedType(
|
||||
name=name, type=substitute_typevars(type2, substitutions)
|
||||
)
|
||||
|
||||
case Function(
|
||||
params=params,
|
||||
returns=returns,
|
||||
):
|
||||
return Function(
|
||||
params=sub_param_spec(params),
|
||||
returns=substitute_typevars(returns, substitutions),
|
||||
)
|
||||
|
||||
case OverloadedFunction(overloads=overloads):
|
||||
return OverloadedFunction(
|
||||
overloads=[
|
||||
substitute_typevars(overload, substitutions)
|
||||
for overload in overloads
|
||||
]
|
||||
)
|
||||
|
||||
case ComplexType(members=members):
|
||||
members2: dict[str, Type] = {
|
||||
name: substitute_typevars(prop, substitutions)
|
||||
for name, prop in members.items()
|
||||
}
|
||||
return ComplexType(members=members2)
|
||||
|
||||
case ExtensionType(base=base, extension=ComplexType(members=members)):
|
||||
return ExtensionType(
|
||||
base=substitute_typevars(base, substitutions),
|
||||
extension=ComplexType(
|
||||
members={
|
||||
name: substitute_typevars(prop, substitutions)
|
||||
for name, prop in members.items()
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
case AppliedType(name=name, args=args, body=body):
|
||||
return AppliedType(
|
||||
name=name,
|
||||
args=[substitute_typevars(arg, substitutions) for arg in args],
|
||||
body=substitute_typevars(body, substitutions),
|
||||
)
|
||||
|
||||
case ConstraintType():
|
||||
return ConstraintType(
|
||||
type=substitute_typevars(type.type, substitutions),
|
||||
constraint=type.constraint,
|
||||
)
|
||||
|
||||
case TypeVar(name=name):
|
||||
if name in substitutions:
|
||||
return substitutions[name]
|
||||
raise ValueError(f"Missing TypeVar substitution for {name}")
|
||||
|
||||
case GenericType(name=name, params=params, body=body):
|
||||
params2: list[TypeVar] = []
|
||||
for param in params:
|
||||
param2: Type = substitute_typevars(param, substitutions)
|
||||
if not isinstance(param2, TypeVar):
|
||||
raise ValueError(
|
||||
f"Invalid type parameter substitution, expected TypeVar, got {param2}"
|
||||
)
|
||||
params2.append(param2)
|
||||
return GenericType(
|
||||
name=name,
|
||||
params=params2,
|
||||
body=substitute_typevars(body, substitutions),
|
||||
)
|
||||
|
||||
case TupleType(items=items):
|
||||
return TupleType(
|
||||
items=tuple(substitute_typevars(item, substitutions) for item in items),
|
||||
)
|
||||
|
||||
case ColumnType(type=items_type):
|
||||
return ColumnType(
|
||||
type=substitute_typevars(items_type, substitutions),
|
||||
)
|
||||
|
||||
case DataFrameType(columns=columns):
|
||||
return DataFrameType(
|
||||
columns=list(map(sub_column, columns)),
|
||||
)
|
||||
|
||||
case FrameGroupBy(frame=frame):
|
||||
return FrameGroupBy(
|
||||
frame=cast(DataFrameType, substitute_typevars(frame, substitutions))
|
||||
)
|
||||
|
||||
case ColumnGroupBy(column=column):
|
||||
return ColumnGroupBy(
|
||||
column=cast(ColumnType, substitute_typevars(column, substitutions))
|
||||
)
|
||||
|
||||
case UnknownType() | UnitType():
|
||||
return type
|
||||
|
||||
case TopType() | GenericType():
|
||||
raise NotImplementedError(f"Unsupported type {type}")
|
||||
|
||||
# Ensure exhaustiveness
|
||||
case _:
|
||||
assert_never(type)
|
||||
|
||||
|
||||
def unfold_type(type: Type) -> Type:
    """Strip every `DerivedType` layer wrapping a type

    Args:
        type (Type): the type to unfold

    Returns:
        Type: the first wrapped type that is not a `DerivedType`
    """
    current: Type = type
    while isinstance(current, DerivedType):
        current = current.type
    return current
|
||||
|
||||
|
||||
def to_annotation(type: Type) -> str:
|
||||
def _params_annotation(spec: ParamSpec) -> str:
|
||||
if len(spec.kw) != 0:
|
||||
return "..."
|
||||
|
||||
params: str = ", ".join(
|
||||
to_annotation(param.type) for param in spec.pos + spec.mixed
|
||||
)
|
||||
return f"[{params}]"
|
||||
|
||||
match type:
|
||||
case TopType():
|
||||
return "Any"
|
||||
|
||||
case BaseType(name=name):
|
||||
return name
|
||||
|
||||
case DerivedType(name=name):
|
||||
return name
|
||||
|
||||
case UnknownType():
|
||||
return "Any"
|
||||
|
||||
case UnitType():
|
||||
return "None"
|
||||
|
||||
case Function(params=params, returns=returns):
|
||||
params_annot: str = _params_annotation(params)
|
||||
return f"Callable[{params_annot}, {to_annotation(returns)}]"
|
||||
|
||||
case OverloadedFunction():
|
||||
return "Callable"
|
||||
|
||||
case ComplexType() | ExtensionType():
|
||||
raise NotImplementedError
|
||||
|
||||
case TypeVar(name=name):
|
||||
return name
|
||||
|
||||
case GenericType(name=name, params=params):
|
||||
return f"{name}[{', '.join(map(to_annotation, params))}]"
|
||||
|
||||
case AppliedType(name=name, args=args):
|
||||
return f"{name}[{', '.join(map(to_annotation, args))}]"
|
||||
|
||||
case ConstraintType():
|
||||
return str(type)
|
||||
|
||||
case TupleType(items=items):
|
||||
return f"Tuple[{', '.join(map(to_annotation, items))}]"
|
||||
|
||||
case ColumnType():
|
||||
return "pd.Series"
|
||||
|
||||
case DataFrameType():
|
||||
return "pd.DataFrame"
|
||||
|
||||
case FrameGroupBy():
|
||||
return "pd.api.typing.DataFrameGroupBy"
|
||||
|
||||
case ColumnGroupBy():
|
||||
return "pd.api.typing.SeriesGroupBy"
|
||||
|
||||
case _:
|
||||
assert_never(type)
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class Predicate:
    # NOTE(review): field meanings are not visible in this part of the file.
    # Presumably `type` is the type the predicate applies to and `body` its
    # expression; confirm, and check what `alias` stands for, at the place
    # where predicates are registered.
    type: Type
    body: m.Expr
    alias: bool
|
||||
|
||||
|
||||
# Union of every type representation defined in this module. Functions that
# match on a `Type` end with `assert_never`, so a member added here needs a
# case in each of them.
Type = (
    TopType
    | BaseType
    | DerivedType
    | UnknownType
    | UnitType
    | Function
    | OverloadedFunction
    | ComplexType
    | ExtensionType
    | TypeVar
    | GenericType
    | AppliedType
    | ConstraintType
    | TupleType
    | ColumnType
    | DataFrameType
    | FrameGroupBy
    | ColumnGroupBy
)
|
||||
201
midas/checker/unifier.py
Normal file
201
midas/checker/unifier.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import (
|
||||
AppliedType,
|
||||
ColumnType,
|
||||
DataFrameType,
|
||||
Function,
|
||||
GenericType,
|
||||
ParamSpec,
|
||||
TopType,
|
||||
Type,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
|
||||
# Caught by Unifier.unify_generic, which then reports failure by returning None
class UnificationError(Exception): ...
|
||||
|
||||
|
||||
class Unifier:
|
||||
def __init__(self, types: TypesRegistry) -> None:
|
||||
self.types: TypesRegistry = types
|
||||
self.logger: logging.Logger = logging.getLogger("Unifier")
|
||||
|
||||
def unify_call(
|
||||
self,
|
||||
type: GenericType,
|
||||
positional: list[Type],
|
||||
keywords: dict[str, Type],
|
||||
) -> Optional[Type]:
|
||||
concrete_func: Function = Function(
|
||||
params=ParamSpec(
|
||||
pos=[
|
||||
Function.Parameter(
|
||||
pos=i,
|
||||
name=str(i),
|
||||
type=arg,
|
||||
required=True,
|
||||
)
|
||||
for i, arg in enumerate(positional)
|
||||
],
|
||||
kw=[
|
||||
Function.Parameter(
|
||||
pos=len(positional) + i,
|
||||
name=name,
|
||||
type=arg,
|
||||
required=True,
|
||||
)
|
||||
for i, (name, arg) in enumerate(keywords.items())
|
||||
],
|
||||
),
|
||||
returns=TopType(), # TODO: use expected type
|
||||
)
|
||||
return self.unify_generic(type, concrete_func, match_return=False)
|
||||
|
||||
def unify_generic(
|
||||
self,
|
||||
template: GenericType,
|
||||
concrete: Type,
|
||||
match_return: bool = True,
|
||||
) -> Optional[Type]:
|
||||
substitutions: dict[str, Type]
|
||||
try:
|
||||
substitutions = self.match(template.body, concrete, match_return)
|
||||
except UnificationError:
|
||||
return None
|
||||
|
||||
args: list[Type] = []
|
||||
for param in template.params:
|
||||
if param.name not in substitutions:
|
||||
return None
|
||||
args.append(substitutions[param.name])
|
||||
|
||||
applied: Type = self.types.apply_generic(template, args)
|
||||
return applied
|
||||
|
||||
def match(
|
||||
self,
|
||||
template: Type,
|
||||
concrete: Type,
|
||||
match_return: bool = True,
|
||||
) -> dict[str, Type]:
|
||||
# TODO: if concrete is Generic, record bound TypeVar. Then when merging
|
||||
# substitutions, check that the constraint is respected
|
||||
match (template, concrete):
|
||||
case (TypeVar(name=name), _):
|
||||
return {name: concrete}
|
||||
|
||||
case (
|
||||
AppliedType(name=template_name, args=template_args),
|
||||
AppliedType(name=concrete_name, args=concrete_args),
|
||||
) if template_name == concrete_name and len(template_args) == len(
|
||||
concrete_args
|
||||
):
|
||||
substitutions: dict[str, Type] = {}
|
||||
for template_arg, concrete_arg in zip(template_args, concrete_args):
|
||||
new_substistutions: dict[str, Type] = self.match(
|
||||
template_arg, concrete_arg
|
||||
)
|
||||
substitutions = self.merge(substitutions, new_substistutions)
|
||||
|
||||
return substitutions
|
||||
|
||||
case (
|
||||
DataFrameType(columns=template_columns),
|
||||
DataFrameType(columns=concrete_columns),
|
||||
) if len(template_columns) == len(concrete_columns):
|
||||
substitutions: dict[str, Type] = {}
|
||||
for template_column, concrete_column in zip(
|
||||
template_columns, concrete_columns
|
||||
):
|
||||
if template_column.index != concrete_column or (
|
||||
template_column.name != concrete_column.name
|
||||
):
|
||||
self.logger.debug(
|
||||
f"Column mismatch: template={template_column}, concrete={concrete_column}"
|
||||
)
|
||||
raise UnificationError
|
||||
new_substistutions: dict[str, Type] = self.match(
|
||||
template_column.type, concrete_column.type
|
||||
)
|
||||
substitutions = self.merge(substitutions, new_substistutions)
|
||||
return substitutions
|
||||
|
||||
case (ColumnType(type=template_column), ColumnType(type=concrete_column)):
|
||||
return self.match(template_column, concrete_column)
|
||||
|
||||
case (Function(), Function()):
|
||||
mapped: list[tuple[Function.Parameter, Function.Parameter]] = (
|
||||
self.map_params(template, concrete)
|
||||
)
|
||||
substitutions: dict[str, Type] = {}
|
||||
for template_arg, concrete_arg in mapped:
|
||||
arg_subs: dict[str, Type] = self.match(
|
||||
template_arg.type, concrete_arg.type
|
||||
)
|
||||
substitutions = self.merge(substitutions, arg_subs)
|
||||
|
||||
if match_return:
|
||||
return_subs: dict[str, Type] = self.match(
|
||||
template.returns, concrete.returns
|
||||
)
|
||||
substitutions = self.merge(substitutions, return_subs)
|
||||
|
||||
return substitutions
|
||||
|
||||
case _:
|
||||
self.logger.debug(f"Can't match {concrete!r} with {template!r}")
|
||||
return {}
|
||||
|
||||
def merge(self, subs1: dict[str, Type], subs2: dict[str, Type]) -> dict[str, Type]:
|
||||
merged: dict[str, Type] = subs1.copy()
|
||||
|
||||
for k, v in subs2.items():
|
||||
if k in merged and merged[k] != v:
|
||||
self.logger.debug(
|
||||
f"Substitution already defined for {k} with type {merged[k]}, got {v}"
|
||||
)
|
||||
raise UnificationError
|
||||
merged[k] = v
|
||||
return merged
|
||||
|
||||
def map_params(
|
||||
self, func1: Function, func2: Function
|
||||
) -> list[tuple[Function.Parameter, Function.Parameter]]:
|
||||
pos1: list[Function.Parameter] = func1.params.pos
|
||||
mixed1: list[Function.Parameter] = func1.params.mixed
|
||||
kw1: list[Function.Parameter] = func1.params.kw
|
||||
|
||||
pos2: list[Function.Parameter] = func2.params.pos
|
||||
mixed2: list[Function.Parameter] = func2.params.mixed
|
||||
kw2: list[Function.Parameter] = func2.params.kw
|
||||
|
||||
mapped: list[tuple[Function.Parameter, Function.Parameter]] = []
|
||||
|
||||
by_pos2: dict[int, Function.Parameter] = {
|
||||
param.pos: param for param in pos2 + mixed2
|
||||
}
|
||||
by_name2: dict[str, Function.Parameter] = {
|
||||
param.name: param for param in mixed2 + kw2
|
||||
}
|
||||
|
||||
for arg1 in pos1:
|
||||
if (arg2 := by_pos2.get(arg1.pos)) is not None:
|
||||
mapped.append((arg1, arg2))
|
||||
|
||||
for arg1 in mixed1:
|
||||
# Match both positionally and by name, conflicts are caught
|
||||
# when merging substitutions
|
||||
if (arg2 := by_pos2.get(arg1.pos)) is not None:
|
||||
mapped.append((arg1, arg2))
|
||||
|
||||
if (arg2 := by_name2.get(arg1.name)) is not None:
|
||||
mapped.append((arg1, arg2))
|
||||
|
||||
for arg1 in kw1:
|
||||
if (arg2 := by_name2.get(arg1.name)) is not None:
|
||||
mapped.append((arg1, arg2))
|
||||
|
||||
return mapped
|
||||
129
midas/checker/variance.py
Normal file
129
midas/checker/variance.py
Normal file
@@ -0,0 +1,129 @@
|
||||
from typing import Literal, Optional, cast
|
||||
|
||||
from midas.checker.registry import Member, TypesRegistry
|
||||
from midas.checker.types import (
|
||||
AppliedType,
|
||||
ConstraintType,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
Type,
|
||||
TypeVar,
|
||||
Variance,
|
||||
)
|
||||
|
||||
Polarity = Literal[-1, 0, 1]
|
||||
|
||||
|
||||
class Tracker:
    """Record the polarities at which each tracked type variable is referenced."""

    def __init__(self, vars: list[TypeVar]) -> None:
        self.vars: list[TypeVar] = vars
        # Maps a variable name to the set of polarities seen so far.
        self.refs: dict[str, set[Polarity]] = {}
        for tracked in self.vars:
            self.refs[tracked.name] = set()

    def record(self, var: TypeVar, polarity: Polarity):
        """Note that ``var`` occurred at ``polarity``; ``var`` must be tracked."""
        seen: set[Polarity] = self.refs[var.name]
        seen.add(polarity)

    def get_updated_vars(self) -> list[TypeVar]:
        """Rebuild the tracked variables with their inferred variance."""
        updated: list[TypeVar] = []
        for tracked in self.vars:
            updated.append(
                TypeVar(
                    name=tracked.name,
                    bound=tracked.bound,
                    variance=self.get_variance(tracked.name),
                )
            )
        return updated

    def get_variance(self, name: str) -> Variance:
        """Derive the variance of ``name`` from its recorded polarities.

        Only-positive references give covariance, only-negative ones give
        contravariance, and anything else (including none) gives invariance.
        """
        seen: set[Polarity] = self.refs[name]
        if seen == {1}:
            return Variance.COVARIANT
        if seen == {-1}:
            return Variance.CONTRAVARIANT
        return Variance.INVARIANT

    def __contains__(self, item: TypeVar | str):
        # A type variable is looked up through its name.
        key = item.name if isinstance(item, TypeVar) else item
        return key in self.refs
|
||||
|
||||
|
||||
class VarianceInferrer:
|
||||
def __init__(self, types: TypesRegistry) -> None:
|
||||
self.types: TypesRegistry = types
|
||||
self.tracker: Tracker = Tracker([])
|
||||
|
||||
def infer(self, type: GenericType) -> GenericType:
|
||||
self.tracker = Tracker(type.params)
|
||||
|
||||
self.walk(type.body, 1, type.name)
|
||||
members: dict[str, Member] = self.types._members.get(type.name, {})
|
||||
for name, member in members.items():
|
||||
self.walk(member.type, 1, type.name, [f"member:'{name}'"])
|
||||
|
||||
return GenericType(
|
||||
name=type.name,
|
||||
params=self.tracker.get_updated_vars(),
|
||||
body=type.body,
|
||||
)
|
||||
|
||||
def walk(
|
||||
self,
|
||||
type: Type,
|
||||
polarity: Polarity,
|
||||
base_name: str,
|
||||
path: Optional[list[str]] = None,
|
||||
):
|
||||
if path is None:
|
||||
path = []
|
||||
|
||||
match type:
|
||||
# Arguments are negative positions -> flip polarity
|
||||
# Return is positive position -> keep polarity
|
||||
case Function(params=spec):
|
||||
all_params: list[Function.Parameter] = spec.pos + spec.mixed + spec.kw
|
||||
for param in all_params:
|
||||
self.walk(
|
||||
param.type,
|
||||
-polarity,
|
||||
base_name,
|
||||
path + [f"param:'{param.name}'"],
|
||||
)
|
||||
|
||||
self.walk(type.returns, polarity, base_name, path + ["return"])
|
||||
|
||||
# Walk all overloads
|
||||
case OverloadedFunction(overloads=overloads):
|
||||
for overload in overloads:
|
||||
self.walk(overload, polarity, base_name, path)
|
||||
|
||||
# If same name as root generic -> skip
|
||||
# Get inferred variance of parameters and multiply with current
|
||||
# polarity to recurse through arguments
|
||||
case AppliedType(name=name, args=args):
|
||||
# TODO: handle mutually recursive types
|
||||
if name == base_name:
|
||||
return
|
||||
generic: Type = self.types.get_type(name)
|
||||
assert isinstance(generic, GenericType)
|
||||
params: list[TypeVar] = generic.params
|
||||
polarities: dict[Variance, Polarity] = {
|
||||
Variance.INVARIANT: 0,
|
||||
Variance.COVARIANT: 1,
|
||||
Variance.CONTRAVARIANT: -1,
|
||||
}
|
||||
for param, param in zip(args, params):
|
||||
param_polarity: Polarity = polarities[param.variance]
|
||||
self.walk(
|
||||
param,
|
||||
cast(Polarity, polarity * param_polarity),
|
||||
base_name,
|
||||
path + [f"applied:'{name}'"],
|
||||
)
|
||||
|
||||
# Walk base type
|
||||
case ConstraintType(type=base):
|
||||
self.walk(base, polarity, base_name, path + ["constraint"])
|
||||
|
||||
# Reached end
|
||||
# If tracked, record polarity
|
||||
case TypeVar():
|
||||
if type in self.tracker:
|
||||
self.tracker.record(type, polarity)
|
||||
0
midas/cli/__init__.py
Normal file
0
midas/cli/__init__.py
Normal file
41
midas/cli/ansi.py
Normal file
41
midas/cli/ansi.py
Normal file
@@ -0,0 +1,41 @@
|
||||
class Ansi:
    """ANSI escape sequences and colour offsets for terminal output."""

    # Control sequence introducer shared by every sequence below.
    CTRL = "\x1b["

    # Text attributes.
    RESET = f"{CTRL}0m"
    BOLD = f"{CTRL}1m"
    DIM = f"{CTRL}2m"
    ITALIC = f"{CTRL}3m"
    UNDERLINE = f"{CTRL}4m"

    # Colour offsets; `FG` adds them to 30 and `BG` adds them to 40.
    BLACK = 0
    RED = 1
    GREEN = 2
    YELLOW = 3
    BLUE = 4
    MAGENTA = 5
    CYAN = 6
    WHITE = 7

    BRIGHT_BLACK = 60
    BRIGHT_RED = 61
    BRIGHT_GREEN = 62
    BRIGHT_YELLOW = 63
    BRIGHT_BLUE = 64
    BRIGHT_MAGENTA = 65
    BRIGHT_CYAN = 66
    BRIGHT_WHITE = 67

    @classmethod
    def FG(cls, col: int) -> str:
        """Return the sequence selecting foreground colour offset ``col``."""
        code = 30 + col
        return cls.CTRL + str(code) + "m"

    @classmethod
    def BG(cls, col: int) -> str:
        """Return the sequence selecting background colour offset ``col``."""
        code = 40 + col
        return cls.CTRL + str(code) + "m"

    @classmethod
    def FG_RGB(cls, r: int, g: int, b: int) -> str:
        """Return the sequence selecting an RGB foreground colour."""
        channels = ";".join(str(channel) for channel in (r, g, b))
        return cls.CTRL + "38;2;" + channels + "m"

    @classmethod
    def BG_RGB(cls, r: int, g: int, b: int) -> str:
        """Return the sequence selecting an RGB background colour."""
        channels = ";".join(str(channel) for channel in (r, g, b))
        return cls.CTRL + "48;2;" + channels + "m"
|
||||
9
midas/cli/commands/__init__.py
Normal file
9
midas/cli/commands/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from .check import check as check
|
||||
from .compile import compile as compile
|
||||
from .format import format as format
|
||||
from .highlight import highlight as highlight
|
||||
from .parse import parse as parse
|
||||
from .registry import dump_registry as dump_registry
|
||||
from .stubs import stubs as stubs
|
||||
from .types import types as types
|
||||
from .validate import validate as validate
|
||||
41
midas/cli/commands/check.py
Normal file
41
midas/cli/commands/check.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# **Run type checker and report diagnostics**
|
||||
# ```shell
|
||||
# midas check <file.py> [--types <file.midas>]
|
||||
# ```
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, TextIO
|
||||
|
||||
import click
|
||||
|
||||
from midas.checker.checker import TypeChecker
|
||||
from midas.checker.diagnostic import Diagnostic
|
||||
from midas.cli.highlighter import DiagnosticsHighlighter
|
||||
from midas.cli.utils import DiagnosticPrinter
|
||||
|
||||
|
||||
@click.command(help="Run type checker and report diagnostics")
@click.argument("file", type=click.File("r"))
@click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.option("-l", "--highlight", type=click.File("w"))
def check(
    file: TextIO,
    types: tuple[TextIO],
    highlight: Optional[TextIO],
):
    """Type-check ``file`` and print the resulting diagnostics.

    Every ``--types`` file is imported into the checker first. When
    ``--highlight`` is given, an annotated copy of the source is written to it.
    """
    target: Path = Path(file.name).resolve()

    checker = TypeChecker()
    for definitions in types:
        definitions_path: Path = Path(definitions.name).resolve()
        checker.import_midas(definitions_path)

    checker.type_check(target)
    found: list[Diagnostic] = checker.diagnostics.copy()
    DiagnosticPrinter().print_all(found)

    if highlight is None:
        return

    # The source text is only needed for the highlighted output.
    highlighter = DiagnosticsHighlighter(file.read())
    highlighter.highlight(found)
    highlighter.dump(highlight)
|
||||
51
midas/cli/commands/compile.py
Normal file
51
midas/cli/commands/compile.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# **Compile source**
|
||||
# ```shell
|
||||
# midas compile <file.py> [--types <file.midas>] [--stubs <stubs>] [--ignore-errors]
|
||||
# ```
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional, TextIO
|
||||
|
||||
import click
|
||||
|
||||
from midas.checker.checker import TypeChecker
|
||||
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||
from midas.cli.utils import DiagnosticPrinter
|
||||
from midas.generator.generator import Generator
|
||||
from midas.utils import TypedAST
|
||||
|
||||
|
||||
@click.command(help="Compile source")
@click.argument("file", type=click.File("r"))
@click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.option("-s", "--stubs", type=str, multiple=True)
@click.option("--ignore-errors", is_flag=True)
def compile(
    file: TextIO,
    types: tuple[TextIO],
    stubs: tuple[str],
    ignore_errors: bool,
):
    """Type-check ``file`` and hand the typed AST to the generator.

    The i-th ``--stubs`` value is paired with the i-th ``--types`` file;
    types files without a matching value are paired with ``None``. Unless
    ``--ignore-errors`` is set, the process exits with status 1 when any
    diagnostic is an error.
    """
    source: str = file.read()
    source_path: Path = Path(file.name).resolve()

    checker = TypeChecker()
    type_files: list[tuple[Path, Optional[str]]] = []
    for index, types_file in enumerate(types):
        in_path: Path = Path(types_file.name).resolve()
        checker.import_midas(in_path)
        stub: Optional[str] = None
        if index < len(stubs):
            stub = stubs[index]
        type_files.append((in_path, stub))

    typed_ast: TypedAST = checker.type_check_source(source, str(source_path))
    diagnostics: list[Diagnostic] = checker.diagnostics.copy()
    DiagnosticPrinter().print_all(diagnostics)

    if not ignore_errors:
        if any(diag.type == DiagnosticType.ERROR for diag in diagnostics):
            sys.exit(1)

    generator = Generator(workdir=source_path.parent, types=checker.types)
    generator.generate(typed_ast, source_path, type_files=type_files)
|
||||
25
midas/cli/commands/format.py
Normal file
25
midas/cli/commands/format.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import TextIO
|
||||
|
||||
import click
|
||||
|
||||
import midas.ast.midas as m
|
||||
from midas.ast.printer import MidasPrinter
|
||||
from midas.lexer.midas import MidasLexer
|
||||
from midas.lexer.token import Token
|
||||
from midas.parser.midas import MidasParser
|
||||
|
||||
|
||||
@click.command(help="Parse and pretty print a Midas file")
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-")
def format(file: TextIO, output: TextIO):
    """Parse a Midas file and write its pretty-printed statements to ``output``.

    Parser errors are reported on standard output; whatever statements were
    parsed are still written.
    """
    text: str = file.read()
    printer = MidasPrinter()

    tokens: list[Token] = MidasLexer(text, file=file.name).process()
    parser = MidasParser(tokens)
    statements: list[m.Stmt] = parser.parse()

    for error in parser.errors:
        print(error.get_report())

    for statement in statements:
        rendered: str = printer.print(statement)
        output.write(rendered + "\n")
|
||||
66
midas/cli/commands/highlight.py
Normal file
66
midas/cli/commands/highlight.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import ast
|
||||
from typing import TextIO
|
||||
|
||||
import click
|
||||
|
||||
import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
from midas.cli.highlighter import (
|
||||
Highlighter,
|
||||
LocatableToken,
|
||||
MidasHighlighter,
|
||||
PythonHighlighter,
|
||||
)
|
||||
from midas.lexer.midas import MidasLexer
|
||||
from midas.lexer.token import Token, TokenType
|
||||
from midas.parser.midas import MidasParser
|
||||
from midas.parser.python import PythonParser
|
||||
|
||||
|
||||
def highlight_python(source: str, path: str) -> Highlighter:
    """Parse Python ``source`` and return a highlighter fed with its statements."""
    module: ast.Module = ast.parse(source, filename=path)
    statements: list[p.Stmt] = PythonParser().parse_module(module)

    result = PythonHighlighter(source)
    for statement in statements:
        result.highlight(statement)
    return result
|
||||
|
||||
|
||||
def highlight_midas(source: str, path: str) -> Highlighter:
    """Lex and parse Midas ``source`` and return a highlighter fed with the result.

    Parser errors are printed. Besides the parsed statements, comment and
    keyword tokens are wrapped with their own CSS class.
    """
    tokens: list[Token] = MidasLexer(source, file=path).process()
    parser = MidasParser(tokens)
    statements: list[m.Stmt] = parser.parse()
    result = MidasHighlighter(source)

    for error in parser.errors:
        print(error.get_report())

    for statement in statements:
        result.highlight(statement)

    for token in tokens:
        if token.type == TokenType.COMMENT:
            css_class = "comment"
        elif token.is_keyword:
            css_class = "keyword"
        else:
            # Other tokens get no dedicated wrapping.
            continue
        result.wrap(LocatableToken(token), css_class)
    return result
|
||||
|
||||
|
||||
@click.command(
    help="Parse a Python or Midas file and produce a highlighted version showing AST node types inline",
    short_help="Parse and highlight a Python or Midas file",
)
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-")
def highlight(output: TextIO, file: TextIO):
    """Write a highlighted rendering of ``file`` to ``output``.

    The file kind is chosen from the ``.py`` / ``.midas`` suffix.

    Raises:
        ValueError: for any other suffix.
    """
    source: str = file.read()
    filename: str = file.name

    if filename.endswith(".py"):
        build = highlight_python
    elif filename.endswith(".midas"):
        build = highlight_midas
    else:
        raise ValueError("Unsupported file type")

    highlighter: Highlighter = build(source, filename)
    highlighter.dump(output)
|
||||
66
midas/cli/commands/parse.py
Normal file
66
midas/cli/commands/parse.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# **Parse and pretty-print AST**
|
||||
# ```shell
|
||||
# midas parse <file.midas / file.py>
|
||||
# ```
|
||||
|
||||
import ast
|
||||
from typing import TextIO
|
||||
|
||||
import click
|
||||
|
||||
import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
from midas.ast.printer import MidasAstPrinter, PythonAstPrinter
|
||||
from midas.lexer.midas import MidasLexer
|
||||
from midas.lexer.token import Token
|
||||
from midas.parser.midas import MidasParser
|
||||
from midas.parser.python import PythonParser
|
||||
|
||||
|
||||
def dump_python_ast(tree: ast.Module) -> str:
    """Render every statement of ``tree`` with the Python AST printer.

    Each rendered statement is followed by a newline.
    """
    statements: list[p.Stmt] = PythonParser().parse_module(tree)
    printer = PythonAstPrinter()
    return "".join(printer.print(statement) + "\n" for statement in statements)
|
||||
|
||||
|
||||
def dump_midas_ast(source: str, filename: str) -> str:
    """Lex, parse and render Midas ``source`` with the Midas AST printer.

    Each rendered statement is followed by a newline.

    Raises:
        RuntimeError: if the parser reported errors (they are printed first).
    """
    tokens: list[Token] = MidasLexer(source, file=filename).process()
    parser = MidasParser(tokens)
    statements: list[m.Stmt] = parser.parse()

    if len(parser.errors) > 0:
        for error in parser.errors:
            print(error.get_report())
        raise RuntimeError("A parsing error occurred")

    printer = MidasAstPrinter()
    return "".join(printer.print(statement) + "\n" for statement in statements)
|
||||
|
||||
|
||||
@click.command(help="Parse a Python or Midas file and pretty-print its AST")
@click.argument("file", type=click.File("r"))
@click.option("--raw", is_flag=True)
def parse(file: TextIO, raw: bool):
    """Echo the AST of ``file``.

    For ``.py`` files ``--raw`` selects the standard ``ast.dump`` output
    instead of the project's printer; the flag is not consulted for
    ``.midas`` files.

    Raises:
        ValueError: if the suffix is neither ``.py`` nor ``.midas``.
    """
    source: str = file.read()
    filename: str = file.name

    if filename.endswith(".py"):
        tree: ast.Module = ast.parse(source, filename=filename)
        rendered: str = ast.dump(tree, indent=4) if raw else dump_python_ast(tree)
    elif filename.endswith(".midas"):
        rendered = dump_midas_ast(source, filename)
    else:
        raise ValueError("Unsupported file type")

    click.echo(rendered)
|
||||
66
midas/cli/commands/registry.py
Normal file
66
midas/cli/commands/registry.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# **Dump types registry**
|
||||
# ```shell
|
||||
# midas dump-registry [--types <file.midas>]
|
||||
# ```
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TextIO
|
||||
|
||||
import click
|
||||
|
||||
from midas.ast.printer import MidasPrinter
|
||||
from midas.checker.checker import TypeChecker
|
||||
from midas.checker.registry import Member
|
||||
from midas.checker.types import AppliedType, BaseType, DerivedType, GenericType, Type
|
||||
|
||||
|
||||
def base_type(type: Type) -> Type:
|
||||
match type:
|
||||
case BaseType():
|
||||
return type
|
||||
case DerivedType(type=base):
|
||||
return base
|
||||
case AppliedType(body=body):
|
||||
return body
|
||||
case GenericType(body=body):
|
||||
return body
|
||||
case _:
|
||||
return type
|
||||
|
||||
|
||||
@click.command(help="Dump types registry")
@click.option("-t", "--types", type=click.File("r"), multiple=True)
def dump_registry(
    types: tuple[TextIO],
):
    """Print the types, members and predicates known to the checker's registry.

    Every ``--types`` file is imported into the checker first.
    """
    checker = TypeChecker()
    for definitions in types:
        checker.import_midas(Path(definitions.name).resolve())

    registry = checker.types

    print("##### Types #####")
    for name, type in registry._types.items():
        members: dict[str, Member] = registry._members.get(name, {})

        # Generic types list their parameters in brackets after the name.
        if isinstance(type, GenericType):
            params: str = "[" + ", ".join(str(param) for param in type.params) + "]"
        else:
            params = ""
        print(f"{name}{params} = {base_type(type)}")

        if len(members) != 0:
            print(" " * 4 + "Members:")
        for member_name, member in members.items():
            kind: str = member.kind.name
            print(" " * 8 + f"({kind:8}) {member_name}: {member.type}")

    print("##### Predicates #####")
    printer = MidasPrinter()
    for name, predicate in registry._predicates.items():
        body: str = printer.print(predicate.body)
        if predicate.alias:
            print(f"{name}: {predicate.type} = {body}")
            continue

        print(f"{name}{predicate.type}:")
        # Prefix the first body line with `return`, then indent every line.
        body_lines: list[str] = body.split("\n")
        body_lines[0] = "return " + body_lines[0]
        print("\n".join(" " + line for line in body_lines))
|
||||
66
midas/cli/commands/stubs.py
Normal file
66
midas/cli/commands/stubs.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import ast
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, TextIO
|
||||
|
||||
import black
|
||||
import click
|
||||
from watchdog.events import DirModifiedEvent, FileModifiedEvent, FileSystemEventHandler
|
||||
from watchdog.observers import Observer
|
||||
|
||||
from midas.checker.checker import TypeChecker
|
||||
from midas.generator.stubs import StubsGenerator
|
||||
|
||||
|
||||
def generate_stubs(in_path: Path, out_path: Path):
    """Generate stubs for the Midas definitions at ``in_path``.

    The stub module is unparsed, formatted with black in ``.pyi`` mode and
    written to ``out_path``.
    """
    checker = TypeChecker()
    checker.import_midas(in_path)

    stub_module: ast.Module = StubsGenerator(checker.types).generate_stubs()
    stub_module = ast.fix_missing_locations(stub_module)

    unformatted: str = ast.unparse(stub_module)
    formatted: str = black.format_str(unformatted, mode=black.Mode(is_pyi=True))
    out_path.write_text(formatted)
|
||||
|
||||
|
||||
class Handler(FileSystemEventHandler):
    """File-system event handler that regenerates stubs on modification events."""

    def __init__(self, in_path: Path, out_path: Path) -> None:
        super().__init__()
        # Paths forwarded to `generate_stubs` on every modification event.
        self.in_path: Path = in_path
        self.out_path: Path = out_path

    def on_modified(self, event: DirModifiedEvent | FileModifiedEvent) -> None:
        # The event itself is not inspected; any modification triggers a rebuild.
        generate_stubs(self.in_path, self.out_path)
|
||||
|
||||
|
||||
@click.command(help="Generate stubs from Midas definitions")
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"))
@click.option("-w", "--watch", is_flag=True)
def stubs(
    file: TextIO,
    output: Optional[TextIO],
    watch: bool,
):
    """Generate a stub file for ``file``, optionally regenerating on change.

    Without ``--output`` the stubs go next to the source with a ``.pyi``
    suffix. With ``--watch`` the command keeps running until interrupted.
    """
    source_path: Path = Path(file.name).resolve()
    if output is None:
        out_path: Path = source_path.with_suffix(".pyi")
    else:
        out_path = Path(output.name).resolve()
    generate_stubs(source_path, out_path)

    if not watch:
        return

    print(f"Watching {source_path}...")
    print("Press CTRL+C to stop")
    observer = Observer()
    observer.schedule(Handler(source_path, out_path), str(source_path))
    observer.start()
    try:
        # Idle until the user interrupts.
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        observer.stop()
    observer.join()
|
||||
52
midas/cli/commands/types.py
Normal file
52
midas/cli/commands/types.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# **Print judgements**
|
||||
# ```shell
|
||||
# midas types <file.py> [--types <file.midas>]
|
||||
# ```
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, TextIO
|
||||
|
||||
import click
|
||||
|
||||
from midas.checker.checker import TypeChecker
|
||||
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||
from midas.cli.highlighter import DiagnosticsHighlighter
|
||||
from midas.cli.utils import DiagnosticPrinter
|
||||
|
||||
|
||||
@click.command(help="Print typing judgements")
@click.argument("file", type=click.File("r"))
@click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.option("-l", "--highlight", type=click.File("w"))
def types(
    file: TextIO,
    types: tuple[TextIO],
    highlight: Optional[TextIO],
):
    """Type-check ``file`` and report every typing judgement as an INFO diagnostic.

    The checker's own diagnostics are listed after the judgements. When
    ``--highlight`` is given, an annotated copy of the source is written to it.
    """
    source_path: Path = Path(file.name).resolve()

    checker = TypeChecker()
    for definitions in types:
        checker.import_midas(Path(definitions.name).resolve())

    checker.type_check(source_path)

    diagnostics: list[Diagnostic] = [
        Diagnostic(
            file_path=str(source_path),
            location=expr.location,
            type=DiagnosticType.INFO,
            message=f"Type: {judged}",
        )
        for expr, judged in checker.python_typer.judgements
    ]
    diagnostics.extend(checker.diagnostics)
    DiagnosticPrinter().print_all(diagnostics)

    if highlight is None:
        return

    highlighter = DiagnosticsHighlighter(file.read())
    highlighter.highlight(diagnostics)
    highlighter.dump(highlight)
|
||||
37
midas/cli/commands/validate.py
Normal file
37
midas/cli/commands/validate.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# **Validate midas definitions**
|
||||
# ```shell
|
||||
# midas validate <file.midas>
|
||||
# ```
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, TextIO
|
||||
|
||||
import click
|
||||
|
||||
from midas.checker.checker import TypeChecker
|
||||
from midas.checker.diagnostic import Diagnostic
|
||||
from midas.cli.highlighter import DiagnosticsHighlighter
|
||||
from midas.cli.utils import DiagnosticPrinter
|
||||
|
||||
|
||||
@click.command(help="Validate Midas definitions")
@click.argument("file", type=click.File("r"))
@click.option("-l", "--highlight", type=click.File("w"))
def validate(
    file: TextIO,
    highlight: Optional[TextIO],
):
    """Import a Midas definitions file and print the diagnostics it produces.

    When ``--highlight`` is given, an annotated copy of the source is
    written to it.
    """
    definitions_path: Path = Path(file.name).resolve()

    checker = TypeChecker()
    checker.import_midas(definitions_path)

    found: list[Diagnostic] = checker.diagnostics.copy()
    DiagnosticPrinter().print_all(found)

    if highlight is None:
        return

    highlighter = DiagnosticsHighlighter(file.read())
    highlighter.highlight(found)
    highlighter.dump(highlight)
|
||||
58
midas/cli/highlight.css
Normal file
58
midas/cli/highlight.css
Normal file
@@ -0,0 +1,58 @@
|
||||
/* Page reset and base font size. */
html,
body {
    margin: 0;
    font-size: 14pt;
}

* {
    box-sizing: border-box;
}

/* Container holding the rendered source lines, stacked vertically. */
#code {
    display: flex;
    flex-direction: column;
    font-family: monospace;
    white-space: pre-wrap;
}

/* One source line: a line-number gutter (.no) next to the text (.txt). */
.line {
    display: flex;

    /* Zebra striping on odd rows. */
    &:nth-child(odd) {
        background-color: rgb(247, 247, 247);
    }

    .no {
        width: 4em;
        text-align: right;
        padding: 0.2em 0.4em;
        border-right: solid black 1px;
        flex-shrink: 0;
    }

    .txt {
        flex-grow: 1;
        padding: 0.2em 0.8em;
    }
}

/* Highlight spans. --col is passed to both rgba() and rgb(), so overriding
   rules presumably supply an "r, g, b" triple -- confirm against the extra
   stylesheets. */
span {
    --col: transparent;
    --opacity: 0.1;
    --border: 0px;
    background-color: rgba(var(--col), var(--opacity));
    outline: solid rgb(var(--col)) var(--border);
    outline-offset: 2px;
    border-radius: 2px;

    /* Emphasise a hovered span only when none of its descendants is hovered. */
    &:hover:not(:has(*:hover)) {
        --opacity: 0.8;
        --border: 2px;
        z-index: 10;
    }

    /* Keywords are coloured and ignore pointer events. */
    &.keyword {
        color: rgb(211, 72, 9);
        pointer-events: none;
    }
}
|
||||
374
midas/cli/highlighter.py
Normal file
374
midas/cli/highlighter.py
Normal file
@@ -0,0 +1,374 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Generic, Optional, Protocol, TextIO, TypeVar
|
||||
|
||||
import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.diagnostic import Diagnostic
|
||||
from midas.lexer.token import Token
|
||||
|
||||
H = TypeVar("H", bound="Highlighter", contravariant=True)
|
||||
|
||||
|
||||
class Highlightable(Protocol, Generic[H]):
    """Structural type for nodes that accept a highlighter of type ``H`` as visitor."""

    def accept(self, visitor: H): ...
|
||||
|
||||
|
||||
class Locatable(Protocol):
    """Structural type for objects exposing an optional source ``location``."""

    @property
    @abstractmethod
    def location(self) -> Optional[Location]: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LocatableToken:
    """Adapter exposing a token's location through a ``location`` property."""

    # Wrapped lexer token.
    token: Token

    @property
    def location(self) -> Location:
        """Location reported by the wrapped token's ``get_location()``."""
        wrapped: Token = self.token
        return wrapped.get_location()
|
||||
|
||||
|
||||
class Highlighter(ABC):
|
||||
BASE_CSS_PATH: Path = Path(__file__).parent / "highlight.css"
|
||||
EXTRA_CSS_PATH: Optional[Path] = None
|
||||
|
||||
def __init__(self, source: str) -> None:
|
||||
self.source: str = source
|
||||
self.lines: list[str] = self.source.splitlines()
|
||||
self.openings: dict[tuple[int, int], list[str]] = {}
|
||||
self.closings: dict[tuple[int, int], list[str]] = {}
|
||||
|
||||
def format_css(self, path: Path) -> list[str]:
    """Read the stylesheet at ``path`` and wrap it in a ``<style>`` element.

    Every stylesheet line is prefixed for indentation and stripped of
    trailing whitespace. The result has three entries: the opening tag, the
    stylesheet text, and the closing tag.
    """
    indented: list[str] = [
        (" " + raw_line).rstrip() for raw_line in path.read_text().splitlines()
    ]
    return [" <style>", "\n".join(indented), " </style>"]
|
||||
|
||||
def dump(self, buf: TextIO):
|
||||
base_css: list[str] = self.format_css(self.BASE_CSS_PATH)
|
||||
extra_css: list[str] = (
|
||||
self.format_css(self.EXTRA_CSS_PATH)
|
||||
if self.EXTRA_CSS_PATH is not None
|
||||
else []
|
||||
)
|
||||
lines: list[str] = [
|
||||
"<!DOCTYPE html>",
|
||||
'<html lang="en">',
|
||||
"<head>",
|
||||
' <meta charset="UTF-8">',
|
||||
' <meta name="viewport" content="width=device-width, initial-scale=1.0">',
|
||||
" <title>Highlighted file</title>",
|
||||
*base_css,
|
||||
*extra_css,
|
||||
"</head>",
|
||||
"<body>",
|
||||
' <div id="code">',
|
||||
]
|
||||
for l, line in enumerate(self.lines):
|
||||
lineno: int = l + 1
|
||||
line_buf: str = (
|
||||
f'<div class="line" id="l{lineno}"><div class="no">{lineno}</div><div class="txt">'
|
||||
)
|
||||
for c, char in enumerate(line):
|
||||
pos: tuple[int, int] = (lineno, c)
|
||||
closings: list[str] = self.closings.get(pos, [])
|
||||
openings: list[str] = self.openings.get(pos, [])
|
||||
line_buf += "".join(closings + openings)
|
||||
line_buf += char
|
||||
line_buf += "".join(self.closings.get((lineno, len(line)), []))
|
||||
line_buf += "</div></div>"
|
||||
lines.append(" " + line_buf)
|
||||
lines.extend(
|
||||
[
|
||||
" </div>",
|
||||
"</body>",
|
||||
"</html>",
|
||||
]
|
||||
)
|
||||
|
||||
buf.write("\n".join(lines))
|
||||
|
||||
def wrap(self, node: Locatable, cls: str, message: Optional[str] = None):
|
||||
if node.location is None:
|
||||
return
|
||||
if node.location.end_lineno is None or node.location.end_col_offset is None:
|
||||
return
|
||||
start_pos: tuple[int, int] = (node.location.lineno, node.location.col_offset)
|
||||
end_pos: tuple[int, int] = (
|
||||
node.location.end_lineno,
|
||||
node.location.end_col_offset,
|
||||
)
|
||||
opening: str = f'<span class="{cls}" title="{cls}">'
|
||||
closing: str = "</span>"
|
||||
if message is not None:
|
||||
opening = f'<span class="with-msg">{opening}'
|
||||
closing = f'{closing}<span class="message">{message}</span></span>'
|
||||
|
||||
self.openings.setdefault(start_pos, []).append(opening)
|
||||
self.closings.setdefault(end_pos, []).insert(0, closing)
|
||||
if start_pos[0] != end_pos[0]:
|
||||
for l in range(start_pos[0], end_pos[0]):
|
||||
c: int = len(self.lines[l - 1])
|
||||
self.closings.setdefault((l, c), []).insert(0, closing)
|
||||
self.openings.setdefault((l + 1, 0), []).append(opening)
|
||||
|
||||
|
||||
class PythonHighlighter(
    Highlighter,
    p.MidasType.Visitor[None],
    p.Stmt.Visitor[None],
    p.Expr.Visitor[None],
):
    """Visitor over the Python AST that registers highlight spans.

    Types, functions, parameters, calls and a few statements are wrapped;
    the remaining expression kinds are visited without being tagged.
    """

    EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_python.css"

    def _accept_each(self, nodes) -> None:
        """Visit every element of ``nodes`` in order."""
        for child in nodes:
            child.accept(self)

    def highlight(self, node: Highlightable[PythonHighlighter]):
        """Entry point: let ``node`` dispatch to the matching ``visit_*``."""
        node.accept(self)

    # --- types ---------------------------------------------------------

    def visit_base_type(self, node: p.BaseType) -> None:
        """Tag the type, then tag and visit each of its arguments."""
        self.wrap(node, "base-type")
        for type_arg in node.args:
            self.wrap(type_arg, "arg")
            type_arg.accept(self)

    def visit_constraint_type(self, node: p.ConstraintType) -> None:
        """Tag the constrained type and descend into the underlying type."""
        self.wrap(node, "constraint-type")
        node.type.accept(self)

    def visit_frame_column(self, node: p.FrameColumn) -> None:
        """Tag the column and visit its type when one is present."""
        self.wrap(node, "frame-column")
        column_type = node.type
        if column_type is not None:
            column_type.accept(self)

    def visit_frame_type(self, node: p.FrameType) -> None:
        """Tag the frame type and visit its columns."""
        self.wrap(node, "frame-type")
        self._accept_each(node.columns)

    # --- statements ----------------------------------------------------

    def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
        """Visit the wrapped expression; the statement itself is not tagged."""
        stmt.expr.accept(self)

    def visit_function(self, stmt: p.Function) -> None:
        """Tag the function, its parameters, then visit the body."""
        self.wrap(stmt, "function")
        self._highlight_param_spec(stmt.params)
        self._accept_each(stmt.body)

    def _highlight_param_spec(self, spec: p.ParamSpec) -> None:
        """Highlight every parameter listed in ``spec.all``."""
        for parameter in spec.all:
            self._highlight_function_param(parameter)

    def _highlight_function_param(self, param: p.Function.Parameter) -> None:
        """Tag one parameter and visit its type when one is present."""
        self.wrap(param, "parameter")
        annotation = param.type
        if annotation is not None:
            annotation.accept(self)

    def visit_type_assign(self, stmt: p.TypeAssign) -> None:
        """Visit the assigned type."""
        stmt.type.accept(self)

    def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
        """Visit every target, then the assigned value."""
        self._accept_each(stmt.targets)
        stmt.value.accept(self)

    def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
        """Tag the statement and visit the returned value, if any."""
        self.wrap(stmt, "return")
        returned = stmt.value
        if returned is not None:
            returned.accept(self)

    def visit_if_stmt(self, stmt: p.IfStmt) -> None:
        """Tag the statement, then visit test, body and else branch."""
        self.wrap(stmt, "if")
        stmt.test.accept(self)
        self._accept_each(stmt.body)
        self._accept_each(stmt.orelse)

    def visit_pass(self, stmt: p.Pass) -> None:
        """Nothing to highlight."""
        return None

    def visit_for_stmt(self, stmt: p.ForStmt) -> None:
        """Tag the loop, then visit iterator, target and body."""
        self.wrap(stmt, "for")
        stmt.iterator.accept(self)
        stmt.target.accept(self)
        self._accept_each(stmt.body)

    # --- expressions ---------------------------------------------------

    def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
        """Not highlighted; operands are not visited."""

    def visit_compare_expr(self, expr: p.CompareExpr) -> None:
        """Not highlighted; operands are not visited."""

    def visit_unary_expr(self, expr: p.UnaryExpr) -> None:
        """Not highlighted; the operand is not visited."""

    def visit_call_expr(self, expr: p.CallExpr) -> None:
        """Tag the call, then visit callee, positional and keyword arguments."""
        self.wrap(expr, "call")
        expr.callee.accept(self)
        self._accept_each(expr.arguments)
        self._accept_each(expr.keywords.values())

    def visit_get_expr(self, expr: p.GetExpr) -> None:
        """Not highlighted."""

    def visit_literal_expr(self, expr: p.LiteralExpr) -> None:
        """Not highlighted."""

    def visit_variable_expr(self, expr: p.VariableExpr) -> None:
        """Not highlighted."""

    def visit_logical_expr(self, expr: p.LogicalExpr) -> None:
        """Not highlighted; operands are not visited."""

    def visit_cast_expr(self, expr: p.CastExpr) -> None:
        """Not highlighted."""

    def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
        """Not highlighted; branches are not visited."""

    def visit_list_expr(self, expr: p.ListExpr) -> None:
        """Visit every item of the list."""
        self._accept_each(expr.items)

    def visit_dict_expr(self, expr: p.DictExpr) -> None:
        """Visit all non-``None`` keys first, then all values."""
        self._accept_each(key for key in expr.keys if key is not None)
        self._accept_each(expr.values)

    def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
        """Visit the subscripted object, then the index."""
        expr.object.accept(self)
        expr.index.accept(self)

    def visit_slice_expr(self, expr: p.SliceExpr) -> None:
        """Visit whichever of lower, upper and step are present, in that order."""
        for bound in (expr.lower, expr.upper, expr.step):
            if bound is not None:
                bound.accept(self)

    def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
        """Visit every item of the tuple."""
        self._accept_each(expr.items)

    def visit_raw_expr(self, expr: p.RawExpr) -> None:
        """Not highlighted."""

    def visit_raw_stmt(self, stmt: p.RawStmt) -> None:
        """Not highlighted."""
|
||||
|
||||
|
||||
class MidasHighlighter(
    Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None]
):
    """Visitor over the Midas AST that registers highlight spans."""

    EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_midas.css"

    def _accept_each(self, nodes) -> None:
        """Visit every element of ``nodes`` in order."""
        for child in nodes:
            child.accept(self)

    def highlight(self, node: Highlightable[MidasHighlighter]):
        """Entry point: let ``node`` dispatch to the matching ``visit_*``."""
        node.accept(self)

    # --- statements ----------------------------------------------------

    def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
        """Tag the statement and its name token, then visit the type."""
        self.wrap(stmt, "type-stmt")
        name_token = LocatableToken(stmt.name)
        self.wrap(name_token, "type-name")
        stmt.type.accept(self)

    def visit_member_stmt(self, stmt: m.MemberStmt) -> None:
        """Tag the member and visit its type."""
        self.wrap(stmt, "member")
        stmt.type.accept(self)

    def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
        """Tag the extension and visit each member."""
        self.wrap(stmt, "extend")
        self._accept_each(stmt.members)

    def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
        """Tag the predicate and its name, then visit parameters and body."""
        self.wrap(stmt, "predicate")
        name_token = LocatableToken(stmt.name)
        self.wrap(name_token, "predicate-name")
        for param_spec in stmt.params:
            self._visit_param_spec(param_spec)
        stmt.body.accept(self)

    # --- expressions ---------------------------------------------------

    def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
        """Tag the expression and visit both operands, left first."""
        self.wrap(expr, "logical-expr")
        self._accept_each((expr.left, expr.right))

    def visit_binary_expr(self, expr: m.BinaryExpr) -> None:
        """Tag the expression and visit both operands, left first."""
        self.wrap(expr, "binary-expr")
        self._accept_each((expr.left, expr.right))

    def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
        """Tag the expression and visit its operand."""
        self.wrap(expr, "unary-expr")
        expr.right.accept(self)

    def visit_call_expr(self, expr: m.CallExpr) -> None:
        """Tag the call, then visit callee, positional and keyword arguments."""
        self.wrap(expr, "call-expr")
        expr.callee.accept(self)
        self._accept_each(expr.arguments)
        self._accept_each(expr.keywords.values())

    def visit_get_expr(self, expr: m.GetExpr) -> None:
        """Tag the attribute access and visit the accessed expression."""
        self.wrap(expr, "get-expr")
        expr.expr.accept(self)

    def visit_variable_expr(self, expr: m.VariableExpr) -> None:
        """Tag the variable reference."""
        self.wrap(expr, "variable")

    def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
        """Visit the grouped expression; the grouping itself is not tagged."""
        expr.expr.accept(self)

    def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
        """Not highlighted."""

    def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
        """Not highlighted."""

    # --- types ---------------------------------------------------------

    def visit_named_type(self, type: m.NamedType) -> None:
        """Tag the named type."""
        self.wrap(type, "named-type")

    def visit_generic_type(self, type: m.GenericType) -> None:
        """Tag the generic type, then visit its base type and arguments."""
        self.wrap(type, "generic-type")
        type.type.accept(self)
        self._accept_each(type.args)

    def visit_constraint_type(self, type: m.ConstraintType) -> None:
        """Tag the type, then visit the underlying type and the constraint."""
        self.wrap(type, "constraint-type")
        self._accept_each((type.type, type.constraint))

    def visit_complex_type(self, type: m.ComplexType) -> None:
        """Tag the type and visit each member."""
        self.wrap(type, "complex-type")
        self._accept_each(type.members)

    def visit_function_type(self, type: m.FunctionType) -> None:
        """Tag the function type, then visit parameter types and return type."""
        self.wrap(type, "function")
        self._visit_param_spec(type.params)
        type.returns.accept(self)

    def visit_extension_type(self, type: m.ExtensionType) -> None:
        """Tag the type, then visit the base and the extension."""
        self.wrap(type, "extension")
        self._accept_each((type.base, type.extension))

    def _visit_param_spec(self, spec: m.ParamSpec) -> None:
        """Visit the type of every positional, mixed and keyword parameter."""
        every_param = spec.pos + spec.mixed + spec.kw
        for parameter in every_param:
            parameter.type.accept(self)

    def visit_frame_type(self, type: m.FrameType) -> None:
        """Tag the frame type and each of its columns."""
        self.wrap(type, "frame")
        for frame_column in type.columns:
            self._visit_frame_column(frame_column)

    def _visit_frame_column(self, column: m.FrameType.Column) -> None:
        """Tag a single frame column."""
        self.wrap(column, "column")
|
||||
|
||||
|
||||
class DiagnosticsHighlighter(Highlighter):
    """Highlighter that overlays checker diagnostics on the source."""

    EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"

    def highlight(self, diagnostics: list[Diagnostic]):
        """Wrap each diagnostic's location, attaching its message.

        The span class is the lower-cased string form of the diagnostic type.
        """
        for item in diagnostics:
            css_class = str(item.type).lower()
            self.wrap(item, css_class, item.message)
|
||||
39
midas/cli/hl_diagnostic.css
Normal file
39
midas/cli/hl_diagnostic.css
Normal file
@@ -0,0 +1,39 @@
|
||||
/* Diagnostic overlay: one accent colour per severity, message shown on hover. */
span {
  --opacity: 0.4;

  /* Accent colours as RGB triplets. */
  &.error {
    --col: 255, 0, 0;
  }
  &.warning {
    --col: 250, 160, 0;
  }
  &.info {
    --col: 150, 190, 250;
  }

  &.with-msg {
    position: relative;

    /* Hidden by default; positioned just below the highlighted span. */
    .message {
      display: none;
      position: absolute;
      top: calc(100% + 0.2em);
      left: -.2em;
      background-color: black;
      color: white;
      padding: 0.2em 0.4em;
      border-radius: .2em;
      z-index: 10;
      width: 300%;
    }

    /* Only the innermost hovered diagnostic reveals its message. */
    &:hover:not(:has(.with-msg:hover)) {
      .message {
        display: inline-block;
      }
    }
  }
}
|
||||
52
midas/cli/hl_midas.css
Normal file
52
midas/cli/hl_midas.css
Normal file
@@ -0,0 +1,52 @@
|
||||
/* Accent colours (RGB triplets in --col) for Midas highlight classes. */
span {
  &.comment {
    --col: 200, 200, 200;
    color: rgb(110, 110, 110);
    font-style: italic;
  }

  /* Type nodes share one neutral colour. */
  &.named-type,
  &.generic-type,
  &.constraint-type,
  &.complex-type {
    --col: 150, 150, 150;
  }

  &.constraint {
    --col: 233, 108, 108;
  }

  &.property {
    --col: 233, 108, 176;
  }

  &.extend {
    --col: 108, 197, 233;
  }

  &.op {
    --col: 108, 148, 233;
  }

  &.predicate {
    --col: 193, 108, 233;
  }

  /* Expression nodes share one colour. */
  &.logical-expr,
  &.binary-expr,
  &.unary-expr,
  &.get-expr {
    --col: 123, 215, 193;
  }

  &.template {
    --col: 163, 117, 71;
  }

  /* Declared names are emphasised. */
  &.type-name,
  &.op-name,
  &.predicate-name {
    --col: 200, 200, 200;
    font-weight: bold;
  }
}
|
||||
29
midas/cli/hl_python.css
Normal file
29
midas/cli/hl_python.css
Normal file
@@ -0,0 +1,29 @@
|
||||
/* Accent colours (RGB triplets in --col) for Python highlight classes. */
span {
  /* Type annotations. */
  &.base-type {
    --col: 108, 233, 108;
  }

  &.arg {
    --col: 103, 192, 224;
  }

  &.constraint-type {
    --col: 174, 200, 195;
  }

  /* Data-frame types and their columns. */
  &.frame-column {
    --col: 216, 231, 81;
  }

  &.frame-type {
    --col: 231, 46, 40;
  }

  /* Function definitions and their parameters. */
  &.function {
    --col: 215, 103, 224;
  }

  &.parameter {
    --col: 103, 192, 224;
  }
}
|
||||
26
midas/cli/main.py
Normal file
26
midas/cli/main.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import logging
|
||||
|
||||
import click
|
||||
|
||||
from midas.cli import commands
|
||||
|
||||
|
||||
# Root command group. No docstring on purpose: click would display it as the
# command's help text.
@click.group()
@click.option("-v", "--verbose", is_flag=True)
def midas(verbose: bool):
    # -v/--verbose lowers the logging threshold from WARN to DEBUG.
    level = logging.DEBUG if verbose else logging.WARN
    logging.basicConfig(level=level)


# Register the sub-commands on the root group, in a fixed order.
for _command in (
    commands.check,
    commands.compile,
    commands.format,
    commands.highlight,
    commands.parse,
    commands.dump_registry,
    commands.types,
    commands.stubs,
    commands.validate,
):
    midas.add_command(_command)
del _command


if __name__ == "__main__":
    midas()
|
||||
121
midas/cli/utils.py
Normal file
121
midas/cli/utils.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||
from midas.cli.ansi import Ansi
|
||||
|
||||
|
||||
class DiagnosticPrinter:
    """Print diagnostics to stdout with ANSI colours and source context."""

    # Foreground colour used for each diagnostic type.
    COLORS: dict[DiagnosticType, int] = {
        DiagnosticType.ERROR: Ansi.RED,
        DiagnosticType.WARNING: Ansi.YELLOW,
        DiagnosticType.INFO: Ansi.CYAN,
        DiagnosticType.DEBUG: Ansi.MAGENTA,
    }

    def __init__(self) -> None:
        # Cache of source lines per file name, filled lazily by get_lines().
        self.files: dict[Optional[str], list[str]] = {}

    def get_lines(self, filename: Optional[str]) -> list[str]:
        """Return the lines of ``filename``, reading the file at most once.

        Returns an empty list when ``filename`` is ``None`` or does not name
        an existing regular file.
        """
        if filename is None:
            return []
        if filename not in self.files:
            path: Path = Path(filename)
            self.files[filename] = (
                path.read_text().split("\n") if path.is_file() else []
            )
        return self.files[filename]

    def print_all(self, diagnostics: list[Diagnostic], indent: int = 4):
        """Print every diagnostic, then a one-line count per diagnostic type.

        Nothing is printed when ``diagnostics`` is empty.
        """
        by_type: dict[DiagnosticType, int] = defaultdict(int)
        for diagnostic in diagnostics:
            lines = self.get_lines(diagnostic.file_path)
            self.print(lines, diagnostic, indent=indent)
            by_type[diagnostic.type] += 1

        if not diagnostics:
            return

        counts: list[str] = []
        # Iterate the enum so counts appear in declaration order.
        for diag_type in DiagnosticType:
            if diag_type not in by_type:
                continue
            count: int = by_type[diag_type]
            color: int = self.COLORS.get(diag_type, Ansi.WHITE)
            counts.append(f"{Ansi.FG(color)}{diag_type.value}s{Ansi.RESET}: {count}")

        print(" ".join(counts))

    def _print_without_context(self, diagnostic: Diagnostic, color: int) -> None:
        """Print only the location header and the coloured message.

        Used when the source lines for the diagnostic are not available.
        """
        print(diagnostic.location_str + ":")
        print(Ansi.FG(color) + diagnostic.message + Ansi.RESET)
        print()

    def print(self, lines: list[str], diagnostic: Diagnostic, indent: int = 4):
        """Pretty-print a diagnostic, showing some context if possible

        If the diagnostic concerns a specific part of one line, the line is shown
        with the affected part highlighted. The message is printed under the
        line with an underline further indicating the target expression.

        If multiple lines are concerned, printing is delegated to
        :meth:`print_multiline`. If the concerned line is not available in
        ``lines`` (e.g. the file could not be read), only the location and the
        message are printed.

        Args:
            lines (list[str]): source code lines
            diagnostic (Diagnostic): the diagnostic to print
            indent (int, optional): the number of spaces added before the target line to indent it from the location header. Defaults to 4.
        """

        loc: Location = diagnostic.location
        # A location without an end line is treated as a single-line location.
        if loc.end_lineno is not None and loc.lineno != loc.end_lineno:
            self.print_multiline(lines, diagnostic, indent)
            return

        color: int = self.COLORS.get(diagnostic.type, Ansi.WHITE)
        if not 1 <= loc.lineno <= len(lines):
            self._print_without_context(diagnostic, color)
            return

        start_offset: int = loc.col_offset
        # Without an end column, highlight a single character.
        end_offset: int = loc.end_col_offset or (start_offset + 1)

        line: str = lines[loc.lineno - 1]
        before: str = line[:start_offset]
        after: str = line[end_offset:]

        subject: str = Ansi.FG(color) + line[start_offset:end_offset] + Ansi.RESET
        cursor: str = (
            " " * start_offset
            + Ansi.FG(color)
            + "~" * (end_offset - start_offset)
            + "> "
            + diagnostic.message
            + Ansi.RESET
        )

        indent_str: str = " " * indent
        print(diagnostic.location_str + ":")
        print(indent_str + before + subject + after)
        print(indent_str + cursor)
        print()

    def print_multiline(
        self, all_lines: list[str], diagnostic: Diagnostic, indent: int = 4
    ):
        """Print a diagnostic spanning several lines.

        The concerned lines are printed with the affected range coloured,
        followed by the coloured message. When none of the concerned lines is
        available, only the location and the message are printed.
        """
        loc: Location = diagnostic.location
        color: int = self.COLORS.get(diagnostic.type, Ansi.WHITE)
        lines: list[str] = all_lines[loc.lineno - 1 : loc.end_lineno]
        if not lines:
            self._print_without_context(diagnostic, color)
            return

        start_offset: int = loc.col_offset
        end_offset: int = loc.end_col_offset or (start_offset + 1)

        indent_str: str = " " * indent
        res: str = indent_str + lines[0][:start_offset]
        res += Ansi.FG(color) + lines[0][start_offset:]
        for line in lines[1:-1]:
            res += "\n" + indent_str + line
        res += "\n" + indent_str + lines[-1][:end_offset]
        res += Ansi.RESET + lines[-1][end_offset:]

        print(diagnostic.location_str + ":")
        print(res)
        print()
        print(Ansi.FG(color) + diagnostic.message + Ansi.RESET)
        print()
|
||||
59
midas/generator/collector.py
Normal file
59
midas/generator/collector.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
|
||||
import midas.ast.python as p
|
||||
|
||||
# Callable producing the Python AST expression of an assertion.
AssertionBuilder = Callable[..., ast.expr]


@dataclass
class Assertion:
    """A pending assertion attached to one expression."""

    bound_expr: p.Expr  # expression the assertion is attached to
    inputs: list[p.Expr]
    builder: AssertionBuilder
    message: str

    def is_bound_to(self, expr: p.Expr) -> bool:
        """Tell whether ``expr`` compares equal to the bound expression."""
        matches = expr == self.bound_expr
        return matches
|
||||
|
||||
|
||||
class AssertionCollector:
    """Accumulates assertions and named definitions."""

    def __init__(self):
        self.assertions: list[Assertion] = []
        # Keyed by name; the first definition registered for a name wins.
        self.definitions: dict[str, ast.stmt] = {}

    def add(
        self,
        bound_expr: p.Expr,
        inputs: list[p.Expr],
        builder: AssertionBuilder,
        message: str,
    ):
        """Record a new assertion bound to ``bound_expr``."""
        entry = Assertion(
            bound_expr=bound_expr,
            inputs=inputs,
            builder=builder,
            message=message,
        )
        self.assertions.append(entry)

    def remove(self, assertion: Assertion):
        """Drop the first equal assertion; do nothing when there is none."""
        if assertion in self.assertions:
            self.assertions.remove(assertion)

    def define(self, name: str, stmt: ast.stmt):
        """Register ``stmt`` under ``name`` unless the name is already taken."""
        self.definitions.setdefault(name, stmt)

    def get_definitions(self) -> list[ast.stmt]:
        """Return the registered definitions as a new list, in insertion order."""
        return [*self.definitions.values()]

    def get_assertions(self) -> list[Assertion]:
        """Return the internal list of assertions (not a copy)."""
        return self.assertions

    def get_assertions_for(self, expr: p.Expr) -> list[Assertion]:
        """Return the assertions bound to ``expr``, in insertion order."""
        return [entry for entry in self.assertions if entry.is_bound_to(expr)]
|
||||
225
midas/generator/constraints.py
Normal file
225
midas/generator/constraints.py
Normal file
@@ -0,0 +1,225 @@
|
||||
import ast
|
||||
from typing import Optional
|
||||
|
||||
import midas.ast.midas as m
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import (
|
||||
Function,
|
||||
ParamSpec,
|
||||
Predicate,
|
||||
Type,
|
||||
to_annotation,
|
||||
)
|
||||
from midas.lexer.token import TokenType
|
||||
|
||||
# Midas token type -> Python AST operator class. The commented-out entries
# are not mapped; an operator absent from its table cannot be generated.

LOGICAL_OPERATORS: dict[TokenType, type[ast.boolop]] = {
    TokenType.AND: ast.And,
    # TokenType.OR: ast.Or,
}

BINARY_OPERATORS: dict[TokenType, type[ast.operator]] = {
    # TokenType.PLUS: ast.Add,
    TokenType.MINUS: ast.Sub,
    TokenType.STAR: ast.Mult,
    TokenType.SLASH: ast.Div,
}

UNARY_OPERATORS: dict[TokenType, type[ast.unaryop]] = {
    # TokenType.PLUS: ast.UAdd,
    TokenType.MINUS: ast.USub,
}

COMPARISON_OPERATORS: dict[TokenType, type[ast.cmpop]] = {
    TokenType.GREATER: ast.Gt,
    TokenType.GREATER_EQUAL: ast.GtE,
    TokenType.LESS: ast.Lt,
    TokenType.LESS_EQUAL: ast.LtE,
    TokenType.EQUAL_EQUAL: ast.Eq,
    TokenType.BANG_EQUAL: ast.NotEq,
}
|
||||
|
||||
|
||||
class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
|
||||
def __init__(self, types: TypesRegistry):
|
||||
self.types: TypesRegistry = types
|
||||
self._id: int = 0
|
||||
self._definitions: list[ast.stmt] = []
|
||||
self._aliases: dict[str, str] = {}
|
||||
|
||||
def get_definitions(self) -> list[ast.stmt]:
|
||||
return self._definitions
|
||||
|
||||
def generate(self, expr: m.Expr) -> ast.expr:
|
||||
match expr:
|
||||
case m.VariableExpr():
|
||||
return expr.accept(self)
|
||||
case _:
|
||||
func = Function(
|
||||
params=ParamSpec(
|
||||
mixed=[
|
||||
Function.Parameter(
|
||||
pos=0,
|
||||
name="_",
|
||||
type=self.types.get_type("Any"),
|
||||
required=True,
|
||||
)
|
||||
],
|
||||
),
|
||||
returns=self.types.get_type("bool"),
|
||||
)
|
||||
alias: str = self.make_alias(None)
|
||||
definition: ast.stmt = self.make_definition(
|
||||
alias, Predicate(type=func, body=expr, alias=False)
|
||||
)
|
||||
self._definitions.append(definition)
|
||||
return ast.Name(id=alias)
|
||||
|
||||
def make_alias(self, name: Optional[str]) -> str:
|
||||
suffix: str
|
||||
if name is None:
|
||||
suffix = f"p{self._id}"
|
||||
self._id += 1
|
||||
else:
|
||||
suffix = name
|
||||
alias: str = f"__midas_{suffix}__"
|
||||
return alias
|
||||
|
||||
def make_definition(self, name: str, predicate: Predicate) -> ast.stmt:
|
||||
body: ast.expr = predicate.body.accept(self)
|
||||
if predicate.alias:
|
||||
return ast.Assign(
|
||||
targets=[
|
||||
ast.Name(id=name),
|
||||
],
|
||||
value=body,
|
||||
)
|
||||
return self.make_func(name, [ast.Return(value=body)], predicate.type)
|
||||
|
||||
def make_args(self, params: ParamSpec) -> ast.arguments:
|
||||
return ast.arguments(
|
||||
posonlyargs=[
|
||||
ast.arg(
|
||||
arg=param.name,
|
||||
annotation=ast.Constant(value=to_annotation(param.type)),
|
||||
)
|
||||
for param in params.pos
|
||||
],
|
||||
args=[
|
||||
ast.arg(
|
||||
arg=param.name,
|
||||
annotation=ast.Constant(value=to_annotation(param.type)),
|
||||
)
|
||||
for param in params.mixed
|
||||
],
|
||||
kwonlyargs=[
|
||||
ast.arg(
|
||||
arg=param.name,
|
||||
annotation=ast.Constant(value=to_annotation(param.type)),
|
||||
)
|
||||
for param in params.kw
|
||||
],
|
||||
defaults=[],
|
||||
kw_defaults=[],
|
||||
)
|
||||
|
||||
def make_func(
|
||||
self, name: str, inner_body: list[ast.stmt], type: Type, level: int = 0
|
||||
) -> ast.stmt:
|
||||
match type:
|
||||
case Function(params=params, returns=Function()):
|
||||
inner_name: str = f"inner{level}"
|
||||
return ast.FunctionDef(
|
||||
name=name,
|
||||
args=self.make_args(params),
|
||||
body=[
|
||||
self.make_func(inner_name, inner_body, type.returns, level + 1),
|
||||
ast.Return(value=ast.Name(id=inner_name)),
|
||||
],
|
||||
returns=ast.Constant(value=to_annotation(type.returns)),
|
||||
decorator_list=[],
|
||||
)
|
||||
|
||||
case Function(params=params):
|
||||
return ast.FunctionDef(
|
||||
name=name,
|
||||
args=self.make_args(params),
|
||||
body=inner_body,
|
||||
returns=ast.Constant(value=to_annotation(type.returns)),
|
||||
decorator_list=[],
|
||||
)
|
||||
|
||||
case _:
|
||||
raise ValueError(f"Expected function, got {type!r}")
|
||||
|
||||
def get_predicate(self, name: str) -> Optional[ast.expr]:
|
||||
if name not in self._aliases:
|
||||
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
|
||||
if predicate is None:
|
||||
return None
|
||||
alias: str = self.make_alias(name)
|
||||
self._aliases[name] = alias
|
||||
self._definitions.append(self.make_definition(alias, predicate))
|
||||
|
||||
return ast.Name(id=self._aliases[name])
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> ast.expr:
|
||||
return ast.BoolOp(
|
||||
op=LOGICAL_OPERATORS[expr.operator.type](),
|
||||
values=[
|
||||
expr.left.accept(self),
|
||||
expr.right.accept(self),
|
||||
],
|
||||
)
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> ast.expr:
|
||||
op: TokenType = expr.operator.type
|
||||
if op in BINARY_OPERATORS:
|
||||
return ast.BinOp(
|
||||
left=expr.left.accept(self),
|
||||
op=BINARY_OPERATORS[op](),
|
||||
right=expr.right.accept(self),
|
||||
)
|
||||
if op in COMPARISON_OPERATORS:
|
||||
return ast.Compare(
|
||||
left=expr.left.accept(self),
|
||||
ops=[COMPARISON_OPERATORS[op]()],
|
||||
comparators=[expr.right.accept(self)],
|
||||
)
|
||||
raise ValueError(f"Unexpected binary operator {op}")
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> ast.expr:
|
||||
return ast.UnaryOp(
|
||||
op=UNARY_OPERATORS[expr.operator.type](),
|
||||
operand=expr.right.accept(self),
|
||||
)
|
||||
|
||||
def visit_call_expr(self, expr: m.CallExpr) -> ast.expr:
|
||||
return ast.Call(
|
||||
func=expr.callee.accept(self),
|
||||
args=[arg.accept(self) for arg in expr.arguments],
|
||||
keywords=[
|
||||
ast.keyword(arg=name, value=arg.accept(self))
|
||||
for name, arg in expr.keywords.items()
|
||||
],
|
||||
)
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> ast.expr:
|
||||
return ast.Attribute(
|
||||
value=expr.expr.accept(self),
|
||||
attr=expr.name.lexeme,
|
||||
)
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr) -> ast.expr:
|
||||
name: str = expr.name.lexeme
|
||||
if (p := self.get_predicate(name)) is not None:
|
||||
return p
|
||||
return ast.Name(id=name)
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> ast.expr:
|
||||
return expr.accept(self)
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> ast.expr:
|
||||
return ast.Constant(value=expr.value)
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> ast.expr:
|
||||
return ast.Name(id="_")
|
||||
690
midas/generator/generator.py
Normal file
690
midas/generator/generator.py
Normal file
@@ -0,0 +1,690 @@
|
||||
import ast
|
||||
import logging
|
||||
import shutil
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional, assert_never
|
||||
|
||||
import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.ast.printer import MidasPrinter
|
||||
from midas.checker.checker import TypeChecker
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import (
|
||||
AppliedType,
|
||||
BaseType,
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
DataFrameType,
|
||||
DerivedType,
|
||||
ExtensionType,
|
||||
FrameGroupBy,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
TopType,
|
||||
TupleType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnitType,
|
||||
UnknownType,
|
||||
)
|
||||
from midas.generator.collector import Assertion, AssertionCollector
|
||||
from midas.generator.constraints import ConstraintGenerator
|
||||
from midas.generator.stubs import StubsGenerator
|
||||
from midas.utils import TypedAST
|
||||
|
||||
|
||||
@dataclass
class Scope:
    """Per-scope state; each instance starts with its own empty lists."""

    pre_assertions: list[ast.stmt] = field(default_factory=list)
    aliases: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
IS_DATAFRAME_FUNC = "__midas_is_dataframe__"
|
||||
IS_COLUMN_FUNC = "__midas_is_column__"
|
||||
|
||||
def __init__(self, workdir: Path, types: TypesRegistry) -> None:
    """Set up paths, logger and empty generation state.

    Output is placed under ``<workdir>/build/midas`` by default.
    """
    # Paths and logging.
    self.workdir: Path = workdir.resolve()
    self.build_dir: Path = self.workdir / "build" / "midas"
    self.rel_src_path: Path = Path()
    self.logger: logging.Logger = logging.getLogger("Generator")

    # Placeholder typed AST, replaced by generate_ast().
    empty_ast = TypedAST(
        stmts=[],
        judgements=[],
        evaluated_casts=[],
        assertions=AssertionCollector(),
    )
    self._typed_ast: TypedAST = empty_ast

    # Counters and alias/scope bookkeeping.
    self._alias_count: int = 0
    self._predicate_count: int = 0
    self._scopes: list[Scope] = []
    self._aliases: list[tuple[p.Expr, ast.expr]] = []

    # Constraint translation.
    self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
    self._constraints: list[tuple[m.Expr, ast.expr]] = []

    # Whether the runtime helper definitions must be emitted.
    self.define_is_dataframe: bool = False
    self.define_is_column: bool = False
|
||||
|
||||
def set_src_path(self, path: Path):
    """Store ``path`` relative to the working directory.

    ``Path.relative_to`` raises ``ValueError`` when the resolved path is not
    located under the working directory.
    """
    absolute = path.resolve()
    self.rel_src_path = absolute.relative_to(self.workdir)
|
||||
|
||||
def generate_ast(self, typed_ast: TypedAST) -> ast.AST:
    """Build the output module for ``typed_ast``.

    The module body is, in order: the column helper (if needed), the
    dataframe helper (if needed), the generated predicate definitions, then
    the translated statements.
    """
    self._typed_ast = typed_ast
    statements: list[ast.stmt] = self._visit_body(typed_ast.stmts, can_be_empty=True)
    predicates: list[ast.stmt] = self._constraint_generator.get_definitions()

    body: list[ast.stmt] = [*predicates, *statements]
    if self.define_is_dataframe:
        body.insert(0, self._is_dataframe_definition())
    if self.define_is_column:
        body.insert(0, self._is_column_definition())

    return ast.fix_missing_locations(ast.Module(body=body, type_ignores=[]))
|
||||
|
||||
def generate(
    self,
    typed_ast: TypedAST,
    src_path: Path,
    out_path: Optional[Path] = None,
    type_files: Optional[list[tuple[Path, Optional[str]]]] = None,
) -> Path:
    """Compile ``typed_ast`` to Python source and write it to disk.

    Args:
        typed_ast: the checked AST to compile.
        src_path: path of the source file, located under the working directory.
        out_path: destination file. When omitted, the build directory is
            wiped and the file is written at the same relative path inside it.
        type_files: ``(path, output name)`` pairs for which stub modules are
            generated next to the output file; a ``None`` name falls back
            to the stem of the path.

    Returns:
        The path of the written file.

    Raises:
        ValueError: if the default output path would fall outside the build
            directory.
    """
    self.set_src_path(src_path)
    if out_path is None:
        if self.build_dir.exists():
            shutil.rmtree(self.build_dir)
        self.build_dir.mkdir(parents=True, exist_ok=True)
        out_path = (self.build_dir / self.rel_src_path).resolve()
        try:
            _ = out_path.relative_to(self.build_dir)
        except ValueError as err:
            raise ValueError(
                f"Directory traversal, {self.rel_src_path} points outside of parent directory"
            ) from err
    out_dir: Path = out_path.parent
    # Create the output directory itself (and any missing parents): creating
    # only its parent leaves nested outputs without a directory to write to.
    out_dir.mkdir(parents=True, exist_ok=True)

    if type_files is not None:
        for in_path, out_name in type_files:
            if out_name is None:
                out_name = in_path.stem
            self.generate_stubs(in_path, out_dir / f"{out_name}.py")

    module: ast.AST = self.generate_ast(typed_ast)
    compiled: str = ast.unparse(module)

    out_path.write_text(compiled)
    return out_path
|
||||
|
||||
def generate_stubs(self, in_path: Path, out_path: Path):
    """Write a Python stub module for the Midas type file ``in_path``.

    The file is imported through a fresh ``TypeChecker``; the types it
    registers are rendered by ``StubsGenerator`` and the unparsed module is
    written to ``out_path``.
    """
    checker = TypeChecker()
    checker.import_midas(in_path)

    stubs_module: ast.Module = StubsGenerator(checker.types).generate_stubs()
    stubs_module = ast.fix_missing_locations(stubs_module)
    out_path.write_text(ast.unparse(stubs_module))
|
||||
|
||||
def convert(self, expr: p.Expr) -> ast.expr:
    """Convert an expression node into a Python AST expression.

    Resolution order: an alias already registered for an equal expression
    is returned as-is; otherwise pending assertions bound to the expression
    are emitted before converting it; otherwise the expression is simply
    dispatched to the matching ``visit_*`` method.
    """
    for aliased, alias_node in self._aliases:
        if aliased == expr:
            return alias_node

    pending = self._typed_ast.assertions.get_assertions_for(expr)
    if len(pending) == 0:
        return expr.accept(self)
    return self._apply_assertions(expr, pending)
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> ast.expr:
    """Build an ``ast.BinOp``; operands are converted left first, then right."""
    lhs: ast.expr = self.convert(expr.left)
    operator = expr.operator
    rhs: ast.expr = self.convert(expr.right)
    return ast.BinOp(left=lhs, op=operator, right=rhs)
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> ast.expr:
    """Build a single-operator ``ast.Compare`` from a comparison node."""
    lhs: ast.expr = self.convert(expr.left)
    operators = [expr.operator]
    rhs: ast.expr = self.convert(expr.right)
    return ast.Compare(left=lhs, ops=operators, comparators=[rhs])
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> ast.expr:
    """Build an ``ast.UnaryOp`` applying the node's operator to its operand."""
    operator = expr.operator
    operand: ast.expr = self.convert(expr.right)
    return ast.UnaryOp(op=operator, operand=operand)
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> ast.expr:
    """Build an ``ast.Call``.

    The callee is converted first, then positional arguments in order, then
    keyword arguments in the order of ``expr.keywords``.
    """
    callee: ast.expr = self.convert(expr.callee)

    positional: list[ast.expr] = []
    for argument in expr.arguments:
        positional.append(self.convert(argument))

    named: list[ast.keyword] = []
    for key, argument in expr.keywords.items():
        named.append(ast.keyword(arg=key, value=self.convert(argument)))

    return ast.Call(func=callee, args=positional, keywords=named)
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> ast.expr:
    """Build an ``ast.Attribute`` access (``<object>.<name>``)."""
    owner: ast.expr = self.convert(expr.object)
    return ast.Attribute(value=owner, attr=expr.name)
|
||||
|
||||
def visit_literal_expr(self, expr: p.LiteralExpr) -> ast.expr:
    """Wrap the literal's value in an ``ast.Constant``."""
    literal = expr.value
    return ast.Constant(value=literal)
|
||||
|
||||
def visit_variable_expr(self, expr: p.VariableExpr) -> ast.expr:
    """Reference the variable by name with an ``ast.Name``."""
    identifier = expr.name
    return ast.Name(id=identifier)
|
||||
|
||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> ast.expr:
    """Build a two-operand ``ast.BoolOp`` (left converted before right)."""
    operator = expr.operator
    lhs: ast.expr = self.convert(expr.left)
    rhs: ast.expr = self.convert(expr.right)
    return ast.BoolOp(op=operator, values=[lhs, rhs])
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
    """Convert a cast, emitting runtime checks when they are needed.

    Casts listed in ``evaluated_casts`` and casts flagged ``unsafe`` produce
    no check: the inner expression is returned directly. Otherwise the inner
    expression is bound to an alias, the assertions for the cast's type are
    queued on the current scope, and the alias is returned.
    """
    inner: ast.expr = self.convert(expr.expr)

    skip_checks = expr in self._typed_ast.evaluated_casts or expr.unsafe
    if skip_checks:
        return inner

    alias: ast.expr = self._make_alias(expr.expr, inner)
    target_type: Type = self._get_expr_type(expr)
    for check in self._make_cast_asserts(expr.location, alias, target_type):
        self._add_assert(check)
    return alias
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> ast.expr:
    """Build an ``ast.IfExp``; test, true branch and false branch are converted in that order."""
    condition: ast.expr = self.convert(expr.test)
    when_true: ast.expr = self.convert(expr.if_true)
    when_false: ast.expr = self.convert(expr.if_false)
    return ast.IfExp(test=condition, body=when_true, orelse=when_false)
|
||||
|
||||
def visit_list_expr(self, expr: p.ListExpr) -> ast.expr:
    """Build an ``ast.List`` from the converted items, in order."""
    elements: list[ast.expr] = []
    for item in expr.items:
        elements.append(self.convert(item))
    return ast.List(elts=elements)
|
||||
|
||||
def visit_dict_expr(self, expr: p.DictExpr) -> ast.expr:
    """Build an ``ast.Dict``.

    A ``None`` key is kept as ``None`` (which ``ast.Dict`` uses for ``**``
    unpacking entries). All keys are converted before any value.
    """
    keys: list[Optional[ast.expr]] = []
    for key in expr.keys:
        keys.append(None if key is None else self.convert(key))

    values: list[ast.expr] = []
    for value in expr.values:
        values.append(self.convert(value))

    return ast.Dict(keys=keys, values=values)
|
||||
|
||||
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> ast.expr:
    """Build an ``ast.Subscript`` (``<object>[<index>]``)."""
    container: ast.expr = self.convert(expr.object)
    index: ast.expr = self.convert(expr.index)
    return ast.Subscript(value=container, slice=index)
|
||||
|
||||
def visit_slice_expr(self, expr: p.SliceExpr) -> ast.expr:
    """Build an ``ast.Slice``; missing bounds stay ``None``."""
    lower: Optional[ast.expr] = None
    if expr.lower is not None:
        lower = self.convert(expr.lower)

    upper: Optional[ast.expr] = None
    if expr.upper is not None:
        upper = self.convert(expr.upper)

    step: Optional[ast.expr] = None
    if expr.step is not None:
        step = self.convert(expr.step)

    return ast.Slice(lower=lower, upper=upper, step=step)
|
||||
|
||||
def visit_tuple_expr(self, expr: p.TupleExpr) -> ast.expr:
    """Build an ``ast.Tuple`` from the converted items, in order."""
    elements: list[ast.expr] = []
    for item in expr.items:
        elements.append(self.convert(item))
    return ast.Tuple(elts=elements)
|
||||
|
||||
def visit_raw_expr(self, expr: p.RawExpr) -> ast.expr:
    """Return the Python AST expression carried by the node, untouched."""
    wrapped: ast.expr = expr.expr
    return wrapped
|
||||
|
||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> ast.stmt:
    """Wrap the converted expression in an ``ast.Expr`` statement."""
    value: ast.expr = self.convert(stmt.expr)
    return ast.Expr(value=value)
|
||||
|
||||
def make_args(self, params: p.ParamSpec) -> ast.arguments:
    """Build the ``ast.arguments`` of a function from its parameter spec.

    ``params.pos`` become positional-only arguments, ``params.mixed`` regular
    arguments and ``params.kw`` keyword-only arguments. Parameter
    annotations are not emitted. Positional defaults only list parameters
    that have one, while keyword-only defaults keep one entry per parameter
    (``None`` when there is no default).
    """

    def plain_args(group) -> list[ast.arg]:
        return [ast.arg(arg=param.name) for param in group]

    posonly = plain_args(params.pos)
    regular = plain_args(params.mixed)
    kwonly = plain_args(params.kw)

    positional_defaults: list[ast.expr] = []
    for param in params.pos + params.mixed:
        if param.default is not None:
            positional_defaults.append(self.convert(param.default))

    keyword_defaults: list[Optional[ast.expr]] = []
    for param in params.kw:
        if param.default is None:
            keyword_defaults.append(None)
        else:
            keyword_defaults.append(self.convert(param.default))

    return ast.arguments(
        posonlyargs=posonly,
        args=regular,
        kwonlyargs=kwonly,
        defaults=positional_defaults,
        kw_defaults=keyword_defaults,
    )
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> ast.stmt:
    """Build an undecorated ``ast.FunctionDef`` with the converted signature and body."""
    signature: ast.arguments = self.make_args(stmt.params)
    statements: list[ast.stmt] = self._visit_body(stmt.body)
    return ast.FunctionDef(
        name=stmt.name,
        args=signature,
        body=statements,
        decorator_list=[],
    )
|
||||
|
||||
def visit_type_assign(self, stmt: p.TypeAssign) -> ast.stmt:
    """Emit a ``pass`` placeholder: type assignments produce no runtime code."""
    # TODO: is that ok?
    placeholder = ast.Pass()
    return placeholder
|
||||
|
||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> ast.stmt:
    """Build an ``ast.Assign``; targets are converted before the value."""
    targets: list[ast.expr] = []
    for target in stmt.targets:
        targets.append(self.convert(target))
    value: ast.expr = self.convert(stmt.value)
    return ast.Assign(targets=targets, value=value)
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> ast.stmt:
    """Build an ``ast.Return``; a bare ``return`` keeps ``value=None``."""
    result: Optional[ast.expr] = None
    if stmt.value is not None:
        result = self.convert(stmt.value)
    return ast.Return(value=result)
|
||||
|
||||
def visit_if_stmt(self, stmt: p.IfStmt) -> ast.stmt:
    """Build an ``ast.If``.

    The main branch always contains at least one statement, whereas the
    ``else`` branch is allowed to be empty.
    """
    condition: ast.expr = self.convert(stmt.test)
    then_branch: list[ast.stmt] = self._visit_body(stmt.body)
    else_branch: list[ast.stmt] = self._visit_body(stmt.orelse, can_be_empty=True)
    return ast.If(test=condition, body=then_branch, orelse=else_branch)
|
||||
|
||||
def visit_pass(self, stmt: p.Pass) -> ast.stmt:
    """Translate a ``pass`` statement into ``ast.Pass``."""
    node = ast.Pass()
    return node
|
||||
|
||||
def visit_for_stmt(self, stmt: p.ForStmt) -> ast.stmt:
    """Build an ``ast.For`` loop without an ``else`` clause."""
    target: ast.expr = self.convert(stmt.target)
    iterable: ast.expr = self.convert(stmt.iterator)
    statements: list[ast.stmt] = self._visit_body(stmt.body)
    return ast.For(target=target, iter=iterable, body=statements, orelse=[])
|
||||
|
||||
def visit_raw_stmt(self, stmt: p.RawStmt) -> ast.stmt:
    """Return the Python AST statement carried by the node, untouched."""
    wrapped: ast.stmt = stmt.stmt
    return wrapped
|
||||
|
||||
def _visit_body(
|
||||
self, stmts: list[p.Stmt], can_be_empty: bool = False
|
||||
) -> list[ast.stmt]:
|
||||
generated: list[ast.stmt] = []
|
||||
for stmt in stmts:
|
||||
scope = Scope()
|
||||
self._scopes.append(scope)
|
||||
|
||||
stmt2 = stmt.accept(self)
|
||||
generated.extend(scope.pre_assertions)
|
||||
generated.append(stmt2)
|
||||
if len(scope.aliases) != 0:
|
||||
generated.append(
|
||||
ast.Delete(targets=[ast.Name(id=alias) for alias in scope.aliases])
|
||||
)
|
||||
self._scopes.pop()
|
||||
|
||||
# Remove redundant pass statements
|
||||
if len(generated) > 1:
|
||||
generated = [stmt for stmt in generated if not isinstance(stmt, ast.Pass)]
|
||||
if len(generated) == 0 and not can_be_empty:
|
||||
generated = [ast.Pass()]
|
||||
return generated
|
||||
|
||||
def _make_alias(self, node: p.Expr, expr: ast.expr) -> ast.expr:
|
||||
name: str = f"__midas_a{self._alias_count}__"
|
||||
alias = ast.Name(id=name)
|
||||
self._alias_count += 1
|
||||
self._scopes[-1].aliases.append(name)
|
||||
self._scopes[-1].pre_assertions.append(
|
||||
ast.Assign(
|
||||
targets=[alias],
|
||||
value=expr,
|
||||
)
|
||||
)
|
||||
self._aliases.append((node, alias))
|
||||
return alias
|
||||
|
||||
def _build_assert(self, expr: ast.expr, message: str | ast.expr) -> ast.stmt:
|
||||
if isinstance(message, str):
|
||||
message = ast.Constant(value=message)
|
||||
return ast.Assert(
|
||||
test=expr,
|
||||
msg=message,
|
||||
)
|
||||
|
||||
def _add_assert(self, assertion: ast.stmt):
|
||||
self._scopes[-1].pre_assertions.append(assertion)
|
||||
|
||||
def _get_expr_type(self, query: p.Expr) -> Type:
|
||||
for expr, type in self._typed_ast.judgements:
|
||||
if expr == query:
|
||||
return type
|
||||
raise RuntimeError(f"Cannot get type judgement for {query}")
|
||||
|
||||
def _make_cast_asserts(
|
||||
self, src_location: Location, expr: ast.expr, type: Type
|
||||
) -> list[ast.stmt]:
|
||||
match type:
|
||||
case UnknownType() | TopType():
|
||||
return []
|
||||
|
||||
case BaseType(name=name):
|
||||
return [
|
||||
self._build_assert(
|
||||
ast.Call(
|
||||
func=ast.Name(id="isinstance"),
|
||||
args=[expr, ast.Name(id=name)],
|
||||
keywords=[],
|
||||
),
|
||||
self._make_cast_assert_message(src_location, expr, type),
|
||||
)
|
||||
]
|
||||
|
||||
case DerivedType(type=base):
|
||||
return self._make_cast_asserts(src_location, expr, base)
|
||||
|
||||
case UnitType():
|
||||
return [
|
||||
self._build_assert(
|
||||
ast.Compare(
|
||||
left=expr,
|
||||
ops=[ast.Is()],
|
||||
comparators=[
|
||||
ast.Constant(value=None),
|
||||
],
|
||||
),
|
||||
self._make_cast_assert_message(src_location, expr, type),
|
||||
),
|
||||
]
|
||||
|
||||
case AppliedType(body=body):
|
||||
return self._make_cast_asserts(src_location, expr, body)
|
||||
|
||||
case ConstraintType(type=base, constraint=constraint):
|
||||
asserts: list[ast.stmt] = self._make_cast_asserts(
|
||||
src_location, expr, base
|
||||
)
|
||||
asserts.append(
|
||||
self._make_constraint_assert(src_location, expr, constraint)
|
||||
)
|
||||
return asserts
|
||||
|
||||
case TypeVar(bound=bound):
|
||||
# TODO: check with type from arguments / use call-site context
|
||||
if bound is None:
|
||||
return []
|
||||
return self._make_cast_asserts(src_location, expr, bound)
|
||||
|
||||
case TupleType(items=items):
|
||||
asserts: list[ast.stmt] = [
|
||||
self._build_assert(
|
||||
ast.Call(
|
||||
func=ast.Name(id="isinstance"),
|
||||
args=[expr, ast.Name(id="tuple")],
|
||||
keywords=[],
|
||||
),
|
||||
self._make_cast_assert_message(src_location, expr, type),
|
||||
),
|
||||
]
|
||||
assert isinstance(expr, ast.Tuple)
|
||||
for item, item_type in zip(expr.elts, items):
|
||||
asserts.extend(
|
||||
self._make_cast_asserts(src_location, item, item_type)
|
||||
)
|
||||
return asserts
|
||||
|
||||
case DataFrameType(columns=columns):
|
||||
self.define_is_dataframe = True
|
||||
asserts: list[ast.stmt] = [
|
||||
self._build_assert(
|
||||
ast.Call(
|
||||
func=ast.Name(id=self.IS_DATAFRAME_FUNC),
|
||||
args=[expr],
|
||||
keywords=[],
|
||||
),
|
||||
self._make_cast_assert_message(
|
||||
src_location, expr, type, ": Not a dataframe"
|
||||
),
|
||||
),
|
||||
]
|
||||
for column in columns:
|
||||
asserts.append(
|
||||
self._build_assert(
|
||||
ast.Compare(
|
||||
left=ast.Constant(value=column.name),
|
||||
ops=[ast.In()],
|
||||
comparators=[expr],
|
||||
),
|
||||
self._make_cast_assert_message(
|
||||
src_location,
|
||||
expr,
|
||||
type,
|
||||
f": Missing column {column.name}",
|
||||
),
|
||||
)
|
||||
)
|
||||
asserts.extend(
|
||||
self._make_cast_asserts(
|
||||
src_location,
|
||||
ast.Subscript(
|
||||
value=expr, slice=ast.Constant(value=column.name)
|
||||
),
|
||||
column.type,
|
||||
)
|
||||
)
|
||||
return asserts
|
||||
|
||||
case ColumnType():
|
||||
self.define_is_column = True
|
||||
asserts: list[ast.stmt] = [
|
||||
self._build_assert(
|
||||
ast.Call(
|
||||
func=ast.Name(id=self.IS_COLUMN_FUNC),
|
||||
args=[expr],
|
||||
keywords=[],
|
||||
),
|
||||
self._make_cast_assert_message(
|
||||
src_location, expr, type, ": Not a column"
|
||||
),
|
||||
),
|
||||
]
|
||||
inner_assert: Optional[ast.stmt] = self._make_column_inner_assert(
|
||||
src_location, expr, type
|
||||
)
|
||||
if inner_assert is not None:
|
||||
asserts.append(inner_assert)
|
||||
return asserts
|
||||
|
||||
case (
|
||||
Function()
|
||||
| OverloadedFunction()
|
||||
| ComplexType()
|
||||
| ExtensionType()
|
||||
| GenericType()
|
||||
| FrameGroupBy()
|
||||
| ColumnGroupBy()
|
||||
):
|
||||
self.logger.warning(f"Can't make assertion for type {type}")
|
||||
return []
|
||||
|
||||
# Ensure exhaustiveness
|
||||
case _:
|
||||
assert_never(type)
|
||||
|
||||
def _make_cast_assert_message(
|
||||
self,
|
||||
location: Location,
|
||||
expr: ast.expr,
|
||||
type: Type,
|
||||
extra: Optional[str] = None,
|
||||
) -> ast.expr:
|
||||
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
|
||||
# f"file.py:L1:1: CastError: Cannot cast {type(expr).__name__} to Type"
|
||||
return ast.JoinedStr(
|
||||
values=[
|
||||
ast.Constant(f"{loc_str}: CastError: Cannot cast "),
|
||||
ast.FormattedValue(
|
||||
value=ast.Attribute(
|
||||
value=ast.Call(
|
||||
func=ast.Name(id="type"),
|
||||
args=[expr],
|
||||
keywords=[],
|
||||
),
|
||||
attr="__name__",
|
||||
),
|
||||
conversion=-1,
|
||||
),
|
||||
ast.Constant(f" to {type}{extra or ''}"),
|
||||
]
|
||||
)
|
||||
|
||||
def _make_constraint_assert(
    self, src_location: Location, expr: ast.expr, constraint: m.Expr
) -> ast.stmt:
    """Build the assertion checking ``expr`` against a constraint.

    The constraint is turned into a callable expression (cached per
    constraint by ``_get_constraint``) that is applied to ``expr``.
    """
    predicate: ast.expr = self._get_constraint(constraint)
    check = ast.Call(func=predicate, args=[expr], keywords=[])
    message: ast.expr = self._make_constraint_assert_message(
        src_location, expr, constraint
    )
    return self._build_assert(check, message)
|
||||
|
||||
def _make_constraint_assert_message(
    self, location: Location, expr: ast.expr, constraint: m.Expr
) -> ast.expr:
    """Build the constant message of a failed constraint assertion.

    The constraint is rendered back to Midas syntax with ``MidasPrinter`` and
    prefixed with the source position (1-based column). ``expr`` is not used.
    """
    rendered: str = MidasPrinter().print(constraint)
    position: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
    # f"file.py:L1:1: ConstraintError: Value does not fit constraint 'v > 0'"
    text: str = (
        f"{position}: ConstraintError: Value does not fit constraint '{rendered}'"
    )
    return ast.Constant(text)
|
||||
|
||||
def _get_constraint(self, expr: m.Expr) -> ast.expr:
|
||||
for expr2, constraint in self._constraints:
|
||||
if expr2 == expr:
|
||||
return constraint
|
||||
|
||||
constraint: ast.expr = self._constraint_generator.generate(expr)
|
||||
self._constraints.append((expr, constraint))
|
||||
return constraint
|
||||
|
||||
def _is_dataframe_definition(self) -> ast.stmt:
|
||||
"""
|
||||
def IS_DATAFRAME_FUNC(obj) -> bool:
|
||||
import pandas as pd
|
||||
return isinstance(obj, pd.DataFrame)
|
||||
"""
|
||||
|
||||
return ast.FunctionDef(
|
||||
name=self.IS_DATAFRAME_FUNC,
|
||||
args=ast.arguments(
|
||||
posonlyargs=[ast.arg(arg="obj")],
|
||||
args=[],
|
||||
kwonlyargs=[],
|
||||
defaults=[],
|
||||
kw_defaults=[],
|
||||
),
|
||||
body=[
|
||||
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
|
||||
ast.Return(
|
||||
value=ast.Call(
|
||||
func=ast.Name(id="isinstance"),
|
||||
args=[
|
||||
ast.Name(id="obj"),
|
||||
ast.Attribute(
|
||||
value=ast.Name(id="pd"),
|
||||
attr="DataFrame",
|
||||
),
|
||||
],
|
||||
keywords=[],
|
||||
)
|
||||
),
|
||||
],
|
||||
decorator_list=[],
|
||||
returns=ast.Name(id="bool"),
|
||||
)
|
||||
|
||||
def _is_column_definition(self) -> ast.stmt:
|
||||
"""
|
||||
def IS_COLUMN_FUNC(obj) -> bool:
|
||||
import pandas as pd
|
||||
return isinstance(obj, pd.Series)
|
||||
"""
|
||||
|
||||
return ast.FunctionDef(
|
||||
name=self.IS_COLUMN_FUNC,
|
||||
args=ast.arguments(
|
||||
posonlyargs=[ast.arg(arg="obj")],
|
||||
args=[],
|
||||
kwonlyargs=[],
|
||||
defaults=[],
|
||||
kw_defaults=[],
|
||||
),
|
||||
body=[
|
||||
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
|
||||
ast.Return(
|
||||
value=ast.Call(
|
||||
func=ast.Name(id="isinstance"),
|
||||
args=[
|
||||
ast.Name(id="obj"),
|
||||
ast.Attribute(
|
||||
value=ast.Name(id="pd"),
|
||||
attr="Series",
|
||||
),
|
||||
],
|
||||
keywords=[],
|
||||
)
|
||||
),
|
||||
],
|
||||
decorator_list=[],
|
||||
returns=ast.Name(id="bool"),
|
||||
)
|
||||
|
||||
def _make_column_inner_assert(
    self, src_location: Location, column: ast.expr, type: ColumnType
) -> Optional[ast.stmt]:
    """Build a loop checking every element of a column against its inner type.

    Returns:
        A ``for col in <column>`` loop whose body holds the checks for the
        column's element type, or ``None`` when that type needs no check.
    """
    # TODO: improve message, maybe chain contexts
    element: ast.expr = ast.Name(id="col")
    checks: list[ast.stmt] = self._make_cast_asserts(src_location, element, type.type)
    if len(checks) != 0:
        return ast.For(target=element, iter=column, body=checks, orelse=[])
    return None
|
||||
|
||||
def _convert_assertion(self, assertion: Assertion) -> ast.stmt:
    """Turn a checker assertion into a Python ``assert`` statement.

    Each input expression is converted and bound to an alias, so the
    assertion's builder only receives alias references. The message is
    prefixed with the position of the expression the assertion is bound to
    (1-based column).
    """
    aliased_inputs: list[ast.expr] = [
        self._make_alias(source, self.convert(source))
        for source in assertion.inputs
    ]

    condition: ast.expr = assertion.builder(*aliased_inputs)
    where: Location = assertion.bound_expr.location
    position: str = f"{self.rel_src_path}:L{where.lineno}:{where.col_offset+1}"
    message: str = f"{position}: AssertionError: {assertion.message}"
    return self._build_assert(condition, message)
|
||||
|
||||
def _apply_assertions(self, expr: p.Expr, assertions: list[Assertion]) -> ast.expr:
    """Emit the pending assertions bound to ``expr``, then convert it.

    Every assertion is converted, queued on the current scope and removed
    from the typed AST's assertion store, so it is only emitted once.

    Returns:
        The converted expression.
    """
    # Iterate over a snapshot: assertions are removed from the store while
    # looping, and ``assertions`` may be a live view of that store, in which
    # case removing during iteration would skip entries.
    for assertion in tuple(assertions):
        assert_stmt: ast.stmt = self._convert_assertion(assertion)
        self._add_assert(assert_stmt)

        # Mutating list in frozen dataclass
        # Not ideal but easiest way to avoid duplicate assertions
        self._typed_ast.assertions.remove(assertion)

    return expr.accept(self)
|
||||
480
midas/generator/stubs.py
Normal file
480
midas/generator/stubs.py
Normal file
@@ -0,0 +1,480 @@
|
||||
import ast
|
||||
from typing import Optional, assert_never
|
||||
|
||||
import midas.ast.midas as m
|
||||
from midas.checker.registry import Member, TypesRegistry
|
||||
from midas.checker.types import (
|
||||
AppliedType,
|
||||
BaseType,
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
DataFrameType,
|
||||
DerivedType,
|
||||
ExtensionType,
|
||||
FrameGroupBy,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
ParamSpec,
|
||||
TopType,
|
||||
TupleType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnitType,
|
||||
UnknownType,
|
||||
Variance,
|
||||
substitute_typevars,
|
||||
)
|
||||
|
||||
Empty = ast.Constant(value=...)
|
||||
|
||||
|
||||
class StubsGenerator:
|
||||
def __init__(self, types: TypesRegistry) -> None:
|
||||
self.types: TypesRegistry = types
|
||||
self.stubs: list[ast.stmt] = []
|
||||
self.typing_imports: set[str] = set()
|
||||
self.import_pandas: bool = False
|
||||
self.protocol_idx: int = 0
|
||||
self.stub_idx: int = 0
|
||||
self.type_var_idx: int = 0
|
||||
self.substitutions: dict[str, dict[str, Type]] = {}
|
||||
|
||||
def generate_stubs(self) -> ast.Module:
|
||||
self.stubs = []
|
||||
self.typing_imports = set()
|
||||
self.import_pandas = False
|
||||
for name, type in self.types._types.items():
|
||||
# Skip builtin types, not just based on name so the user can override
|
||||
# TODO: check if added members on builtin type
|
||||
match type:
|
||||
case BaseType(name=name_) if name == name_:
|
||||
continue
|
||||
case GenericType(
|
||||
name=name1,
|
||||
body=BaseType(name=name2),
|
||||
) if (
|
||||
name == name1 == name2
|
||||
):
|
||||
continue
|
||||
self.generate_stub(name, type)
|
||||
|
||||
imports: list[ast.stmt] = [
|
||||
ast.ImportFrom(
|
||||
module="__future__",
|
||||
names=[ast.alias(name="annotations")],
|
||||
level=0,
|
||||
)
|
||||
]
|
||||
if len(self.typing_imports) != 0:
|
||||
imports.append(
|
||||
ast.ImportFrom(
|
||||
module="typing",
|
||||
names=[
|
||||
ast.alias(name=name) for name in sorted(self.typing_imports)
|
||||
],
|
||||
level=0,
|
||||
)
|
||||
)
|
||||
if self.import_pandas:
|
||||
imports.append(
|
||||
ast.Import(
|
||||
names=[
|
||||
ast.alias(
|
||||
name="pandas",
|
||||
asname="pd",
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
return ast.Module(body=imports + self.stubs, type_ignores=[])
|
||||
|
||||
def generate_stub(self, name: str, type: Type):
|
||||
base_type: Type = type
|
||||
|
||||
# TODO: improve
|
||||
match type:
|
||||
case DerivedType(name=name_) | GenericType(name=name_) if name_ == name:
|
||||
pass
|
||||
case UnitType() if name == "None":
|
||||
pass
|
||||
case TopType() if name == "Any":
|
||||
pass
|
||||
case _:
|
||||
alias = ast.Assign(
|
||||
targets=[ast.Name(id=name)], value=self.dump_type(type)
|
||||
)
|
||||
self.add_stub(alias)
|
||||
return
|
||||
|
||||
members: dict[str, Member] = self.types._members.get(name, {})
|
||||
if isinstance(base_type, (BaseType, TopType, UnitType)) and len(members) == 0:
|
||||
return
|
||||
|
||||
bases: list[ast.expr] = []
|
||||
substitutions: dict[str, Type] = {}
|
||||
bases, substitutions = self.get_bases(type)
|
||||
self.substitutions[name] = substitutions
|
||||
|
||||
body = self.generate_body(members, substitutions)
|
||||
stub = ast.ClassDef(
|
||||
name=name,
|
||||
bases=bases,
|
||||
body=body,
|
||||
keywords=[],
|
||||
decorator_list=[],
|
||||
)
|
||||
self.add_stub(stub)
|
||||
|
||||
def get_bases(self, type: Type) -> tuple[list[ast.expr], dict[str, Type]]:
|
||||
match type:
|
||||
case DerivedType(type=base):
|
||||
return [self.dump_type(base)], {}
|
||||
|
||||
case GenericType(params=params, body=body):
|
||||
self.add_typing_import("Generic")
|
||||
type_vars: ast.expr
|
||||
|
||||
params2: list[TypeVar] = self.define_type_vars(params)
|
||||
if len(params) == 1:
|
||||
type_vars = ast.Name(id=params2[0].name)
|
||||
else:
|
||||
type_vars = ast.Tuple(
|
||||
elts=[ast.Name(id=param.name) for param in params2]
|
||||
)
|
||||
|
||||
substitutions: dict[str, TypeVar] = {
|
||||
param.name: param2 for param, param2 in zip(params, params2)
|
||||
}
|
||||
|
||||
body_bases, body_subsitutions = self.get_bases(body)
|
||||
return (
|
||||
body_bases
|
||||
+ [
|
||||
ast.Subscript(
|
||||
value=ast.Name(id="Generic"),
|
||||
slice=type_vars,
|
||||
)
|
||||
],
|
||||
body_subsitutions | substitutions,
|
||||
)
|
||||
|
||||
case ConstraintType(type=base):
|
||||
return self.get_bases(base)
|
||||
|
||||
case TypeVar(bound=bound) if bound is not None:
|
||||
return [self.dump_type(bound)], {}
|
||||
|
||||
case _:
|
||||
return [], {}
|
||||
|
||||
def generate_body(
|
||||
self, members: dict[str, Member], substitutions: dict[str, Type]
|
||||
) -> list[ast.stmt]:
|
||||
if len(members) == 0:
|
||||
return [ast.Expr(value=Empty)]
|
||||
|
||||
body: list[ast.stmt] = []
|
||||
for name, member in members.items():
|
||||
type: Type = member.type
|
||||
type = substitute_typevars(type, substitutions)
|
||||
match member.kind:
|
||||
case m.MemberKind.PROPERTY:
|
||||
body.append(
|
||||
ast.AnnAssign(
|
||||
target=ast.Name(id=name),
|
||||
annotation=self.dump_type(type),
|
||||
simple=1,
|
||||
)
|
||||
)
|
||||
case m.MemberKind.METHOD:
|
||||
body.extend(self.dump_method(name, type))
|
||||
return body
|
||||
|
||||
def dump_type(self, type: Type) -> ast.expr:
|
||||
match type:
|
||||
case DerivedType(name=name) | GenericType(name=name) if (
|
||||
name in self.substitutions
|
||||
):
|
||||
type = substitute_typevars(type, self.substitutions[name])
|
||||
|
||||
match type:
|
||||
case TopType() | UnknownType():
|
||||
self.add_typing_import("Any")
|
||||
return ast.Name(id="Any")
|
||||
|
||||
case BaseType(name=name):
|
||||
return ast.Name(id=name)
|
||||
|
||||
case DerivedType(name=name):
|
||||
return ast.Name(id=name)
|
||||
|
||||
case UnitType():
|
||||
return ast.Constant(value=None)
|
||||
|
||||
case Function():
|
||||
name: str = self.define_protocol(type)
|
||||
return ast.Name(id=name)
|
||||
|
||||
case OverloadedFunction(overloads=overloads):
|
||||
if len(overloads) == 1:
|
||||
return self.dump_type(overloads[0])
|
||||
return ast.BinOp(
|
||||
left=self.dump_type(OverloadedFunction(overloads=overloads[:-1])),
|
||||
op=ast.BitOr(),
|
||||
right=self.dump_type(overloads[-1]),
|
||||
)
|
||||
|
||||
case ComplexType():
|
||||
name: str = self.new_stub_name()
|
||||
self.generate_stub(name, type)
|
||||
return ast.Name(id=name)
|
||||
|
||||
case ExtensionType():
|
||||
raise NotImplementedError
|
||||
|
||||
case TypeVar():
|
||||
return ast.Name(id=type.name)
|
||||
|
||||
case GenericType(name=name):
|
||||
params: ast.expr
|
||||
if len(type.params) == 1:
|
||||
params = self.dump_type(type.params[0])
|
||||
else:
|
||||
params = ast.Tuple(
|
||||
elts=[self.dump_type(param) for param in type.params]
|
||||
)
|
||||
return ast.Subscript(
|
||||
value=ast.Name(id=type.name),
|
||||
slice=params,
|
||||
)
|
||||
|
||||
case AppliedType():
|
||||
args: ast.expr
|
||||
if len(type.args) == 1:
|
||||
args = self.dump_type(type.args[0])
|
||||
else:
|
||||
args = ast.Tuple(elts=[self.dump_type(arg) for arg in type.args])
|
||||
return ast.Subscript(
|
||||
value=ast.Name(id=type.name),
|
||||
slice=args,
|
||||
)
|
||||
|
||||
case ConstraintType():
|
||||
return self.dump_type(type.type)
|
||||
|
||||
case TupleType(items=items):
|
||||
return ast.Subscript(
|
||||
value=ast.Name(id="tuple"),
|
||||
slice=ast.Tuple(
|
||||
elts=[self.dump_type(item) for item in items],
|
||||
),
|
||||
)
|
||||
|
||||
case ColumnType(type=inner):
|
||||
self.import_pandas = True
|
||||
return ast.Subscript(
|
||||
value=ast.Attribute(
|
||||
value=ast.Name(id="pd"),
|
||||
attr="Series",
|
||||
),
|
||||
slice=self.dump_type(inner),
|
||||
)
|
||||
|
||||
case DataFrameType():
|
||||
self.import_pandas = True
|
||||
return ast.Attribute(
|
||||
value=ast.Name(id="pd"),
|
||||
attr="DataFrame",
|
||||
)
|
||||
|
||||
case FrameGroupBy():
|
||||
self.import_pandas = True
|
||||
return ast.Attribute(
|
||||
value=ast.Attribute(
|
||||
value=ast.Attribute(
|
||||
value=ast.Name(id="pd"),
|
||||
attr="api",
|
||||
),
|
||||
attr="typing",
|
||||
),
|
||||
attr="DataFrameGroupBy",
|
||||
)
|
||||
|
||||
case ColumnGroupBy():
|
||||
self.import_pandas = True
|
||||
return ast.Attribute(
|
||||
value=ast.Attribute(
|
||||
value=ast.Attribute(
|
||||
value=ast.Name(id="pd"),
|
||||
attr="api",
|
||||
),
|
||||
attr="typing",
|
||||
),
|
||||
attr="SeriesGroupBy",
|
||||
)
|
||||
|
||||
case _:
|
||||
assert_never(type)
|
||||
|
||||
def dump_method(
|
||||
self, name: str, method: Type, overloaded: bool = False
|
||||
) -> list[ast.stmt]:
|
||||
match method:
|
||||
case Function():
|
||||
if overloaded:
|
||||
self.add_typing_import("overload")
|
||||
return [
|
||||
ast.FunctionDef(
|
||||
name=name,
|
||||
args=self.dump_params(method.params, with_self=True),
|
||||
returns=self.dump_type(method.returns),
|
||||
body=[ast.Expr(value=Empty)],
|
||||
decorator_list=[ast.Name(id="overload")] if overloaded else [],
|
||||
)
|
||||
]
|
||||
case OverloadedFunction(overloads=overloads):
|
||||
stmts: list[ast.stmt] = []
|
||||
for overload in overloads:
|
||||
stmts.extend(self.dump_method(name, overload, True))
|
||||
return stmts
|
||||
case _:
|
||||
return [
|
||||
ast.AnnAssign(
|
||||
target=ast.Name(id=name),
|
||||
annotation=self.dump_type(method),
|
||||
simple=1,
|
||||
)
|
||||
]
|
||||
|
||||
def dump_params(self, params: ParamSpec, with_self: bool = False) -> ast.arguments:
|
||||
pos: list[ast.arg] = [
|
||||
ast.arg(
|
||||
arg=f"_{param.pos}",
|
||||
annotation=self.dump_type(param.type),
|
||||
)
|
||||
for param in params.pos
|
||||
]
|
||||
mixed: list[ast.arg] = [
|
||||
ast.arg(
|
||||
arg=param.name,
|
||||
annotation=self.dump_type(param.type),
|
||||
)
|
||||
for param in params.mixed
|
||||
]
|
||||
kw: list[ast.arg] = [
|
||||
ast.arg(
|
||||
arg=param.name,
|
||||
annotation=self.dump_type(param.type),
|
||||
)
|
||||
for param in params.kw
|
||||
]
|
||||
defaults: list[ast.expr] = [
|
||||
Empty for param in params.pos + params.mixed if not param.required
|
||||
]
|
||||
kw_defaults: list[Optional[ast.expr]] = [
|
||||
None if param.required else Empty for param in params.kw
|
||||
]
|
||||
if with_self:
|
||||
arg = ast.arg(arg="self", annotation=None)
|
||||
if len(pos) != 0:
|
||||
pos.insert(0, arg)
|
||||
else:
|
||||
mixed.insert(0, arg)
|
||||
return ast.arguments(
|
||||
posonlyargs=pos,
|
||||
args=mixed,
|
||||
kwonlyargs=kw,
|
||||
defaults=defaults,
|
||||
kw_defaults=kw_defaults,
|
||||
)
|
||||
|
||||
def define_protocol(self, func: Function) -> str:
|
||||
self.add_typing_import("Protocol")
|
||||
name: str = self.new_protocol_name()
|
||||
protocol = ast.ClassDef(
|
||||
name=name,
|
||||
bases=[ast.Name(id="Protocol")],
|
||||
keywords=[],
|
||||
body=[
|
||||
ast.FunctionDef(
|
||||
name="__call__",
|
||||
args=self.dump_params(func.params, with_self=True),
|
||||
returns=self.dump_type(func.returns),
|
||||
body=[ast.Expr(value=Empty)],
|
||||
decorator_list=[],
|
||||
),
|
||||
],
|
||||
decorator_list=[],
|
||||
)
|
||||
self.add_stub(protocol)
|
||||
return name
|
||||
|
||||
def new_protocol_name(self) -> str:
|
||||
name: str = f"_Protocol{self.protocol_idx}"
|
||||
self.protocol_idx += 1
|
||||
return name
|
||||
|
||||
def new_stub_name(self) -> str:
|
||||
name: str = f"_Stub_{self.stub_idx}"
|
||||
self.stub_idx += 1
|
||||
return name
|
||||
|
||||
def new_type_var_name(self) -> str:
|
||||
name: str = f"_T{self.type_var_idx}"
|
||||
self.type_var_idx += 1
|
||||
return name
|
||||
|
||||
def add_stub(self, stub: ast.stmt):
|
||||
self.stubs.append(stub)
|
||||
|
||||
def add_typing_import(self, name: str):
|
||||
self.typing_imports.add(name)
|
||||
|
||||
def define_type_vars(self, vars: list[TypeVar]) -> list[TypeVar]:
|
||||
vars2: list[TypeVar] = []
|
||||
for var in vars:
|
||||
vars2.append(self.define_type_var(var))
|
||||
return vars2
|
||||
|
||||
def define_type_var(self, var: TypeVar) -> TypeVar:
|
||||
name: str = self.new_type_var_name()
|
||||
self.add_typing_import("TypeVar")
|
||||
|
||||
kwargs: list[ast.keyword] = []
|
||||
if var.bound is not None:
|
||||
kwargs.append(
|
||||
ast.keyword(
|
||||
arg="bound",
|
||||
value=self.dump_type(var.bound),
|
||||
)
|
||||
)
|
||||
if var.variance == Variance.COVARIANT:
|
||||
kwargs.append(
|
||||
ast.keyword(
|
||||
arg="covariant",
|
||||
value=ast.Constant(value=True),
|
||||
)
|
||||
)
|
||||
elif var.variance == Variance.CONTRAVARIANT:
|
||||
kwargs.append(
|
||||
ast.keyword(
|
||||
arg="contravariant",
|
||||
value=ast.Constant(value=True),
|
||||
)
|
||||
)
|
||||
self.add_stub(
|
||||
ast.Assign(
|
||||
targets=[ast.Name(id=name)],
|
||||
value=ast.Call(
|
||||
func=ast.Name(id="TypeVar"),
|
||||
args=[
|
||||
ast.Constant(value=name),
|
||||
],
|
||||
keywords=kwargs,
|
||||
),
|
||||
)
|
||||
)
|
||||
return TypeVar(name=name, bound=None)
|
||||
0
midas/lexer/__init__.py
Normal file
0
midas/lexer/__init__.py
Normal file
@@ -1,17 +1,25 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from lexer.position import Position
|
||||
from lexer.token import Token, TokenType
|
||||
from midas.lexer.position import Position
|
||||
from midas.lexer.token import Token, TokenType
|
||||
|
||||
|
||||
class MidasSyntaxError(Exception):
    """Syntax error found while scanning Midas source text.

    The position and the bare message stay available as attributes; the
    exception text combines both.
    """

    def __init__(self, pos: Position, message: str):
        text = f"[ERROR] Error at {pos}: {message}"
        super().__init__(text)
        self.pos: Position = pos
        self.message: str = message
|
||||
|
||||
|
||||
class Lexer(ABC):
|
||||
"""An abstract lexer which provides methods to easily extend it into a concrete one
|
||||
|
||||
This implementation is based on the [_Crafting Interpreters_][1] book by Robert Nystrom,
|
||||
more specifically on my [previous Python implementation](https://git.kb28.ch/HEL/pebble)
|
||||
more specifically on my [previous Python implementation][2]
|
||||
|
||||
[1]: https://craftinginterpreters.com/
|
||||
[2]: https://git.kb28.ch/HEL/pebble
|
||||
"""
|
||||
|
||||
def __init__(self, source: str, file: Optional[str] = None) -> None:
|
||||
@@ -38,9 +46,9 @@ class Lexer(ABC):
|
||||
msg (str): the error message
|
||||
|
||||
Raises:
|
||||
SyntaxError
|
||||
MidasSyntaxError
|
||||
"""
|
||||
raise SyntaxError(f"[ERROR] Error at {self.start_pos}: {msg}")
|
||||
raise MidasSyntaxError(self.start_pos, msg)
|
||||
|
||||
def process(self) -> list[Token]:
|
||||
"""Scan tokens out of the source text
|
||||
@@ -49,7 +57,7 @@ class Lexer(ABC):
|
||||
list[Token]: all the tokens that could be scanned
|
||||
|
||||
Raises:
|
||||
SyntaxError: if a syntax error is found
|
||||
MidasSyntaxError: if a syntax error is found
|
||||
"""
|
||||
self.scan_tokens()
|
||||
self.tokens.append(Token(TokenType.EOF, "", None, self.get_position()))
|
||||
@@ -161,6 +169,6 @@ class Lexer(ABC):
|
||||
def scan_token(self) -> None:
|
||||
"""Scan a token
|
||||
|
||||
This function should (at least) consume the current character and produce the appropriate token(s), using `add_token`
|
||||
This function should (at least) consume the current character and produce the appropriate token(s), using :func:`add_token`
|
||||
"""
|
||||
pass
|
||||
@@ -1,6 +1,5 @@
|
||||
from lexer.base import Lexer
|
||||
from lexer.keyword import MIDAS_KEYWORDS
|
||||
from lexer.token import TokenType
|
||||
from midas.lexer.base import Lexer
|
||||
from midas.lexer.token import KEYWORDS, TokenType
|
||||
|
||||
|
||||
class MidasLexer(Lexer):
|
||||
@@ -31,30 +30,34 @@ class MidasLexer(Lexer):
|
||||
self.add_token(
|
||||
TokenType.EQUAL_EQUAL if self.match("=") else TokenType.EQUAL
|
||||
)
|
||||
case "!":
|
||||
if self.match("="):
|
||||
self.add_token(TokenType.BANG_EQUAL)
|
||||
else:
|
||||
self.error("Unexpected single bang. Did you mean '!=' ?")
|
||||
case "!" if self.match("="):
|
||||
self.add_token(TokenType.BANG_EQUAL)
|
||||
case ":":
|
||||
self.add_token(TokenType.COLON)
|
||||
case ".":
|
||||
self.add_token(TokenType.DOT)
|
||||
case "&":
|
||||
self.add_token(TokenType.AND)
|
||||
case "?":
|
||||
self.add_token(TokenType.QMARK)
|
||||
case ",":
|
||||
self.add_token(TokenType.COMMA)
|
||||
case "_":
|
||||
case "_" if not self.is_identifier_char(self.peek_next(), start=False):
|
||||
self.add_token(TokenType.UNDERSCORE)
|
||||
case "-" if self.match(">"):
|
||||
self.add_token(TokenType.ARROW)
|
||||
case "+":
|
||||
self.add_token(TokenType.PLUS)
|
||||
case "-":
|
||||
self.add_token(TokenType.MINUS)
|
||||
case "*":
|
||||
self.add_token(TokenType.STAR)
|
||||
case "/" if self.match("/"):
|
||||
self.scan_comment()
|
||||
case "/" if self.match("*"):
|
||||
self.scan_comment_multiline()
|
||||
case "/":
|
||||
if self.match("/"):
|
||||
self.scan_comment()
|
||||
elif self.match("*"):
|
||||
self.scan_comment_multiline()
|
||||
else:
|
||||
self.add_token(TokenType.SLASH)
|
||||
self.add_token(TokenType.SLASH)
|
||||
case "\n":
|
||||
self.add_token(TokenType.NEWLINE)
|
||||
case " " | "\r" | "\t":
|
||||
@@ -66,15 +69,34 @@ class MidasLexer(Lexer):
|
||||
):
|
||||
self.advance()
|
||||
self.add_token(TokenType.WHITESPACE)
|
||||
case '"' | "'":
|
||||
self.scan_string(char)
|
||||
case _:
|
||||
if char.isdigit():
|
||||
self.scan_number()
|
||||
elif char.isalpha():
|
||||
elif self.is_identifier_char(char, start=True):
|
||||
self.scan_identifier()
|
||||
else:
|
||||
self.error("Unexpected character")
|
||||
return None
|
||||
|
||||
def scan_string(self, opening: str):
    """Scan the remainder of a string literal and add it as a token

    The opening quote has already been consumed by the caller. The token
    value is the text between the quotes, quotes excluded.

    Args:
        opening (str): the quote character that opened the string; the
            same character closes it
    """
    # Walk forward until the matching quote or the end of the source
    while not (self.peek() == opening or self.is_at_end()):
        self.advance()

    if self.is_at_end():
        self.error("Unterminated string")

    # Consume the closing quote
    self.advance()

    # Drop the surrounding quotes from the scanned span
    first: int = self.start + 1
    last: int = self.idx - 1
    self.add_token(TokenType.STRING, self.source[first:last])
|
||||
|
||||
def scan_number(self):
|
||||
"""Scan the rest of number and add it as a token
|
||||
|
||||
@@ -98,11 +120,11 @@ class MidasLexer(Lexer):
|
||||
An identifier starts with a letter, followed by any number of
|
||||
alphanumerical characters or underscores
|
||||
"""
|
||||
while self.peek().isalnum() or self.peek() == "_":
|
||||
while self.is_identifier_char(self.peek(), start=False):
|
||||
self.advance()
|
||||
|
||||
lexeme: str = self.source[self.start : self.idx]
|
||||
token_type: TokenType = MIDAS_KEYWORDS.get(lexeme, TokenType.IDENTIFIER)
|
||||
token_type: TokenType = KEYWORDS.get(lexeme, TokenType.IDENTIFIER)
|
||||
self.add_token(token_type)
|
||||
|
||||
def scan_comment(self):
|
||||
@@ -129,3 +151,24 @@ class MidasLexer(Lexer):
|
||||
if not self.is_at_end():
|
||||
self.advance()
|
||||
self.add_token(TokenType.COMMENT)
|
||||
|
||||
def is_identifier_char(self, char: str, *, start: bool) -> bool:
    """Check whether a character is valid as part of an identifier

    Letters and underscores are accepted anywhere in an identifier.
    Digits are accepted everywhere except in first position.

    Args:
        char (str): the character to check
        start (bool): whether this is the first character of the identifier

    Returns:
        bool: `True` if the character is valid, `False` otherwise
    """
    if char == "_" or char.isalpha():
        return True
    # Only digits remain as candidates, and never in first position
    return (not start) and char.isdigit()
|
||||
@@ -5,6 +5,7 @@ from typing import Optional
|
||||
@dataclass(frozen=True)
|
||||
class Position:
|
||||
"""A simple structure to store the position of a token"""
|
||||
|
||||
file: Optional[str]
|
||||
line: int
|
||||
column: int
|
||||
120
midas/lexer/token.py
Normal file
120
midas/lexer/token.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Any
|
||||
|
||||
from midas.ast.location import Location
|
||||
from midas.lexer.position import Position
|
||||
|
||||
|
||||
class TokenType(Enum):
    """The kinds of tokens used by the Midas lexer and parser

    Members are grouped by category. Their values are assigned by `auto()`
    in declaration order, so the order of the members matters.
    """

    # Punctuation
    LEFT_PAREN = auto()
    RIGHT_PAREN = auto()
    LEFT_BRACKET = auto()
    RIGHT_BRACKET = auto()
    LEFT_BRACE = auto()
    RIGHT_BRACE = auto()
    COLON = auto()
    COMMA = auto()
    UNDERSCORE = auto()
    ARROW = auto()
    AND = auto()
    QMARK = auto()
    DOT = auto()

    # Operators
    PLUS = auto()
    MINUS = auto()
    STAR = auto()
    SLASH = auto()
    GREATER = auto()
    GREATER_EQUAL = auto()
    LESS = auto()
    LESS_EQUAL = auto()
    EQUAL = auto()
    EQUAL_EQUAL = auto()
    BANG_EQUAL = auto()

    # Literals
    IDENTIFIER = auto()
    NUMBER = auto()
    TRUE = auto()
    FALSE = auto()
    NONE = auto()
    STRING = auto()

    # Keywords
    TYPE = auto()
    ALIAS = auto()
    PREDICATE = auto()
    EXTEND = auto()
    WHERE = auto()
    PROP = auto()
    DEF = auto()
    FUNC = auto()

    # Misc
    COMMENT = auto()
    WHITESPACE = auto()
    EOF = auto()
    NEWLINE = auto()
|
||||
|
||||
|
||||
# Maps every reserved word to its token type.
# Note that the `fn` keyword maps to TokenType.FUNC, and that the literal
# words `true`, `false` and `none` are listed here as well.
KEYWORDS: dict[str, TokenType] = {
    "type": TokenType.TYPE,
    "alias": TokenType.ALIAS,
    "predicate": TokenType.PREDICATE,
    "extend": TokenType.EXTEND,
    "where": TokenType.WHERE,
    "true": TokenType.TRUE,
    "false": TokenType.FALSE,
    "none": TokenType.NONE,
    "prop": TokenType.PROP,
    "def": TokenType.DEF,
    "fn": TokenType.FUNC,
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Token:
    """A scanned token"""

    # Kind of the token
    type: TokenType
    # Text of the token; `get_location` measures the token extent from it
    lexeme: str
    # Value attached to the token, if any
    value: Any
    # Where the token starts (presumably 1-based column, since
    # `get_location` subtracts one -- confirm against Position)
    position: Position

    def get_location(self) -> Location:
        """Compute the source span covered by this token

        The span starts at the token position, with the column shifted down
        by one, and extends over the whole lexeme. Each newline in the
        lexeme moves the end to the next line and restarts the end column
        at zero.

        Returns:
            Location: the location spanning the lexeme of this token
        """
        text: str = self.lexeme
        first_line: int = self.position.line
        first_col: int = self.position.column - 1

        line_breaks: int = text.count("\n")
        last_col: int
        if line_breaks:
            # Only the characters after the last newline count
            last_col = len(text) - text.rfind("\n") - 1
        else:
            last_col = first_col + len(text)

        return Location(
            lineno=first_line,
            col_offset=first_col,
            end_lineno=first_line + line_breaks,
            end_col_offset=last_col,
        )

    def location_to(self, to: Token) -> Location:
        """Create a new :class:`Location` spanning from this token to another

        Args:
            to (Token): the end token

        Returns:
            Location: a new :class:`Location` starting at this token and ending
                at `to`, both included
        """
        start: Location = self.get_location()
        end: Location = to.get_location()
        return Location.span(start, end)

    @property
    def is_keyword(self) -> bool:
        """Whether the lexeme of this token is a reserved word"""
        return self.lexeme in KEYWORDS
|
||||
@@ -2,8 +2,8 @@ from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from lexer.token import Token, TokenType
|
||||
from parser.errors import ParsingError
|
||||
from midas.lexer.token import Token, TokenType
|
||||
from midas.parser.errors import ParsingError
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -16,6 +16,9 @@ class TokenError:
|
||||
def get_report(self) -> str:
|
||||
"""Get a detailed error message
|
||||
|
||||
The error message is formatted as "(<position>) Error at <token>: <message>".
|
||||
For example: "(L2:5) Error at '3': Expected ')' after arguments."
|
||||
|
||||
Returns:
|
||||
str: the complete error message
|
||||
"""
|
||||
@@ -32,9 +35,10 @@ class Parser(ABC, Generic[T]):
|
||||
"""An abstract parser which provides methods to easily extend it into a concrete one
|
||||
|
||||
This implementation is based on the [_Crafting Interpreters_][1] book by Robert Nystrom,
|
||||
more specifically on my [previous Python implementation](https://git.kb28.ch/HEL/pebble)
|
||||
more specifically on my [previous Python implementation][2]
|
||||
|
||||
[1]: https://craftinginterpreters.com/
|
||||
[2]: https://git.kb28.ch/HEL/pebble
|
||||
"""
|
||||
|
||||
IGNORE: set[TokenType] = {
|
||||
@@ -173,7 +177,7 @@ class Parser(ABC, Generic[T]):
|
||||
error_msg (str): the error message if the token doesn't match
|
||||
|
||||
Raises:
|
||||
SyntaxError: if the current token doesn't match the given type
|
||||
ParsingError: if the current token doesn't match the given type
|
||||
|
||||
Returns:
|
||||
Token: the current token which matched the given type
|
||||
874
midas/parser/midas.py
Normal file
874
midas/parser/midas.py
Normal file
@@ -0,0 +1,874 @@
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.location import Location
|
||||
from midas.ast.midas import (
|
||||
AliasStmt,
|
||||
BinaryExpr,
|
||||
CallExpr,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
Expr,
|
||||
ExtendStmt,
|
||||
ExtensionType,
|
||||
FrameType,
|
||||
FunctionType,
|
||||
GenericType,
|
||||
GetExpr,
|
||||
GroupingExpr,
|
||||
LiteralExpr,
|
||||
LogicalExpr,
|
||||
MemberKind,
|
||||
MemberStmt,
|
||||
NamedType,
|
||||
ParamSpec,
|
||||
PredicateStmt,
|
||||
Stmt,
|
||||
Type,
|
||||
TypeParam,
|
||||
TypeStmt,
|
||||
UnaryExpr,
|
||||
VariableExpr,
|
||||
WildcardExpr,
|
||||
)
|
||||
from midas.lexer.token import KEYWORDS, Token, TokenType
|
||||
from midas.parser.base import Parser
|
||||
from midas.parser.errors import ParsingError
|
||||
|
||||
|
||||
class MidasParser(Parser[list[Stmt]]):
|
||||
"""A simple parser for midas type definitions"""
|
||||
|
||||
SYNC_BOUNDARY: set[TokenType] = {
|
||||
TokenType.ALIAS,
|
||||
TokenType.TYPE,
|
||||
TokenType.EXTEND,
|
||||
TokenType.PREDICATE,
|
||||
TokenType.PROP,
|
||||
TokenType.FUNC,
|
||||
}
|
||||
|
||||
def parse(self) -> list[Stmt]:
    """Parse all the declarations of the token stream

    A declaration that fails to parse is returned as `None` by
    :func:`declaration`, which has already resynchronized the parser at
    that point. Such a declaration is skipped and parsing resumes with the
    next one, instead of aborting at the first error.

    Returns:
        list[Stmt]: the successfully parsed statements, in source order
    """
    statements: list[Stmt] = []
    while not self.is_at_end():
        stmt: Optional[Stmt] = self.declaration()
        # `None` marks a declaration that failed and was already recovered
        # from; keep going so that later declarations are still parsed
        if stmt is not None:
            statements.append(stmt)
    return statements
|
||||
|
||||
def synchronize(self):
    """Skip tokens until a synchronization boundary is found

    This is used to recover from a parse error: one token is skipped
    unconditionally, then tokens are skipped until either the last consumed
    token is a newline, the next token is one of `SYNC_BOUNDARY`, or the end
    of the stream is reached.
    """
    self.advance()
    while not self.is_at_end():
        after_newline: bool = self.previous().type == TokenType.NEWLINE
        if after_newline or self.peek().type in self.SYNC_BOUNDARY:
            return
        self.advance()
|
||||
|
||||
def declaration(self) -> Optional[Stmt]:
    """Try and parse a declaration

    The leading keyword selects which kind of declaration is parsed.
    Any parsing error is caught; the parser is then resynchronized
    (see :func:`synchronize`) and `None` is returned.

    Returns:
        Optional[Stmt]: the parsed Midas statement, or `None` if a
            ParsingError was raised
    """
    # Leading keyword -> method parsing the rest of the declaration
    handlers = (
        (TokenType.TYPE, self.type_declaration),
        (TokenType.ALIAS, self.alias_declaration),
        (TokenType.EXTEND, self.extend_declaration),
        (TokenType.PREDICATE, self.predicate_declaration),
    )
    try:
        for keyword, handler in handlers:
            if self.match(keyword):
                return handler()
        raise self.error(self.peek(), "Unexpected token")
    except ParsingError:
        self.synchronize()
        return None
|
||||
|
||||
def type_declaration(self) -> TypeStmt:
    """Parse a type declaration

    The `type` keyword has already been consumed by the caller.
    The rest of the statement consists of:
    - a name (identifier)
    - (optional) type parameters in brackets (see :func:`type_params`)
    - the `=` token
    - a body, a type expression (see :func:`type_expr`)

    Returns:
        TypeStmt: the parsed type declaration statement
    """
    start: Token = self.previous()
    type_name: Token = self.consume_identifier("Expected type name")
    type_params: list[TypeParam] = self.type_params()

    self.consume(TokenType.EQUAL, "Expected '=' before type definition")
    body: Type = self.type_expr()

    # The statement spans from the keyword to the last token of the body
    end: Token = self.previous()
    return TypeStmt(
        location=start.location_to(end),
        name=type_name,
        params=type_params,
        type=body,
    )
|
||||
|
||||
def type_params(self) -> list[TypeParam]:
    """Parse a list of type parameters

    Type parameters are a comma-separated list of type variables wrapped
    in brackets. Each type variable is either a simple variable, or a
    bounded variable written `S <: T`.

    Returns:
        list[TypeParam]: the list of type parameters, or an empty list if
            there is no opening bracket
    """
    if not self.match(TokenType.LEFT_BRACKET):
        return []

    parsed: list[TypeParam] = []
    while not (self.is_at_end() or self.check(TokenType.RIGHT_BRACKET)):
        variable: Token = self.consume_identifier("Expected type variable")

        upper_bound: Optional[Type] = None
        # `<:` is made of two tokens, `<` followed by `:`
        if self.match(TokenType.LESS):
            self.consume(TokenType.COLON, "Expected ':' after '<'")
            upper_bound = self.type_expr()

        param = TypeParam(
            location=variable.location_to(self.previous()),
            name=variable,
            bound=upper_bound,
        )
        parsed.append(param)

        if not self.match(TokenType.COMMA):
            break

    self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after type parameters")
    return parsed
|
||||
|
||||
def alias_declaration(self) -> AliasStmt:
    """Parse an alias declaration

    The `alias` keyword has already been consumed by the caller.
    The rest of the statement consists of:
    - a name (identifier)
    - the `=` token
    - a body, a type expression (see :func:`type_expr`)

    Returns:
        AliasStmt: the parsed alias declaration statement
    """
    start: Token = self.previous()
    alias_name: Token = self.consume_identifier("Expected alias name")

    self.consume(TokenType.EQUAL, "Expected '=' before alias definition")
    target: Type = self.type_expr()

    # The statement spans from the keyword to the last token of the body
    end: Token = self.previous()
    return AliasStmt(
        location=start.location_to(end),
        name=alias_name,
        type=target,
    )
|
||||
|
||||
def type_expr(self) -> Type:
    """Parse a type expression

    A type expression is either a function type, introduced by the `fn`
    keyword (see :func:`function`), or a constraint type
    (see :func:`constraint_type`). It can be followed by `&` and a complex
    type (see :func:`complex_type`), which extends it.

    Returns:
        Type: the parsed type expression
    """
    base: Type = (
        self.function() if self.match(TokenType.FUNC) else self.constraint_type()
    )
    if not self.match(TokenType.AND):
        return base

    extension: ComplexType = self.complex_type()
    return ExtensionType(
        location=Location.span(base.location, extension.location),
        base=base,
        extension=extension,
    )
|
||||
|
||||
def constraint_type(self) -> Type:
    """Parse a constraint type expression

    A constraint type consists of a base type (see :func:`base_type`),
    optionally followed by the `where` keyword and a constraint
    expression (see :func:`constraint`)

    Returns:
        Type: the constrained type, or the bare base type when there is no
            `where` clause
    """
    constrained: Type = self.base_type()
    if not self.match(TokenType.WHERE):
        return constrained

    condition: Expr = self.constraint()
    return ConstraintType(
        location=Location.span(constrained.location, condition.location),
        type=constrained,
        constraint=condition,
    )
|
||||
|
||||
def base_type(self) -> Type:
    """Parse a base type expression

    A base type is one of:
    - a parenthesized type expression (see :func:`type_expr`)
    - a complex type, starting with `{` (see :func:`complex_type`)
    - a generic type (see :func:`generic_type`)

    Returns:
        Type: the parsed base type expression
    """
    if self.match(TokenType.LEFT_PAREN):
        inner: Type = self.type_expr()
        self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
        return inner

    # The brace is only peeked at: complex_type consumes it itself
    parse_rest = (
        self.complex_type if self.check(TokenType.LEFT_BRACE) else self.generic_type
    )
    return parse_rest()
|
||||
|
||||
def generic_type(self) -> Type:
    """Parse a generic type expression

    A generic type consists of a named type (see :func:`named_type`),
    optionally followed by type arguments in brackets
    (see :func:`type_args`).

    When the name is `Frame`, the brackets hold a frame schema instead of
    type arguments (see :func:`frame_type`).

    Returns:
        Type: the parsed type; a plain named type when no bracket follows
    """
    named: NamedType = self.named_type()
    if not self.check(TokenType.LEFT_BRACKET):
        return named

    if named.name.lexeme == "Frame":
        return self.frame_type()

    arguments: list[Type] = self.type_args()
    # The last consumed token is the closing bracket of the arguments
    closing: Token = self.previous()
    return GenericType(
        location=Location.span(named.location, closing.get_location()),
        type=named,
        args=arguments,
    )
|
||||
|
||||
def type_args(self) -> list[Type]:
    """Parse a list of type arguments

    Type arguments are a comma-separated list of type expressions
    (see :func:`type_expr`) wrapped in brackets. A trailing comma is
    accepted.

    Returns:
        list[Type]: the list of type arguments, possibly empty
    """
    self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic arguments")

    parsed: list[Type] = []
    while True:
        if self.is_at_end() or self.check(TokenType.RIGHT_BRACKET):
            break
        parsed.append(self.type_expr())
        if not self.match(TokenType.COMMA):
            break

    self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments")
    return parsed
|
||||
|
||||
def named_type(self) -> NamedType:
    """Parse a named type expression

    A named type is a single identifier token
    (see :func:`consume_identifier`)

    Returns:
        NamedType: the parsed named type expression
    """
    token: Token = self.consume_identifier("Expected type name")
    where: Location = token.get_location()
    return NamedType(location=where, name=token)
|
||||
|
||||
def complex_type(self) -> ComplexType:
    """Parse a complex type expression

    A complex type consists of zero or more member statements
    (see :func:`member_stmt`) enclosed in curly braces

    Returns:
        ComplexType: the parsed complex type expression
    """
    opening: Token = self.consume(
        TokenType.LEFT_BRACE, "Expected '{' to start type body"
    )

    # TODO: add keyword to differentiate properties and methods, and allow
    # multiple methods with the same name but not properties (i.e. reject
    # duplicate property names with a "Duplicate property" error)
    members: list[MemberStmt] = []
    while not (self.check(TokenType.RIGHT_BRACE) or self.is_at_end()):
        members.append(self.member_stmt())

    closing: Token = self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
    return ComplexType(
        location=opening.location_to(closing),
        members=members,
    )
|
||||
|
||||
def frame_type(self) -> FrameType:
    """Parse a frame type expression

    The `Frame` identifier has already been consumed by the caller.
    The rest of the frame type consists of:
    - an opening bracket `[`
    - a list of comma-separated column expressions, each consisting of:
        - a name (any single token)
        - a colon `:`
        - a type expression (see :func:`type_expr`)
    - a closing bracket `]`

    Returns:
        FrameType: the parsed frame type
    """
    start: Token = self.previous()
    self.consume(TokenType.LEFT_BRACKET, "Expected '[' to start frame schema")

    columns: list[FrameType.Column] = []
    while not (self.check(TokenType.RIGHT_BRACKET) or self.is_at_end()):
        # The column name is taken as is, whatever its token type
        column_name: Token = self.advance()
        self.consume(TokenType.COLON, "Expected ':' between column name and type")
        column_type: Type = self.type_expr()

        column = FrameType.Column(
            location=column_name.location_to(self.previous()),
            name=column_name,
            type=column_type,
        )
        columns.append(column)

        if not self.match(TokenType.COMMA):
            break

    self.consume(TokenType.RIGHT_BRACKET, "Unclosed frame schema")

    end: Token = self.previous()
    return FrameType(location=start.location_to(end), columns=columns)
|
||||
|
||||
def constraint(self) -> Expr:
    """Parse a constraint expression

    A constraint is a plain expression; this method only delegates to
    :func:`expression`

    Returns:
        Expr: the parsed constraint expression
    """
    parsed: Expr = self.expression()
    return parsed
|
||||
|
||||
def expression(self) -> Expr:
    """Parse an expression

    The entry point of the expression grammar: it delegates to the rule
    parsed by :func:`and_`, the logical AND

    Returns:
        Expr: the parsed expression
    """
    parsed: Expr = self.and_()
    return parsed
|
||||
|
||||
def and_(self) -> Expr:
    """Parse a logical AND expression

    An AND consists of one or more equality expressions
    (see :func:`equality`) separated by logical AND operators (`&`).
    Operands are grouped from left to right.

    Returns:
        Expr: the parsed expression
    """
    left: Expr = self.equality()
    while self.match(TokenType.AND):
        op: Token = self.previous()
        rhs: Expr = self.equality()
        left = LogicalExpr(
            location=Location.span(left.location, rhs.location),
            left=left,
            operator=op,
            right=rhs,
        )
    return left
|
||||
|
||||
def equality(self) -> Expr:
    """Parse an equality expression

    An equality consists of one or more comparison expressions
    (see :func:`comparison`) separated by equality operators (`==`, `!=`).
    Operands are grouped from left to right.

    Returns:
        Expr: the parsed expression
    """
    operators = (TokenType.BANG_EQUAL, TokenType.EQUAL_EQUAL)

    left: Expr = self.comparison()
    while self.match(*operators):
        op: Token = self.previous()
        rhs: Expr = self.comparison()
        left = BinaryExpr(
            location=Location.span(left.location, rhs.location),
            left=left,
            operator=op,
            right=rhs,
        )
    return left
|
||||
|
||||
def comparison(self) -> Expr:
    """Parse a comparison expression

    A comparison consists of one or more term expressions
    (see :func:`term`) separated by comparison operators
    (`<`, `<=`, `>`, `>=`). Operands are grouped from left to right.

    Returns:
        Expr: the parsed expression
    """
    operators = (
        TokenType.LESS,
        TokenType.LESS_EQUAL,
        TokenType.GREATER,
        TokenType.GREATER_EQUAL,
    )

    left: Expr = self.term()
    while self.match(*operators):
        op: Token = self.previous()
        rhs: Expr = self.term()
        left = BinaryExpr(
            location=Location.span(left.location, rhs.location),
            left=left,
            operator=op,
            right=rhs,
        )
    return left
|
||||
|
||||
def term(self) -> Expr:
    """Parse a term expression

    A term consists of one or more factor expressions (see :func:`factor`)
    separated by weak arithmetic operators (`+`, `-`).
    Operands are grouped from left to right.

    Returns:
        Expr: the parsed expression
    """
    operators = (TokenType.PLUS, TokenType.MINUS)

    left: Expr = self.factor()
    while self.match(*operators):
        op: Token = self.previous()
        rhs: Expr = self.factor()
        left = BinaryExpr(
            location=Location.span(left.location, rhs.location),
            left=left,
            operator=op,
            right=rhs,
        )
    return left
|
||||
|
||||
def factor(self) -> Expr:
    """Parse a factor expression

    A factor consists of one or more unary expressions (see :func:`unary`)
    separated by strong arithmetic operators (`*`, `/`).
    Operands are grouped from left to right.

    Returns:
        Expr: the parsed expression
    """
    operators = (TokenType.STAR, TokenType.SLASH)

    left: Expr = self.unary()
    while self.match(*operators):
        op: Token = self.previous()
        rhs: Expr = self.unary()
        left = BinaryExpr(
            location=Location.span(left.location, rhs.location),
            left=left,
            operator=op,
            right=rhs,
        )
    return left
|
||||
|
||||
def unary(self) -> Expr:
    """Parse a unary expression

    A unary consists of a call expression (see :func:`call`) preceded by
    zero or more unary operators (`+`, `-`). Operators nest, the one closest
    to the call being applied first.

    Returns:
        Expr: the parsed expression
    """
    if not self.match(TokenType.PLUS, TokenType.MINUS):
        return self.call()

    op: Token = self.previous()
    # Recurse so that chained operators wrap each other
    operand: Expr = self.unary()
    return UnaryExpr(
        location=Location.span(op.get_location(), operand.location),
        operator=op,
        right=operand,
    )
|
||||
|
||||
def call(self) -> Expr:
    """Parse a call expression

    A call consists of a reference expression (see :func:`reference`)
    followed by zero or more argument groups.

    Argument groups are parenthesized, comma-separated lists of arguments
    (see :func:`finish_call`). Each group calls the result of what
    precedes it.

    Returns:
        Expr: the parsed expression
    """
    callee: Expr = self.reference()
    while self.match(TokenType.LEFT_PAREN):
        callee = self.finish_call(callee)
    return callee
|
||||
|
||||
def finish_call(self, callee: Expr) -> Expr:
    """Parse an argument group, i.e. the arguments of a call

    The opening parenthesis has already been consumed by the caller.

    Arguments are either passed positionally or by name (keyword argument).
    All positional arguments must come before any keyword argument.
    Arguments are separated by commas, and a trailing comma is accepted.

    A positional argument simply consists of an expression
    (see :func:`expression`)

    A keyword argument consists of an identifier, followed by the equal `=`
    token and an expression (see :func:`expression`).

    Passing the same keyword twice is reported as an error on the repeated
    keyword, but does not interrupt parsing: the last value is kept.

    Args:
        callee (Expr): the callee expression

    Raises:
        ParsingError: if a positional argument is passed after a keyword
            argument or if a keyword argument's name is invalid (i.e. not
            an identifier)

    Returns:
        Expr: the parsed call expression
    """
    pos_args: list[Expr] = []
    kw_args: dict[str, Expr] = {}
    seen_keyword: bool = False
    while not self.check(TokenType.RIGHT_PAREN):
        if self.check_identifier() and self.check_next(TokenType.EQUAL):
            seen_keyword = True
            keyword: Token = self.advance()
            # Skip the '=' token
            self.advance()
            value: Expr = self.expression()
            name: str = keyword.lexeme
            if name in kw_args:
                # Not raised on purpose: the error is only reported, on the
                # repeated keyword itself, and the new value wins
                self.error(
                    keyword,
                    f"Multiple values passed for '{name}', only the last occurrence will be used",
                )
            kw_args[name] = value
        else:
            first: Token = self.peek()
            value = self.expression()
            if self.check(TokenType.EQUAL):
                # `<expression> = ...` where the left side is not a lone
                # identifier, e.g. `f(a.b = 1)`
                raise self.error(self.peek(), "Invalid keyword argument name")
            if seen_keyword:
                raise self.error(
                    first,
                    "Cannot pass positional arguments after a keyword argument",
                )
            pos_args.append(value)

        if not self.match(TokenType.COMMA):
            break

    r_paren: Token = self.consume(
        TokenType.RIGHT_PAREN, "Expected ')' after arguments."
    )
    return CallExpr(
        location=Location.span(callee.location, r_paren.get_location()),
        callee=callee,
        arguments=pos_args,
        keywords=kw_args,
    )
|
||||
|
||||
def reference(self) -> Expr:
    """Parse a reference expression

    A reference consists of a primary expression (see :func:`primary`)
    followed by zero or more attribute accesses.

    An attribute access consists of a dot `.` token followed by an
    identifier (see :func:`consume_identifier`)

    Returns:
        Expr: the parsed expression
    """
    target: Expr = self.primary()
    while self.match(TokenType.DOT):
        attribute: Token = self.consume_identifier("Expected property name after '.'")
        target = GetExpr(
            location=Location.span(target.location, attribute.get_location()),
            expr=target,
            name=attribute,
        )
    return target
|
||||
|
||||
def primary(self) -> Expr:
    """Parse a primary expression

    This includes literals (booleans, `none`, numbers, strings),
    identifiers, wildcards (`_`) and grouped expressions

    Raises:
        ParsingError: if a primary expression cannot be parsed from the
            following tokens

    Returns:
        Expr: the parsed expression
    """
    token: Token = self.peek()

    # Literal keywords, checked before identifiers since keyword tokens
    # are also accepted as identifiers (see :func:`match_identifier`)
    constants = (
        (TokenType.FALSE, False),
        (TokenType.TRUE, True),
        (TokenType.NONE, None),
    )
    for token_type, constant in constants:
        if self.match(token_type):
            return LiteralExpr(location=token.get_location(), value=constant)

    # Literals carrying their value on the token
    if self.match(TokenType.NUMBER, TokenType.STRING):
        return LiteralExpr(location=token.get_location(), value=token.value)

    if self.match_identifier():
        return VariableExpr(location=token.get_location(), name=token)

    if self.match(TokenType.UNDERSCORE):
        return WildcardExpr(location=token.get_location(), token=token)

    if self.match(TokenType.LEFT_PAREN):
        inner: Expr = self.constraint()
        closing: Token = self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
        return GroupingExpr(location=token.location_to(closing), expr=inner)

    raise self.error(self.peek(), "Expected expression")
|
||||
|
||||
def consume_identifier(self, message: str = "Expected identifier") -> Token:
    """Consume the current token if it is a valid identifier, or raise an error

    See :func:`check_identifier` for what counts as a valid identifier.

    Args:
        message (str, optional): the error message used when the current
            token is not a valid identifier. Defaults to "Expected identifier".

    Raises:
        ParsingError: if the current token is not a valid identifier

    Returns:
        Token: the consumed identifier token
    """
    if self.match_identifier():
        return self.previous()
    raise self.error(self.peek(), message)
|
||||
|
||||
def match_identifier(self) -> bool:
    """Try to consume the next token as an identifier

    Plain identifier tokens and every keyword token are accepted
    (see :func:`check_identifier`).

    Returns:
        bool: whether a token was matched and consumed
    """
    accepted_types = (TokenType.IDENTIFIER, *KEYWORDS.values())
    return self.match(*accepted_types)
|
||||
|
||||
def check_identifier(self) -> bool:
    """Tell whether the current token can be used as an identifier

    Identifier tokens and keyword tokens are both accepted. The token is
    only inspected through :func:`check`, never consumed; the result at
    the EOF token therefore follows whatever :func:`check` reports there.

    Returns:
        bool: True if the current token is a valid identifier
    """
    accepted_types = (TokenType.IDENTIFIER, *KEYWORDS.values())
    return any(self.check(token_type) for token_type in accepted_types)
|
||||
|
||||
def member_stmt(self) -> MemberStmt:
    """Parse a member statement

    A member statement consists of:
    - the `prop` (for a property) or `def` (for a method) keyword
    - a name (identifier)
    - a colon `:`
    - a type expression (see :func:`type_expr`)

    Raises:
        ParsingError: if the first token is neither `prop` nor `def`

    Returns:
        MemberStmt: the parsed member statement
    """
    member_kind: MemberKind
    if self.match(TokenType.PROP):
        member_kind = MemberKind.PROPERTY
    elif not self.match(TokenType.DEF):
        raise self.error(self.peek(), "Expected 'prop' or 'def'")
    else:
        member_kind = MemberKind.METHOD

    member_name: Token = self.consume_identifier("Expected member name")
    self.consume(TokenType.COLON, "Expected ':' after member name")
    member_type: Type = self.type_expr()

    # The location starts at the name (not the keyword) and ends at the
    # last token consumed by the type expression
    span = member_name.location_to(self.previous())
    return MemberStmt(
        location=span,
        name=member_name,
        type=member_type,
        kind=member_kind,
    )
|
||||
|
||||
def extend_declaration(self) -> ExtendStmt:
    """Parse an extension definition

    The `extend` keyword is expected to be the token that was consumed
    just before this method is called. What follows is:
    - a type name (identifier)
    - (optional) type parameters (see :func:`type_params`)
    - an opening brace `{`
    - zero or more member statements (see :func:`member_stmt`)
    - a closing brace `}`

    Returns:
        ExtendStmt: the parsed extension statement
    """
    extend_keyword: Token = self.previous()
    type_name: Token = self.consume_identifier("Expected type name")
    type_parameters: list[TypeParam] = self.type_params()

    self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
    body: list[MemberStmt] = []
    while not (self.is_at_end() or self.check(TokenType.RIGHT_BRACE)):
        body.append(self.member_stmt())
    self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body")

    # Spans from the `extend` keyword to the closing brace
    return ExtendStmt(
        location=extend_keyword.location_to(self.previous()),
        name=type_name,
        params=type_parameters,
        members=body,
    )
|
||||
|
||||
def predicate_declaration(self) -> PredicateStmt:
    """Parse a predicate declaration

    The `predicate` keyword is expected to be the token that was consumed
    just before this method is called. What follows is:
    - a name (identifier)
    - zero or more parameter specs (see :func:`function_params`)
    - an equal sign `=`
    - a body, which is a constraint expression (see :func:`constraint`)

    Returns:
        PredicateStmt: the parsed predicate declaration statement
    """
    predicate_keyword: Token = self.previous()
    predicate_name: Token = self.consume_identifier("Expected predicate name")

    # Each parenthesised group is parsed as its own parameter spec
    param_specs: list[ParamSpec] = []
    while self.check(TokenType.LEFT_PAREN):
        param_specs.append(self.function_params())

    self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
    constraint_body: Expr = self.constraint()

    span = predicate_keyword.location_to(self.previous())
    return PredicateStmt(
        location=span,
        name=predicate_name,
        params=param_specs,
        body=constraint_body,
    )
|
||||
|
||||
def function(self) -> FunctionType:
    """Parse a function type expression

    This method parses, in order:
    - a parameter spec (see :func:`function_params`)
    - the arrow keyword `->`
    - a result type expression (see :func:`type_expr`)

    The `fn` keyword that introduces a function type is not consumed here.

    Returns:
        FunctionType: the parsed function type expression
    """
    param_spec: ParamSpec = self.function_params()
    self.consume(TokenType.ARROW, "Expected '->' before result type")
    return_type: Type = self.type_expr()

    # Spans from the opening parenthesis to the end of the result type
    span = param_spec.l_paren.location_to(self.previous())
    return FunctionType(location=span, params=param_spec, returns=return_type)
|
||||
|
||||
def function_params(self) -> ParamSpec:
    """Parse a parameter spec

    A parameter spec is a parenthesised, comma-separated list of zero or
    more parameters.

    Like in Python, `/` ends the positional-only parameters and `*` starts
    the keyword-only parameters; parameters in between are "mixed".

    Each parameter is a type (see :func:`type_expr`), optionally followed
    by `?` to mark it as not required. Keyword-only parameters must be
    preceded by a name and a colon `:`. Other parameters may be; a mixed
    parameter without a name is reported as an error.

    Returns:
        ParamSpec: the parsed parameter spec
    """
    opening: Token = self.consume(
        TokenType.LEFT_PAREN, "Expected '(' before function parameters"
    )
    positional: list[FunctionType.Parameter] = []
    mixed: list[FunctionType.Parameter] = []
    keyword: list[FunctionType.Parameter] = []

    # First token of every parameter currently in `mixed`, kept in lockstep
    # with it so unnamed mixed parameters can be reported at the right place
    first_tokens: list[Token] = []

    # 0: no separator seen yet, 1: after `/`, 2: after `*`
    section: int = 0
    while not (self.is_at_end() or self.check(TokenType.RIGHT_PAREN)):
        if section == 0 and self.match(TokenType.SLASH):
            # Everything collected so far turns out to be positional-only
            positional, mixed = mixed, []
            first_tokens = []
            section = 1
        elif section in (0, 1) and self.match(TokenType.STAR):
            section = 2
        else:
            keyword_only: bool = section == 2
            if not keyword_only:
                first_tokens.append(self.peek())

            param_name: Optional[Token] = None
            if keyword_only:
                param_name = self.consume_identifier("Expected keyword parameter name")
                self.consume(TokenType.COLON, "Expected ':' after parameter name")
            elif self.check_identifier() and self.check_next(TokenType.COLON):
                param_name = self.advance()
                self.advance()  # skip the colon

            param_type: Type = self.type_expr()
            is_optional: bool = self.match(TokenType.QMARK)
            parameter = FunctionType.Parameter(
                location=None,
                name=param_name,
                type=param_type,
                required=not is_optional,
            )
            (keyword if keyword_only else mixed).append(parameter)

        # Separators and parameters alike must be followed by a comma to continue
        if not self.match(TokenType.COMMA):
            break

    for parameter, first_token in zip(mixed, first_tokens):
        if parameter.name is None:
            # Not raised because we can keep parsing
            self.error(first_token, "Unnamed mixed parameter")

    self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
    return ParamSpec(l_paren=opening, pos=positional, mixed=mixed, kw=keyword)
|
||||
566
midas/parser/python.py
Normal file
566
midas/parser/python.py
Normal file
@@ -0,0 +1,566 @@
|
||||
import ast
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.location import Location
|
||||
from midas.ast.python import (
|
||||
AssignStmt,
|
||||
BaseType,
|
||||
BinaryExpr,
|
||||
CallExpr,
|
||||
CastExpr,
|
||||
CompareExpr,
|
||||
ConstraintType,
|
||||
DictExpr,
|
||||
Expr,
|
||||
ExpressionStmt,
|
||||
ForStmt,
|
||||
FrameColumn,
|
||||
FrameType,
|
||||
Function,
|
||||
GetExpr,
|
||||
IfStmt,
|
||||
ListExpr,
|
||||
LiteralExpr,
|
||||
LogicalExpr,
|
||||
MidasType,
|
||||
ParamSpec,
|
||||
RawExpr,
|
||||
RawStmt,
|
||||
ReturnStmt,
|
||||
SliceExpr,
|
||||
Stmt,
|
||||
SubscriptExpr,
|
||||
TernaryExpr,
|
||||
TupleExpr,
|
||||
TypeAssign,
|
||||
UnaryExpr,
|
||||
VariableExpr,
|
||||
)
|
||||
|
||||
|
||||
class InvalidSyntaxError(Exception):
    """Raised when a recognised construct is written in an invalid way"""
|
||||
|
||||
|
||||
class UnsupportedSyntaxError(Exception):
    """Raised when an `ast` node has no equivalent in the custom IR

    The message contains the position of the node (line and column offset)
    and its unparsed source.
    """

    def __init__(self, expr: ast.expr) -> None:
        message: str = (
            f"Unsupported syntax at L{expr.lineno}:{expr.col_offset}: {ast.unparse(expr)}"
        )
        super().__init__(message)
|
||||
|
||||
|
||||
class PythonParser:
|
||||
"""A parser to convert raw Python `ast` nodes in custom IR nodes"""
|
||||
|
||||
CAST_FUNCTION = "cast"
|
||||
UNSAFE_CAST_FUNCTION = "unsafe_cast"
|
||||
|
||||
def parse_module(self, node: ast.Module) -> list[Stmt]:
    """Convert the top-level statements of a module to IR statements

    Statements for which :class:`UnsupportedSyntaxError` is raised are
    reported on standard output and left out of the result.

    Args:
        node (ast.Module): the module to convert

    Returns:
        list[Stmt]: the converted statements, in source order
    """
    converted: list[Stmt] = []
    for raw_stmt in node.body:
        try:
            parsed: None | Stmt | list[Stmt] = self.parse_stmt(raw_stmt)
        except UnsupportedSyntaxError as e:
            print(f"{e}, skipping")
            continue

        # A raw statement yields nothing, one statement or several statements
        if parsed is None:
            continue
        if isinstance(parsed, Stmt):
            converted.append(parsed)
        else:
            converted.extend(parsed)
    return converted
|
||||
|
||||
def parse_stmt(self, node: ast.stmt) -> None | Stmt | list[Stmt]:
|
||||
location: Location = Location.from_ast(node)
|
||||
match node:
|
||||
case ast.AnnAssign():
|
||||
return self.parse_annotation_assign(node)
|
||||
|
||||
case ast.Assign():
|
||||
return self.parse_assign(node)
|
||||
|
||||
case ast.AugAssign():
|
||||
return self.parse_aug_assign(node)
|
||||
|
||||
case ast.FunctionDef():
|
||||
return self.parse_function(node)
|
||||
|
||||
case ast.Expr(value=expr):
|
||||
return ExpressionStmt(
|
||||
location=location,
|
||||
expr=self.parse_expr(expr),
|
||||
)
|
||||
|
||||
case ast.Return(value=value):
|
||||
return ReturnStmt(
|
||||
location=location,
|
||||
value=self.parse_expr(value) if value is not None else None,
|
||||
)
|
||||
|
||||
case ast.If():
|
||||
return self.parse_if(node)
|
||||
|
||||
case ast.Pass():
|
||||
return None
|
||||
|
||||
case ast.For(orelse=[]):
|
||||
return self.parse_for(node)
|
||||
|
||||
case _:
|
||||
print(f"Unsupported statement: {ast.unparse(node)}")
|
||||
return RawStmt(location=location, stmt=node)
|
||||
|
||||
def parse_annotation_assign(self, node: ast.AnnAssign) -> list[Stmt]:
|
||||
statements: list[Stmt] = []
|
||||
loc: Location = Location.from_ast(node)
|
||||
match node:
|
||||
case ast.AnnAssign(
|
||||
target=ast.Name(id=target),
|
||||
annotation=annotation,
|
||||
value=value,
|
||||
simple=1,
|
||||
):
|
||||
type = self._parse_type(annotation)
|
||||
statements.append(
|
||||
TypeAssign(
|
||||
location=loc,
|
||||
name=target,
|
||||
type=type,
|
||||
)
|
||||
)
|
||||
|
||||
if value is not None:
|
||||
statements.append(
|
||||
AssignStmt(
|
||||
location=loc,
|
||||
targets=[
|
||||
VariableExpr(
|
||||
location=Location.from_ast(node.target), name=target
|
||||
),
|
||||
],
|
||||
value=self.parse_expr(value),
|
||||
),
|
||||
)
|
||||
case _:
|
||||
print(f"Unsupported annotation: {ast.unparse(node)}")
|
||||
return statements
|
||||
|
||||
def parse_assign(self, node: ast.Assign) -> AssignStmt:
    """Convert a plain assignment, possibly with several targets (`a = b = value`)

    Args:
        node (ast.Assign): the assignment to convert

    Returns:
        AssignStmt: the converted assignment
    """
    # Targets are converted before the value
    converted_targets: list[Expr] = [self.parse_expr(t) for t in node.targets]
    converted_value: Expr = self.parse_expr(node.value)
    return AssignStmt(
        location=Location.from_ast(node),
        targets=converted_targets,
        value=converted_value,
    )
|
||||
|
||||
def parse_aug_assign(self, node: ast.AugAssign) -> AssignStmt:
    """Convert an augmented assignment by desugaring it

    `target op= value` becomes the assignment `target = target op value`.
    The same converted target expression is used on both sides.

    Args:
        node (ast.AugAssign): the augmented assignment to convert

    Returns:
        AssignStmt: the desugared assignment
    """
    loc: Location = Location.from_ast(node)
    assigned_to: Expr = self.parse_expr(node.target)
    operand: Expr = self.parse_expr(node.value)
    combined = BinaryExpr(
        location=loc,
        left=assigned_to,
        operator=node.op,
        right=operand,
    )
    return AssignStmt(location=loc, targets=[assigned_to], value=combined)
|
||||
|
||||
def parse_if(self, node: ast.If) -> IfStmt:
    """Convert an `if` statement, including its `else` branch

    Args:
        node (ast.If): the `if` statement to convert

    Returns:
        IfStmt: the converted statement; `orelse` is empty when there is
        no `else` branch
    """
    # Both branches are converted the same way: body first, then orelse
    branches: list[list[Stmt]] = []
    for raw_block in (node.body, node.orelse):
        block: list[Stmt] = []
        for raw_stmt in raw_block:
            parsed = self.parse_stmt(raw_stmt)
            if parsed is None:
                continue
            if isinstance(parsed, Stmt):
                block.append(parsed)
            else:
                block.extend(parsed)
        branches.append(block)
    then_block, else_block = branches

    return IfStmt(
        location=Location.from_ast(node),
        test=self.parse_expr(node.test),
        body=then_block,
        orelse=else_block,
    )
|
||||
|
||||
def parse_for(self, node: ast.For) -> ForStmt:
    """Convert a `for` loop

    Only the target, the iterated expression and the body are converted;
    `node.orelse` is not read.

    Args:
        node (ast.For): the loop to convert

    Returns:
        ForStmt: the converted loop
    """
    loop_body: list[Stmt] = []
    for raw_stmt in node.body:
        parsed = self.parse_stmt(raw_stmt)
        if parsed is None:
            continue
        if isinstance(parsed, Stmt):
            loop_body.append(parsed)
        else:
            loop_body.extend(parsed)

    return ForStmt(
        location=Location.from_ast(node),
        target=self.parse_expr(node.target),
        iterator=self.parse_expr(node.iter),
        body=loop_body,
    )
|
||||
|
||||
def parse_function(self, node: ast.FunctionDef) -> Function:
|
||||
loc: Location = Location.from_ast(node)
|
||||
match node:
|
||||
case ast.FunctionDef(
|
||||
name=name,
|
||||
args=args,
|
||||
returns=returns,
|
||||
body=raw_body,
|
||||
):
|
||||
body: list[Stmt] = []
|
||||
for stmt in raw_body:
|
||||
stmts = self.parse_stmt(stmt)
|
||||
if isinstance(stmts, Stmt):
|
||||
body.append(stmts)
|
||||
elif stmts is not None:
|
||||
body.extend(stmts)
|
||||
|
||||
return Function(
|
||||
location=loc,
|
||||
name=name,
|
||||
params=self._parse_param_spec(args),
|
||||
returns=self._parse_type(returns) if returns is not None else None,
|
||||
body=body,
|
||||
)
|
||||
case _:
|
||||
print(f"Unsupported function definition: {ast.unparse(node)}")
|
||||
|
||||
def _parse_param_spec(self, args: ast.arguments) -> ParamSpec:
    """Convert the arguments of a function definition

    Only positional-only, regular and keyword-only arguments are read;
    `*args` and `**kwargs` are not.

    Args:
        args (ast.arguments): the arguments to convert

    Returns:
        ParamSpec: the converted parameters with their defaults
    """

    def convert(
        raw_args: list[ast.arg], arg_defaults: list[Optional[Expr]]
    ) -> list[Function.Parameter]:
        return [
            self._parse_function_parameter(raw_arg, arg_default)
            for raw_arg, arg_default in zip(raw_args, arg_defaults)
        ]

    # `args.defaults` only covers the LAST positional arguments (positional-only
    # and regular ones together): left-pad with None to align both lists
    positional_count: int = len(args.posonlyargs) + len(args.args)
    padding: list[Optional[Expr]] = [None] * (positional_count - len(args.defaults))
    aligned_defaults: list[Optional[Expr]] = padding + [
        self.parse_expr(raw_default) for raw_default in args.defaults
    ]
    split: int = len(args.posonlyargs)

    # `args.kw_defaults` already has one entry (possibly None) per keyword-only argument
    keyword_defaults: list[Optional[Expr]] = [
        None if raw_default is None else self.parse_expr(raw_default)
        for raw_default in args.kw_defaults
    ]

    return ParamSpec(
        pos=convert(args.posonlyargs, aligned_defaults[:split]),
        mixed=convert(args.args, aligned_defaults[split:]),
        kw=convert(args.kwonlyargs, keyword_defaults),
    )
|
||||
|
||||
def _parse_function_parameter(
    self, arg: ast.arg, default: Optional[Expr]
) -> Function.Parameter:
    """Convert a single function argument

    Args:
        arg (ast.arg): the argument to convert
        default (Optional[Expr]): the already converted default value, if any

    Returns:
        Function.Parameter: the converted parameter; its type is None when
        the argument has no annotation
    """
    loc: Location = Location.from_ast(arg)
    annotation = arg.annotation
    param_type: Optional[MidasType] = (
        None if annotation is None else self._parse_type(annotation)
    )
    return Function.Parameter(
        location=loc,
        name=arg.arg,
        type=param_type,
        default=default,
    )
|
||||
|
||||
def _parse_type(self, type_expr: ast.expr) -> MidasType:
|
||||
loc: Location = Location.from_ast(type_expr)
|
||||
match type_expr:
|
||||
case ast.Subscript(value=ast.Name(id="Frame"), slice=schema):
|
||||
return self._parse_frame_type(schema)
|
||||
|
||||
case ast.Subscript(value=ast.Name(id=name), slice=arg):
|
||||
args: tuple[MidasType, ...] = (
|
||||
tuple(self._parse_type(a) for a in arg.elts)
|
||||
if isinstance(arg, ast.Tuple)
|
||||
else (self._parse_type(arg),)
|
||||
)
|
||||
return BaseType(
|
||||
location=loc,
|
||||
base=name,
|
||||
args=args,
|
||||
)
|
||||
|
||||
case ast.Name(id=name):
|
||||
return BaseType(
|
||||
location=loc,
|
||||
base=name,
|
||||
args=(),
|
||||
)
|
||||
|
||||
case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr):
|
||||
left = self._parse_type(left_expr)
|
||||
match left:
|
||||
# If chained constraints, separate base type and rebuild constraint
|
||||
case ConstraintType(type=left_type, constraint=left_constraint):
|
||||
constraint = ast.BinOp(
|
||||
left=left_constraint,
|
||||
op=ast.Add(),
|
||||
right=right_expr,
|
||||
)
|
||||
ast.copy_location(constraint, type_expr)
|
||||
return ConstraintType(
|
||||
location=loc,
|
||||
type=left_type,
|
||||
constraint=constraint,
|
||||
)
|
||||
|
||||
case _:
|
||||
return ConstraintType(
|
||||
location=loc,
|
||||
type=left,
|
||||
constraint=right_expr,
|
||||
)
|
||||
|
||||
case ast.Constant(value=None):
|
||||
return BaseType(
|
||||
location=loc,
|
||||
base="None",
|
||||
args=(),
|
||||
)
|
||||
|
||||
case _:
|
||||
raise UnsupportedSyntaxError(type_expr)
|
||||
|
||||
def _parse_frame_type(self, schema: ast.expr) -> FrameType:
    """Convert the schema of a `Frame[...]` annotation

    The schema is either a single column or a tuple of columns
    (see :func:`_parse_frame_column`).

    Args:
        schema (ast.expr): the subscript of the `Frame` annotation

    Raises:
        UnsupportedSyntaxError: if the schema is neither a tuple, a slice
            nor a name

    Returns:
        FrameType: the converted frame type, located at the schema
    """
    loc: Location = Location.from_ast(schema)

    raw_columns: list[ast.expr]
    if isinstance(schema, ast.Tuple):
        raw_columns = schema.elts
    elif isinstance(schema, (ast.Slice, ast.Name)):
        raw_columns = [schema]
    else:
        raise UnsupportedSyntaxError(schema)

    return FrameType(
        location=loc,
        columns=[self._parse_frame_column(raw) for raw in raw_columns],
    )
|
||||
|
||||
def _parse_frame_column(self, column: ast.expr) -> FrameColumn:
    """Convert one column of a frame schema

    A column is written either as a type alone (`Type`, unnamed column) or
    as a slice `name: Type`. In the slice form, `_` can replace the name
    (unnamed column) and/or the type (untyped column).

    Args:
        column (ast.expr): the column to convert

    Raises:
        InvalidSyntaxError: if a slice has no upper part (`name:`)
        UnsupportedSyntaxError: if the column has another form

    Returns:
        FrameColumn: the converted column
    """
    loc: Location = Location.from_ast(column)

    if isinstance(column, ast.Name):
        return FrameColumn(location=loc, name=None, type=self._parse_type(column))

    if not (isinstance(column, ast.Slice) and isinstance(column.lower, ast.Name)):
        raise UnsupportedSyntaxError(column)

    column_name: Optional[str] = None if column.lower.id == "_" else column.lower.id
    raw_type = column.upper

    if raw_type is None:
        raise InvalidSyntaxError("Missing column type")

    column_type: Optional[MidasType]
    if isinstance(raw_type, ast.Name) and raw_type.id == "_":
        column_type = None
    elif isinstance(raw_type, ast.expr):
        column_type = self._parse_type(raw_type)
    else:
        raise UnsupportedSyntaxError(raw_type)

    return FrameColumn(location=loc, name=column_name, type=column_type)
|
||||
|
||||
def parse_expr(self, node: ast.expr) -> Expr:
|
||||
location: Location = Location.from_ast(node)
|
||||
match node:
|
||||
case ast.BoolOp():
|
||||
return self.parse_bool_op(node)
|
||||
|
||||
case ast.BinOp(left=left, op=op, right=right):
|
||||
return BinaryExpr(
|
||||
location=location,
|
||||
left=self.parse_expr(left),
|
||||
operator=op,
|
||||
right=self.parse_expr(right),
|
||||
)
|
||||
|
||||
case ast.UnaryOp(op=op, operand=right):
|
||||
return UnaryExpr(
|
||||
location=location,
|
||||
operator=op,
|
||||
right=self.parse_expr(right),
|
||||
)
|
||||
|
||||
case ast.Compare():
|
||||
return self.parse_compare(node)
|
||||
|
||||
case ast.Call(func=ast.Name(id=self.CAST_FUNCTION)):
|
||||
return self.parse_cast(node)
|
||||
|
||||
case ast.Call(func=ast.Name(id=self.UNSAFE_CAST_FUNCTION)):
|
||||
return self.parse_cast(node)
|
||||
|
||||
case ast.Call():
|
||||
return self.parse_call(node)
|
||||
|
||||
case ast.IfExp():
|
||||
return self.parse_ternary(node)
|
||||
|
||||
case ast.Constant(value=value):
|
||||
return LiteralExpr(location=location, value=value)
|
||||
|
||||
case ast.Attribute(value=object, attr=name):
|
||||
return GetExpr(
|
||||
location=location,
|
||||
object=self.parse_expr(object),
|
||||
name=name,
|
||||
)
|
||||
|
||||
case ast.Name(id=name):
|
||||
return VariableExpr(location=location, name=name)
|
||||
|
||||
case ast.List(elts=items):
|
||||
return ListExpr(
|
||||
location=location,
|
||||
items=[self.parse_expr(item) for item in items],
|
||||
)
|
||||
|
||||
case ast.Dict(keys=keys, values=values):
|
||||
return DictExpr(
|
||||
location=location,
|
||||
keys=[
|
||||
self.parse_expr(key) if key is not None else None
|
||||
for key in keys
|
||||
],
|
||||
values=[self.parse_expr(value) for value in values],
|
||||
)
|
||||
|
||||
case ast.Subscript(value=value, slice=index):
|
||||
return SubscriptExpr(
|
||||
location=location,
|
||||
object=self.parse_expr(value),
|
||||
index=self.parse_expr(index),
|
||||
)
|
||||
|
||||
case ast.Slice(lower=lower, upper=upper, step=step):
|
||||
return SliceExpr(
|
||||
location=location,
|
||||
lower=self.parse_expr(lower) if lower is not None else None,
|
||||
upper=self.parse_expr(upper) if upper is not None else None,
|
||||
step=self.parse_expr(step) if step is not None else None,
|
||||
)
|
||||
|
||||
case ast.Tuple(elts=items):
|
||||
return TupleExpr(
|
||||
location=location,
|
||||
items=tuple(self.parse_expr(item) for item in items),
|
||||
)
|
||||
|
||||
case _:
|
||||
print(f"Unsupported expression: {ast.unparse(node)}")
|
||||
return RawExpr(location=location, expr=node)
|
||||
|
||||
def parse_bool_op(self, node: ast.BoolOp) -> LogicalExpr:
    """Convert an n-ary boolean operation to nested binary operations

    `a and b and c` becomes `(a and b) and c`: operands are folded from
    the left, all with the operator of the node.

    Args:
        node (ast.BoolOp): the boolean operation, with at least two operands

    Returns:
        LogicalExpr: the outermost binary operation
    """
    operator: ast.boolop = node.op
    operands: list[Expr] = [self.parse_expr(raw) for raw in node.values]

    head, second = operands[0], operands[1]
    folded: LogicalExpr = LogicalExpr(
        location=Location.span(head.location, second.location),
        left=head,
        operator=operator,
        right=second,
    )
    for operand in operands[2:]:
        folded = LogicalExpr(
            location=Location.span(folded.location, operand.location),
            left=folded,
            operator=operator,
            right=operand,
        )
    return folded
|
||||
|
||||
def parse_compare(self, node: ast.Compare) -> Expr:
    """Convert a (possibly chained) comparison

    A simple comparison `a < b` becomes a single :class:`CompareExpr`.
    A chained comparison `a < b > c` is desugared in
    `(a < b) and (b > c)`: one :class:`CompareExpr` per pair of
    neighbouring operands, joined from the left with `and`.
    The converted middle operands are shared by two comparisons.

    Args:
        node (ast.Compare): the comparison to convert

    Returns:
        Expr: a :class:`CompareExpr` for a simple comparison, a
        :class:`LogicalExpr` for a chained one
    """
    operands: list[Expr] = [
        self.parse_expr(node.left),
        *(self.parse_expr(raw) for raw in node.comparators),
    ]

    # `node.ops[i]` sits between `operands[i]` and `operands[i + 1]`, so
    # zipping keeps every operator with its own pair of operands
    comparisons: list[CompareExpr] = [
        CompareExpr(
            location=Location.span(lhs.location, rhs.location),
            left=lhs,
            operator=operator,
            right=rhs,
        )
        for lhs, operator, rhs in zip(operands, node.ops, operands[1:])
    ]

    expr: Expr = comparisons[0]
    for comparison in comparisons[1:]:
        expr = LogicalExpr(
            location=Location.span(expr.location, comparison.location),
            left=expr,
            operator=ast.And(),
            right=comparison,
        )
    return expr
|
||||
|
||||
def parse_cast(self, node: ast.Call) -> CastExpr:
    """Convert a call to one of the cast functions

    The call must have exactly two positional arguments (a type, then an
    expression) and no keyword argument.

    Args:
        node (ast.Call): the call to convert, whose callee is a plain name

    Raises:
        InvalidSyntaxError: if the arguments do not have the expected shape

    Returns:
        CastExpr: the converted cast, flagged as unsafe when the callee is
        `UNSAFE_CAST_FUNCTION`
    """
    assert isinstance(node.func, ast.Name)
    func: str = node.func.id

    if len(node.args) != 2 or node.keywords:
        raise InvalidSyntaxError(
            f"Invalid call to {func}, expected type and expression"
        )

    raw_type, raw_value = node.args
    return CastExpr(
        location=Location.from_ast(node),
        type=self._parse_type(raw_type),
        expr=self.parse_expr(raw_value),
        unsafe=func == self.UNSAFE_CAST_FUNCTION,
    )
|
||||
|
||||
def parse_call(self, node: ast.Call) -> CallExpr:
    """Convert a function call

    Args:
        node (ast.Call): the call to convert

    Returns:
        CallExpr: the converted call, with its positional arguments in
        order and its named arguments indexed by name
    """
    loc: Location = Location.from_ast(node)
    callee: Expr = self.parse_expr(node.func)
    positional: list[Expr] = [self.parse_expr(raw) for raw in node.args]

    named: dict[str, Expr] = {}
    for keyword in node.keywords:
        # `keyword.arg` is None for `**mapping` arguments: those are left out
        if keyword.arg is not None:
            named[keyword.arg] = self.parse_expr(keyword.value)

    return CallExpr(
        location=loc,
        callee=callee,
        arguments=positional,
        keywords=named,
    )
|
||||
|
||||
def parse_ternary(self, node: ast.IfExp) -> TernaryExpr:
    """Convert a conditional expression (`a if test else b`)

    Args:
        node (ast.IfExp): the conditional expression to convert

    Returns:
        TernaryExpr: the converted expression
    """
    loc: Location = Location.from_ast(node)
    condition: Expr = self.parse_expr(node.test)
    when_true: Expr = self.parse_expr(node.body)
    when_false: Expr = self.parse_expr(node.orelse)
    return TernaryExpr(
        location=loc,
        test=condition,
        if_true=when_true,
        if_false=when_false,
    )
|
||||
52
midas/typing.py
Normal file
52
midas/typing.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import Generic, TypeVar
|
||||
from typing import cast as typing_cast
|
||||
|
||||
# At run-time `cast` is a plain alias of `typing.cast`. The string below is
# an attribute docstring: it has no effect when the module is executed.
cast = typing_cast
"""### Midas documentation
Cast a value to a type.

- **Compile-time**: tells the type checker that the return value has the designated type.
- **Run-time**: generates assertions to ensure the value can be interpreted as the given type.

---
<br>
<br>
<br>

_**Internal Python documentation**_
"""


# Same alias under a second name, so that both kinds of cast can be told
# apart by the name used at the call site.
unsafe_cast = typing_cast
"""### Midas documentation
Cast a value to a type.

- **Compile-time**: tells the type checker that the return value has the designated type.
- **Run-time**: -

This operation is unsound, use at your own risk!

---
<br>
<br>
<br>

_**Internal Python documentation**_
"""
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Frame(Generic[T]):
    """A `Frame` is the abstract type implemented by `DataFrame`

    A frame contains any number of named columns (see :class:`Column`).
    The class has no members: it is only used in type annotations.
    """
|
||||
|
||||
|
||||
class Column(Generic[T]):
    """A `Column` is the abstract type implemented by `Series`

    A column contains any number of values of the same type.
    The class has no members: it is only used in type annotations.
    """
|
||||
67
midas/utils.py
Normal file
67
midas/utils.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.checker.types import Type
|
||||
from midas.generator.collector import AssertionCollector
|
||||
|
||||
AllowRepeat = Callable[[object], bool]
|
||||
|
||||
|
||||
class UniversalJSONDumper:
|
||||
@classmethod
|
||||
def dump(
|
||||
cls,
|
||||
obj: Any,
|
||||
include_keys: Optional[list[str | tuple[str, str]]] = None,
|
||||
allow_repeat: Optional[AllowRepeat] = None,
|
||||
) -> Any:
|
||||
if include_keys is None:
|
||||
include_keys = []
|
||||
return cls._dump(obj, include_keys, allow_repeat, [])
|
||||
|
||||
@classmethod
|
||||
def _dump(
|
||||
cls,
|
||||
obj: Any,
|
||||
include_keys: list[str | tuple[str, str]],
|
||||
allow_repeat: Optional[AllowRepeat],
|
||||
visited: list[Any],
|
||||
) -> Any:
|
||||
if obj in visited:
|
||||
return None
|
||||
match obj:
|
||||
case str() | int() | float() | None:
|
||||
return obj
|
||||
case list() | set() | tuple():
|
||||
return [
|
||||
cls._dump(child, include_keys, allow_repeat, visited)
|
||||
for child in obj
|
||||
]
|
||||
case dict():
|
||||
return {
|
||||
str(k): cls._dump(v, include_keys, allow_repeat, visited)
|
||||
for k, v in obj.items()
|
||||
}
|
||||
case object():
|
||||
if allow_repeat is None or not allow_repeat(obj):
|
||||
visited.append(obj)
|
||||
return {
|
||||
"_type": obj.__class__.__name__,
|
||||
} | {
|
||||
k: cls._dump(v, include_keys, allow_repeat, visited)
|
||||
for k, v in obj.__dict__.items()
|
||||
if not k.startswith("_")
|
||||
or k in include_keys
|
||||
or (obj.__class__.__name__, k) in include_keys
|
||||
}
|
||||
case _:
|
||||
raise ValueError(f"Unsupported value: {obj}")
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class TypedAST:
    """Immutable, keyword-only bundle of a Python IR and its typing results

    NOTE(review): the meaning of each field below is inferred from its name
    and type only; confirm against the code producing this object.
    """

    # Statements of the Python IR (`midas.ast.python`)
    stmts: list[p.Stmt]
    # (expression, type) pairs, presumably the type inferred for each expression
    judgements: list[tuple[p.Expr, Type]]
    # Cast expressions, presumably those handled by the checker (TODO confirm)
    evaluated_casts: list[p.CastExpr]
    # Collector from `midas.generator.collector`
    assertions: AssertionCollector
|
||||
@@ -1,152 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.ast.annotations import (
|
||||
AnnotationStmt,
|
||||
ConstraintExpr,
|
||||
Expr,
|
||||
LiteralExpr,
|
||||
SchemaElementExpr,
|
||||
SchemaExpr,
|
||||
Stmt,
|
||||
TypeExpr,
|
||||
WildcardExpr,
|
||||
)
|
||||
from lexer.token import Token, TokenType
|
||||
from parser.base import Parser
|
||||
from parser.errors import ParsingError
|
||||
|
||||
|
||||
class AnnotationParser(Parser):
    """A simple parser for custom type annotations"""

    # Token types at which parsing can resume after an error (none for now)
    SYNC_BOUNDARY: set[TokenType] = set()

    def parse(self) -> Optional[Stmt]:
        """Parse the whole token stream as a single annotation

        Returns:
            Optional[Stmt]: the parsed annotation, or None if a parsing
            error occurred. Tokens left after the annotation are reported
            as an error.
        """
        stmt: Optional[Stmt] = None
        try:
            stmt = self.annotation()
        except ParsingError:
            self.synchronize()
        if not self.is_at_end():
            # Not raised: the annotation parsed so far is still returned
            self.error(self.peek(), "Extra tokens")
        return stmt

    def synchronize(self):
        """Skip tokens until a synchronization boundary is found

        This method allows gracefully recovering from a parse error
        to a safe place and continue parsing
        """
        self.advance()
        while not self.is_at_end():
            if self.peek().type in self.SYNC_BOUNDARY:
                return
            self.advance()

    def annotation(self) -> AnnotationStmt:
        """Parse an annotation

        An annotation is written as `Type` or `Type[Schema]`

        Returns:
            AnnotationStmt: the parsed annotation statement
        """
        name: Token = self.consume(TokenType.IDENTIFIER, "Expected type identifier")
        schema: Optional[SchemaExpr] = None
        if self.match(TokenType.LEFT_BRACKET):
            schema = self.schema()
        return AnnotationStmt(name=name, schema=schema)

    def type_expr(self) -> TypeExpr:
        """Parse a type expression

        A type expression is a type name followed by zero or more
        constraints, each written as `+ (constraint)`
        (see :func:`constraint_expr`)

        Returns:
            TypeExpr: the parsed type expression
        """
        name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
        constraints: list[ConstraintExpr] = []

        while not self.is_at_end() and self.match(TokenType.PLUS):
            self.consume(TokenType.LEFT_PAREN, "Expected '(' before type constraint")
            constraints.append(self.constraint_expr())
            self.consume(TokenType.RIGHT_PAREN, "Expected ')' after type constraint")

        return TypeExpr(name=name, constraints=constraints)

    def constraint_expr(self) -> ConstraintExpr:
        """Parse a type constraint

        A constraint is a comparison between two values
        (see :func:`constraint_value` and :func:`constraint_operator`)

        Returns:
            ConstraintExpr: the parsed type constraint expression
        """
        left: Expr = self.constraint_value()
        op: Token = self.constraint_operator()
        right: Expr = self.constraint_value()
        return ConstraintExpr(left=left, op=op, right=right)

    def constraint_value(self) -> Expr:
        """Parse an operand of a constraint: a wildcard `_` or a literal

        Returns:
            Expr: the parsed wildcard or literal expression
        """
        if self.match(TokenType.UNDERSCORE):
            return WildcardExpr(self.previous())
        return self.literal()

    def literal(self) -> LiteralExpr:
        """Parse a literal: a boolean, none or a number

        Raises:
            ParsingError: if the current token is not a literal

        Returns:
            LiteralExpr: the parsed literal expression
        """
        if self.match(TokenType.FALSE):
            return LiteralExpr(False)
        if self.match(TokenType.TRUE):
            return LiteralExpr(True)
        if self.match(TokenType.NONE):
            return LiteralExpr(None)

        if self.match(TokenType.NUMBER):
            return LiteralExpr(self.previous().value)

        raise self.error(self.peek(), "Expected literal")

    def constraint_operator(self) -> Token:
        """Parse a comparison operator (`<`, `<=`, `>`, `>=`, `==` or `!=`)

        Raises:
            ParsingError: if the current token is not a comparison operator

        Returns:
            Token: the operator token
        """
        if self.match(
            TokenType.LESS,
            TokenType.LESS_EQUAL,
            TokenType.GREATER,
            TokenType.GREATER_EQUAL,
            TokenType.EQUAL_EQUAL,
            TokenType.BANG_EQUAL,
        ):
            return self.previous()
        raise self.error(self.peek(), "Expected constraint operator")

    def schema(self) -> SchemaExpr:
        """Parse a schema definition

        A comma separated list of schema elements. The opening bracket is
        expected to be the token consumed just before this method is called.

        Returns:
            SchemaExpr: the parsed schema expression
        """
        left: Token = self.previous()
        elements: list[Expr] = []
        while not self.check(TokenType.RIGHT_BRACKET) and not self.is_at_end():
            elements.append(self.schema_element())
            if not self.check(TokenType.RIGHT_BRACKET):
                self.consume(TokenType.COMMA, "Expected ',' between schema elements")

        right: Token = self.consume(TokenType.RIGHT_BRACKET, "Unclosed schema")
        return SchemaExpr(left=left, elements=elements, right=right)

    def schema_element(self) -> SchemaElementExpr:
        """Parse a schema element

        An anonymous element (`_`), a type, an untyped named column (`name: _`),
        or a named column (`name: Type`)

        Raises:
            ParsingError: if the current token cannot start a schema element

        Returns:
            SchemaElementExpr: the parsed schema element expression
        """
        if self.match(TokenType.UNDERSCORE):
            return SchemaElementExpr(name=None, type=None)

        if not self.check(TokenType.IDENTIFIER):
            raise self.error(self.peek(), "Expected schema element")

        name: Optional[Token] = None
        type: Optional[TypeExpr] = None
        if self.check_next(TokenType.COLON):
            name = self.advance()
            self.advance()  # skip the colon
            if not self.match(TokenType.UNDERSCORE):
                type = self.type_expr()
        else:
            # Unnamed column: the identifier starts a type expression and must
            # be consumed, otherwise the schema loop stalls on the same token
            type = self.type_expr()
        return SchemaElementExpr(name=name, type=type)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user