Compare commits
146 Commits
v0.0.1-pro
...
feat/subty
| Author | SHA1 | Date | |
|---|---|---|---|
|
e0179bc442
|
|||
|
e665d03533
|
|||
|
b8cb2b4273
|
|||
|
d278dc5f5b
|
|||
|
59e73f0fd9
|
|||
|
3e0dc60283
|
|||
|
c24eb5125e
|
|||
|
25bd895dde
|
|||
|
bccd75317e
|
|||
|
f0e3f7574f
|
|||
|
5d44081847
|
|||
|
2a2bb0aec7
|
|||
|
67c40a3909
|
|||
|
1c30188122
|
|||
|
82a0f13242
|
|||
| 288d15a9bc | |||
|
504703d0f7
|
|||
|
e48895d0af
|
|||
| 13d32d0d27 | |||
| 19b9fdd623 | |||
|
ddcaebb51a
|
|||
|
f182312cd2
|
|||
|
73b21789d5
|
|||
|
5d7c724bc8
|
|||
|
74b297c89c
|
|||
|
822a74acce
|
|||
|
9a934fabfd
|
|||
|
828ec9a3fa
|
|||
|
63a43d79dd
|
|||
|
029caf4526
|
|||
|
1c5c418f1c
|
|||
|
a4139d4652
|
|||
|
2fd2071d40
|
|||
|
97b1ee8ab8
|
|||
|
dee479def5
|
|||
|
c8536e20d2
|
|||
|
d70137775f
|
|||
|
35ceda99aa
|
|||
|
7f3d74ee49
|
|||
|
b9f378de6f
|
|||
|
ccb17c7290
|
|||
|
505779310a
|
|||
|
bea3f399ad
|
|||
|
55060bfecd
|
|||
|
dd126f2559
|
|||
|
4151f5373d
|
|||
|
bd31713ab4
|
|||
|
f4dc57cb96
|
|||
|
261fd47494
|
|||
|
1b66a8553d
|
|||
|
65164abadb
|
|||
|
9d45163d9c
|
|||
|
ab0fa1de1a
|
|||
|
5d4df7978b
|
|||
|
86ad348b99
|
|||
|
29f691e38a
|
|||
|
f2c61d24e2
|
|||
|
112ed0e816
|
|||
|
7eb1e13b70
|
|||
|
893e1ba190
|
|||
|
1a1b0e8e15
|
|||
|
4ddde364ed
|
|||
|
4a3363a3d6
|
|||
|
0a3216e07d
|
|||
|
c29c0ed3ec
|
|||
|
fa7e56cb77
|
|||
|
13c19db818
|
|||
|
95b218fbed
|
|||
|
c3722c7438
|
|||
|
9dd547d6c1
|
|||
|
e2d5943517
|
|||
|
86e4763a12
|
|||
|
89ec63cb05
|
|||
|
e6375f1aa9
|
|||
|
d16e192a3a
|
|||
|
3f61f84e5a
|
|||
|
fd5399f50a
|
|||
|
8906ac3db8
|
|||
|
022aebf55b
|
|||
|
5dc6903425
|
|||
|
1b078b832c
|
|||
|
7515716864
|
|||
|
218b0c5b78
|
|||
|
928901ef9c
|
|||
|
4b62c78874
|
|||
|
f882eebaf5
|
|||
|
a872938405
|
|||
|
146be72fd7
|
|||
|
6de54e1da1
|
|||
|
c82b41a4df
|
|||
|
8304760fe0
|
|||
|
6bf91db757
|
|||
|
3f6b650a4b
|
|||
| ec079f32ca | |||
|
6524b3591a
|
|||
|
170101aa37
|
|||
|
0b3f33d7fe
|
|||
|
8a9b4f3989
|
|||
|
bbd0e3ae8d
|
|||
|
4d23e8840e
|
|||
|
c64d626d1c
|
|||
|
ecab1b74a4
|
|||
|
0bbdf04621
|
|||
|
939e5af4ce
|
|||
|
a735113466
|
|||
|
0e0a1b26f2
|
|||
|
e94db2181f
|
|||
|
9b59058881
|
|||
|
d0c54db33a
|
|||
|
5aedddfabb
|
|||
|
8d7c115432
|
|||
|
832c350b61
|
|||
|
3d599b3462
|
|||
|
4f799caaf5
|
|||
|
f4d2be3b1b
|
|||
|
7ce2840f03
|
|||
|
e2f3cabe15
|
|||
|
5a112332f2
|
|||
|
eb79cf6dc3
|
|||
|
8a9bb6ef4e
|
|||
|
6e0190a378
|
|||
| b5969e9a2b | |||
|
409d9f8fa6
|
|||
|
12d762429d
|
|||
|
53929ee514
|
|||
|
2f6e137f1a
|
|||
|
5224e79d9f
|
|||
|
bdcb12c58a
|
|||
|
5cb4d587e3
|
|||
|
8f9ec8d73b
|
|||
|
c1c50a448e
|
|||
|
19229db0b1
|
|||
|
f3b6bd146f
|
|||
|
98c3510bd4
|
|||
|
429d0d98fe
|
|||
|
db8fe5d3ff
|
|||
|
7477ec8d70
|
|||
|
adf7f4e7a2
|
|||
|
abf6787946
|
|||
|
e282b08597
|
|||
|
0a02b9d3d9
|
|||
| 875ca589e4 | |||
|
88f92d6e1f
|
|||
|
db4ed74365
|
|||
|
7cbf4fdece
|
|||
|
1fa9a09bfe
|
5
.gitignore
vendored
5
.gitignore
vendored
@@ -3,4 +3,7 @@ __pycache__
|
||||
.env
|
||||
venv
|
||||
.venv
|
||||
*.pyc
|
||||
*.pyc
|
||||
uv.lock
|
||||
.python-version
|
||||
/out
|
||||
79
README.md
79
README.md
@@ -5,3 +5,82 @@
|
||||
*Midas* aims at providing Python developers with a simple annotation system to enable compile-time integrity and data type checks, as well as generating runtime assertions.
|
||||
|
||||
This framework is being developed as part of a Bachelor's Thesis by Louis Heredero at HEI Sion.
|
||||
|
||||
## Requirements
|
||||
|
||||
- Python 3.11+
|
||||
- [uv](https://docs.astral.sh/uv/getting-started/installation/)
|
||||
|
||||
## Installation
|
||||
|
||||
1. Clone the repository
|
||||
```shell
|
||||
git clone https://git.kb28.ch/HEL/midas.git
|
||||
```
|
||||
2. Go in the project directory
|
||||
```shell
|
||||
cd midas
|
||||
```
|
||||
3. Install the CLI as a user-wide tool
|
||||
```shell
|
||||
uv tool install .
|
||||
```
|
||||
4. You can now run the `midas` command from anywhere
|
||||
```shell
|
||||
midas --help
|
||||
```
|
||||
|
||||
## Commands
|
||||
|
||||
### Compiling
|
||||
|
||||
> [!NOTE]
|
||||
> In the current state of the project, the `compile` command doesn't generate any runnable code, it only runs the parsers and type checker on the provided files
|
||||
|
||||
```shell
|
||||
midas compile -t types.midas source.py
|
||||
```
|
||||
|
||||
With the `compile` command, you can process a source Python file, with any number of custom type definition files (`-t FILE` option), and the type checker will verify the coherence of your program and generate the runnable code with valid syntax and runtime assertions.
|
||||
|
||||
The optional `-l FILE` option lets you produce a highlighted version of the source code showing diagnostics from the type checker (see [Highlighting](#highlighting))
|
||||
|
||||
### Highlighting
|
||||
|
||||
```shell
|
||||
midas utils highlight source.py
|
||||
# or
|
||||
midas utils highlight types.midas
|
||||
```
|
||||
|
||||
The `highlight` command takes in a source file (Python or Midas), runs the appropriate parser and outputs an HTML file containing the source code with added highlighting. This highlighting takes the form of hoverable annotations showing some of the parsed structures (e.g. a function definition, an assignment, a generic type, etc.)
|
||||
|
||||
The optional `-o FILE` option can be used to specify an output path. By default, the file is printed in stdout (equivalent to `-o -`).
|
||||
|
||||
### Dumping the AST
|
||||
|
||||
```shell
|
||||
midas utils dump-ast source.py
|
||||
# or
|
||||
midas utils dump-ast types.midas
|
||||
```
|
||||
|
||||
For debugging purposes, you can output the AST parsed from a Python or Midas file. For Python files, the `-p` flags lets you toggle the custom AST parsing. Without `-p`, the raw AST is returned, as produced by the builtin `ast` module. This flag has no effect on Midas files.
|
||||
|
||||
The optional `-o FILE` option can be used to specify an output path. By default, the file is printed in stdout (equivalent to `-o -`).
|
||||
|
||||
## Tests
|
||||
|
||||
Several snapshot tests are available to assert the good behaviour of the parsers and type checker. They can be run as follows:
|
||||
|
||||
```shell
|
||||
uv run -m tests.midas run -a
|
||||
uv run -m tests.python run -a
|
||||
uv run -m tests.checker run -a
|
||||
```
|
||||
|
||||
**Available subcommands:**
|
||||
- Run all tests: `run -a`
|
||||
- Run specific tests: `run tests/cases/test1.py tests/cases/test2.py ...`
|
||||
- Update all tests: `update -a`
|
||||
- Update specific tests: `update tests/cases/test1.py tests/cases/test2.py ...`
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from lexer.token import Token
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Stmt(ABC):
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_annotation_stmt(self, stmt: AnnotationStmt) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AnnotationStmt(Stmt):
|
||||
name: Token
|
||||
schema: Optional[SchemaExpr]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_annotation_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Expr(ABC):
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_literal_expr(self, expr: LiteralExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_type_expr(self, expr: TypeExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_constraint_expr(self, expr: ConstraintExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_schema_expr(self, expr: SchemaExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_schema_element_expr(self, expr: SchemaElementExpr) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WildcardExpr(Expr):
|
||||
token: Token
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_wildcard_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LiteralExpr(Expr):
|
||||
value: Any
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_literal_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TypeExpr(Expr):
|
||||
name: Token
|
||||
constraints: list[ConstraintExpr]
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_type_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConstraintExpr(Expr):
|
||||
left: Expr
|
||||
op: Token
|
||||
right: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_constraint_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SchemaExpr(Expr):
|
||||
left: Token
|
||||
elements: list[Expr]
|
||||
right: Token
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_schema_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SchemaElementExpr(Expr):
|
||||
name: Optional[Token]
|
||||
type: Optional[Expr]
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_schema_element_expr(self)
|
||||
@@ -1,138 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from lexer.token import Token
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
# Statements
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Stmt(ABC):
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_type_stmt(self, stmt: TypeStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_property_stmt(self, stmt: PropertyStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_op_stmt(self, stmt: OpStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_constraint_stmt(self, stmt: ConstraintStmt) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TypeStmt(Stmt):
|
||||
name: Token
|
||||
bases: list[TypeExpr]
|
||||
body: Optional[TypeBodyExpr]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_type_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PropertyStmt(Stmt):
|
||||
name: Token
|
||||
type: TypeExpr
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_property_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class OpStmt(Stmt):
|
||||
left: TypeExpr
|
||||
op: Token
|
||||
right: TypeExpr
|
||||
result: TypeExpr
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_op_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConstraintStmt(Stmt):
|
||||
name: Token
|
||||
constraint: ConstraintExpr
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_constraint_stmt(self)
|
||||
|
||||
|
||||
# Expressions
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Expr(ABC):
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_literal_expr(self, expr: LiteralExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_type_expr(self, expr: TypeExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_constraint_expr(self, expr: ConstraintExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_type_body_expr(self, expr: TypeBodyExpr) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WildcardExpr(Expr):
|
||||
token: Token
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_wildcard_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LiteralExpr(Expr):
|
||||
value: Any
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_literal_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TypeExpr(Expr):
|
||||
name: Token
|
||||
constraints: list[ConstraintExpr]
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_type_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConstraintExpr(Expr):
|
||||
left: Expr
|
||||
op: Token
|
||||
right: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_constraint_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TypeBodyExpr(Expr):
|
||||
properties: list[PropertyStmt]
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_type_body_expr(self)
|
||||
@@ -1,360 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum, auto
|
||||
import io
|
||||
from typing import Generator, Generic, Optional, Protocol, TypeVar
|
||||
|
||||
import core.ast.annotations as a
|
||||
import core.ast.midas as m
|
||||
|
||||
|
||||
class _Level(Enum):
|
||||
EMPTY = auto()
|
||||
ACTIVE = auto()
|
||||
LAST = auto()
|
||||
|
||||
|
||||
class Expr(Protocol):
|
||||
def accept(self, printer: AstPrinter) -> None: ...
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Expr)
|
||||
|
||||
|
||||
class AstPrinter(Generic[T]):
|
||||
LAST_CHILD = "└── "
|
||||
CHILD = "├── "
|
||||
VERTICAL = "│ "
|
||||
EMPTY = " "
|
||||
|
||||
def __init__(self):
|
||||
self._levels: list[_Level] = []
|
||||
self._idx: Optional[int] = None
|
||||
self._buf: io.StringIO = io.StringIO()
|
||||
|
||||
def print(self, expr: T):
|
||||
self._buf = io.StringIO()
|
||||
expr.accept(self)
|
||||
return self._buf.getvalue()
|
||||
|
||||
@contextmanager
|
||||
def _child_level(self, last: bool = False) -> Generator[None, None, None]:
|
||||
self._levels.append(_Level.LAST if last else _Level.ACTIVE)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._levels.pop()
|
||||
|
||||
def _mark_last(self):
|
||||
if self._levels:
|
||||
self._levels[-1] = _Level.LAST
|
||||
|
||||
def _write_line(self, text: str, *, last: bool = False):
|
||||
if last:
|
||||
self._mark_last()
|
||||
indent: str = self._build_indent()
|
||||
if self._idx is not None:
|
||||
text = f"[{self._idx}] {text}"
|
||||
self._idx = None
|
||||
self._buf.write(indent + text + "\n")
|
||||
|
||||
def _build_indent(self) -> str:
|
||||
parts: list[str] = []
|
||||
for level in self._levels[:-1]:
|
||||
parts.append(self.EMPTY if level == _Level.EMPTY else self.VERTICAL)
|
||||
if self._levels:
|
||||
if self._levels[-1] == _Level.LAST:
|
||||
parts.append(self.LAST_CHILD)
|
||||
self._levels[-1] = _Level.EMPTY
|
||||
else:
|
||||
parts.append(self.CHILD)
|
||||
return "".join(parts)
|
||||
|
||||
def _write_optional_child(
|
||||
self, label: str, child: Optional[T], *, last: bool = False
|
||||
):
|
||||
if last:
|
||||
self._mark_last()
|
||||
if child is None:
|
||||
self._write_line(f"{label}: None")
|
||||
else:
|
||||
self._write_line(label)
|
||||
with self._child_level(last=True):
|
||||
child.accept(self)
|
||||
|
||||
|
||||
class AnnotationAstPrinter(AstPrinter, a.Expr.Visitor[None], a.Stmt.Visitor[None]):
|
||||
def visit_annotation_stmt(self, stmt: a.AnnotationStmt) -> None:
|
||||
self._write_line("AnnotationStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_optional_child("schema", stmt.schema, last=True)
|
||||
|
||||
def visit_type_expr(self, expr: a.TypeExpr):
|
||||
self._write_line("TypeExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{expr.name.lexeme}"')
|
||||
self._write_line("constraints", last=True)
|
||||
with self._child_level():
|
||||
for i, constraint in enumerate(expr.constraints):
|
||||
self._idx = i
|
||||
if i == len(expr.constraints) - 1:
|
||||
self._mark_last()
|
||||
constraint.accept(self)
|
||||
|
||||
def visit_constraint_expr(self, expr: a.ConstraintExpr) -> None:
|
||||
self._write_line("ConstraintExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level():
|
||||
self._mark_last()
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.op.lexeme}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level():
|
||||
self._mark_last()
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_schema_expr(self, expr: a.SchemaExpr):
|
||||
self._write_line("SchemaExpr")
|
||||
with self._child_level():
|
||||
for i, elmt in enumerate(expr.elements):
|
||||
self._idx = i
|
||||
if i == len(expr.elements) - 1:
|
||||
self._mark_last()
|
||||
elmt.accept(self)
|
||||
|
||||
def visit_schema_element_expr(self, expr: a.SchemaElementExpr):
|
||||
self._write_line("SchemaElementExpr")
|
||||
with self._child_level():
|
||||
name_text: str = "None" if expr.name is None else f'"{expr.name.lexeme}"'
|
||||
self._write_line(f"name: {name_text}")
|
||||
self._write_optional_child("type", expr.type, last=True)
|
||||
|
||||
def visit_wildcard_expr(self, expr: a.WildcardExpr) -> None:
|
||||
self._write_line("WildcardExpr")
|
||||
|
||||
def visit_literal_expr(self, expr: a.LiteralExpr) -> None:
|
||||
self._write_line("LiteralExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f"value: {expr.value}", last=True)
|
||||
|
||||
|
||||
class AnnotationPrinter(a.Expr.Visitor[str], a.Stmt.Visitor[str]):
|
||||
def print(self, expr: a.Expr | a.Stmt):
|
||||
return expr.accept(self)
|
||||
|
||||
def visit_annotation_stmt(self, stmt: a.AnnotationStmt) -> str:
|
||||
schema: str = ""
|
||||
if stmt.schema is not None:
|
||||
schema = stmt.schema.accept(self)
|
||||
return f"{stmt.name.lexeme}{schema}"
|
||||
|
||||
def visit_type_expr(self, expr: a.TypeExpr) -> str:
|
||||
parts: list[str] = [expr.name.lexeme]
|
||||
for constraint in expr.constraints:
|
||||
parts.append("(" + constraint.accept(self) + ")")
|
||||
return " + ".join(parts)
|
||||
|
||||
def visit_constraint_expr(self, expr: a.ConstraintExpr) -> str:
|
||||
parts: list[str] = [
|
||||
expr.left.accept(self),
|
||||
expr.op.lexeme,
|
||||
expr.right.accept(self),
|
||||
]
|
||||
return " ".join(parts)
|
||||
|
||||
def visit_schema_expr(self, expr: a.SchemaExpr) -> str:
|
||||
res: str = expr.left.lexeme
|
||||
res += ", ".join(elmt.accept(self) for elmt in expr.elements)
|
||||
res += expr.right.lexeme
|
||||
return res
|
||||
|
||||
def visit_schema_element_expr(self, expr: a.SchemaElementExpr) -> str:
|
||||
parts: list[str] = []
|
||||
if expr.name is not None:
|
||||
parts.append(expr.name.lexeme)
|
||||
|
||||
if expr.type is None:
|
||||
parts.append("_")
|
||||
else:
|
||||
parts.append(expr.type.accept(self))
|
||||
return ": ".join(parts)
|
||||
|
||||
def visit_wildcard_expr(self, expr: a.WildcardExpr) -> str:
|
||||
return "_"
|
||||
|
||||
def visit_literal_expr(self, expr: a.LiteralExpr) -> str:
|
||||
return str(expr.value)
|
||||
|
||||
|
||||
class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt):
|
||||
self._write_line("TypeStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line("bases")
|
||||
with self._child_level():
|
||||
for i, base in enumerate(stmt.bases):
|
||||
self._idx = i
|
||||
if i == len(stmt.bases) - 1:
|
||||
self._mark_last()
|
||||
base.accept(self)
|
||||
self._write_optional_child("body", stmt.body, last=True)
|
||||
|
||||
def visit_property_stmt(self, stmt: m.PropertyStmt):
|
||||
self._write_line("PropertyStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line("type", last=True)
|
||||
with self._child_level():
|
||||
self._mark_last()
|
||||
stmt.type.accept(self)
|
||||
|
||||
def visit_op_stmt(self, stmt: m.OpStmt) -> None:
|
||||
self._write_line("OpStmt")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level():
|
||||
self._mark_last()
|
||||
stmt.left.accept(self)
|
||||
|
||||
self._write_line(f'op: "{stmt.op.lexeme}"')
|
||||
|
||||
self._write_line("right")
|
||||
with self._child_level():
|
||||
self._mark_last()
|
||||
stmt.right.accept(self)
|
||||
|
||||
self._write_line("result", last=True)
|
||||
with self._child_level():
|
||||
self._mark_last()
|
||||
stmt.result.accept(self)
|
||||
|
||||
def visit_constraint_stmt(self, stmt: m.ConstraintStmt):
|
||||
self._write_line("ConstraintStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line("constraint", last=True)
|
||||
with self._child_level():
|
||||
self._mark_last()
|
||||
stmt.constraint.accept(self)
|
||||
|
||||
def visit_type_expr(self, expr: m.TypeExpr):
|
||||
self._write_line("TypeExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{expr.name.lexeme}"')
|
||||
self._write_line("constraints", last=True)
|
||||
with self._child_level():
|
||||
for i, constraint in enumerate(expr.constraints):
|
||||
self._idx = i
|
||||
if i == len(expr.constraints) - 1:
|
||||
self._mark_last()
|
||||
constraint.accept(self)
|
||||
|
||||
def visit_constraint_expr(self, expr: m.ConstraintExpr):
|
||||
self._write_line("ConstraintExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level():
|
||||
self._mark_last()
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.op.lexeme}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level():
|
||||
self._mark_last()
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_type_body_expr(self, expr: m.TypeBodyExpr):
|
||||
self._write_line("TypeBodyExpr")
|
||||
with self._child_level():
|
||||
self._write_line("properties", last=True)
|
||||
with self._child_level():
|
||||
for i, property in enumerate(expr.properties):
|
||||
self._idx = i
|
||||
if i == len(expr.properties) - 1:
|
||||
self._mark_last()
|
||||
property.accept(self)
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
|
||||
self._write_line("WildcardExpr")
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
|
||||
self._write_line("LiteralExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f"value: {expr.value}", last=True)
|
||||
|
||||
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
|
||||
def __init__(self, indent: int = 4):
|
||||
self.indent: int = indent
|
||||
self.level: int = 0
|
||||
|
||||
def indented(self, text: str) -> str:
|
||||
return " " * (self.level * self.indent) + text
|
||||
|
||||
def print(self, expr: m.Expr | m.Stmt):
|
||||
self.level = 0
|
||||
return expr.accept(self)
|
||||
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt):
|
||||
bases: list[str] = [
|
||||
b.accept(self)
|
||||
for b in stmt.bases
|
||||
]
|
||||
|
||||
res: str = self.indented(f"type {stmt.name.lexeme}<{', '.join(bases)}>")
|
||||
if stmt.body is not None:
|
||||
res += " {\n"
|
||||
self.level += 1
|
||||
res += stmt.body.accept(self)
|
||||
self.level -= 1
|
||||
res += "\n" + self.indented("}")
|
||||
|
||||
return res
|
||||
|
||||
def visit_property_stmt(self, stmt: m.PropertyStmt):
|
||||
return f"{stmt.name.lexeme}: {stmt.type.accept(self)}"
|
||||
|
||||
def visit_op_stmt(self, stmt: m.OpStmt):
|
||||
left: str = stmt.left.accept(self)
|
||||
op: str = stmt.op.lexeme
|
||||
right: str = stmt.right.accept(self)
|
||||
result: str = stmt.result.accept(self)
|
||||
return self.indented(f"op <{left}> {op} <{right}> = <{result}>")
|
||||
|
||||
def visit_constraint_stmt(self, stmt: m.ConstraintStmt):
|
||||
name: str = stmt.name.lexeme
|
||||
constraint: str = stmt.constraint.accept(self)
|
||||
return self.indented(f"constraint {name} = {constraint}")
|
||||
|
||||
def visit_type_expr(self, expr: m.TypeExpr):
|
||||
parts: list[str] = [expr.name.lexeme]
|
||||
for constraint in expr.constraints:
|
||||
parts.append("(" + constraint.accept(self) + ")")
|
||||
return " + ".join(parts)
|
||||
|
||||
def visit_constraint_expr(self, expr: m.ConstraintExpr):
|
||||
parts: list[str] = [
|
||||
expr.left.accept(self),
|
||||
expr.op.lexeme,
|
||||
expr.right.accept(self),
|
||||
]
|
||||
return " ".join(parts)
|
||||
|
||||
def visit_type_body_expr(self, expr: m.TypeBodyExpr):
|
||||
properties: list[str] = [
|
||||
self.indented(prop.accept(self))
|
||||
for prop in expr.properties
|
||||
]
|
||||
return "\n".join(properties)
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr):
|
||||
return "_"
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr):
|
||||
return str(expr.value)
|
||||
150
docs/architecture.typ
Normal file
150
docs/architecture.typ
Normal file
@@ -0,0 +1,150 @@
|
||||
#import "@preview/cetz:0.5.2": canvas, draw
|
||||
|
||||
#let diagram-only = false
|
||||
|
||||
#set document(
|
||||
title: [Midas Architecture],
|
||||
//author: "Louis Heredero",
|
||||
)
|
||||
|
||||
#set text(
|
||||
font: "Source Sans 3",
|
||||
)
|
||||
|
||||
#let diagram = canvas({
|
||||
let framed = draw.content.with(
|
||||
padding: (x: .8em, y: 1em),
|
||||
frame: "rect",
|
||||
stroke: black,
|
||||
)
|
||||
let arrow = draw.line.with(mark: (end: ">", fill: black))
|
||||
framed(
|
||||
(0, 0),
|
||||
name: "python-parser",
|
||||
)[Python parser]
|
||||
|
||||
draw.content(
|
||||
(rel: (0, 1), to: "python-parser.north"),
|
||||
padding: 5pt,
|
||||
anchor: "south",
|
||||
name: "source-py",
|
||||
)[_`source.py`_]
|
||||
arrow("source-py", "python-parser")
|
||||
|
||||
framed(
|
||||
(rel: (3, 0), to: "python-parser.east"),
|
||||
anchor: "west",
|
||||
name: "custom-parser",
|
||||
align(center)[Custom python\ parser],
|
||||
)
|
||||
|
||||
arrow("python-parser", "custom-parser", name: "arrow-python-ast")
|
||||
draw.content(
|
||||
"arrow-python-ast",
|
||||
anchor: "south",
|
||||
padding: 5pt,
|
||||
)[`ast.Module`]
|
||||
|
||||
framed(
|
||||
(rel: (-3, -2), to: "custom-parser.south"),
|
||||
anchor: "east",
|
||||
name: "python-resolver",
|
||||
)[Python Resolver]
|
||||
arrow(
|
||||
"custom-parser",
|
||||
((), "|-", "python-resolver.east"),
|
||||
"python-resolver",
|
||||
name: "arrow-python-custom-ast",
|
||||
)
|
||||
draw.content(
|
||||
(rel: (1.5, 0), to: "arrow-python-custom-ast.end"),
|
||||
padding: 5pt,
|
||||
anchor: "south",
|
||||
)[P-AST#footnote[#strong[P]ython *AST*]<fn-past>]
|
||||
draw.content(
|
||||
"python-resolver.west",
|
||||
padding: 5pt,
|
||||
anchor: "south-east",
|
||||
)[Resolved P-AST@fn-past]
|
||||
|
||||
draw.circle(
|
||||
(rel: (1, -2), to: "custom-parser.south-east"),
|
||||
radius: .4,
|
||||
name: "midas-loader",
|
||||
)
|
||||
arrow(
|
||||
"custom-parser",
|
||||
"midas-loader",
|
||||
name: "arrow-load-midas",
|
||||
mark: (end: (symbol: ">", fill: black), start: "o"),
|
||||
)
|
||||
draw.content(
|
||||
"arrow-load-midas",
|
||||
anchor: "west",
|
||||
padding: 5pt,
|
||||
)[```python midas.using("types.midas")```]
|
||||
|
||||
framed(
|
||||
(rel: (0, -2), to: "midas-loader.south"),
|
||||
name: "midas-parser",
|
||||
)[Midas lexer/parser]
|
||||
arrow("midas-loader", "midas-parser", name: "arrow-midas-source")
|
||||
draw.content(
|
||||
"arrow-midas-source",
|
||||
anchor: "west",
|
||||
padding: 5pt,
|
||||
)[_`types.midas`_]
|
||||
|
||||
|
||||
framed(
|
||||
(rel: (-2, 0), to: "midas-parser.west"),
|
||||
anchor: "east",
|
||||
name: "midas-resolver",
|
||||
)[Midas Resolver]
|
||||
arrow("midas-parser", "midas-resolver", name: "arrow-midas-ast")
|
||||
draw.content(
|
||||
"arrow-midas-ast",
|
||||
anchor: "south",
|
||||
padding: 5pt,
|
||||
)[M-AST#footnote[#strong[M]idas *AST*]<fn-mast>]
|
||||
|
||||
framed(
|
||||
(rel: (-3, 0), to: "midas-resolver.west"),
|
||||
anchor: "east",
|
||||
name: "checker",
|
||||
)[Checker]
|
||||
arrow("midas-resolver", "checker", name: "arrow-type-ctx")
|
||||
arrow(
|
||||
"python-resolver",
|
||||
((), "-|", "checker.north"),
|
||||
"checker",
|
||||
)
|
||||
draw.content(
|
||||
"arrow-type-ctx",
|
||||
anchor: "south",
|
||||
padding: 5pt,
|
||||
)[Types context]
|
||||
})
|
||||
|
||||
#show: doc => if diagram-only {
|
||||
set page(width: auto, height: auto, margin: .5cm)
|
||||
diagram
|
||||
} else { doc }
|
||||
|
||||
#align(center, title())
|
||||
|
||||
#v(1cm)
|
||||
|
||||
#figure(
|
||||
diagram,
|
||||
caption: [Midas type-checker architecture],
|
||||
)
|
||||
|
||||
== Components
|
||||
|
||||
- *Python parser*: builtin Python AST parser, extracts abstract syntax from the raw Python source (```python ast.parse(...)```)
|
||||
- *Custom python parser*: converts the raw Python AST into custom, more suitable constructs, especially for type annotations
|
||||
- *Python resolver*: resolves bindings and references, tracks binding scopes
|
||||
- *Midas lexer/parser*: parses a Midas type definition file and extracts its AST
|
||||
- *Midas resolver*: walks the AST and fills the environment with the defined types and operations
|
||||
- *Checker*: evaluates expressions and checks type coherence
|
||||
@@ -2,10 +2,6 @@
|
||||
# ruff: disable[F821]
|
||||
from __future__ import annotations
|
||||
|
||||
# Prototype of custom type import to use valid Python syntax
|
||||
import midas
|
||||
midas.using("02_custom_types.midas")
|
||||
|
||||
# A data-frame using a custom type
|
||||
df: Frame[
|
||||
location: GeoLocation
|
||||
@@ -21,7 +17,7 @@ lat + lon # Invalid operation
|
||||
# Registered operations are permitted
|
||||
lat1: Latitude = lat[0]
|
||||
lat2: Latitude = lat[1]
|
||||
lat_diff: LatitudeDiff = lat2 - lat1 # Valid operation
|
||||
lat_diff: Difference[Latitude] = lat2 - lat1 # Valid operation
|
||||
|
||||
# In addition to the type, a column can have one or more constraints, either defined inline or in a separate file
|
||||
df2: Frame[
|
||||
|
||||
73
examples/00_syntax_prototype/03_custom_types_v2.midas
Normal file
73
examples/00_syntax_prototype/03_custom_types_v2.midas
Normal file
@@ -0,0 +1,73 @@
|
||||
// Simple custom type derived from float
|
||||
type Custom(float)
|
||||
|
||||
// Simple custom types with constraints
|
||||
type Latitude(float) where (-90 <= _ <= 90)
|
||||
type Longitude(float) where (-180 <= _ <= 180)
|
||||
|
||||
// Generic custom type (a Difference of T is derived from T, e.g. a difference of floats is a float
|
||||
type Difference[T](T)
|
||||
|
||||
// Complex custom type, containing two values accessible through properties
|
||||
type GeoLocation {
|
||||
lat: Latitude
|
||||
lon: Longitude
|
||||
}
|
||||
|
||||
// Define operations on our custom type
|
||||
extend GeoLocation {
|
||||
// This type is compatible with the `-` operation with another GeoLocation
|
||||
// i.e. you can subtract a GeoLocation from another GeoLocation, resulting
|
||||
// in a Difference of GeoLocations
|
||||
op __sub__(GeoLocation) -> Difference[GeoLocation]
|
||||
}
|
||||
|
||||
// For complex generics, you need to specify how the genericity the properties
|
||||
// are handled
|
||||
type Difference[GeoLocation] {
|
||||
lat: Difference[Latitude]
|
||||
lon: Difference[Longitude]
|
||||
}
|
||||
|
||||
// Simple operation defined on our custom types
|
||||
extend Latitude {
|
||||
op __sub__(Latitude) -> Difference[Latitude]
|
||||
}
|
||||
|
||||
extend Longitude {
|
||||
op __sub__(Longitude) -> Difference[Longitude]
|
||||
}
|
||||
|
||||
// Predefined custom predicates that can be referenced in other definitions
|
||||
predicate Positive(v: float) = v >= 0
|
||||
predicate StrictlyPositive(v: float) = v > 0
|
||||
predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10)
|
||||
predicate Arctic(loc: GeoLocation) = (loc.lat >= 66)
|
||||
|
||||
type Person {
|
||||
name: str
|
||||
|
||||
// Property with an inline constraint
|
||||
age: int? where (0 <= _ < 150)
|
||||
|
||||
// Property referencing a predicate
|
||||
height: float where StrictlyPositive
|
||||
|
||||
home: GeoLocation
|
||||
}
|
||||
|
||||
// Custom complex type derived from another complex type, with a constraint
|
||||
// on a property
|
||||
// Multiple proposed syntaxes, not yet defined
|
||||
|
||||
// Explicit, but new keyword
|
||||
type EquatorialPerson refines Person where Equatorial(_.home)
|
||||
|
||||
// Explicit with existing keyword, might be confusing if expectations regarding 'is'
|
||||
type EquatorialPerson is Person where Equatorial(_.home)
|
||||
|
||||
// Consistent and Python-friendly but can be confused with structural extension
|
||||
type EquatorialPerson(Person) where Equatorial(_.home)
|
||||
|
||||
// Allow new properties, probably not useful
|
||||
type EquatorialPerson extends Person where Equatorial(_.home)
|
||||
15
examples/00_syntax_prototype/04_functions.py
Normal file
15
examples/00_syntax_prototype/04_functions.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# type: ignore
|
||||
# ruff: disable[F821]
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def func(
|
||||
col1: Column[float + (0 <= _ <= 1)],
|
||||
col2: Column[float + (0 <= _ <= 1)],
|
||||
) -> Column[float + (0 <= _ <= 2)]:
|
||||
result: Column[float + (0 <= _ <= 2)] = col1 + col2
|
||||
return result
|
||||
|
||||
|
||||
def func2(a: int, /, b: float, *, c: str):
|
||||
pass
|
||||
33
examples/00_syntax_prototype/05_custom_types_v3.midas
Normal file
33
examples/00_syntax_prototype/05_custom_types_v3.midas
Normal file
@@ -0,0 +1,33 @@
|
||||
type Foo1 = float
|
||||
type Foo2 = float where (_ > 3)
|
||||
type Foo3 = int | float
|
||||
type Foo4 = int where (_ > 3) | float where (_ > 3)
|
||||
type Foo5 = (int | float) where (_ > 3)
|
||||
type Foo6 = {
|
||||
foo: float
|
||||
bar: float where (_ > 3)
|
||||
}
|
||||
|
||||
type Foo7[T] = T where (_ > 3)
|
||||
type Foo8[A, B<:int] = {
|
||||
a: A
|
||||
b: B
|
||||
}
|
||||
|
||||
type Complex = {
|
||||
a: int
|
||||
b: int
|
||||
}
|
||||
type Complex2 = Complex where (_.a > 3 & _.b < 5)
|
||||
|
||||
predicate Positive(n: int) = n >= 0
|
||||
|
||||
extend Foo1 {
|
||||
op __add__(Foo1) -> Foo1
|
||||
}
|
||||
|
||||
extend Foo7[T] {
|
||||
op __add__(Foo7[T]) -> Foo7[T]
|
||||
}
|
||||
|
||||
type Optional[T] = None | T
|
||||
11
examples/01_simple_type_checking/01_simple_operations.py
Normal file
11
examples/01_simple_type_checking/01_simple_operations.py
Normal file
@@ -0,0 +1,11 @@
|
||||
a: int = 3
|
||||
b: int = 4
|
||||
|
||||
c = a + b # -> int
|
||||
|
||||
c = "invalid" # -> can't assign str to int variable
|
||||
|
||||
d = True
|
||||
e = d + d
|
||||
|
||||
f: float = a
|
||||
14
examples/01_simple_type_checking/02_simple_types.midas
Normal file
14
examples/01_simple_type_checking/02_simple_types.midas
Normal file
@@ -0,0 +1,14 @@
|
||||
type Meter = float
|
||||
type Second = float
|
||||
type MeterPerSecond = float
|
||||
|
||||
extend Meter {
|
||||
op __add__(Meter) -> Meter
|
||||
op __sub__(Meter) -> Meter
|
||||
op __truediv__(Second) -> MeterPerSecond
|
||||
}
|
||||
|
||||
extend Second {
|
||||
op __add__(Second) -> Second
|
||||
op __sub__(Second) -> Second
|
||||
}
|
||||
6
examples/01_simple_type_checking/02_simple_types.py
Normal file
6
examples/01_simple_type_checking/02_simple_types.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# type: ignore
|
||||
# ruff: disable [F821]
|
||||
|
||||
distance: Meter = cast(Meter, 123.45)
|
||||
time: Second = cast(Second, 6.7)
|
||||
speed = distance / time
|
||||
23
examples/01_simple_type_checking/03_control_flow.py
Normal file
23
examples/01_simple_type_checking/03_control_flow.py
Normal file
@@ -0,0 +1,23 @@
|
||||
def minimum(x: int, y: int):
|
||||
if x < y:
|
||||
return x
|
||||
else:
|
||||
return y
|
||||
|
||||
|
||||
a = 15
|
||||
b = 72
|
||||
c = minimum(a, b)
|
||||
|
||||
|
||||
def factorial(n: int) -> int:
|
||||
if n <= 1:
|
||||
return 1
|
||||
return n * factorial(n - 1)
|
||||
|
||||
|
||||
category = "Category 1" if a < 10 else "Category 2"
|
||||
|
||||
|
||||
def foo() -> None:
|
||||
pass
|
||||
11
examples/01_simple_type_checking/04_complex_types.midas
Normal file
11
examples/01_simple_type_checking/04_complex_types.midas
Normal file
@@ -0,0 +1,11 @@
|
||||
type Meter = float
|
||||
|
||||
extend Meter {
|
||||
op __add__(Meter) -> Meter
|
||||
op __sub__(Meter) -> Meter
|
||||
}
|
||||
|
||||
type Coordinate = {
|
||||
x: Meter
|
||||
y: Meter
|
||||
}
|
||||
11
examples/01_simple_type_checking/04_complex_types.py
Normal file
11
examples/01_simple_type_checking/04_complex_types.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# type: ignore
|
||||
# ruff: disable [F821]
|
||||
p1: Coordinate
|
||||
p2: Coordinate
|
||||
|
||||
diff_x = p2.x - p1.x
|
||||
diff_y = p2.y - p1.y
|
||||
|
||||
dist = diff_x + diff_y
|
||||
|
||||
p2.x += cast(Meter, 1)
|
||||
146
gen/gen.py
Normal file
146
gen/gen.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
HEADER = '''"""
|
||||
This file was generated by a script. Any manual changes might be overwritten.
|
||||
Please modify {defs_path} instead and run {gen_path}
|
||||
"""'''
|
||||
|
||||
SECTION_TEMPLATE = """{banner}
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class {base}(ABC):
|
||||
location: Location
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
{visitor_methods}
|
||||
|
||||
|
||||
{classes}"""
|
||||
|
||||
TEMPLATE = """{header}
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
{imports}
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
{sections}
|
||||
"""
|
||||
|
||||
VISITOR_METHOD_TEMPLATE = """
|
||||
@abstractmethod
|
||||
def visit_{func_name}(self, {param}: {cls}) -> T: ...
|
||||
"""
|
||||
|
||||
CLASS_TEMPLATE = """
|
||||
@dataclass(frozen=True)
|
||||
class {cls}({base}):
|
||||
{body}
|
||||
|
||||
def accept(self, visitor: {base}.Visitor[T]) -> T:
|
||||
return visitor.visit_{func_name}(self)
|
||||
"""
|
||||
|
||||
SECTION_REGEX = re.compile(
|
||||
r"^###>\s*(?P<base>[^\n]*?)\s*\|\s*(?P<name>[^\n]*?)(\s*\|\s*(?P<param>[^\n]*?))?\s*?\n(?P<body>.*?)\n###<$",
|
||||
re.MULTILINE | re.DOTALL,
|
||||
)
|
||||
|
||||
IMPORTS_REGEX = re.compile(
|
||||
r"^###>\s*Imports\s*?\n(?P<body>.*?)\n###<$",
|
||||
re.MULTILINE | re.DOTALL,
|
||||
)
|
||||
|
||||
|
||||
def snake_case(text: str) -> str:
|
||||
return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_")
|
||||
|
||||
|
||||
def make_visitor_method(cls: str, param: str):
|
||||
method: str = VISITOR_METHOD_TEMPLATE.format(
|
||||
func_name=snake_case(cls), param=param, cls=cls
|
||||
)
|
||||
return method.strip("\n")
|
||||
|
||||
|
||||
def make_class(name: str, cls: str, base: str):
|
||||
body: str = cls.split("\n", 1)[1]
|
||||
func_name: str = snake_case(name)
|
||||
cls_def: str = CLASS_TEMPLATE.format(
|
||||
cls=name,
|
||||
base=base,
|
||||
body=body,
|
||||
func_name=func_name,
|
||||
)
|
||||
return cls_def.strip("\n")
|
||||
|
||||
|
||||
def make_banner(text: str) -> str:
|
||||
middle: str = f"# {text} #"
|
||||
rule: str = "#" * len(middle)
|
||||
return "\n".join((rule, middle, rule))
|
||||
|
||||
|
||||
def make_section(full_name: str, base: str, param: str, body: str) -> str:
|
||||
visitor_methods: list[str] = []
|
||||
classes: list[str] = []
|
||||
definitions: list[str] = body.strip("\n").split("\n\n\n")
|
||||
for cls in definitions:
|
||||
cls = cls.strip("\n")
|
||||
name: str = re.match("class (.*?):", cls).group(1) # type: ignore
|
||||
print(f"Processing {name}")
|
||||
visitor_methods.append(make_visitor_method(name, param))
|
||||
classes.append(make_class(name, cls, base))
|
||||
|
||||
return SECTION_TEMPLATE.format(
|
||||
banner=make_banner(full_name),
|
||||
base=base,
|
||||
visitor_methods="\n\n".join(visitor_methods),
|
||||
classes="\n\n\n".join(classes),
|
||||
)
|
||||
|
||||
|
||||
def generate(definitions_path: Path, out_path: Path):
|
||||
root_dir: Path = Path(__file__).parent.parent
|
||||
rel_path: Path = definitions_path.relative_to(root_dir)
|
||||
src: str = definitions_path.read_text()
|
||||
sections: list[str] = []
|
||||
|
||||
imports: str = ""
|
||||
if m := IMPORTS_REGEX.search(src):
|
||||
imports = m.group("body").strip("\n")
|
||||
|
||||
for section_m in SECTION_REGEX.finditer(src):
|
||||
full_name: str = section_m.group("name")
|
||||
base: str = section_m.group("base")
|
||||
param: str = section_m.group("param") or base.lower()
|
||||
body: str = section_m.group("body")
|
||||
sections.append(make_section(full_name, base, param, body))
|
||||
|
||||
result: str = TEMPLATE.format(
|
||||
header=HEADER.format(
|
||||
defs_path=rel_path,
|
||||
gen_path=Path(__file__).relative_to(root_dir),
|
||||
),
|
||||
imports=imports,
|
||||
sections="\n\n\n".join(sections),
|
||||
)
|
||||
out_path.write_text(result)
|
||||
|
||||
|
||||
def main():
|
||||
root: Path = Path(__file__).parent.parent
|
||||
defs_dir: Path = root / "gen"
|
||||
ast_dir: Path = root / "midas" / "ast"
|
||||
generate(defs_dir / "midas.py", ast_dir / "midas.py")
|
||||
generate(defs_dir / "python.py", ast_dir / "python.py")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
118
gen/midas.py
Normal file
118
gen/midas.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# type: ignore
|
||||
# ruff: disable[F821, F401]
|
||||
|
||||
###> Imports
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from midas.ast.location import Location
|
||||
from midas.lexer.token import Token
|
||||
|
||||
###<
|
||||
|
||||
|
||||
###> Stmt | Statements
|
||||
class TypeStmt:
|
||||
name: Token
|
||||
params: list[Param]
|
||||
type: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Param:
|
||||
location: Location
|
||||
name: Token
|
||||
bound: Optional[Type]
|
||||
|
||||
|
||||
class PropertyStmt:
|
||||
name: Token
|
||||
type: Type
|
||||
|
||||
|
||||
class ExtendStmt:
|
||||
type: Type
|
||||
operations: list[OpStmt]
|
||||
|
||||
|
||||
class OpStmt:
|
||||
name: Token
|
||||
operand: Type
|
||||
result: Type
|
||||
|
||||
|
||||
class PredicateStmt:
|
||||
name: Token
|
||||
subject: Token
|
||||
type: Type
|
||||
condition: Expr
|
||||
|
||||
|
||||
###<
|
||||
|
||||
|
||||
###> Expr | Expressions
|
||||
|
||||
|
||||
class LogicalExpr:
|
||||
left: Expr
|
||||
operator: Token
|
||||
right: Expr
|
||||
|
||||
|
||||
class BinaryExpr:
|
||||
left: Expr
|
||||
operator: Token
|
||||
right: Expr
|
||||
|
||||
|
||||
class UnaryExpr:
|
||||
operator: Token
|
||||
right: Expr
|
||||
|
||||
|
||||
class GetExpr:
|
||||
expr: Expr
|
||||
name: Token
|
||||
|
||||
|
||||
class VariableExpr:
|
||||
name: Token
|
||||
|
||||
|
||||
class GroupingExpr:
|
||||
expr: Expr
|
||||
|
||||
|
||||
class LiteralExpr:
|
||||
value: Any
|
||||
|
||||
|
||||
class WildcardExpr:
|
||||
token: Token
|
||||
|
||||
|
||||
###<
|
||||
|
||||
###> Type | Types
|
||||
|
||||
|
||||
class NamedType:
|
||||
name: Token
|
||||
|
||||
|
||||
class GenericType:
|
||||
type: Type
|
||||
params: list[Type]
|
||||
|
||||
|
||||
class ConstraintType:
|
||||
type: Type
|
||||
constraint: Expr
|
||||
|
||||
|
||||
class ComplexType:
|
||||
properties: list[PropertyStmt]
|
||||
|
||||
|
||||
###<
|
||||
142
gen/python.py
Normal file
142
gen/python.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# type: ignore
|
||||
# ruff: disable[F821, F401]
|
||||
|
||||
###> Imports
|
||||
import ast
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from midas.ast.location import Location
|
||||
|
||||
###<
|
||||
|
||||
|
||||
###> MidasType | Type annotations | node
|
||||
class BaseType:
|
||||
base: str
|
||||
param: Optional[MidasType]
|
||||
|
||||
|
||||
class ConstraintType:
|
||||
type: MidasType
|
||||
constraint: ast.expr
|
||||
|
||||
|
||||
class FrameColumn:
|
||||
name: Optional[str]
|
||||
type: Optional[MidasType]
|
||||
|
||||
|
||||
class FrameType:
|
||||
columns: list[FrameColumn]
|
||||
|
||||
|
||||
###<
|
||||
|
||||
|
||||
###> Stmt | Statements
|
||||
class ExpressionStmt:
|
||||
expr: Expr
|
||||
|
||||
|
||||
class Function:
|
||||
name: str
|
||||
posonlyargs: list[Argument]
|
||||
args: list[Argument]
|
||||
sink: Optional[Argument]
|
||||
kwonlyargs: list[Argument]
|
||||
kw_sink: Optional[Argument]
|
||||
returns: Optional[MidasType]
|
||||
body: list[Stmt]
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
location: Optional[Location] = None
|
||||
name: str
|
||||
type: Optional[MidasType]
|
||||
default: Optional[Expr]
|
||||
|
||||
@property
|
||||
def all_args(self) -> list[Argument]:
|
||||
return self.posonlyargs + self.args + self.kwonlyargs
|
||||
|
||||
|
||||
class TypeAssign:
|
||||
name: str
|
||||
type: MidasType
|
||||
|
||||
|
||||
class AssignStmt:
|
||||
targets: list[Expr]
|
||||
value: Expr
|
||||
|
||||
|
||||
class ReturnStmt:
|
||||
value: Optional[Expr]
|
||||
|
||||
|
||||
class IfStmt:
|
||||
test: Expr
|
||||
body: list[Stmt]
|
||||
orelse: list[Stmt]
|
||||
|
||||
|
||||
###<
|
||||
|
||||
|
||||
###> Expr | Expressions
|
||||
class BinaryExpr:
|
||||
left: Expr
|
||||
operator: ast.operator
|
||||
right: Expr
|
||||
|
||||
|
||||
class CompareExpr:
|
||||
left: Expr
|
||||
operator: ast.cmpop
|
||||
right: Expr
|
||||
|
||||
|
||||
class UnaryExpr:
|
||||
operator: ast.unaryop
|
||||
right: Expr
|
||||
|
||||
|
||||
class CallExpr:
|
||||
callee: Expr
|
||||
arguments: list[Expr]
|
||||
keywords: dict[str, Expr]
|
||||
|
||||
|
||||
class GetExpr:
|
||||
object: Expr
|
||||
name: str
|
||||
|
||||
|
||||
class LiteralExpr:
|
||||
value: Any
|
||||
|
||||
|
||||
class VariableExpr:
|
||||
name: str
|
||||
|
||||
|
||||
class LogicalExpr:
|
||||
left: Expr
|
||||
operator: ast.boolop
|
||||
right: Expr
|
||||
|
||||
|
||||
class CastExpr:
|
||||
type: MidasType
|
||||
expr: Expr
|
||||
|
||||
|
||||
class TernaryExpr:
|
||||
test: Expr
|
||||
if_true: Expr
|
||||
if_false: Expr
|
||||
|
||||
|
||||
###<
|
||||
@@ -1,102 +0,0 @@
|
||||
from lexer.base import Lexer
|
||||
from lexer.keyword import ANNOTATION_KEYWORDS
|
||||
from lexer.token import TokenType
|
||||
|
||||
|
||||
class AnnotationLexer(Lexer):
|
||||
def scan_token(self) -> None:
|
||||
char: str = self.advance()
|
||||
match char:
|
||||
case "(":
|
||||
self.add_token(TokenType.LEFT_PAREN)
|
||||
case ")":
|
||||
self.add_token(TokenType.RIGHT_PAREN)
|
||||
case "[":
|
||||
self.add_token(TokenType.LEFT_BRACKET)
|
||||
case "]":
|
||||
self.add_token(TokenType.RIGHT_BRACKET)
|
||||
case "<":
|
||||
self.add_token(
|
||||
TokenType.LESS_EQUAL if self.match("=") else TokenType.LESS
|
||||
)
|
||||
case ">":
|
||||
self.add_token(
|
||||
TokenType.GREATER_EQUAL if self.match("=") else TokenType.GREATER
|
||||
)
|
||||
case "=":
|
||||
self.add_token(
|
||||
TokenType.EQUAL_EQUAL if self.match("=") else TokenType.EQUAL
|
||||
)
|
||||
case "!":
|
||||
if self.match("="):
|
||||
self.add_token(TokenType.BANG_EQUAL)
|
||||
else:
|
||||
self.error("Unexpected single bang. Did you mean '!=' ?")
|
||||
case ":":
|
||||
self.add_token(TokenType.COLON)
|
||||
case ",":
|
||||
self.add_token(TokenType.COMMA)
|
||||
case "_":
|
||||
self.add_token(TokenType.UNDERSCORE)
|
||||
case "+":
|
||||
self.add_token(TokenType.PLUS)
|
||||
case "#":
|
||||
self.scan_comment()
|
||||
case "\n":
|
||||
self.add_token(TokenType.NEWLINE)
|
||||
case " " | "\r" | "\t":
|
||||
# Consume all whitespace characters until EOL or EOF
|
||||
while (
|
||||
self.peek().isspace()
|
||||
and self.peek() != "\n"
|
||||
and not self.is_at_end()
|
||||
):
|
||||
self.advance()
|
||||
self.add_token(TokenType.WHITESPACE)
|
||||
case _:
|
||||
if char.isdigit():
|
||||
self.scan_number()
|
||||
elif char.isalpha():
|
||||
self.scan_identifier()
|
||||
else:
|
||||
self.error("Unexpected character")
|
||||
return None
|
||||
|
||||
def scan_number(self):
|
||||
"""Scan the rest of number and add it as a token
|
||||
|
||||
This method handles both simple integers and floats. Scientific notation
|
||||
and base prefixes (0x, 0b, 0o) are not supported
|
||||
"""
|
||||
while self.peek().isdigit():
|
||||
self.advance()
|
||||
|
||||
if self.peek() == "." and self.peek_next().isdigit():
|
||||
self.advance()
|
||||
while self.peek().isdigit():
|
||||
self.advance()
|
||||
|
||||
value: float = float(self.source[self.start : self.idx])
|
||||
self.add_token(TokenType.NUMBER, value)
|
||||
|
||||
def scan_identifier(self):
|
||||
"""Scan the rest of an identifier and add it as a token
|
||||
|
||||
An identifier starts with a letter, followed by any number of
|
||||
alphanumerical characters or underscores
|
||||
"""
|
||||
while self.peek().isalnum() or self.peek() == "_":
|
||||
self.advance()
|
||||
|
||||
lexeme: str = self.source[self.start : self.idx]
|
||||
token_type: TokenType = ANNOTATION_KEYWORDS.get(lexeme, TokenType.IDENTIFIER)
|
||||
self.add_token(token_type)
|
||||
|
||||
def scan_comment(self):
|
||||
"""Scan the rest of a comment and add it as a token
|
||||
|
||||
A comment starts with a `#` character and ends at the EOL/EOF
|
||||
"""
|
||||
while self.peek() != "\n" and not self.is_at_end():
|
||||
self.advance()
|
||||
self.add_token(TokenType.COMMENT)
|
||||
@@ -1,16 +0,0 @@
|
||||
from lexer.token import TokenType
|
||||
|
||||
ANNOTATION_KEYWORDS: dict[str, TokenType] = {
|
||||
"True": TokenType.TRUE,
|
||||
"False": TokenType.FALSE,
|
||||
"None": TokenType.NONE,
|
||||
}
|
||||
|
||||
MIDAS_KEYWORDS: dict[str, TokenType] = {
|
||||
"type": TokenType.TYPE,
|
||||
"op": TokenType.OP,
|
||||
"constraint": TokenType.CONSTRAINT,
|
||||
"true": TokenType.TRUE,
|
||||
"false": TokenType.FALSE,
|
||||
"none": TokenType.NONE,
|
||||
}
|
||||
@@ -1,59 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Any
|
||||
|
||||
from lexer.position import Position
|
||||
|
||||
|
||||
class TokenType(Enum):
|
||||
# Punctuation
|
||||
LEFT_PAREN = auto()
|
||||
RIGHT_PAREN = auto()
|
||||
LEFT_BRACKET = auto()
|
||||
RIGHT_BRACKET = auto()
|
||||
LEFT_BRACE = auto()
|
||||
RIGHT_BRACE = auto()
|
||||
COLON = auto()
|
||||
COMMA = auto()
|
||||
UNDERSCORE = auto()
|
||||
|
||||
# Operators
|
||||
PLUS = auto()
|
||||
MINUS = auto()
|
||||
STAR = auto()
|
||||
SLASH = auto()
|
||||
GREATER = auto()
|
||||
GREATER_EQUAL = auto()
|
||||
LESS = auto()
|
||||
LESS_EQUAL = auto()
|
||||
EQUAL = auto()
|
||||
EQUAL_EQUAL = auto()
|
||||
BANG_EQUAL = auto()
|
||||
|
||||
# Literals
|
||||
IDENTIFIER = auto()
|
||||
NUMBER = auto()
|
||||
TRUE = auto()
|
||||
FALSE = auto()
|
||||
NONE = auto()
|
||||
|
||||
# Keywords
|
||||
TYPE = auto()
|
||||
OP = auto()
|
||||
CONSTRAINT = auto()
|
||||
|
||||
# Misc
|
||||
COMMENT = auto()
|
||||
WHITESPACE = auto()
|
||||
EOF = auto()
|
||||
NEWLINE = auto()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Token:
|
||||
"""A scanned token"""
|
||||
|
||||
type: TokenType
|
||||
lexeme: str
|
||||
value: Any
|
||||
position: Position
|
||||
37
midas/ast/location.py
Normal file
37
midas/ast/location.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Protocol
|
||||
|
||||
|
||||
class HasLocation(Protocol):
|
||||
lineno: int
|
||||
col_offset: int
|
||||
end_lineno: Optional[int]
|
||||
end_col_offset: Optional[int]
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Location:
|
||||
lineno: int
|
||||
col_offset: int
|
||||
end_lineno: Optional[int]
|
||||
end_col_offset: Optional[int]
|
||||
|
||||
@staticmethod
|
||||
def from_ast(obj: HasLocation) -> Location:
|
||||
return Location(
|
||||
lineno=obj.lineno,
|
||||
col_offset=obj.col_offset,
|
||||
end_lineno=obj.end_lineno,
|
||||
end_col_offset=obj.end_col_offset,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def span(start: Location, end: Location) -> Location:
|
||||
return Location(
|
||||
lineno=start.lineno,
|
||||
col_offset=start.col_offset,
|
||||
end_lineno=end.lineno,
|
||||
end_col_offset=end.end_col_offset,
|
||||
)
|
||||
266
midas/ast/midas.py
Normal file
266
midas/ast/midas.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""
|
||||
This file was generated by a script. Any manual changes might be overwritten.
|
||||
Please modify gen/midas.py instead and run gen/gen.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from midas.ast.location import Location
|
||||
from midas.lexer.token import Token
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
##############
|
||||
# Statements #
|
||||
##############
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Stmt(ABC):
|
||||
location: Location
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_type_stmt(self, stmt: TypeStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_property_stmt(self, stmt: PropertyStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_extend_stmt(self, stmt: ExtendStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_op_stmt(self, stmt: OpStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_predicate_stmt(self, stmt: PredicateStmt) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TypeStmt(Stmt):
|
||||
name: Token
|
||||
params: list[Param]
|
||||
type: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Param:
|
||||
location: Location
|
||||
name: Token
|
||||
bound: Optional[Type]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_type_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PropertyStmt(Stmt):
|
||||
name: Token
|
||||
type: Type
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_property_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExtendStmt(Stmt):
|
||||
type: Type
|
||||
operations: list[OpStmt]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_extend_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class OpStmt(Stmt):
|
||||
name: Token
|
||||
operand: Type
|
||||
result: Type
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_op_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PredicateStmt(Stmt):
|
||||
name: Token
|
||||
subject: Token
|
||||
type: Type
|
||||
condition: Expr
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_predicate_stmt(self)
|
||||
|
||||
|
||||
###############
|
||||
# Expressions #
|
||||
###############
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Expr(ABC):
|
||||
location: Location
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_binary_expr(self, expr: BinaryExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_unary_expr(self, expr: UnaryExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_get_expr(self, expr: GetExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_variable_expr(self, expr: VariableExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_grouping_expr(self, expr: GroupingExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_literal_expr(self, expr: LiteralExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LogicalExpr(Expr):
|
||||
left: Expr
|
||||
operator: Token
|
||||
right: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_logical_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BinaryExpr(Expr):
|
||||
left: Expr
|
||||
operator: Token
|
||||
right: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_binary_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UnaryExpr(Expr):
|
||||
operator: Token
|
||||
right: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_unary_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GetExpr(Expr):
|
||||
expr: Expr
|
||||
name: Token
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_get_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class VariableExpr(Expr):
|
||||
name: Token
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_variable_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GroupingExpr(Expr):
|
||||
expr: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_grouping_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LiteralExpr(Expr):
|
||||
value: Any
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_literal_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WildcardExpr(Expr):
|
||||
token: Token
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_wildcard_expr(self)
|
||||
|
||||
|
||||
#########
|
||||
# Types #
|
||||
#########
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Type(ABC):
|
||||
location: Location
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_named_type(self, type: NamedType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_generic_type(self, type: GenericType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_constraint_type(self, type: ConstraintType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_complex_type(self, type: ComplexType) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NamedType(Type):
|
||||
name: Token
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_named_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GenericType(Type):
|
||||
type: Type
|
||||
params: list[Type]
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_generic_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConstraintType(Type):
|
||||
type: Type
|
||||
constraint: Expr
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_constraint_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ComplexType(Type):
|
||||
properties: list[PropertyStmt]
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_complex_type(self)
|
||||
628
midas/ast/printer.py
Normal file
628
midas/ast/printer.py
Normal file
@@ -0,0 +1,628 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import io
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum, auto
|
||||
from typing import Generator, Generic, Optional, Protocol, TypeVar
|
||||
|
||||
import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
|
||||
|
||||
class _Level(Enum):
|
||||
EMPTY = auto()
|
||||
ACTIVE = auto()
|
||||
LAST = auto()
|
||||
|
||||
|
||||
class Expr(Protocol):
|
||||
def accept(self, printer: AstPrinter) -> None: ...
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Expr)
|
||||
|
||||
|
||||
class AstPrinter(Generic[T]):
|
||||
LAST_CHILD = "└── "
|
||||
CHILD = "├── "
|
||||
VERTICAL = "│ "
|
||||
EMPTY = " "
|
||||
|
||||
def __init__(self):
|
||||
self._levels: list[_Level] = []
|
||||
self._idx: Optional[int] = None
|
||||
self._buf: io.StringIO = io.StringIO()
|
||||
|
||||
def print(self, expr: T):
|
||||
self._buf = io.StringIO()
|
||||
expr.accept(self)
|
||||
return self._buf.getvalue()
|
||||
|
||||
@contextmanager
|
||||
def _child_level(self, single: bool = False) -> Generator[None, None, None]:
|
||||
self._levels.append(_Level.LAST if single else _Level.ACTIVE)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._levels.pop()
|
||||
|
||||
def _mark_last(self):
|
||||
if self._levels:
|
||||
self._levels[-1] = _Level.LAST
|
||||
|
||||
def _write_line(self, text: str, *, last: bool = False):
|
||||
if last:
|
||||
self._mark_last()
|
||||
indent: str = self._build_indent()
|
||||
if self._idx is not None:
|
||||
text = f"[{self._idx}] {text}"
|
||||
self._idx = None
|
||||
self._buf.write(indent + text + "\n")
|
||||
|
||||
def _build_indent(self) -> str:
|
||||
parts: list[str] = []
|
||||
for level in self._levels[:-1]:
|
||||
parts.append(self.EMPTY if level == _Level.EMPTY else self.VERTICAL)
|
||||
if self._levels:
|
||||
if self._levels[-1] == _Level.LAST:
|
||||
parts.append(self.LAST_CHILD)
|
||||
self._levels[-1] = _Level.EMPTY
|
||||
else:
|
||||
parts.append(self.CHILD)
|
||||
return "".join(parts)
|
||||
|
||||
def _write_optional_child(
|
||||
self, label: str, child: Optional[T], *, last: bool = False
|
||||
):
|
||||
if last:
|
||||
self._mark_last()
|
||||
if child is None:
|
||||
self._write_line(f"{label}: None")
|
||||
else:
|
||||
self._write_line(label)
|
||||
with self._child_level(single=True):
|
||||
child.accept(self)
|
||||
|
||||
|
||||
class MidasAstPrinter(
|
||||
AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None], m.Type.Visitor[None]
|
||||
):
|
||||
# Statements
|
||||
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||
self._write_line("TypeStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line("params")
|
||||
with self._child_level():
|
||||
for i, param in enumerate(stmt.params):
|
||||
self._idx = i
|
||||
if i == len(stmt.params) - 1:
|
||||
self._mark_last()
|
||||
self._print_type_stmt_param(param)
|
||||
self._write_line("type", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
|
||||
def _print_type_stmt_param(self, param: m.TypeStmt.Param) -> None:
|
||||
self._write_line("Param")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{param.name.lexeme}"')
|
||||
self._write_optional_child("bound", param.bound, last=True)
|
||||
|
||||
def visit_property_stmt(self, stmt: m.PropertyStmt):
|
||||
self._write_line("PropertyStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line("type", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||
self._write_line("ExtendStmt")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
self._write_line("operations", last=True)
|
||||
with self._child_level():
|
||||
for i, op in enumerate(stmt.operations):
|
||||
self._idx = i
|
||||
if i == len(stmt.operations) - 1:
|
||||
self._mark_last()
|
||||
op.accept(self)
|
||||
|
||||
def visit_op_stmt(self, stmt: m.OpStmt) -> None:
|
||||
self._write_line("OpStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
|
||||
self._write_line("operand")
|
||||
with self._child_level(single=True):
|
||||
stmt.operand.accept(self)
|
||||
|
||||
self._write_line("result", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.result.accept(self)
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
||||
self._write_line("PredicateStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line(f'subject: "{stmt.subject.lexeme}"')
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
self._write_line("condition", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.condition.accept(self)
|
||||
|
||||
# Expressions
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr):
|
||||
self._write_line("LogicalExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level(single=True):
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.operator.lexeme}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr):
|
||||
self._write_line("BinaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level(single=True):
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.operator.lexeme}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr):
|
||||
self._write_line("UnaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f"operator: {expr.operator.lexeme}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr):
|
||||
self._write_line("GetExpr")
|
||||
with self._child_level():
|
||||
self._write_line("expr")
|
||||
with self._child_level(single=True):
|
||||
expr.expr.accept(self)
|
||||
self._write_line(f'name: "{expr.name.lexeme}"', last=True)
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr):
|
||||
self._write_line("VariableExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{expr.name.lexeme}"', last=True)
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr):
|
||||
self._write_line("GroupingExpr")
|
||||
with self._child_level():
|
||||
self._write_line("expr", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.expr.accept(self)
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
|
||||
self._write_line("LiteralExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f"value: {expr.value}", last=True)
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
|
||||
self._write_line("WildcardExpr")
|
||||
|
||||
def visit_named_type(self, type: m.NamedType) -> None:
|
||||
self._write_line("NamedType")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{type.name.lexeme}"', last=True)
|
||||
|
||||
def visit_generic_type(self, type: m.GenericType) -> None:
|
||||
self._write_line("GenericType")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level():
|
||||
type.type.accept(self)
|
||||
self._write_line("params", last=True)
|
||||
with self._child_level():
|
||||
for i, param in enumerate(type.params):
|
||||
self._idx = i
|
||||
if i == len(type.params) - 1:
|
||||
self._mark_last()
|
||||
param.accept(self)
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> None:
|
||||
self._write_line("ConstraintType")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
type.type.accept(self)
|
||||
self._write_line("constraint", last=True)
|
||||
with self._child_level(single=True):
|
||||
type.constraint.accept(self)
|
||||
|
||||
def visit_complex_type(self, type: m.ComplexType) -> None:
|
||||
self._write_line("ComplexType")
|
||||
with self._child_level():
|
||||
self._write_line("properties", last=True)
|
||||
with self._child_level():
|
||||
for i, prop in enumerate(type.properties):
|
||||
self._idx = i
|
||||
if i == len(type.properties) - 1:
|
||||
self._mark_last()
|
||||
prop.accept(self)
|
||||
|
||||
|
||||
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
|
||||
def __init__(self, indent: int = 4):
|
||||
self.indent: int = indent
|
||||
self.level: int = 0
|
||||
|
||||
def indented(self, text: str) -> str:
|
||||
return " " * (self.level * self.indent) + text
|
||||
|
||||
def print(self, expr: m.Expr | m.Stmt | m.Type) -> str:
|
||||
self.level = 0
|
||||
return expr.accept(self)
|
||||
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
|
||||
template: str = ""
|
||||
if len(stmt.params) != 0:
|
||||
params: list[str] = [
|
||||
self._print_type_template_param(param) for param in stmt.params
|
||||
]
|
||||
template = f"[{', '.join(params)}]"
|
||||
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
|
||||
return self.indented(res)
|
||||
|
||||
def _print_type_template_param(self, param: m.TypeStmt.Param) -> str:
|
||||
res: str = param.name.lexeme
|
||||
if param.bound is not None:
|
||||
res += "<:" + param.bound.accept(self)
|
||||
return res
|
||||
|
||||
def visit_property_stmt(self, stmt: m.PropertyStmt):
|
||||
res: str = f"{stmt.name.lexeme}: {stmt.type.accept(self)}"
|
||||
return self.indented(res)
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt):
|
||||
res: str = self.indented(f"extend {stmt.type.accept(self)}")
|
||||
res += " {\n"
|
||||
self.level += 1
|
||||
for op in stmt.operations:
|
||||
res += op.accept(self)
|
||||
self.level -= 1
|
||||
res += self.indented("}")
|
||||
return res
|
||||
|
||||
def visit_op_stmt(self, stmt: m.OpStmt):
|
||||
operand: str = stmt.operand.accept(self)
|
||||
result: str = stmt.result.accept(self)
|
||||
return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}\n")
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
||||
name: str = stmt.name.lexeme
|
||||
subject: str = stmt.subject.lexeme
|
||||
type: str = stmt.type.accept(self)
|
||||
condition: str = stmt.condition.accept(self)
|
||||
return self.indented(f"predicate {name}({subject}: {type}) = {condition}")
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr):
|
||||
left: str = expr.left.accept(self)
|
||||
operator: str = expr.operator.lexeme
|
||||
right: str = expr.right.accept(self)
|
||||
return f"{left} {operator} {right}"
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr):
|
||||
left: str = expr.left.accept(self)
|
||||
operator: str = expr.operator.lexeme
|
||||
right: str = expr.right.accept(self)
|
||||
return f"{left} {operator} {right}"
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr):
|
||||
operator: str = expr.operator.lexeme
|
||||
right: str = expr.right.accept(self)
|
||||
return f"{operator}{right}"
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr):
|
||||
expr_: str = expr.expr.accept(self)
|
||||
name: str = expr.name.lexeme
|
||||
return f"{expr_}.{name}"
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr):
|
||||
return expr.name.lexeme
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr):
|
||||
expr_: str = expr.expr.accept(self)
|
||||
return f"({expr_})"
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr):
|
||||
return str(expr.value)
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr):
|
||||
return "_"
|
||||
|
||||
def visit_named_type(self, type: m.NamedType) -> str:
|
||||
return type.name.lexeme
|
||||
|
||||
def visit_generic_type(self, type: m.GenericType) -> str:
|
||||
res: str = type.type.accept(self)
|
||||
if len(type.params) != 0:
|
||||
params: list[str] = [param.accept(self) for param in type.params]
|
||||
res += f"[{', '.join(params)}]"
|
||||
return res
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> str:
|
||||
res: str = type.type.accept(self)
|
||||
res += " where " + type.constraint.accept(self)
|
||||
return res
|
||||
|
||||
def visit_complex_type(self, type: m.ComplexType) -> str:
|
||||
res: str = "{\n"
|
||||
self.level += 1
|
||||
for prop in type.properties:
|
||||
res += prop.accept(self)
|
||||
res += "\n"
|
||||
self.level -= 1
|
||||
res += self.indented("}")
|
||||
return res
|
||||
|
||||
|
||||
class PythonAstPrinter(
|
||||
AstPrinter,
|
||||
p.MidasType.Visitor[None],
|
||||
p.Stmt.Visitor[None],
|
||||
p.Expr.Visitor[None],
|
||||
):
|
||||
def visit_base_type(self, node: p.BaseType) -> None:
|
||||
self._write_line("BaseType")
|
||||
with self._child_level():
|
||||
self._write_line(f"base: {node.base}")
|
||||
self._write_optional_child("param", node.param, last=True)
|
||||
|
||||
def visit_constraint_type(self, node: p.ConstraintType) -> None:
|
||||
self._write_line("ConstraintType")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
node.type.accept(self)
|
||||
self._write_line(f"constraint: {ast.unparse(node.constraint)}", last=True)
|
||||
|
||||
def visit_frame_column(self, node: p.FrameColumn) -> None:
|
||||
self._write_line("FrameColumn")
|
||||
with self._child_level():
|
||||
self._write_line(f"name: {node.name}")
|
||||
self._write_optional_child("type", node.type, last=True)
|
||||
|
||||
def visit_frame_type(self, node: p.FrameType) -> None:
|
||||
self._write_line("FrameType")
|
||||
with self._child_level():
|
||||
self._write_line("columns", last=True)
|
||||
with self._child_level():
|
||||
for i, col in enumerate(node.columns):
|
||||
self._idx = i
|
||||
if i == len(node.columns) - 1:
|
||||
self._mark_last()
|
||||
col.accept(self)
|
||||
|
||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||
stmt.expr.accept(self)
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> None:
|
||||
self._write_line("Function")
|
||||
with self._child_level():
|
||||
self._write_line(f"name: {stmt.name}")
|
||||
|
||||
self._write_line("posonlyargs")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(stmt.posonlyargs):
|
||||
self._idx = i
|
||||
if i == len(stmt.posonlyargs) - 1:
|
||||
self._mark_last()
|
||||
self._print_argument(arg)
|
||||
|
||||
self._write_line("args")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(stmt.args):
|
||||
self._idx = i
|
||||
if i == len(stmt.args) - 1:
|
||||
self._mark_last()
|
||||
self._print_argument(arg)
|
||||
|
||||
self._write_line("kwonlyargs")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(stmt.kwonlyargs):
|
||||
self._idx = i
|
||||
if i == len(stmt.kwonlyargs) - 1:
|
||||
self._mark_last()
|
||||
self._print_argument(arg)
|
||||
|
||||
self._write_optional_child("returns", stmt.returns)
|
||||
self._write_line("body", last=True)
|
||||
with self._child_level():
|
||||
for i, body_stmt in enumerate(stmt.body):
|
||||
self._idx = i
|
||||
if i == len(stmt.body) - 1:
|
||||
self._mark_last()
|
||||
body_stmt.accept(self)
|
||||
|
||||
def _print_argument(self, arg: p.Function.Argument) -> None:
|
||||
self._write_line("FunctionArgument")
|
||||
with self._child_level():
|
||||
self._write_line(f"name: {arg.name}")
|
||||
self._write_optional_child("type", arg.type, last=True)
|
||||
|
||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||
self._write_line("TypeAssign")
|
||||
with self._child_level():
|
||||
self._write_line(f"name: {stmt.name}")
|
||||
self._write_line("type", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
|
||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||
self._write_line("AssignStmt")
|
||||
with self._child_level():
|
||||
self._write_line("targets")
|
||||
with self._child_level():
|
||||
for i, target in enumerate(stmt.targets):
|
||||
self._idx = i
|
||||
if i == len(stmt.targets) - 1:
|
||||
self._mark_last()
|
||||
target.accept(self)
|
||||
self._write_line("value", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.value.accept(self)
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||
self._write_line("ReturnStmt")
|
||||
with self._child_level():
|
||||
self._write_optional_child("value", stmt.value, last=True)
|
||||
|
||||
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||
self._write_line("IfStmt")
|
||||
with self._child_level():
|
||||
self._write_line("test")
|
||||
with self._child_level(single=True):
|
||||
stmt.test.accept(self)
|
||||
self._write_line("body")
|
||||
with self._child_level():
|
||||
for i, body_stmt in enumerate(stmt.body):
|
||||
self._idx = i
|
||||
if i == len(stmt.body) - 1:
|
||||
self._mark_last()
|
||||
body_stmt.accept(self)
|
||||
self._write_line("orelse", last=True)
|
||||
with self._child_level():
|
||||
for i, else_stmt in enumerate(stmt.orelse):
|
||||
self._idx = i
|
||||
if i == len(stmt.orelse) - 1:
|
||||
self._mark_last()
|
||||
else_stmt.accept(self)
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
|
||||
self._write_line("BinaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level(single=True):
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.operator.__class__.__name__}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> None:
|
||||
self._write_line("CompareExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level(single=True):
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.operator.__class__.__name__}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> None:
|
||||
self._write_line("UnaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f"operator: {expr.operator.__class__.__name__}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> None:
|
||||
self._write_line("CallExpr")
|
||||
with self._child_level():
|
||||
self._write_line("callee")
|
||||
with self._child_level(single=True):
|
||||
expr.callee.accept(self)
|
||||
|
||||
self._write_line("arguments")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(expr.arguments):
|
||||
self._idx = i
|
||||
if i == len(expr.arguments) - 1:
|
||||
self._mark_last()
|
||||
arg.accept(self)
|
||||
|
||||
self._write_line("keywords", last=True)
|
||||
with self._child_level():
|
||||
for i, (name, arg) in enumerate(expr.keywords.items()):
|
||||
self._idx = i
|
||||
if i == len(expr.keywords) - 1:
|
||||
self._mark_last()
|
||||
self._write_line(name)
|
||||
with self._child_level(single=True):
|
||||
arg.accept(self)
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> None:
|
||||
self._write_line("GetExpr")
|
||||
with self._child_level():
|
||||
self._write_line("object")
|
||||
with self._child_level(single=True):
|
||||
expr.object.accept(self)
|
||||
self._write_line(f"name: {expr.name}", last=True)
|
||||
|
||||
def visit_literal_expr(self, expr: p.LiteralExpr) -> None:
|
||||
self._write_line("LiteralExpr")
|
||||
with self._child_level(single=True):
|
||||
self._write_line(f"value: {expr.value}")
|
||||
|
||||
def visit_variable_expr(self, expr: p.VariableExpr) -> None:
|
||||
self._write_line("VariableExpr")
|
||||
with self._child_level(single=True):
|
||||
self._write_line(f"name: {expr.name}")
|
||||
|
||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> None:
|
||||
self._write_line("LogicalExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level(single=True):
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.operator.__class__.__name__}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> None:
|
||||
self._write_line("CastExpr")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
expr.type.accept(self)
|
||||
self._write_line("expr", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.expr.accept(self)
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
|
||||
self._write_line("TernaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line("test")
|
||||
with self._child_level(single=True):
|
||||
expr.test.accept(self)
|
||||
|
||||
self._write_line("if_true")
|
||||
with self._child_level(single=True):
|
||||
expr.if_true.accept(self)
|
||||
|
||||
self._write_line("if_false", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.if_false.accept(self)
|
||||
314
midas/ast/python.py
Normal file
314
midas/ast/python.py
Normal file
@@ -0,0 +1,314 @@
|
||||
"""
|
||||
This file was generated by a script. Any manual changes might be overwritten.
|
||||
Please modify gen/python.py instead and run gen/gen.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from midas.ast.location import Location
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
####################
|
||||
# Type annotations #
|
||||
####################
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class MidasType(ABC):
|
||||
location: Location
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_base_type(self, node: BaseType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_constraint_type(self, node: ConstraintType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_frame_column(self, node: FrameColumn) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_frame_type(self, node: FrameType) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BaseType(MidasType):
|
||||
base: str
|
||||
param: Optional[MidasType]
|
||||
|
||||
def accept(self, visitor: MidasType.Visitor[T]) -> T:
|
||||
return visitor.visit_base_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConstraintType(MidasType):
|
||||
type: MidasType
|
||||
constraint: ast.expr
|
||||
|
||||
def accept(self, visitor: MidasType.Visitor[T]) -> T:
|
||||
return visitor.visit_constraint_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FrameColumn(MidasType):
|
||||
name: Optional[str]
|
||||
type: Optional[MidasType]
|
||||
|
||||
def accept(self, visitor: MidasType.Visitor[T]) -> T:
|
||||
return visitor.visit_frame_column(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FrameType(MidasType):
|
||||
columns: list[FrameColumn]
|
||||
|
||||
def accept(self, visitor: MidasType.Visitor[T]) -> T:
|
||||
return visitor.visit_frame_type(self)
|
||||
|
||||
|
||||
##############
|
||||
# Statements #
|
||||
##############
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Stmt(ABC):
|
||||
location: Location
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_expression_stmt(self, stmt: ExpressionStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_function(self, stmt: Function) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_type_assign(self, stmt: TypeAssign) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_assign_stmt(self, stmt: AssignStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_return_stmt(self, stmt: ReturnStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_if_stmt(self, stmt: IfStmt) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExpressionStmt(Stmt):
|
||||
expr: Expr
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_expression_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Function(Stmt):
|
||||
name: str
|
||||
posonlyargs: list[Argument]
|
||||
args: list[Argument]
|
||||
sink: Optional[Argument]
|
||||
kwonlyargs: list[Argument]
|
||||
kw_sink: Optional[Argument]
|
||||
returns: Optional[MidasType]
|
||||
body: list[Stmt]
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
location: Optional[Location] = None
|
||||
name: str
|
||||
type: Optional[MidasType]
|
||||
default: Optional[Expr]
|
||||
|
||||
@property
|
||||
def all_args(self) -> list[Argument]:
|
||||
return self.posonlyargs + self.args + self.kwonlyargs
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_function(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TypeAssign(Stmt):
|
||||
name: str
|
||||
type: MidasType
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_type_assign(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AssignStmt(Stmt):
|
||||
targets: list[Expr]
|
||||
value: Expr
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_assign_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ReturnStmt(Stmt):
|
||||
value: Optional[Expr]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_return_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IfStmt(Stmt):
|
||||
test: Expr
|
||||
body: list[Stmt]
|
||||
orelse: list[Stmt]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_if_stmt(self)
|
||||
|
||||
|
||||
###############
|
||||
# Expressions #
|
||||
###############
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Expr(ABC):
|
||||
location: Location
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_binary_expr(self, expr: BinaryExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_compare_expr(self, expr: CompareExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_unary_expr(self, expr: UnaryExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_call_expr(self, expr: CallExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_get_expr(self, expr: GetExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_literal_expr(self, expr: LiteralExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_variable_expr(self, expr: VariableExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_cast_expr(self, expr: CastExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_ternary_expr(self, expr: TernaryExpr) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BinaryExpr(Expr):
|
||||
left: Expr
|
||||
operator: ast.operator
|
||||
right: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_binary_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CompareExpr(Expr):
|
||||
left: Expr
|
||||
operator: ast.cmpop
|
||||
right: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_compare_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UnaryExpr(Expr):
|
||||
operator: ast.unaryop
|
||||
right: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_unary_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CallExpr(Expr):
|
||||
callee: Expr
|
||||
arguments: list[Expr]
|
||||
keywords: dict[str, Expr]
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_call_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GetExpr(Expr):
|
||||
object: Expr
|
||||
name: str
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_get_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LiteralExpr(Expr):
|
||||
value: Any
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_literal_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class VariableExpr(Expr):
|
||||
name: str
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_variable_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LogicalExpr(Expr):
|
||||
left: Expr
|
||||
operator: ast.boolop
|
||||
right: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_logical_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CastExpr(Expr):
|
||||
type: MidasType
|
||||
expr: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_cast_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TernaryExpr(Expr):
|
||||
test: Expr
|
||||
if_true: Expr
|
||||
if_false: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_ternary_expr(self)
|
||||
4
midas/checker/builtins.py
Normal file
4
midas/checker/builtins.py
Normal file
@@ -0,0 +1,4 @@
|
||||
BUILTIN_SUBTYPES: dict[str, set[str]] = {
|
||||
"float": {"int"},
|
||||
"int": {"bool"},
|
||||
}
|
||||
812
midas/checker/checker.py
Normal file
812
midas/checker/checker.py
Normal file
@@ -0,0 +1,812 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.builtins import BUILTIN_SUBTYPES
|
||||
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||
from midas.checker.environment import Environment
|
||||
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
|
||||
from midas.checker.types import (
|
||||
AliasType,
|
||||
BaseType,
|
||||
ComplexType,
|
||||
Function,
|
||||
Operation,
|
||||
Type,
|
||||
UnitType,
|
||||
UnknownType,
|
||||
)
|
||||
from midas.lexer.midas import MidasLexer
|
||||
from midas.lexer.token import Token
|
||||
from midas.parser.midas import MidasParser
|
||||
from midas.resolver.midas import MidasResolver
|
||||
|
||||
|
||||
class ReturnException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class MappedArgument:
|
||||
expr: p.Expr
|
||||
type: Type
|
||||
argument: Function.Argument
|
||||
|
||||
|
||||
class Checker(
|
||||
p.Stmt.Visitor[None],
|
||||
p.Expr.Visitor[Type],
|
||||
p.MidasType.Visitor[Type],
|
||||
):
|
||||
"""A type checker which can use custom type definitions"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
locals: dict[p.Expr, int],
|
||||
source_path: Path,
|
||||
types_paths: list[Path],
|
||||
):
|
||||
self.logger: logging.Logger = logging.getLogger("Checker")
|
||||
self.source_path: Path = source_path
|
||||
self.types_paths: list[Path] = types_paths
|
||||
self.ctx: MidasResolver = MidasResolver()
|
||||
self.global_env: Environment = Environment()
|
||||
self.env: Environment = self.global_env
|
||||
self.locals: dict[p.Expr, int] = locals
|
||||
self.diagnostics: list[Diagnostic] = []
|
||||
self.judgements: list[tuple[p.Expr, Type]] = []
|
||||
|
||||
def diagnostic(self, type: DiagnosticType, location: Location, message: str):
|
||||
self.diagnostics.append(
|
||||
Diagnostic(
|
||||
file_path=self.source_path,
|
||||
location=location,
|
||||
type=type,
|
||||
message=message,
|
||||
)
|
||||
)
|
||||
|
||||
def error(self, location: Location, message: str):
|
||||
self.diagnostic(
|
||||
type=DiagnosticType.ERROR,
|
||||
location=location,
|
||||
message=message,
|
||||
)
|
||||
|
||||
def warning(self, location: Location, message: str):
|
||||
self.diagnostic(
|
||||
type=DiagnosticType.WARNING,
|
||||
location=location,
|
||||
message=message,
|
||||
)
|
||||
|
||||
def info(self, location: Location, message: str):
|
||||
self.diagnostic(
|
||||
type=DiagnosticType.INFO,
|
||||
location=location,
|
||||
message=message,
|
||||
)
|
||||
|
||||
def type_of(self, expr: p.Expr) -> Type:
|
||||
"""Evaluate the type of an expression
|
||||
|
||||
Args:
|
||||
expr (p.Expr): the expression to evaluate
|
||||
|
||||
Returns:
|
||||
Type: the type of the given expression
|
||||
"""
|
||||
type: Type = expr.accept(self)
|
||||
self.judgements.append((expr, type))
|
||||
return type
|
||||
|
||||
def process_block(self, block: list[p.Stmt], env: Environment) -> bool:
|
||||
"""Evaluate a sequence of statements
|
||||
|
||||
Args:
|
||||
block (list[p.Stmt]): the statements to evaluate
|
||||
env (Environment): the environment in which to evaluate
|
||||
|
||||
Returns:
|
||||
bool: whether a return statement is present in the block
|
||||
"""
|
||||
previous_env: Environment = self.env
|
||||
self.env = env
|
||||
returned: bool = False
|
||||
for i, stmt in enumerate(block):
|
||||
try:
|
||||
stmt.accept(self)
|
||||
except ReturnException:
|
||||
returned = True
|
||||
if i < len(block) - 1:
|
||||
self.warning(block[i + 1].location, "Unreachable statement")
|
||||
break
|
||||
self.env = previous_env
|
||||
return returned
|
||||
|
||||
def check(self, statements: list[p.Stmt]) -> list[Diagnostic]:
|
||||
"""Type check a sequence of statements and returns diagnostics
|
||||
|
||||
Args:
|
||||
statements (list[p.Stmt]): the statements to evaluate and check
|
||||
|
||||
Returns:
|
||||
list[Diagnostic]: the list of diagnostics (errors, warning, etc.)
|
||||
"""
|
||||
self.diagnostics = []
|
||||
|
||||
for path in self.types_paths:
|
||||
self.import_midas(path)
|
||||
self.logger.debug(f"Midas types: {self.ctx._types}")
|
||||
self.logger.debug(f"Midas operations: {self.ctx._operations}")
|
||||
|
||||
for stmt in statements:
|
||||
stmt.accept(self)
|
||||
|
||||
self.logger.debug(f"Final environment: {self.env.flat_dict()}")
|
||||
return self.diagnostics
|
||||
|
||||
def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]:
|
||||
"""Look up a variable in the environment it was declared
|
||||
|
||||
Args:
|
||||
name (str): the name of the variable
|
||||
expr (p.Expr): the variable expression, used to lookup the scope distance
|
||||
|
||||
Returns:
|
||||
Optional[Type]: the type of the variable, or None if it was not found
|
||||
"""
|
||||
distance: Optional[int] = self.locals.get(expr)
|
||||
if distance is not None:
|
||||
return self.env.get_at(distance, name)
|
||||
return self.global_env.get(name)
|
||||
|
||||
def import_midas(self, path: Path) -> None:
|
||||
"""Import Midas definitions from a path
|
||||
|
||||
Args:
|
||||
path (Path): the import path
|
||||
"""
|
||||
self.logger.debug(f"Importing type definitions from {path}")
|
||||
lexer: MidasLexer = MidasLexer(path.read_text())
|
||||
tokens: list[Token] = lexer.process()
|
||||
parser: MidasParser = MidasParser(tokens)
|
||||
stmts: list[m.Stmt] = parser.parse()
|
||||
self.ctx.resolve(stmts)
|
||||
|
||||
def unfold_type(self, type: Type) -> Type:
|
||||
match type:
|
||||
case AliasType(type=ref_type):
|
||||
return self.unfold_type(ref_type)
|
||||
case _:
|
||||
return type
|
||||
|
||||
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||
"""Check whether `type1` is a subtype of `type2`
|
||||
|
||||
For more details on the rules checked here, see TAPL Chap. 15-16-17
|
||||
|
||||
Args:
|
||||
type1 (Type): the potential subtype
|
||||
type2 (Type): the potential supertype
|
||||
|
||||
Returns:
|
||||
bool: whether `type1` is a subtype of `type2`
|
||||
"""
|
||||
|
||||
if type1 == type2:
|
||||
return True
|
||||
|
||||
match (type1, type2):
|
||||
case (AliasType(type=base1), _):
|
||||
return self.is_subtype(base1, type2)
|
||||
|
||||
case (BaseType(name=name1), BaseType(name=name2)):
|
||||
return name1 in BUILTIN_SUBTYPES.get(name2, set())
|
||||
|
||||
case (ComplexType(properties=props1), ComplexType(properties=props2)):
|
||||
for k, t in props2.items():
|
||||
if k not in props1:
|
||||
return False
|
||||
if not self.is_subtype(props1[k], t):
|
||||
return False
|
||||
return True
|
||||
|
||||
case (Function(returns=return1), Function(returns=return2)):
|
||||
if not self.is_func_subtype(type1, type2):
|
||||
return False
|
||||
if not self.is_subtype(return1, return2):
|
||||
return False
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
# TODO: verify the logic in here
|
||||
def is_func_subtype(self, func1: Function, func2: Function) -> bool:
|
||||
"""Check whether a function is a subtype of another
|
||||
|
||||
Args:
|
||||
func1 (Function): the potential function subtype
|
||||
func2 (Function): the potential function supertype
|
||||
|
||||
Returns:
|
||||
bool: whether `func1` is a subtype of `func2`
|
||||
"""
|
||||
if not self.is_subtype(func1.returns, func2.returns):
|
||||
return False
|
||||
|
||||
pos1: list[Function.Argument] = func1.pos_args
|
||||
mixed1: list[Function.Argument] = func1.args
|
||||
kw1: dict[str, Function.Argument] = {a.name: a for a in func1.kw_args}
|
||||
pos2: list[Function.Argument] = func2.pos_args
|
||||
mixed2: list[Function.Argument] = func2.args
|
||||
kw2: dict[str, Function.Argument] = {a.name: a for a in func2.kw_args}
|
||||
|
||||
mixed_by_pos: dict[int, Function.Argument] = {arg.pos: arg for arg in mixed2}
|
||||
mixed_by_name: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2}
|
||||
|
||||
def is_arg_subtype(sub: Function.Argument, sup: Function.Argument) -> bool:
|
||||
if not self.is_subtype(sub.type, sup.type):
|
||||
return False
|
||||
if not sup.required and sub.required:
|
||||
return False
|
||||
return True
|
||||
|
||||
for arg1 in pos1:
|
||||
arg2: Function.Argument
|
||||
if arg1.pos < len(pos2):
|
||||
arg2 = pos2[arg1.pos]
|
||||
elif arg1.pos in mixed_by_pos:
|
||||
arg2 = mixed_by_pos[arg1.pos]
|
||||
elif not arg1.required:
|
||||
continue
|
||||
else:
|
||||
return False
|
||||
if not is_arg_subtype(arg2, arg1):
|
||||
return False
|
||||
|
||||
for name, arg1 in kw1.items():
|
||||
arg2: Function.Argument
|
||||
if name in kw2:
|
||||
arg2 = kw2[name]
|
||||
elif name in mixed_by_name:
|
||||
arg2 = mixed_by_name[name]
|
||||
elif not arg1.required:
|
||||
continue
|
||||
else:
|
||||
return False
|
||||
if not is_arg_subtype(arg2, arg1):
|
||||
return False
|
||||
|
||||
for arg1 in mixed1:
|
||||
pos_arg2: Optional[Function.Argument] = None
|
||||
kw_arg2: Optional[Function.Argument] = None
|
||||
if arg1.name in kw2:
|
||||
kw_arg2 = kw2[arg1.name]
|
||||
elif arg1.name in mixed_by_name:
|
||||
kw_arg2 = mixed_by_name[arg1.name]
|
||||
if arg1.pos < len(pos2):
|
||||
pos_arg2 = pos2[arg1.pos]
|
||||
elif arg1.pos in mixed_by_pos:
|
||||
pos_arg2 = mixed_by_pos[arg1.pos]
|
||||
|
||||
# No match in func2 and arg is required
|
||||
if pos_arg2 is None and kw_arg2 is None and arg1.required:
|
||||
return False
|
||||
|
||||
# Matching keyword argument
|
||||
if kw_arg2 is not None and not is_arg_subtype(kw_arg2, arg1):
|
||||
return False
|
||||
|
||||
# Matching positional argument
|
||||
if pos_arg2 is not None and not is_arg_subtype(pos_arg2, arg1):
|
||||
return False
|
||||
|
||||
mixed_positions: set[int] = {a.pos for a in mixed1}
|
||||
mixed_names: set[str] = {a.name for a in mixed1}
|
||||
for arg2 in pos2:
|
||||
if not arg2.required:
|
||||
continue
|
||||
if arg2.pos >= len(pos1) and arg2.pos not in mixed_positions:
|
||||
return False
|
||||
|
||||
for name, arg2 in kw2.items():
|
||||
if not arg2.required:
|
||||
continue
|
||||
if name not in kw1 and name not in mixed_names:
|
||||
return False
|
||||
|
||||
for arg2 in mixed2:
|
||||
if arg2.required:
|
||||
continue
|
||||
pos_match: bool = arg2.pos < len(pos1) or arg2.pos in mixed_positions
|
||||
kw_match: bool = arg2.name in kw1 or arg2.name in mixed_names
|
||||
if not pos_match or not kw_match:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||
self.type_of(stmt.expr)
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> None:
|
||||
env: Environment = Environment(self.env)
|
||||
pos_args: list[Function.Argument] = []
|
||||
args: list[Function.Argument] = []
|
||||
kw_args: list[Function.Argument] = []
|
||||
|
||||
def eval_arg_type(arg: p.Function.Argument) -> Type:
|
||||
if arg.type is not None:
|
||||
return arg.type.accept(self)
|
||||
if arg.default is not None:
|
||||
return arg.default.accept(self)
|
||||
return UnknownType()
|
||||
|
||||
pos: int = 0
|
||||
for arg in stmt.posonlyargs:
|
||||
pos_args.append(
|
||||
Function.Argument(
|
||||
pos=pos,
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
)
|
||||
)
|
||||
pos += 1
|
||||
for arg in stmt.args:
|
||||
args.append(
|
||||
Function.Argument(
|
||||
pos=pos,
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
)
|
||||
)
|
||||
pos += 1
|
||||
for arg in stmt.kwonlyargs:
|
||||
kw_args.append(
|
||||
Function.Argument(
|
||||
pos=pos, # not relevant
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
)
|
||||
)
|
||||
pos += 1
|
||||
|
||||
for arg in pos_args + args + kw_args:
|
||||
env.define(arg.name, arg.type)
|
||||
|
||||
returns_hint: Optional[Type] = None
|
||||
if stmt.returns is not None:
|
||||
returns_hint = stmt.returns.accept(self)
|
||||
# Early define to handle simple fully-typed recursion
|
||||
inside_function: Function = Function(
|
||||
name=stmt.name,
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
returns=returns_hint,
|
||||
)
|
||||
self.env.define(stmt.name, inside_function)
|
||||
|
||||
returned: bool = self.process_block(stmt.body, env)
|
||||
inferred_return: Type = UnknownType()
|
||||
if not returned:
|
||||
env.return_types.append(UnitType())
|
||||
return_types: set[Type] = set(env.return_types)
|
||||
if len(return_types) == 1:
|
||||
inferred_return = list(return_types)[0]
|
||||
elif len(return_types) > 1:
|
||||
self.error(
|
||||
stmt.location,
|
||||
f"Mixed return types: {env.return_types}",
|
||||
)
|
||||
|
||||
returns: Type = UnknownType()
|
||||
if returns_hint is not None:
|
||||
assert stmt.returns is not None
|
||||
returns = returns_hint
|
||||
if returns != inferred_return:
|
||||
self.error(
|
||||
stmt.returns.location,
|
||||
f"Return type mismatch, annotated {returns} but returns {inferred_return}",
|
||||
)
|
||||
else:
|
||||
returns = inferred_return
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
function: Function = Function(
|
||||
name=stmt.name,
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
returns=returns,
|
||||
)
|
||||
self.env.define(stmt.name, function)
|
||||
|
||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||
# TODO check not yet defined locally
|
||||
type: Type = stmt.type.accept(self)
|
||||
self.env.define(stmt.name, type)
|
||||
|
||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||
value_type: Type = self.type_of(stmt.value)
|
||||
for target in stmt.targets:
|
||||
self._assign(stmt.location, target, value_type)
|
||||
|
||||
def _assign(self, location: Location, target: p.Expr, value_type: Type):
|
||||
match target:
|
||||
case p.VariableExpr():
|
||||
self._assign_var(location, target, value_type)
|
||||
|
||||
case p.GetExpr():
|
||||
self._assign_attr(location, target, value_type)
|
||||
|
||||
case _:
|
||||
if not isinstance(target, p.VariableExpr):
|
||||
self.logger.warning(f"Unsupported assignment to {target}")
|
||||
self.warning(target.location, f"Unsupported assignment to {target}")
|
||||
|
||||
def _assign_var(self, location: Location, target: p.VariableExpr, value_type: Type):
|
||||
name: str = target.name
|
||||
var_type: Optional[Type] = self.look_up_variable(name, target)
|
||||
|
||||
if var_type is None:
|
||||
self.env.define(name, value_type)
|
||||
else:
|
||||
# S <: T
|
||||
# Γ, x: T v: S
|
||||
# x = v
|
||||
if not self.is_subtype(value_type, var_type):
|
||||
self.error(
|
||||
location,
|
||||
f"Cannot assign {value_type} to {name} of type {var_type}",
|
||||
)
|
||||
|
||||
def _assign_attr(self, location: Location, target: p.GetExpr, value_type: Type):
|
||||
object: Type = self.type_of(target.object)
|
||||
base_object: Type = self.unfold_type(object)
|
||||
match base_object:
|
||||
case ComplexType(properties=properties):
|
||||
if target.name not in properties:
|
||||
self.error(
|
||||
target.location, f"Unknown property '{target.name} on {object}"
|
||||
)
|
||||
return
|
||||
|
||||
prop_type: Type = properties[target.name]
|
||||
if not self.is_subtype(value_type, prop_type):
|
||||
self.error(
|
||||
location,
|
||||
f"Cannot assign {value_type} to property '{target.name}' of type {prop_type} on {object}",
|
||||
)
|
||||
return
|
||||
|
||||
case UnknownType():
|
||||
pass
|
||||
|
||||
case _:
|
||||
self.error(
|
||||
target.location,
|
||||
f"Cannot assign {value_type} to unknown property '{target.name}' on {object}",
|
||||
)
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||
type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType()
|
||||
self.env.return_types.append(type)
|
||||
raise ReturnException()
|
||||
|
||||
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||
# Not evaluated in sub-environment because assignments in the test leak out of the if
|
||||
# For example:
|
||||
# if (m := 1 + 1) < 2:
|
||||
# ...
|
||||
# print(m) # <- m is still defined
|
||||
test_type: Type = stmt.test.accept(self)
|
||||
|
||||
# TODO Allow subtypes or any type
|
||||
if test_type != self.ctx.get_type("bool"):
|
||||
self.error(
|
||||
stmt.test.location, f"If test must be a boolean, got {test_type}"
|
||||
)
|
||||
|
||||
env: Environment = Environment(self.env)
|
||||
body_returned: bool = self.process_block(stmt.body, env)
|
||||
else_returned: bool = self.process_block(stmt.orelse, env)
|
||||
self.env.return_types.extend(env.return_types)
|
||||
if body_returned and else_returned:
|
||||
raise ReturnException()
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
||||
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
||||
if method is None:
|
||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||
self.warning(expr.location, f"Unsupported operator {expr.operator}")
|
||||
return UnknownType()
|
||||
left: Type = self.type_of(expr.left)
|
||||
right: Type = self.type_of(expr.right)
|
||||
|
||||
operations: list[Operation] = self.ctx.get_operations_by_name(method)
|
||||
valid_operations: list[Operation] = []
|
||||
for op in operations:
|
||||
sig: Operation.CallSignature = op.signature
|
||||
if self.is_subtype(left, sig.left) and self.is_subtype(right, sig.right):
|
||||
valid_operations.append(op)
|
||||
|
||||
if len(valid_operations) == 0:
|
||||
self.error(
|
||||
expr.location,
|
||||
f"Undefined operation {method} between {left} and {right}",
|
||||
)
|
||||
return UnknownType()
|
||||
elif len(valid_operations) == 1:
|
||||
self.logger.debug(f"Unique operation {method} between {left} and {right}")
|
||||
return valid_operations[0].result
|
||||
|
||||
for i, op1 in enumerate(valid_operations):
|
||||
sig1: Operation.CallSignature = op1.signature
|
||||
best_match: bool = True
|
||||
for j, op2 in enumerate(valid_operations):
|
||||
if i == j:
|
||||
continue
|
||||
sig2: Operation.CallSignature = op2.signature
|
||||
if not self.is_subtype(sig1.left, sig2.left) or not self.is_subtype(
|
||||
sig1.right, sig2.right
|
||||
):
|
||||
best_match = False
|
||||
break
|
||||
self.logger.debug(f"{op1} is a full overload of {op2}")
|
||||
if best_match:
|
||||
return op1.result
|
||||
|
||||
overloads: list[str] = [
|
||||
f"({op.signature.left} {op.signature.method} {op.signature.right}) -> {op.result}"
|
||||
for op in valid_operations
|
||||
]
|
||||
self.error(
|
||||
expr.location,
|
||||
f"Ambiguous operation {method} between {left} and {right}, multiple matching overloads: {', '.join(overloads)}",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
||||
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
|
||||
if method is None:
|
||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||
self.warning(expr.location, f"Unsupported operator {expr.operator}")
|
||||
return UnknownType()
|
||||
left: Type = self.type_of(expr.left)
|
||||
right: Type = self.type_of(expr.right)
|
||||
|
||||
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
|
||||
if result is None:
|
||||
self.error(
|
||||
expr.location,
|
||||
f"Undefined operation {method} between {left} and {right}",
|
||||
)
|
||||
return UnknownType()
|
||||
return result
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ...
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
||||
callee: Type = self.type_of(expr.callee)
|
||||
if not isinstance(callee, Function):
|
||||
self.error(expr.callee.location, "Callee is not a function")
|
||||
return UnknownType()
|
||||
function: Function = callee
|
||||
mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
|
||||
for arg in mapped:
|
||||
if not self.is_subtype(arg.type, arg.argument.type):
|
||||
self.error(
|
||||
arg.expr.location,
|
||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||
)
|
||||
return function.returns
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> Type:
|
||||
object: Type = self.type_of(expr.object)
|
||||
base_object: Type = self.unfold_type(object)
|
||||
match base_object:
|
||||
case ComplexType(properties=properties):
|
||||
if expr.name not in properties:
|
||||
self.error(
|
||||
expr.location, f"Unknown property '{expr.name} on {object}"
|
||||
)
|
||||
return UnknownType()
|
||||
return properties[expr.name]
|
||||
|
||||
case UnknownType():
|
||||
return UnknownType()
|
||||
|
||||
case _:
|
||||
self.error(
|
||||
expr.location, f"Cannot get property '{expr.name}' on {object}"
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
def visit_literal_expr(self, expr: p.LiteralExpr) -> Type:
|
||||
match expr.value:
|
||||
case bool(): # Must be before int
|
||||
return self.ctx.get_type("bool")
|
||||
case int():
|
||||
return self.ctx.get_type("int")
|
||||
case float():
|
||||
return self.ctx.get_type("float")
|
||||
case str():
|
||||
return self.ctx.get_type("str")
|
||||
case _:
|
||||
self.warning(expr.location, f"Unknown literal {expr}")
|
||||
return UnknownType()
|
||||
|
||||
def visit_variable_expr(self, expr: p.VariableExpr) -> Type:
|
||||
return self.look_up_variable(expr.name, expr) or UnknownType()
|
||||
|
||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
|
||||
left: Type = expr.left.accept(self)
|
||||
right: Type = expr.right.accept(self)
|
||||
|
||||
if self.is_subtype(left, right):
|
||||
return right
|
||||
if self.is_subtype(right, left):
|
||||
return left
|
||||
|
||||
self.error(
|
||||
expr.location,
|
||||
f"Incompatible operand types, {left=} and {right=}",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
|
||||
return expr.type.accept(self)
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
|
||||
test_type: Type = expr.test.accept(self)
|
||||
|
||||
# TODO Allow subtypes or any type
|
||||
if test_type != self.ctx.get_type("bool"):
|
||||
self.error(
|
||||
expr.test.location, f"If test must be a boolean, got {test_type}"
|
||||
)
|
||||
|
||||
true_type: Type = expr.if_true.accept(self)
|
||||
false_type: Type = expr.if_false.accept(self)
|
||||
if self.is_subtype(true_type, false_type):
|
||||
return false_type
|
||||
if self.is_subtype(false_type, true_type):
|
||||
return true_type
|
||||
|
||||
self.error(
|
||||
expr.location,
|
||||
f"Incompatible types in ternary if branches: true={true_type} and false={false_type}",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
def visit_base_type(self, node: p.BaseType) -> Type:
|
||||
return self.ctx.get_type(node.base)
|
||||
|
||||
def visit_constraint_type(self, node: p.ConstraintType) -> Type: ...
|
||||
|
||||
def visit_frame_column(self, node: p.FrameColumn) -> Type: ...
|
||||
|
||||
def visit_frame_type(self, node: p.FrameType) -> Type: ...
|
||||
|
||||
def map_call_arguments(
|
||||
self, function: Function, call: p.CallExpr
|
||||
) -> list[MappedArgument]:
|
||||
"""Map call arguments to function parameters as defined in its signature
|
||||
|
||||
This method maps positional-only, keyword-only and mixed parameter definitions
|
||||
with the arguments passed at the call site
|
||||
|
||||
Any mismatched, missing or unexpected argument is reported as a diagnostic
|
||||
|
||||
Args:
|
||||
function (Function): the function definition
|
||||
call (p.CallExpr): the call expression
|
||||
|
||||
Returns:
|
||||
list[MappedArgument]: the list of mapped arguments
|
||||
"""
|
||||
positional: list[tuple[p.Expr, Type]] = [
|
||||
(arg, self.type_of(arg)) for arg in call.arguments
|
||||
]
|
||||
keywords: dict[str, tuple[p.Expr, Type]] = {
|
||||
name: (arg, self.type_of(arg)) for name, arg in call.keywords.items()
|
||||
}
|
||||
set_args: set[str] = set()
|
||||
|
||||
required_positional: list[str] = [
|
||||
arg.name for arg in function.pos_args + function.args if arg.required
|
||||
]
|
||||
required_keyword: list[str] = [
|
||||
arg.name for arg in function.kw_args if arg.required
|
||||
]
|
||||
|
||||
mapped: list[MappedArgument] = []
|
||||
|
||||
pos_params: list[Function.Argument] = list(function.pos_args)
|
||||
mixed_params: list[Function.Argument] = list(function.args)
|
||||
kw_params: dict[str, Function.Argument] = {
|
||||
arg.name: arg for arg in function.kw_args
|
||||
}
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
for arg in positional:
|
||||
param: Function.Argument
|
||||
if len(pos_params) != 0:
|
||||
param = pos_params.pop(0)
|
||||
elif len(mixed_params) != 0:
|
||||
param = mixed_params.pop(0)
|
||||
else:
|
||||
self.error(arg[0].location, "Too many positional arguments")
|
||||
break
|
||||
name: str = param.name
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
kw_params.update({arg.name: arg for arg in mixed_params})
|
||||
for name, arg in keywords.items():
|
||||
param: Function.Argument
|
||||
if name not in kw_params:
|
||||
if name in set_args:
|
||||
self.error(
|
||||
arg[0].location, f"Multiple values for argument '{name}'"
|
||||
)
|
||||
else:
|
||||
self.error(arg[0].location, f"Unknown keyword argument '{name}'")
|
||||
continue
|
||||
param = kw_params.pop(name)
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
def join_args(args: list[str]) -> str:
|
||||
args = list(map(lambda a: f"'{a}'", args))
|
||||
if len(args) == 0:
|
||||
return ""
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return ", ".join(args[:-1]) + " and " + args[-1]
|
||||
|
||||
if len(required_positional) != 0:
|
||||
plural: str = "" if len(required_positional) == 1 else "s"
|
||||
args: str = join_args(required_positional)
|
||||
self.error(
|
||||
call.location,
|
||||
f"Missing required positional argument{plural}: {args}",
|
||||
)
|
||||
|
||||
if len(required_keyword) != 0:
|
||||
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||
args: str = join_args(required_keyword)
|
||||
self.error(
|
||||
call.location,
|
||||
f"Missing required keyword argument{plural}: {args}",
|
||||
)
|
||||
|
||||
return mapped
|
||||
37
midas/checker/diagnostic.py
Normal file
37
midas/checker/diagnostic.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.location import Location
|
||||
|
||||
|
||||
class DiagnosticType(StrEnum):
|
||||
ERROR = "Error"
|
||||
WARNING = "Warning"
|
||||
INFO = "Info"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Diagnostic:
|
||||
file_path: Path
|
||||
location: Location
|
||||
type: DiagnosticType
|
||||
message: str
|
||||
|
||||
@property
|
||||
def location_str(self) -> str:
|
||||
start_loc: str = f"L{self.location.lineno}:{self.location.col_offset+1}"
|
||||
end_loc: Optional[str] = ""
|
||||
if (
|
||||
self.location.end_lineno is not None
|
||||
and self.location.end_col_offset is not None
|
||||
):
|
||||
end_loc = f"L{self.location.end_lineno}:{self.location.end_col_offset+1}"
|
||||
loc: str = (
|
||||
f"at {start_loc}" if end_loc is None else f"from {start_loc} to {end_loc}"
|
||||
)
|
||||
return f"{self.type} in {self.file_path} {loc}"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.location_str}: {self.message}"
|
||||
142
midas/checker/environment.py
Normal file
142
midas/checker/environment.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from midas.checker.types import Type
|
||||
|
||||
|
||||
class Environment:
|
||||
"""
|
||||
A scoped environment in which variables are defined
|
||||
|
||||
Each environment can inherit from a parent/enclosing environment.
|
||||
"""
|
||||
|
||||
def __init__(self, enclosing: Optional[Environment] = None) -> None:
|
||||
self.enclosing: Optional[Environment] = enclosing
|
||||
self.values: dict[str, Type] = {}
|
||||
self.return_types: list[Type] = []
|
||||
|
||||
self._children: list[Environment] = []
|
||||
if enclosing is not None:
|
||||
enclosing._children.append(self)
|
||||
|
||||
def define(self, name: str, value: Type) -> None:
|
||||
"""Define a variable in this environment
|
||||
|
||||
Args:
|
||||
name (str): the name of the variable
|
||||
value (Type): the value
|
||||
"""
|
||||
self.values[name] = value
|
||||
|
||||
def get(self, name: str) -> Optional[Type]:
|
||||
"""Get a variable in the closest environment which has a definition for it
|
||||
|
||||
Args:
|
||||
name (str): the name of the variable
|
||||
|
||||
Returns:
|
||||
Optional[Type]: the value of the variable, or None if it was not found
|
||||
"""
|
||||
if name in self.values:
|
||||
return self.values[name]
|
||||
if self.enclosing is not None:
|
||||
return self.enclosing.get(name)
|
||||
# raise NameError(f"Undefined variable '{name}'")
|
||||
return None
|
||||
|
||||
def assign(self, name: str, value: Type) -> bool:
|
||||
"""Assign a new value to a variable in the environment it was defined in
|
||||
|
||||
Args:
|
||||
name (str): the name of the variable
|
||||
value (Type): the new value
|
||||
|
||||
Returns:
|
||||
bool: True if the variable was assigned in this environment or an ancestor, False otherwise
|
||||
"""
|
||||
if name not in self.values:
|
||||
if self.enclosing is None:
|
||||
return False
|
||||
if self.enclosing.assign(name, value):
|
||||
return True
|
||||
self.values[name] = value
|
||||
return True
|
||||
|
||||
def clear(self):
|
||||
"""Clear all definitions in this environment"""
|
||||
self.values = {}
|
||||
|
||||
def get_at(self, distance: int, name: str) -> Optional[Type]:
|
||||
"""Get the value of a variable at a given distance
|
||||
|
||||
A distance of 0 looks up in this environment, 1 in the parent environment, etc.
|
||||
This methods expects `distance` to be valid. An error will be raised if
|
||||
the stack does not extend far enough to reach `distance`
|
||||
|
||||
Args:
|
||||
distance (int): the scope distance
|
||||
name (str): the name of the variable
|
||||
|
||||
Returns:
|
||||
Optional[Type]: the value at the given distance, or None if it is not defined in that environment
|
||||
|
||||
Raises:
|
||||
AssertionError: if the stack does not extend far enough to reach `distance`
|
||||
"""
|
||||
return self.ancestor(distance).values.get(name)
|
||||
|
||||
def assign_at(self, distance: int, name: str, value: Type) -> None:
|
||||
"""Assign a new value to a variable at a given distance
|
||||
|
||||
A distance of 0 assigns in this environment, 1 in the parent environment, etc.
|
||||
|
||||
Args:
|
||||
distance (int): the scope distance
|
||||
name (str): the name of the variable
|
||||
value (Type): the new value
|
||||
|
||||
Raises:
|
||||
AssertionError: if the stack does not extend far enough to reach `distance`
|
||||
"""
|
||||
self.ancestor(distance).values[name] = value
|
||||
|
||||
def ancestor(self, distance: int) -> Environment:
|
||||
"""Get the ancestor at a given distance
|
||||
|
||||
A distance of 0 references this environment, 1 the parent environment, etc.
|
||||
|
||||
Args:
|
||||
distance (int): the scope distance
|
||||
|
||||
Returns:
|
||||
Environment: the environment
|
||||
|
||||
Raises:
|
||||
AssertionError: if the stack does not extend far enough to reach `distance`
|
||||
"""
|
||||
env: Environment = self
|
||||
for _ in range(distance):
|
||||
assert env.enclosing is not None
|
||||
env = env.enclosing
|
||||
return env
|
||||
|
||||
def flat_dict(self) -> dict[str, Type]:
|
||||
"""Get the current environment including definitions in its ancestor as a flat dictionary
|
||||
|
||||
This method recursively combines this environment definitions with its ancestor's
|
||||
|
||||
Returns:
|
||||
dict: the combined environment
|
||||
"""
|
||||
if self.enclosing is None:
|
||||
return self.values
|
||||
return self.enclosing.flat_dict() | self.values
|
||||
|
||||
def dump(self) -> dict:
|
||||
return {
|
||||
"values": self.values,
|
||||
"return_types": self.return_types,
|
||||
"children": [child.dump() for child in self._children],
|
||||
}
|
||||
31
midas/checker/operators.py
Normal file
31
midas/checker/operators.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import ast
|
||||
from typing import Type
|
||||
|
||||
OPERATOR_METHODS: dict[Type[ast.operator], str] = {
|
||||
ast.Add: "__add__",
|
||||
ast.Sub: "__sub__",
|
||||
ast.Mult: "__mul__",
|
||||
ast.MatMult: "__matmul__",
|
||||
ast.Div: "__truediv__",
|
||||
ast.Mod: "__mod__",
|
||||
ast.Pow: "__pow__",
|
||||
ast.LShift: "__lshift__",
|
||||
ast.RShift: "__rshift__",
|
||||
ast.BitOr: "__or__",
|
||||
ast.BitXor: "__xor__",
|
||||
ast.BitAnd: "__and__",
|
||||
ast.FloorDiv: "__floordiv__",
|
||||
}
|
||||
|
||||
COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
|
||||
ast.Eq: "__eq__",
|
||||
# ast.NotEq: "__noteq__",
|
||||
ast.Lt: "__lt__",
|
||||
ast.LtE: "__le__",
|
||||
ast.Gt: "__gt__",
|
||||
ast.GtE: "__ge__",
|
||||
# ast.Is: "__is__",
|
||||
# ast.IsNot: "__isnot__",
|
||||
# ast.In: "__in__",
|
||||
# ast.NotIn: "__notin__",
|
||||
}
|
||||
60
midas/checker/types.py
Normal file
60
midas/checker/types.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class BaseType:
|
||||
name: str
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class AliasType:
|
||||
name: str
|
||||
type: Type
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class UnknownType:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class UnitType:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Function:
|
||||
name: str
|
||||
pos_args: list[Argument]
|
||||
args: list[Argument]
|
||||
kw_args: list[Argument]
|
||||
returns: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
pos: int
|
||||
name: str
|
||||
type: Type
|
||||
required: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ComplexType:
|
||||
properties: dict[str, Type]
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Operation:
|
||||
signature: CallSignature
|
||||
result: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class CallSignature:
|
||||
left: Type
|
||||
method: str
|
||||
right: Type
|
||||
|
||||
|
||||
Type = BaseType | AliasType | UnknownType | UnitType | Function | ComplexType
|
||||
0
midas/cli/__init__.py
Normal file
0
midas/cli/__init__.py
Normal file
41
midas/cli/ansi.py
Normal file
41
midas/cli/ansi.py
Normal file
@@ -0,0 +1,41 @@
|
||||
class Ansi:
|
||||
CTRL = "\x1b["
|
||||
RESET = CTRL + "0m"
|
||||
BOLD = CTRL + "1m"
|
||||
DIM = CTRL + "2m"
|
||||
ITALIC = CTRL + "3m"
|
||||
UNDERLINE = CTRL + "4m"
|
||||
|
||||
BLACK = 0
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
YELLOW = 3
|
||||
BLUE = 4
|
||||
MAGENTA = 5
|
||||
CYAN = 6
|
||||
WHITE = 7
|
||||
|
||||
BRIGHT_BLACK = 60
|
||||
BRIGHT_RED = 61
|
||||
BRIGHT_GREEN = 62
|
||||
BRIGHT_YELLOW = 63
|
||||
BRIGHT_BLUE = 64
|
||||
BRIGHT_MAGENTA = 65
|
||||
BRIGHT_CYAN = 66
|
||||
BRIGHT_WHITE = 67
|
||||
|
||||
@classmethod
|
||||
def FG(cls, col: int) -> str:
|
||||
return f"{cls.CTRL}{30 + col}m"
|
||||
|
||||
@classmethod
|
||||
def BG(cls, col: int) -> str:
|
||||
return f"{cls.CTRL}{40 + col}m"
|
||||
|
||||
@classmethod
|
||||
def FG_RGB(cls, r: int, g: int, b: int) -> str:
|
||||
return f"{cls.CTRL}38;2;{r};{g};{b}m"
|
||||
|
||||
@classmethod
|
||||
def BG_RGB(cls, r: int, g: int, b: int) -> str:
|
||||
return f"{cls.CTRL}48;2;{r};{g};{b}m"
|
||||
58
midas/cli/highlight.css
Normal file
58
midas/cli/highlight.css
Normal file
@@ -0,0 +1,58 @@
|
||||
html,
|
||||
body {
|
||||
margin: 0;
|
||||
font-size: 14pt;
|
||||
}
|
||||
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
#code {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
font-family: monospace;
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
||||
.line {
|
||||
display: flex;
|
||||
|
||||
&:nth-child(odd) {
|
||||
background-color: rgb(247, 247, 247);
|
||||
}
|
||||
|
||||
.no {
|
||||
width: 4em;
|
||||
text-align: right;
|
||||
padding: 0.2em 0.4em;
|
||||
border-right: solid black 1px;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.txt {
|
||||
flex-grow: 1;
|
||||
padding: 0.2em 0.8em;
|
||||
}
|
||||
}
|
||||
|
||||
span {
|
||||
--col: transparent;
|
||||
--opacity: 0.1;
|
||||
--border: 0px;
|
||||
background-color: rgba(var(--col), var(--opacity));
|
||||
outline: solid rgb(var(--col)) var(--border);
|
||||
outline-offset: 2px;
|
||||
border-radius: 2px;
|
||||
|
||||
&:hover:not(:has(*:hover)) {
|
||||
--opacity: 0.8;
|
||||
--border: 2px;
|
||||
z-index: 10;
|
||||
}
|
||||
|
||||
&.keyword {
|
||||
color: rgb(211, 72, 9);
|
||||
pointer-events: none;
|
||||
}
|
||||
}
|
||||
306
midas/cli/highlighter.py
Normal file
306
midas/cli/highlighter.py
Normal file
@@ -0,0 +1,306 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Generic, Optional, Protocol, TextIO, TypeVar
|
||||
|
||||
import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.diagnostic import Diagnostic
|
||||
from midas.lexer.token import Token
|
||||
|
||||
H = TypeVar("H", bound="Highlighter", contravariant=True)
|
||||
|
||||
|
||||
class Highlightable(Protocol, Generic[H]):
|
||||
def accept(self, visitor: H): ...
|
||||
|
||||
|
||||
class Locatable(Protocol):
|
||||
@property
|
||||
@abstractmethod
|
||||
def location(self) -> Optional[Location]: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LocatableToken:
|
||||
token: Token
|
||||
|
||||
@property
|
||||
def location(self) -> Location:
|
||||
return self.token.get_location()
|
||||
|
||||
|
||||
class Highlighter(ABC):
|
||||
BASE_CSS_PATH: Path = Path(__file__).parent / "highlight.css"
|
||||
EXTRA_CSS_PATH: Optional[Path] = None
|
||||
|
||||
def __init__(self, source: str) -> None:
|
||||
self.source: str = source
|
||||
self.lines: list[str] = self.source.splitlines()
|
||||
self.openings: dict[tuple[int, int], list[str]] = {}
|
||||
self.closings: dict[tuple[int, int], list[str]] = {}
|
||||
|
||||
def format_css(self, path: Path) -> list[str]:
|
||||
css: str = path.read_text()
|
||||
css = "\n".join((" " + line).rstrip() for line in css.splitlines())
|
||||
return [
|
||||
" <style>",
|
||||
css,
|
||||
" </style>",
|
||||
]
|
||||
|
||||
def dump(self, buf: TextIO):
|
||||
base_css: list[str] = self.format_css(self.BASE_CSS_PATH)
|
||||
extra_css: list[str] = (
|
||||
self.format_css(self.EXTRA_CSS_PATH)
|
||||
if self.EXTRA_CSS_PATH is not None
|
||||
else []
|
||||
)
|
||||
lines: list[str] = [
|
||||
"<!DOCTYPE html>",
|
||||
'<html lang="en">',
|
||||
"<head>",
|
||||
' <meta charset="UTF-8">',
|
||||
' <meta name="viewport" content="width=device-width, initial-scale=1.0">',
|
||||
" <title>Highlighted file</title>",
|
||||
*base_css,
|
||||
*extra_css,
|
||||
"</head>",
|
||||
"<body>",
|
||||
' <div id="code">',
|
||||
]
|
||||
for l, line in enumerate(self.lines):
|
||||
lineno: int = l + 1
|
||||
line_buf: str = (
|
||||
f'<div class="line" id="l{lineno}"><div class="no">{lineno}</div><div class="txt">'
|
||||
)
|
||||
for c, char in enumerate(line):
|
||||
pos: tuple[int, int] = (lineno, c)
|
||||
closings: list[str] = self.closings.get(pos, [])
|
||||
openings: list[str] = self.openings.get(pos, [])
|
||||
line_buf += "".join(closings + openings)
|
||||
line_buf += char
|
||||
line_buf += "".join(self.closings.get((lineno, len(line)), []))
|
||||
line_buf += "</div></div>"
|
||||
lines.append(" " + line_buf)
|
||||
lines.extend(
|
||||
[
|
||||
" </div>",
|
||||
"</body>",
|
||||
"</html>",
|
||||
]
|
||||
)
|
||||
|
||||
buf.write("\n".join(lines))
|
||||
|
||||
def wrap(self, node: Locatable, cls: str, message: Optional[str] = None):
|
||||
if node.location is None:
|
||||
return
|
||||
if node.location.end_lineno is None or node.location.end_col_offset is None:
|
||||
return
|
||||
start_pos: tuple[int, int] = (node.location.lineno, node.location.col_offset)
|
||||
end_pos: tuple[int, int] = (
|
||||
node.location.end_lineno,
|
||||
node.location.end_col_offset,
|
||||
)
|
||||
opening: str = f'<span class="{cls}" title="{cls}">'
|
||||
closing: str = "</span>"
|
||||
if message is not None:
|
||||
opening = f'<span class="with-msg">{opening}'
|
||||
closing = f'{closing}<span class="message">{message}</span></span>'
|
||||
|
||||
self.openings.setdefault(start_pos, []).append(opening)
|
||||
self.closings.setdefault(end_pos, []).insert(0, closing)
|
||||
if start_pos[0] != end_pos[0]:
|
||||
for l in range(start_pos[0], end_pos[0]):
|
||||
c: int = len(self.lines[l - 1])
|
||||
self.closings.setdefault((l, c), []).insert(0, closing)
|
||||
self.openings.setdefault((l + 1, 0), []).append(opening)
|
||||
|
||||
|
||||
class PythonHighlighter(
|
||||
Highlighter,
|
||||
p.MidasType.Visitor[None],
|
||||
p.Stmt.Visitor[None],
|
||||
p.Expr.Visitor[None],
|
||||
):
|
||||
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_python.css"
|
||||
|
||||
def highlight(self, node: Highlightable[PythonHighlighter]):
|
||||
node.accept(self)
|
||||
|
||||
def visit_base_type(self, node: p.BaseType) -> None:
|
||||
self.wrap(node, "base-type")
|
||||
if node.param is not None:
|
||||
self.wrap(node.param, "param")
|
||||
node.param.accept(self)
|
||||
|
||||
def visit_constraint_type(self, node: p.ConstraintType) -> None:
|
||||
self.wrap(node, "constraint-type")
|
||||
node.type.accept(self)
|
||||
|
||||
def visit_frame_column(self, node: p.FrameColumn) -> None:
|
||||
self.wrap(node, "frame-column")
|
||||
if node.type is not None:
|
||||
node.type.accept(self)
|
||||
|
||||
def visit_frame_type(self, node: p.FrameType) -> None:
|
||||
self.wrap(node, "frame-type")
|
||||
for column in node.columns:
|
||||
column.accept(self)
|
||||
|
||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||
stmt.expr.accept(self)
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> None:
|
||||
self.wrap(stmt, "function")
|
||||
for arg in stmt.posonlyargs + stmt.args + stmt.kwonlyargs:
|
||||
self._highlight_function_argument(arg)
|
||||
for body_stmt in stmt.body:
|
||||
body_stmt.accept(self)
|
||||
|
||||
def _highlight_function_argument(self, arg: p.Function.Argument) -> None:
|
||||
self.wrap(arg, "argument")
|
||||
if arg.type is not None:
|
||||
arg.type.accept(self)
|
||||
|
||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||
stmt.type.accept(self)
|
||||
|
||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||
for target in stmt.targets:
|
||||
target.accept(self)
|
||||
stmt.value.accept(self)
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||
self.wrap(stmt, "return")
|
||||
if stmt.value is not None:
|
||||
stmt.value.accept(self)
|
||||
|
||||
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||
self.wrap(stmt, "if")
|
||||
stmt.test.accept(self)
|
||||
for body_stmt in stmt.body:
|
||||
body_stmt.accept(self)
|
||||
for else_stmt in stmt.orelse:
|
||||
else_stmt.accept(self)
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ...
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> None: ...
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> None: ...
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> None:
|
||||
self.wrap(expr, "call")
|
||||
expr.callee.accept(self)
|
||||
for arg in expr.arguments:
|
||||
arg.accept(self)
|
||||
for arg in expr.keywords.values():
|
||||
arg.accept(self)
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> None: ...
|
||||
|
||||
def visit_literal_expr(self, expr: p.LiteralExpr) -> None: ...
|
||||
|
||||
def visit_variable_expr(self, expr: p.VariableExpr) -> None: ...
|
||||
|
||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> None: ...
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> None: ...
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ...
|
||||
|
||||
|
||||
class MidasHighlighter(
|
||||
Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None]
|
||||
):
|
||||
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_midas.css"
|
||||
|
||||
def highlight(self, node: Highlightable[MidasHighlighter]):
|
||||
node.accept(self)
|
||||
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||
self.wrap(stmt, "type-stmt")
|
||||
self.wrap(LocatableToken(stmt.name), "type-name")
|
||||
stmt.type.accept(self)
|
||||
|
||||
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None:
|
||||
self.wrap(stmt, "property")
|
||||
stmt.type.accept(self)
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||
self.wrap(stmt, "extend")
|
||||
stmt.type.accept(self)
|
||||
for op in stmt.operations:
|
||||
op.accept(self)
|
||||
|
||||
def visit_op_stmt(self, stmt: m.OpStmt) -> None:
|
||||
self.wrap(stmt, "op")
|
||||
self.wrap(LocatableToken(stmt.name), "op-name")
|
||||
stmt.operand.accept(self)
|
||||
stmt.result.accept(self)
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
||||
self.wrap(stmt, "predicate")
|
||||
self.wrap(LocatableToken(stmt.name), "predicate-name")
|
||||
stmt.type.accept(self)
|
||||
stmt.condition.accept(self)
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
|
||||
self.wrap(expr, "logical-expr")
|
||||
expr.left.accept(self)
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> None:
|
||||
self.wrap(expr, "binary-expr")
|
||||
expr.left.accept(self)
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
|
||||
self.wrap(expr, "unary-expr")
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> None:
|
||||
self.wrap(expr, "get-expr")
|
||||
expr.expr.accept(self)
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr) -> None:
|
||||
self.wrap(expr, "variable")
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
|
||||
expr.expr.accept(self)
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ...
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
|
||||
|
||||
def visit_named_type(self, type: m.NamedType) -> None:
|
||||
self.wrap(type, "named-type")
|
||||
|
||||
def visit_generic_type(self, type: m.GenericType) -> None:
|
||||
self.wrap(type, "generic-type")
|
||||
type.type.accept(self)
|
||||
for param in type.params:
|
||||
param.accept(self)
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> None:
|
||||
self.wrap(type, "constraint-type")
|
||||
type.type.accept(self)
|
||||
type.constraint.accept(self)
|
||||
|
||||
def visit_complex_type(self, type: m.ComplexType) -> None:
|
||||
self.wrap(type, "complex-type")
|
||||
for prop in type.properties:
|
||||
prop.accept(self)
|
||||
|
||||
|
||||
class DiagnosticsHighlighter(Highlighter):
|
||||
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"
|
||||
|
||||
def highlight(self, diagnostics: list[Diagnostic]):
|
||||
for diagnostic in diagnostics:
|
||||
self.wrap(diagnostic, str(diagnostic.type).lower(), diagnostic.message)
|
||||
39
midas/cli/hl_diagnostic.css
Normal file
39
midas/cli/hl_diagnostic.css
Normal file
@@ -0,0 +1,39 @@
|
||||
span {
|
||||
--opacity: 0.4;
|
||||
|
||||
&.error {
|
||||
--col: 255, 0, 0;
|
||||
}
|
||||
&.warning {
|
||||
--col: 250, 160, 0;
|
||||
}
|
||||
&.info {
|
||||
--col: 150, 190, 250;
|
||||
}
|
||||
|
||||
&.with-msg {
|
||||
position: relative;
|
||||
|
||||
.message {
|
||||
display: none;
|
||||
}
|
||||
|
||||
&:hover:not(:has(.with-msg:hover)) {
|
||||
.message {
|
||||
display: inline-block;
|
||||
}
|
||||
}
|
||||
|
||||
.message {
|
||||
position: absolute;
|
||||
top: calc(100% + 0.2em);
|
||||
left: -.2em;
|
||||
background-color: black;
|
||||
color: white;
|
||||
padding: 0.2em 0.4em;
|
||||
border-radius: .2em;
|
||||
z-index: 10;
|
||||
width: 300%;
|
||||
}
|
||||
}
|
||||
}
|
||||
52
midas/cli/hl_midas.css
Normal file
52
midas/cli/hl_midas.css
Normal file
@@ -0,0 +1,52 @@
|
||||
span {
|
||||
&.comment {
|
||||
--col: 200, 200, 200;
|
||||
color: rgb(110, 110, 110);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
&.named-type,
|
||||
&.generic-type,
|
||||
&.constraint-type,
|
||||
&.complex-type {
|
||||
--col: 150, 150, 150;
|
||||
}
|
||||
|
||||
&.constraint {
|
||||
--col: 233, 108, 108;
|
||||
}
|
||||
|
||||
&.property {
|
||||
--col: 233, 108, 176;
|
||||
}
|
||||
|
||||
&.extend {
|
||||
--col: 108, 197, 233;
|
||||
}
|
||||
|
||||
&.op {
|
||||
--col: 108, 148, 233;
|
||||
}
|
||||
|
||||
&.predicate {
|
||||
--col: 193, 108, 233;
|
||||
}
|
||||
|
||||
&.logical-expr,
|
||||
&.binary-expr,
|
||||
&.unary-expr,
|
||||
&.get-expr {
|
||||
--col: 123, 215, 193;
|
||||
}
|
||||
|
||||
&.template {
|
||||
--col: 163, 117, 71;
|
||||
}
|
||||
|
||||
&.type-name,
|
||||
&.op-name,
|
||||
&.predicate-name {
|
||||
--col: 200, 200, 200;
|
||||
font-weight: bold;
|
||||
}
|
||||
}
|
||||
29
midas/cli/hl_python.css
Normal file
29
midas/cli/hl_python.css
Normal file
@@ -0,0 +1,29 @@
|
||||
span {
|
||||
&.base-type {
|
||||
--col: 108, 233, 108;
|
||||
}
|
||||
|
||||
&.param {
|
||||
--col: 103, 192, 224;
|
||||
}
|
||||
|
||||
&.constraint-type {
|
||||
--col: 174, 200, 195;
|
||||
}
|
||||
|
||||
&.frame-column {
|
||||
--col: 216, 231, 81;
|
||||
}
|
||||
|
||||
&.frame-type {
|
||||
--col: 231, 46, 40;
|
||||
}
|
||||
|
||||
&.function {
|
||||
--col: 215, 103, 224;
|
||||
}
|
||||
|
||||
&.argument {
|
||||
--col: 103, 192, 224;
|
||||
}
|
||||
}
|
||||
254
midas/cli/main.py
Normal file
254
midas/cli/main.py
Normal file
@@ -0,0 +1,254 @@
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional, TextIO, get_args
|
||||
|
||||
import click
|
||||
|
||||
import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.ast.printer import MidasAstPrinter, MidasPrinter, PythonAstPrinter
|
||||
from midas.checker.checker import Checker
|
||||
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||
from midas.checker.types import Type
|
||||
from midas.cli.ansi import Ansi
|
||||
from midas.cli.highlighter import (
|
||||
DiagnosticsHighlighter,
|
||||
Highlighter,
|
||||
LocatableToken,
|
||||
MidasHighlighter,
|
||||
PythonHighlighter,
|
||||
)
|
||||
from midas.lexer.midas import MidasLexer
|
||||
from midas.lexer.token import Token, TokenType
|
||||
from midas.parser.midas import MidasParser
|
||||
from midas.parser.python import PythonParser
|
||||
from midas.resolver.resolver import Resolver
|
||||
from midas.utils import UniversalJSONDumper
|
||||
|
||||
|
||||
@click.group()
|
||||
def midas():
|
||||
pass
|
||||
|
||||
|
||||
def print_diagnostic(lines: list[str], diagnostic: Diagnostic, indent: int = 4):
|
||||
"""Pretty-print a diagnostic, showing some context if possible
|
||||
|
||||
If the diagnostic concerns a specific part of one line, the line is shown
|
||||
with the affected part highlighted. The message is clearly printed under the
|
||||
line with an underline further indicating the target expression.
|
||||
|
||||
If multiple lines are concerned, no context is shown, only the
|
||||
diagnostic type, location and message
|
||||
|
||||
Args:
|
||||
lines (list[str]): source code lines
|
||||
diagnostic (Diagnostic): the diagnostic to print
|
||||
indent (int, optional): the number of spaces added before the target line to indent if from the location header. Defaults to 4.
|
||||
"""
|
||||
|
||||
loc: Location = diagnostic.location
|
||||
if loc.lineno != loc.end_lineno:
|
||||
print(diagnostic)
|
||||
return
|
||||
|
||||
start_offset: int = loc.col_offset
|
||||
end_offset: int = loc.end_col_offset or (start_offset + 1)
|
||||
|
||||
line: str = lines[loc.lineno - 1]
|
||||
before: str = line[:start_offset]
|
||||
after: str = line[end_offset:]
|
||||
|
||||
color: int = {
|
||||
DiagnosticType.ERROR: Ansi.RED,
|
||||
DiagnosticType.WARNING: Ansi.YELLOW,
|
||||
DiagnosticType.INFO: Ansi.CYAN,
|
||||
}.get(diagnostic.type, Ansi.WHITE)
|
||||
|
||||
subject: str = Ansi.FG(color) + line[start_offset:end_offset] + Ansi.RESET
|
||||
cursor: str = (
|
||||
" " * start_offset
|
||||
+ Ansi.FG(color)
|
||||
+ "~" * (end_offset - start_offset)
|
||||
+ "> "
|
||||
+ diagnostic.message
|
||||
+ Ansi.RESET
|
||||
)
|
||||
|
||||
indent_str: str = " " * indent
|
||||
print(diagnostic.location_str + ":")
|
||||
print(indent_str + before + subject + after)
|
||||
print(indent_str + cursor)
|
||||
print()
|
||||
|
||||
|
||||
@midas.command()
|
||||
@click.option("-l", "--highlight", type=click.File("w"))
|
||||
@click.option("-t", "--types", type=click.File("r"), multiple=True)
|
||||
@click.option("-v", "--verbose", is_flag=True)
|
||||
@click.argument("file", type=click.File("r"))
|
||||
def compile(
|
||||
highlight: Optional[TextIO],
|
||||
types: tuple[TextIO],
|
||||
verbose: bool,
|
||||
file: TextIO,
|
||||
):
|
||||
logging.basicConfig(level=logging.DEBUG if verbose else logging.WARN)
|
||||
source: str = file.read()
|
||||
tree: ast.Module = ast.parse(source, filename=file.name)
|
||||
parser = PythonParser()
|
||||
stmts: list[p.Stmt] = parser.parse_module(tree)
|
||||
resolver = Resolver()
|
||||
resolver.resolve(*stmts)
|
||||
types_paths: list[Path] = [Path(t.name).resolve() for t in types]
|
||||
checker = Checker(
|
||||
resolver.locals,
|
||||
source_path=Path(file.name).resolve(),
|
||||
types_paths=types_paths,
|
||||
)
|
||||
diagnostics: list[Diagnostic] = checker.check(stmts)
|
||||
lines: list[str] = source.split("\n")
|
||||
for diagnostic in diagnostics:
|
||||
print_diagnostic(lines, diagnostic)
|
||||
|
||||
if verbose:
|
||||
print(
|
||||
json.dumps(
|
||||
UniversalJSONDumper.dump(
|
||||
checker.global_env,
|
||||
[("Environment", "_children")],
|
||||
lambda obj: isinstance(obj, get_args(Type)),
|
||||
),
|
||||
indent=4,
|
||||
)
|
||||
)
|
||||
if highlight is not None:
|
||||
highlighter = DiagnosticsHighlighter(source)
|
||||
highlighter.highlight(diagnostics)
|
||||
highlighter.dump(highlight)
|
||||
|
||||
|
||||
@midas.group()
|
||||
def utils():
|
||||
pass
|
||||
|
||||
|
||||
def dump_python_ast(tree: ast.Module) -> str:
|
||||
parser = PythonParser()
|
||||
stmts: list[p.Stmt] = parser.parse_module(tree)
|
||||
printer = PythonAstPrinter()
|
||||
dump: str = ""
|
||||
for stmt in stmts:
|
||||
dump += printer.print(stmt)
|
||||
dump += "\n"
|
||||
return dump
|
||||
|
||||
|
||||
def dump_midas_ast(source: str, filename: str) -> str:
|
||||
lexer = MidasLexer(source, file=filename)
|
||||
tokens: list[Token] = lexer.process()
|
||||
parser = MidasParser(tokens)
|
||||
stmts: list[m.Stmt] = parser.parse()
|
||||
if len(parser.errors) != 0:
|
||||
for err in parser.errors:
|
||||
print(err.get_report())
|
||||
raise RuntimeError("A parsing error occurred")
|
||||
printer = MidasAstPrinter()
|
||||
dump: str = ""
|
||||
for stmt in stmts:
|
||||
dump += printer.print(stmt)
|
||||
dump += "\n"
|
||||
return dump
|
||||
|
||||
|
||||
@utils.command()
|
||||
@click.option("-o", "--output", type=click.File("w"))
|
||||
@click.option("-p", "--parse", is_flag=True)
|
||||
@click.argument("file", type=click.File("r"))
|
||||
def dump_ast(output: Optional[TextIO], parse: bool, file: TextIO):
|
||||
source: str = file.read()
|
||||
|
||||
dump: str
|
||||
if file.name.endswith(".py"):
|
||||
tree: ast.Module = ast.parse(source, filename=file.name)
|
||||
if parse:
|
||||
dump = dump_python_ast(tree)
|
||||
else:
|
||||
dump = ast.dump(tree, indent=4)
|
||||
elif file.name.endswith(".midas"):
|
||||
dump = dump_midas_ast(source, file.name)
|
||||
else:
|
||||
raise ValueError("Unsupported file type")
|
||||
|
||||
if output is None:
|
||||
click.echo(dump)
|
||||
else:
|
||||
output.write(dump)
|
||||
|
||||
|
||||
def highlight_python(source: str, path: str) -> Highlighter:
|
||||
tree: ast.Module = ast.parse(source, filename=path)
|
||||
parser = PythonParser()
|
||||
stmts: list[p.Stmt] = parser.parse_module(tree)
|
||||
highlighter = PythonHighlighter(source)
|
||||
for stmt in stmts:
|
||||
highlighter.highlight(stmt)
|
||||
return highlighter
|
||||
|
||||
|
||||
def highlight_midas(source: str, path: str) -> Highlighter:
|
||||
lexer = MidasLexer(source, file=path)
|
||||
tokens: list[Token] = lexer.process()
|
||||
parser = MidasParser(tokens)
|
||||
stmts: list[m.Stmt] = parser.parse()
|
||||
highlighter = MidasHighlighter(source)
|
||||
for err in parser.errors:
|
||||
print(err.get_report())
|
||||
|
||||
for stmt in stmts:
|
||||
highlighter.highlight(stmt)
|
||||
for token in tokens:
|
||||
if token.type == TokenType.COMMENT:
|
||||
highlighter.wrap(LocatableToken(token), "comment")
|
||||
elif token.is_keyword:
|
||||
highlighter.wrap(LocatableToken(token), "keyword")
|
||||
return highlighter
|
||||
|
||||
|
||||
@utils.command()
|
||||
@click.option("-o", "--output", type=click.File("w"), default="-")
|
||||
@click.argument("file", type=click.File("r"))
|
||||
def highlight(output: TextIO, file: TextIO):
|
||||
source: str = file.read()
|
||||
highlighter: Highlighter
|
||||
|
||||
if file.name.endswith(".py"):
|
||||
highlighter = highlight_python(source, file.name)
|
||||
elif file.name.endswith(".midas"):
|
||||
highlighter = highlight_midas(source, file.name)
|
||||
else:
|
||||
raise ValueError("Unsupported file type")
|
||||
highlighter.dump(output)
|
||||
|
||||
|
||||
@midas.command()
|
||||
@click.option("-o", "--output", type=click.File("w"), default="-")
|
||||
@click.argument("file", type=click.File("r"))
|
||||
def format(output: TextIO, file: TextIO):
|
||||
source: str = file.read()
|
||||
printer = MidasPrinter()
|
||||
lexer = MidasLexer(source, file=file.name)
|
||||
tokens: list[Token] = lexer.process()
|
||||
parser = MidasParser(tokens)
|
||||
stmts: list[m.Stmt] = parser.parse()
|
||||
for err in parser.errors:
|
||||
print(err.get_report())
|
||||
for stmt in stmts:
|
||||
output.write(printer.print(stmt) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
midas()
|
||||
0
midas/lexer/__init__.py
Normal file
0
midas/lexer/__init__.py
Normal file
@@ -1,8 +1,15 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from lexer.position import Position
|
||||
from lexer.token import Token, TokenType
|
||||
from midas.lexer.position import Position
|
||||
from midas.lexer.token import Token, TokenType
|
||||
|
||||
|
||||
class MidasSyntaxError(Exception):
|
||||
def __init__(self, pos: Position, message: str):
|
||||
super().__init__(f"[ERROR] Error at {pos}: {message}")
|
||||
self.pos: Position = pos
|
||||
self.message: str = message
|
||||
|
||||
|
||||
class Lexer(ABC):
|
||||
@@ -38,9 +45,9 @@ class Lexer(ABC):
|
||||
msg (str): the error message
|
||||
|
||||
Raises:
|
||||
SyntaxError
|
||||
MidasSyntaxError
|
||||
"""
|
||||
raise SyntaxError(f"[ERROR] Error at {self.start_pos}: {msg}")
|
||||
raise MidasSyntaxError(self.start_pos, msg)
|
||||
|
||||
def process(self) -> list[Token]:
|
||||
"""Scan tokens out of the source text
|
||||
@@ -49,7 +56,7 @@ class Lexer(ABC):
|
||||
list[Token]: all the tokens that could be scanned
|
||||
|
||||
Raises:
|
||||
SyntaxError: if a syntax error is found
|
||||
MidasSyntaxError: if a syntax error is found
|
||||
"""
|
||||
self.scan_tokens()
|
||||
self.tokens.append(Token(TokenType.EOF, "", None, self.get_position()))
|
||||
@@ -1,6 +1,5 @@
|
||||
from lexer.base import Lexer
|
||||
from lexer.keyword import MIDAS_KEYWORDS
|
||||
from lexer.token import TokenType
|
||||
from midas.lexer.base import Lexer
|
||||
from midas.lexer.token import KEYWORDS, TokenType
|
||||
|
||||
|
||||
class MidasLexer(Lexer):
|
||||
@@ -31,30 +30,32 @@ class MidasLexer(Lexer):
|
||||
self.add_token(
|
||||
TokenType.EQUAL_EQUAL if self.match("=") else TokenType.EQUAL
|
||||
)
|
||||
case "!":
|
||||
if self.match("="):
|
||||
self.add_token(TokenType.BANG_EQUAL)
|
||||
else:
|
||||
self.error("Unexpected single bang. Did you mean '!=' ?")
|
||||
case "!" if self.match("="):
|
||||
self.add_token(TokenType.BANG_EQUAL)
|
||||
case ":":
|
||||
self.add_token(TokenType.COLON)
|
||||
case ".":
|
||||
self.add_token(TokenType.DOT)
|
||||
case "&":
|
||||
self.add_token(TokenType.AND)
|
||||
case "?":
|
||||
self.add_token(TokenType.QMARK)
|
||||
case ",":
|
||||
self.add_token(TokenType.COMMA)
|
||||
case "_":
|
||||
case "_" if not self.is_identifier_char(self.peek_next(), start=False):
|
||||
self.add_token(TokenType.UNDERSCORE)
|
||||
case "+":
|
||||
self.add_token(TokenType.PLUS)
|
||||
case "-" if self.match(">"):
|
||||
self.add_token(TokenType.ARROW)
|
||||
# case "+":
|
||||
# self.add_token(TokenType.PLUS)
|
||||
case "-":
|
||||
self.add_token(TokenType.MINUS)
|
||||
case "*":
|
||||
self.add_token(TokenType.STAR)
|
||||
case "/":
|
||||
if self.match("/"):
|
||||
self.scan_comment()
|
||||
elif self.match("*"):
|
||||
self.scan_comment_multiline()
|
||||
else:
|
||||
self.add_token(TokenType.SLASH)
|
||||
# case "*":
|
||||
# self.add_token(TokenType.STAR)
|
||||
case "/" if self.match("/"):
|
||||
self.scan_comment()
|
||||
case "/" if self.match("*"):
|
||||
self.scan_comment_multiline()
|
||||
case "\n":
|
||||
self.add_token(TokenType.NEWLINE)
|
||||
case " " | "\r" | "\t":
|
||||
@@ -69,7 +70,7 @@ class MidasLexer(Lexer):
|
||||
case _:
|
||||
if char.isdigit():
|
||||
self.scan_number()
|
||||
elif char.isalpha():
|
||||
elif self.is_identifier_char(char, start=True):
|
||||
self.scan_identifier()
|
||||
else:
|
||||
self.error("Unexpected character")
|
||||
@@ -98,11 +99,11 @@ class MidasLexer(Lexer):
|
||||
An identifier starts with a letter, followed by any number of
|
||||
alphanumerical characters or underscores
|
||||
"""
|
||||
while self.peek().isalnum() or self.peek() == "_":
|
||||
while self.is_identifier_char(self.peek(), start=False):
|
||||
self.advance()
|
||||
|
||||
lexeme: str = self.source[self.start : self.idx]
|
||||
token_type: TokenType = MIDAS_KEYWORDS.get(lexeme, TokenType.IDENTIFIER)
|
||||
token_type: TokenType = KEYWORDS.get(lexeme, TokenType.IDENTIFIER)
|
||||
self.add_token(token_type)
|
||||
|
||||
def scan_comment(self):
|
||||
@@ -129,3 +130,12 @@ class MidasLexer(Lexer):
|
||||
if not self.is_at_end():
|
||||
self.advance()
|
||||
self.add_token(TokenType.COMMENT)
|
||||
|
||||
def is_identifier_char(self, char: str, *, start: bool) -> bool:
|
||||
if char == "_":
|
||||
return True
|
||||
if char.isalpha():
|
||||
return True
|
||||
if not start and char.isdigit():
|
||||
return True
|
||||
return False
|
||||
@@ -5,6 +5,7 @@ from typing import Optional
|
||||
@dataclass(frozen=True)
|
||||
class Position:
|
||||
"""A simple structure to store the position of a token"""
|
||||
|
||||
file: Optional[str]
|
||||
line: int
|
||||
column: int
|
||||
104
midas/lexer/token.py
Normal file
104
midas/lexer/token.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Any
|
||||
|
||||
from midas.ast.location import Location
|
||||
from midas.lexer.position import Position
|
||||
|
||||
|
||||
class TokenType(Enum):
|
||||
# Punctuation
|
||||
LEFT_PAREN = auto()
|
||||
RIGHT_PAREN = auto()
|
||||
LEFT_BRACKET = auto()
|
||||
RIGHT_BRACKET = auto()
|
||||
LEFT_BRACE = auto()
|
||||
RIGHT_BRACE = auto()
|
||||
COLON = auto()
|
||||
COMMA = auto()
|
||||
UNDERSCORE = auto()
|
||||
ARROW = auto()
|
||||
AND = auto()
|
||||
QMARK = auto()
|
||||
DOT = auto()
|
||||
|
||||
# Operators
|
||||
# PLUS = auto()
|
||||
MINUS = auto()
|
||||
# STAR = auto()
|
||||
# SLASH = auto()
|
||||
GREATER = auto()
|
||||
GREATER_EQUAL = auto()
|
||||
LESS = auto()
|
||||
LESS_EQUAL = auto()
|
||||
EQUAL = auto()
|
||||
EQUAL_EQUAL = auto()
|
||||
BANG_EQUAL = auto()
|
||||
|
||||
# Literals
|
||||
IDENTIFIER = auto()
|
||||
NUMBER = auto()
|
||||
TRUE = auto()
|
||||
FALSE = auto()
|
||||
NONE = auto()
|
||||
|
||||
# Keywords
|
||||
TYPE = auto()
|
||||
OP = auto()
|
||||
PREDICATE = auto()
|
||||
EXTEND = auto()
|
||||
WHERE = auto()
|
||||
|
||||
# Misc
|
||||
COMMENT = auto()
|
||||
WHITESPACE = auto()
|
||||
EOF = auto()
|
||||
NEWLINE = auto()
|
||||
|
||||
|
||||
KEYWORDS: dict[str, TokenType] = {
|
||||
"type": TokenType.TYPE,
|
||||
"op": TokenType.OP,
|
||||
"predicate": TokenType.PREDICATE,
|
||||
"extend": TokenType.EXTEND,
|
||||
"where": TokenType.WHERE,
|
||||
"true": TokenType.TRUE,
|
||||
"false": TokenType.FALSE,
|
||||
"none": TokenType.NONE,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Token:
|
||||
"""A scanned token"""
|
||||
|
||||
type: TokenType
|
||||
lexeme: str
|
||||
value: Any
|
||||
position: Position
|
||||
|
||||
def get_location(self) -> Location:
|
||||
lineno: int = self.position.line
|
||||
col_offset: int = self.position.column - 1
|
||||
end_lineno = lineno
|
||||
end_col_offset = col_offset
|
||||
for c in self.lexeme:
|
||||
end_col_offset += 1
|
||||
if c == "\n":
|
||||
end_lineno += 1
|
||||
end_col_offset = 0
|
||||
return Location(
|
||||
lineno=lineno,
|
||||
col_offset=col_offset,
|
||||
end_lineno=end_lineno,
|
||||
end_col_offset=end_col_offset,
|
||||
)
|
||||
|
||||
def location_to(self, to: Token) -> Location:
|
||||
return Location.span(self.get_location(), to.get_location())
|
||||
|
||||
@property
|
||||
def is_keyword(self) -> bool:
|
||||
return self.lexeme in KEYWORDS
|
||||
@@ -2,8 +2,8 @@ from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from lexer.token import Token, TokenType
|
||||
from parser.errors import ParsingError
|
||||
from midas.lexer.token import Token, TokenType
|
||||
from midas.parser.errors import ParsingError
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
447
midas/parser/midas.py
Normal file
447
midas/parser/midas.py
Normal file
@@ -0,0 +1,447 @@
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.location import Location
|
||||
from midas.ast.midas import (
|
||||
BinaryExpr,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
Expr,
|
||||
ExtendStmt,
|
||||
GenericType,
|
||||
GetExpr,
|
||||
GroupingExpr,
|
||||
LiteralExpr,
|
||||
LogicalExpr,
|
||||
NamedType,
|
||||
OpStmt,
|
||||
PredicateStmt,
|
||||
PropertyStmt,
|
||||
Stmt,
|
||||
Type,
|
||||
TypeStmt,
|
||||
UnaryExpr,
|
||||
VariableExpr,
|
||||
WildcardExpr,
|
||||
)
|
||||
from midas.lexer.token import Token, TokenType
|
||||
from midas.parser.base import Parser
|
||||
from midas.parser.errors import ParsingError
|
||||
|
||||
|
||||
class MidasParser(Parser):
|
||||
"""A simple parser for midas type definitions"""
|
||||
|
||||
SYNC_BOUNDARY: set[TokenType] = {
|
||||
TokenType.TYPE,
|
||||
TokenType.OP,
|
||||
TokenType.EXTEND,
|
||||
TokenType.PREDICATE,
|
||||
}
|
||||
|
||||
def parse(self) -> list[Stmt]:
|
||||
statements: list[Stmt] = []
|
||||
while not self.is_at_end():
|
||||
stmt: Optional[Stmt] = self.declaration()
|
||||
if stmt is None:
|
||||
print("Early stop")
|
||||
break
|
||||
statements.append(stmt)
|
||||
return statements
|
||||
|
||||
def synchronize(self):
|
||||
"""Skip tokens until a synchronization boundary is found
|
||||
|
||||
This method allows gracefully recovering from a parse error
|
||||
to a safe place and continue parsing
|
||||
"""
|
||||
self.advance()
|
||||
while not self.is_at_end():
|
||||
if self.previous().type == TokenType.NEWLINE:
|
||||
return
|
||||
if self.peek().type in self.SYNC_BOUNDARY:
|
||||
return
|
||||
self.advance()
|
||||
|
||||
def declaration(self) -> Optional[Stmt]:
|
||||
"""Try and parse a declaration
|
||||
|
||||
Any parsing error is caught and None is returned
|
||||
|
||||
Returns:
|
||||
Optional[Stmt]: the parsed Midas statement, or None if a ParsingError was raised
|
||||
"""
|
||||
try:
|
||||
if self.match(TokenType.TYPE):
|
||||
return self.type_declaration()
|
||||
if self.match(TokenType.EXTEND):
|
||||
return self.extend_declaration()
|
||||
if self.match(TokenType.PREDICATE):
|
||||
return self.predicate_declaration()
|
||||
raise self.error(self.peek(), "Unexpected token")
|
||||
except ParsingError:
|
||||
self.synchronize()
|
||||
return None
|
||||
|
||||
def type_declaration(self) -> TypeStmt:
|
||||
"""Parse a type declaration
|
||||
|
||||
A type declaration can either be a simple type alias or a new complex type.
|
||||
In either case, it can have an optional template expression after its name, wrapped in brackets.
|
||||
A simple type alias is derived from a base type expression, and can have a optional constraint expression preceded by the `where` keyword.
|
||||
A full simple type alias is thus written:
|
||||
```
|
||||
type Name[Template](TypeExpr) where Condition
|
||||
```
|
||||
|
||||
A new complex type has a set of properties which are named, have a type and an optional constraint expression (also preceded by the `where` keyword).
|
||||
A full complex type definition is thus written:
|
||||
```
|
||||
type Name[Template] {
|
||||
prop1: TypeExpr1 where Condition1
|
||||
prop2: TypeExpr2 where Condition2
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
Returns:
|
||||
TypeStmt: the parsed type declaration statement
|
||||
"""
|
||||
keyword: Token = self.previous()
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
params: list[TypeStmt.Param] = []
|
||||
if self.check(TokenType.LEFT_BRACKET):
|
||||
params = self.type_stmt_params()
|
||||
|
||||
self.consume(TokenType.EQUAL, "Expected '=' before type definition")
|
||||
|
||||
type: Type = self.type_expr()
|
||||
|
||||
return TypeStmt(
|
||||
location=keyword.location_to(self.previous()),
|
||||
name=name,
|
||||
params=params,
|
||||
type=type,
|
||||
)
|
||||
|
||||
def type_stmt_params(self) -> list[TypeStmt.Param]:
|
||||
"""Parse a generic template expression
|
||||
|
||||
A template is written `[TypeExpr]`
|
||||
|
||||
Returns:
|
||||
TemplateExpr: the parsed template expression
|
||||
"""
|
||||
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression")
|
||||
params: list[TypeStmt.Param] = []
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable")
|
||||
bound: Optional[Type] = None
|
||||
if self.match(TokenType.LESS):
|
||||
self.consume(TokenType.COLON, "Expected ':' after '<'")
|
||||
bound = self.type_expr()
|
||||
params.append(
|
||||
TypeStmt.Param(
|
||||
location=name.location_to(self.previous()),
|
||||
name=name,
|
||||
bound=bound,
|
||||
)
|
||||
)
|
||||
if not self.match(TokenType.COMMA):
|
||||
break
|
||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression")
|
||||
return params
|
||||
|
||||
def type_expr(self) -> Type:
|
||||
"""Parse a type expression
|
||||
|
||||
A type is an identifier, optionally followed by a template expression.
|
||||
It can also optionally be followed by a '?' to indicate a nullable type
|
||||
|
||||
Returns:
|
||||
TypeExpr: the parsed type expression
|
||||
"""
|
||||
return self.constraint_type()
|
||||
|
||||
def constraint_type(self) -> Type:
|
||||
type: Type = self.base_type()
|
||||
if self.match(TokenType.WHERE):
|
||||
constraint: Expr = self.constraint()
|
||||
return ConstraintType(
|
||||
location=Location.span(type.location, constraint.location),
|
||||
type=type,
|
||||
constraint=constraint,
|
||||
)
|
||||
return type
|
||||
|
||||
def base_type(self) -> Type:
|
||||
if self.match(TokenType.LEFT_PAREN):
|
||||
type: Type = self.type_expr()
|
||||
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
|
||||
return type
|
||||
|
||||
if self.check(TokenType.LEFT_BRACE):
|
||||
return self.complex_type()
|
||||
|
||||
return self.generic_type()
|
||||
|
||||
def generic_type(self) -> Type:
|
||||
type: Type = self.named_type()
|
||||
if self.check(TokenType.LEFT_BRACKET):
|
||||
params: list[Type] = self.type_params()
|
||||
return GenericType(
|
||||
location=Location.span(type.location, self.previous().get_location()),
|
||||
type=type,
|
||||
params=params,
|
||||
)
|
||||
return type
|
||||
|
||||
def type_params(self) -> list[Type]:
|
||||
params: list[Type] = []
|
||||
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic parameters")
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
||||
params.append(self.type_expr())
|
||||
if not self.match(TokenType.COMMA):
|
||||
break
|
||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic parameters")
|
||||
return params
|
||||
|
||||
def named_type(self) -> Type:
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
return NamedType(
|
||||
location=name.get_location(),
|
||||
name=name,
|
||||
)
|
||||
|
||||
def complex_type(self) -> Type:
|
||||
"""Parse a type definition body
|
||||
|
||||
A type definition body is a set of whitespace-separated
|
||||
property statements enclosed in curly braces
|
||||
|
||||
Returns:
|
||||
list[PropertyStmt]: the parsed type properties
|
||||
"""
|
||||
left: Token = self.consume(
|
||||
TokenType.LEFT_BRACE, "Expected '{' to start type body"
|
||||
)
|
||||
properties: list[PropertyStmt] = []
|
||||
names: set[str] = set()
|
||||
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
|
||||
prop: PropertyStmt = self.property_stmt()
|
||||
if prop.name.lexeme in names:
|
||||
raise self.error(prop.name, "Duplicate property")
|
||||
names.add(prop.name.lexeme)
|
||||
properties.append(prop)
|
||||
right: Token = self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
|
||||
return ComplexType(
|
||||
location=left.location_to(right),
|
||||
properties=properties,
|
||||
)
|
||||
|
||||
def constraint(self) -> Expr:
|
||||
"""Parse a constraint
|
||||
|
||||
A constraint is basically a logical predicate
|
||||
|
||||
Returns:
|
||||
Expr: the parsed constraint expression
|
||||
"""
|
||||
return self.and_()
|
||||
|
||||
def and_(self) -> Expr:
|
||||
"""Parse a logical AND expression or a simpler expression
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
"""
|
||||
expr: Expr = self.equality()
|
||||
while self.match(TokenType.AND):
|
||||
operator: Token = self.previous()
|
||||
right: Expr = self.equality()
|
||||
location: Location = Location.span(expr.location, right.location)
|
||||
expr = LogicalExpr(
|
||||
location=location, left=expr, operator=operator, right=right
|
||||
)
|
||||
return expr
|
||||
|
||||
def equality(self) -> Expr:
|
||||
"""Parse a logical equality expression or a simpler expression
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
"""
|
||||
expr: Expr = self.comparison()
|
||||
while self.match(TokenType.BANG_EQUAL, TokenType.EQUAL_EQUAL):
|
||||
operator: Token = self.previous()
|
||||
right: Expr = self.comparison()
|
||||
location: Location = Location.span(expr.location, right.location)
|
||||
expr = BinaryExpr(
|
||||
location=location, left=expr, operator=operator, right=right
|
||||
)
|
||||
return expr
|
||||
|
||||
def comparison(self) -> Expr:
|
||||
"""Parse a logical comparison expression or a simpler expression
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
"""
|
||||
expr: Expr = self.unary()
|
||||
while self.match(
|
||||
TokenType.LESS,
|
||||
TokenType.LESS_EQUAL,
|
||||
TokenType.GREATER,
|
||||
TokenType.GREATER_EQUAL,
|
||||
):
|
||||
operator: Token = self.previous()
|
||||
right: Expr = self.unary()
|
||||
location: Location = Location.span(expr.location, right.location)
|
||||
expr = BinaryExpr(
|
||||
location=location, left=expr, operator=operator, right=right
|
||||
)
|
||||
return expr
|
||||
|
||||
def unary(self) -> Expr:
|
||||
"""Parse a unary expression or a simpler expression
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
"""
|
||||
if self.match(TokenType.MINUS):
|
||||
operator: Token = self.previous()
|
||||
right: Expr = self.unary()
|
||||
location: Location = Location.span(operator.get_location(), right.location)
|
||||
return UnaryExpr(location=location, operator=operator, right=right)
|
||||
return self.reference()
|
||||
|
||||
def reference(self) -> Expr:
|
||||
"""Parse an attribute access expression or a simpler expression
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
"""
|
||||
expr: Expr = self.primary()
|
||||
while self.match(TokenType.DOT):
|
||||
name: Token = self.consume(
|
||||
TokenType.IDENTIFIER, "Expected property name after '.'"
|
||||
)
|
||||
location: Location = Location.span(expr.location, name.get_location())
|
||||
expr = GetExpr(location=location, expr=expr, name=name)
|
||||
return expr
|
||||
|
||||
def primary(self) -> Expr:
|
||||
"""Parse a primary expression
|
||||
|
||||
This includes literals (booleans, numbers, etc.), wildcards, identifiers and grouped expressions
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
"""
|
||||
token: Token = self.peek()
|
||||
if self.match(TokenType.FALSE):
|
||||
return LiteralExpr(location=token.get_location(), value=False)
|
||||
if self.match(TokenType.TRUE):
|
||||
return LiteralExpr(location=token.get_location(), value=True)
|
||||
if self.match(TokenType.NONE):
|
||||
return LiteralExpr(location=token.get_location(), value=None)
|
||||
|
||||
if self.match(TokenType.NUMBER):
|
||||
return LiteralExpr(location=token.get_location(), value=token.value)
|
||||
|
||||
if self.match(TokenType.IDENTIFIER):
|
||||
return VariableExpr(location=token.get_location(), name=token)
|
||||
|
||||
if self.match(TokenType.UNDERSCORE):
|
||||
return WildcardExpr(location=token.get_location(), token=token)
|
||||
|
||||
if self.match(TokenType.LEFT_PAREN):
|
||||
expr: Expr = self.constraint()
|
||||
right: Token = self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
|
||||
return GroupingExpr(location=token.location_to(right), expr=expr)
|
||||
|
||||
raise self.error(self.peek(), "Expected expression")
|
||||
|
||||
def property_stmt(self) -> PropertyStmt:
|
||||
"""Parse a property statement
|
||||
|
||||
A type property statement is written `name: Type` or `name: Type where Condition`
|
||||
|
||||
Returns:
|
||||
PropertyStmt: the parsed property statement
|
||||
"""
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name")
|
||||
self.consume(TokenType.COLON, "Expected ':' after property name")
|
||||
type: Type = self.type_expr()
|
||||
return PropertyStmt(
|
||||
location=name.location_to(self.previous()),
|
||||
name=name,
|
||||
type=type,
|
||||
)
|
||||
|
||||
def extend_declaration(self) -> ExtendStmt:
|
||||
"""Parse an extension definition
|
||||
|
||||
An extension is written `extend Type { operations }`
|
||||
|
||||
Returns:
|
||||
ExtendStmt: the parsed extension statement
|
||||
"""
|
||||
keyword: Token = self.previous()
|
||||
type: Type = self.type_expr()
|
||||
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
|
||||
operations: list[OpStmt] = []
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
|
||||
operations.append(self.op_declaration())
|
||||
self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body")
|
||||
location: Location = keyword.location_to(self.previous())
|
||||
return ExtendStmt(location=location, type=type, operations=operations)
|
||||
|
||||
def op_declaration(self) -> OpStmt:
|
||||
"""Parse an operation definition
|
||||
|
||||
An operation is written `op name(Type) -> Type`
|
||||
|
||||
Returns:
|
||||
OpStmt: the parsed operation statement
|
||||
"""
|
||||
keyword: Token = self.consume(TokenType.OP, "Expected 'op' keyword")
|
||||
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name")
|
||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type")
|
||||
operand: Type = self.type_expr()
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after operand type")
|
||||
|
||||
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
||||
result: Type = self.type_expr()
|
||||
|
||||
return OpStmt(
|
||||
location=keyword.location_to(self.previous()),
|
||||
name=name,
|
||||
operand=operand,
|
||||
result=result,
|
||||
)
|
||||
|
||||
def predicate_declaration(self) -> PredicateStmt:
|
||||
"""Parse a predicate declaration
|
||||
|
||||
A predicate is written `predicate Name(subject: Type) = constraint_expression`
|
||||
|
||||
Returns:
|
||||
PredicateStmt: the parsed predicate declaration statement
|
||||
"""
|
||||
keyword: Token = self.previous()
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected predicate name")
|
||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
|
||||
subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name")
|
||||
self.consume(TokenType.COLON, "Expected ':' after subject name")
|
||||
type: Type = self.type_expr()
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
|
||||
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
|
||||
condition: Expr = self.constraint()
|
||||
return PredicateStmt(
|
||||
location=keyword.location_to(self.previous()),
|
||||
name=name,
|
||||
subject=subject,
|
||||
type=type,
|
||||
condition=condition,
|
||||
)
|
||||
502
midas/parser/python.py
Normal file
502
midas/parser/python.py
Normal file
@@ -0,0 +1,502 @@
|
||||
import ast
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.location import Location
|
||||
from midas.ast.python import (
|
||||
AssignStmt,
|
||||
BaseType,
|
||||
BinaryExpr,
|
||||
CallExpr,
|
||||
CastExpr,
|
||||
CompareExpr,
|
||||
ConstraintType,
|
||||
Expr,
|
||||
ExpressionStmt,
|
||||
FrameColumn,
|
||||
FrameType,
|
||||
Function,
|
||||
GetExpr,
|
||||
IfStmt,
|
||||
LiteralExpr,
|
||||
LogicalExpr,
|
||||
MidasType,
|
||||
ReturnStmt,
|
||||
Stmt,
|
||||
TernaryExpr,
|
||||
TypeAssign,
|
||||
UnaryExpr,
|
||||
VariableExpr,
|
||||
)
|
||||
|
||||
|
||||
class InvalidSyntaxError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UnsupportedSyntaxError(Exception):
|
||||
def __init__(self, expr: ast.expr) -> None:
|
||||
super().__init__(
|
||||
f"Unsupported syntax at L{expr.lineno}:{expr.col_offset}: {ast.unparse(expr)}"
|
||||
)
|
||||
|
||||
|
||||
class PythonParser:
|
||||
CAST_FUNCTION = "cast"
|
||||
|
||||
def parse_module(self, node: ast.Module) -> list[Stmt]:
|
||||
statements: list[Stmt] = []
|
||||
for stmt in node.body:
|
||||
try:
|
||||
parsed: None | Stmt | list[Stmt] = self.parse_stmt(stmt)
|
||||
if isinstance(parsed, Stmt):
|
||||
statements.append(parsed)
|
||||
elif parsed is not None:
|
||||
statements.extend(parsed)
|
||||
except UnsupportedSyntaxError as e:
|
||||
print(f"{e}, skipping")
|
||||
continue
|
||||
return statements
|
||||
|
||||
def parse_stmt(self, node: ast.stmt) -> None | Stmt | list[Stmt]:
|
||||
location: Location = Location.from_ast(node)
|
||||
match node:
|
||||
case ast.AnnAssign():
|
||||
return self.parse_annotation_assign(node)
|
||||
|
||||
case ast.Assign():
|
||||
return self.parse_assign(node)
|
||||
|
||||
case ast.AugAssign():
|
||||
return self.parse_aug_assign(node)
|
||||
|
||||
case ast.FunctionDef():
|
||||
return self.parse_function(node)
|
||||
|
||||
case ast.Expr(value=expr):
|
||||
return ExpressionStmt(
|
||||
location=location,
|
||||
expr=self.parse_expr(expr),
|
||||
)
|
||||
|
||||
case ast.Return(value=value):
|
||||
return ReturnStmt(
|
||||
location=location,
|
||||
value=self.parse_expr(value) if value is not None else None,
|
||||
)
|
||||
|
||||
case ast.If():
|
||||
return self.parse_if(node)
|
||||
|
||||
case ast.Pass():
|
||||
return None
|
||||
|
||||
case _:
|
||||
print(f"Unsupported statement: {ast.unparse(node)}")
|
||||
return None
|
||||
|
||||
def parse_annotation_assign(self, node: ast.AnnAssign) -> list[Stmt]:
|
||||
statements: list[Stmt] = []
|
||||
loc: Location = Location.from_ast(node)
|
||||
match node:
|
||||
case ast.AnnAssign(
|
||||
target=ast.Name(id=target),
|
||||
annotation=annotation,
|
||||
value=value,
|
||||
simple=1,
|
||||
):
|
||||
type = self._parse_type(annotation)
|
||||
statements.append(
|
||||
TypeAssign(
|
||||
location=loc,
|
||||
name=target,
|
||||
type=type,
|
||||
)
|
||||
)
|
||||
|
||||
if value is not None:
|
||||
statements.append(
|
||||
AssignStmt(
|
||||
location=loc,
|
||||
targets=[
|
||||
VariableExpr(
|
||||
location=Location.from_ast(node.target), name=target
|
||||
),
|
||||
],
|
||||
value=self.parse_expr(value),
|
||||
),
|
||||
)
|
||||
case _:
|
||||
print(f"Unsupported annotation: {ast.unparse(node)}")
|
||||
return statements
|
||||
|
||||
def parse_assign(self, node: ast.Assign) -> AssignStmt:
|
||||
targets: list[Expr] = []
|
||||
for target in node.targets:
|
||||
targets.append(self.parse_expr(target))
|
||||
value: Expr = self.parse_expr(node.value)
|
||||
return AssignStmt(
|
||||
location=Location.from_ast(node),
|
||||
targets=targets,
|
||||
value=value,
|
||||
)
|
||||
|
||||
def parse_aug_assign(self, node: ast.AugAssign) -> AssignStmt:
|
||||
location: Location = Location.from_ast(node)
|
||||
target: Expr = self.parse_expr(node.target)
|
||||
value: Expr = self.parse_expr(node.value)
|
||||
return AssignStmt(
|
||||
location=location,
|
||||
targets=[target],
|
||||
value=BinaryExpr(
|
||||
location=location,
|
||||
left=target,
|
||||
operator=node.op,
|
||||
right=value,
|
||||
),
|
||||
)
|
||||
|
||||
def parse_if(self, node: ast.If) -> IfStmt:
|
||||
body: list[Stmt] = []
|
||||
for stmt in node.body:
|
||||
stmts = self.parse_stmt(stmt)
|
||||
if isinstance(stmts, Stmt):
|
||||
body.append(stmts)
|
||||
elif stmts is not None:
|
||||
body.extend(stmts)
|
||||
|
||||
orelse: list[Stmt] = []
|
||||
for stmt in node.orelse:
|
||||
stmts = self.parse_stmt(stmt)
|
||||
if isinstance(stmts, Stmt):
|
||||
orelse.append(stmts)
|
||||
elif stmts is not None:
|
||||
orelse.extend(stmts)
|
||||
|
||||
return IfStmt(
|
||||
location=Location.from_ast(node),
|
||||
test=self.parse_expr(node.test),
|
||||
body=body,
|
||||
orelse=orelse,
|
||||
)
|
||||
|
||||
def parse_function(self, node: ast.FunctionDef) -> Function:
|
||||
loc: Location = Location.from_ast(node)
|
||||
match node:
|
||||
case ast.FunctionDef(
|
||||
name=name,
|
||||
args=ast.arguments(
|
||||
posonlyargs=posonlyargs,
|
||||
args=args,
|
||||
vararg=sink,
|
||||
kwonlyargs=kwonlyargs,
|
||||
kwarg=kw_sink,
|
||||
defaults=defaults,
|
||||
kw_defaults=kw_defaults,
|
||||
),
|
||||
returns=returns,
|
||||
body=raw_body,
|
||||
):
|
||||
|
||||
def parse_args(
|
||||
args_list: list[ast.arg], defaults: list[Optional[Expr]]
|
||||
) -> list[Function.Argument]:
|
||||
return [
|
||||
self._parse_function_argument(arg, default)
|
||||
for arg, default in zip(args_list, defaults)
|
||||
]
|
||||
|
||||
body: list[Stmt] = []
|
||||
for stmt in raw_body:
|
||||
stmts = self.parse_stmt(stmt)
|
||||
if isinstance(stmts, Stmt):
|
||||
body.append(stmts)
|
||||
elif stmts is not None:
|
||||
body.extend(stmts)
|
||||
|
||||
parsed_defaults: list[Optional[Expr]] = [
|
||||
self.parse_expr(default) for default in defaults
|
||||
]
|
||||
n_posargs: int = len(posonlyargs)
|
||||
n_args: int = len(args)
|
||||
n_all_posargs = n_posargs + n_args
|
||||
parsed_defaults = [
|
||||
None,
|
||||
] * (n_all_posargs - len(defaults)) + parsed_defaults
|
||||
|
||||
posargs_defaults: list[Optional[Expr]] = parsed_defaults[:n_posargs]
|
||||
args_defaults: list[Optional[Expr]] = parsed_defaults[n_posargs:]
|
||||
kwargs_defaults: list[Optional[Expr]] = [
|
||||
self.parse_expr(default) if default is not None else None
|
||||
for default in kw_defaults
|
||||
]
|
||||
|
||||
return Function(
|
||||
location=loc,
|
||||
name=name,
|
||||
posonlyargs=parse_args(posonlyargs, posargs_defaults),
|
||||
args=parse_args(args, args_defaults),
|
||||
sink=(
|
||||
self._parse_function_argument(sink, None)
|
||||
if sink is not None
|
||||
else None
|
||||
),
|
||||
kwonlyargs=parse_args(kwonlyargs, kwargs_defaults),
|
||||
kw_sink=(
|
||||
self._parse_function_argument(kw_sink, None)
|
||||
if kw_sink is not None
|
||||
else None
|
||||
),
|
||||
returns=self._parse_type(returns) if returns is not None else None,
|
||||
body=body,
|
||||
)
|
||||
case _:
|
||||
print(f"Unsupported function definition: {ast.unparse(node)}")
|
||||
|
||||
def _parse_function_argument(
|
||||
self, arg: ast.arg, default: Optional[Expr]
|
||||
) -> Function.Argument:
|
||||
loc: Location = Location.from_ast(arg)
|
||||
name: str = arg.arg
|
||||
type: Optional[MidasType] = None
|
||||
if arg.annotation is not None:
|
||||
type = self._parse_type(arg.annotation)
|
||||
return Function.Argument(
|
||||
location=loc,
|
||||
name=name,
|
||||
type=type,
|
||||
default=default,
|
||||
)
|
||||
|
||||
def _parse_type(self, type_expr: ast.expr) -> MidasType:
|
||||
loc: Location = Location.from_ast(type_expr)
|
||||
match type_expr:
|
||||
case ast.Subscript(value=ast.Name(id="Frame"), slice=schema):
|
||||
return self._parse_frame_type(schema)
|
||||
|
||||
case ast.Subscript(value=ast.Name(id=name), slice=param):
|
||||
return BaseType(
|
||||
location=loc,
|
||||
base=name,
|
||||
param=self._parse_type(param),
|
||||
)
|
||||
|
||||
case ast.Name(id=name):
|
||||
return BaseType(
|
||||
location=loc,
|
||||
base=name,
|
||||
param=None,
|
||||
)
|
||||
|
||||
case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr):
|
||||
left = self._parse_type(left_expr)
|
||||
match left:
|
||||
case None:
|
||||
raise InvalidSyntaxError()
|
||||
|
||||
# If chained constraints, separate base type and rebuild constraint
|
||||
case ConstraintType(type=left_type, constraint=left_constraint):
|
||||
constraint = ast.BinOp(
|
||||
left=left_constraint,
|
||||
op=ast.Add(),
|
||||
right=right_expr,
|
||||
)
|
||||
ast.copy_location(constraint, type_expr)
|
||||
return ConstraintType(
|
||||
location=loc,
|
||||
type=left_type,
|
||||
constraint=constraint,
|
||||
)
|
||||
|
||||
case _:
|
||||
return ConstraintType(
|
||||
location=loc,
|
||||
type=left,
|
||||
constraint=right_expr,
|
||||
)
|
||||
|
||||
case ast.Constant(value=None):
|
||||
return BaseType(
|
||||
location=loc,
|
||||
base="None",
|
||||
param=None,
|
||||
)
|
||||
|
||||
case _:
|
||||
raise UnsupportedSyntaxError(type_expr)
|
||||
|
||||
def _parse_frame_type(self, schema: ast.expr) -> FrameType:
|
||||
loc: Location = Location.from_ast(schema)
|
||||
columns: list[FrameColumn] = []
|
||||
|
||||
match schema:
|
||||
case ast.Tuple(elts=cols):
|
||||
for col in cols:
|
||||
columns.append(self._parse_frame_column(col))
|
||||
|
||||
case ast.Slice() | ast.Name():
|
||||
columns.append(self._parse_frame_column(schema))
|
||||
|
||||
case _:
|
||||
raise UnsupportedSyntaxError(schema)
|
||||
|
||||
return FrameType(location=loc, columns=columns)
|
||||
|
||||
def _parse_frame_column(self, column: ast.expr) -> FrameColumn:
|
||||
loc: Location = Location.from_ast(column)
|
||||
match column:
|
||||
case ast.Name():
|
||||
return FrameColumn(
|
||||
location=loc,
|
||||
name=None,
|
||||
type=self._parse_type(column),
|
||||
)
|
||||
|
||||
case ast.Slice(lower=ast.Name(id=name), upper=type_expr):
|
||||
if name == "_":
|
||||
name = None
|
||||
|
||||
type: Optional[MidasType] = None
|
||||
match type_expr:
|
||||
case None:
|
||||
raise InvalidSyntaxError("Missing column type")
|
||||
case ast.Name(id="_"):
|
||||
type = None
|
||||
case ast.expr():
|
||||
type = self._parse_type(type_expr)
|
||||
case _:
|
||||
raise UnsupportedSyntaxError(type_expr)
|
||||
return FrameColumn(location=loc, name=name, type=type)
|
||||
|
||||
case _:
|
||||
raise UnsupportedSyntaxError(column)
|
||||
|
||||
def parse_expr(self, node: ast.expr) -> Expr:
|
||||
location: Location = Location.from_ast(node)
|
||||
match node:
|
||||
case ast.BoolOp():
|
||||
return self.parse_bool_op(node)
|
||||
|
||||
case ast.BinOp(left=left, op=op, right=right):
|
||||
return BinaryExpr(
|
||||
location=location,
|
||||
left=self.parse_expr(left),
|
||||
operator=op,
|
||||
right=self.parse_expr(right),
|
||||
)
|
||||
|
||||
case ast.UnaryOp(op=op, operand=right):
|
||||
return UnaryExpr(
|
||||
location=location,
|
||||
operator=op,
|
||||
right=self.parse_expr(right),
|
||||
)
|
||||
|
||||
case ast.Compare():
|
||||
return self.parse_compare(node)
|
||||
|
||||
case ast.Call(func=ast.Name(id=self.CAST_FUNCTION)):
|
||||
return self.parse_cast(node)
|
||||
|
||||
case ast.Call():
|
||||
return self.parse_call(node)
|
||||
|
||||
case ast.IfExp():
|
||||
return self.parse_ternary(node)
|
||||
|
||||
case ast.Constant(value=value):
|
||||
return LiteralExpr(location=location, value=value)
|
||||
|
||||
case ast.Attribute(value=object, attr=name):
|
||||
return GetExpr(
|
||||
location=location,
|
||||
object=self.parse_expr(object),
|
||||
name=name,
|
||||
)
|
||||
|
||||
case ast.Name(id=name):
|
||||
return VariableExpr(location=location, name=name)
|
||||
|
||||
case _:
|
||||
raise UnsupportedSyntaxError(node)
|
||||
|
||||
def parse_bool_op(self, node: ast.BoolOp) -> LogicalExpr:
|
||||
op: ast.boolop = node.op
|
||||
rights: list[Expr] = [self.parse_expr(expr) for expr in node.values]
|
||||
expr: LogicalExpr = LogicalExpr(
|
||||
location=Location.span(
|
||||
rights[0].location,
|
||||
rights[1].location,
|
||||
),
|
||||
left=rights[0],
|
||||
operator=op,
|
||||
right=rights[1],
|
||||
)
|
||||
for right in rights[2:]:
|
||||
expr = LogicalExpr(
|
||||
location=Location.span(expr.location, right.location),
|
||||
left=expr,
|
||||
operator=op,
|
||||
right=right,
|
||||
)
|
||||
return expr
|
||||
|
||||
def parse_compare(self, node: ast.Compare) -> Expr:
|
||||
ops: list[ast.cmpop] = node.ops
|
||||
left: Expr = self.parse_expr(node.left)
|
||||
rights: list[Expr] = [self.parse_expr(expr) for expr in node.comparators]
|
||||
expr: Expr = CompareExpr(
|
||||
location=Location.span(
|
||||
left.location,
|
||||
rights[0].location,
|
||||
),
|
||||
left=left,
|
||||
operator=ops[0],
|
||||
right=rights[0],
|
||||
)
|
||||
for i, right in enumerate(rights[1:]):
|
||||
comparison = CompareExpr(
|
||||
location=Location.span(rights[i].location, right.location),
|
||||
left=rights[i],
|
||||
operator=ops[i],
|
||||
right=right,
|
||||
)
|
||||
expr = LogicalExpr(
|
||||
location=Location.span(expr.location, comparison.location),
|
||||
left=expr,
|
||||
operator=ast.And(),
|
||||
right=comparison,
|
||||
)
|
||||
return expr
|
||||
|
||||
def parse_cast(self, node: ast.Call) -> CastExpr:
|
||||
match node:
|
||||
case ast.Call(args=[type, expr], keywords=[]):
|
||||
return CastExpr(
|
||||
location=Location.from_ast(node),
|
||||
type=self._parse_type(type),
|
||||
expr=self.parse_expr(expr),
|
||||
)
|
||||
case _:
|
||||
raise InvalidSyntaxError(
|
||||
f"Invalid call to {self.CAST_FUNCTION}, expected type and expression"
|
||||
)
|
||||
|
||||
def parse_call(self, node: ast.Call) -> CallExpr:
|
||||
return CallExpr(
|
||||
location=Location.from_ast(node),
|
||||
callee=self.parse_expr(node.func),
|
||||
arguments=[self.parse_expr(arg) for arg in node.args],
|
||||
keywords={
|
||||
arg.arg: self.parse_expr(arg.value)
|
||||
for arg in node.keywords
|
||||
if arg.arg is not None # Should always be True, type checker happy
|
||||
},
|
||||
)
|
||||
|
||||
def parse_ternary(self, node: ast.IfExp) -> TernaryExpr:
|
||||
return TernaryExpr(
|
||||
location=Location.from_ast(node),
|
||||
test=self.parse_expr(node.test),
|
||||
if_true=self.parse_expr(node.body),
|
||||
if_false=self.parse_expr(node.orelse),
|
||||
)
|
||||
72
midas/resolver/builtin.py
Normal file
72
midas/resolver/builtin.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from midas.checker.types import BaseType, Type, UnitType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.resolver.midas import MidasResolver
|
||||
|
||||
|
||||
def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type):
|
||||
ctx.define_operation(
|
||||
left=t1,
|
||||
operator=operator,
|
||||
right=t2,
|
||||
result=t3,
|
||||
)
|
||||
|
||||
|
||||
def basic_op(ctx: MidasResolver, type: Type, op: str):
|
||||
ctx.define_operation(
|
||||
left=type,
|
||||
operator=op,
|
||||
right=type,
|
||||
result=type,
|
||||
)
|
||||
|
||||
|
||||
def define_builtins(ctx: MidasResolver):
|
||||
"""Define builtin types and operations"""
|
||||
unit = ctx.define_type("None", UnitType())
|
||||
bool = ctx.define_type("bool", BaseType(name="bool"))
|
||||
int = ctx.define_type("int", BaseType(name="int"))
|
||||
float = ctx.define_type("float", BaseType(name="float"))
|
||||
str = ctx.define_type("str", BaseType(name="str"))
|
||||
|
||||
basic_op(ctx, int, "__add__") # int + int = int
|
||||
basic_op(ctx, int, "__sub__") # int - int = int
|
||||
basic_op(ctx, int, "__mul__") # int * int = int
|
||||
basic_op(ctx, int, "__pow__") # int ** int = int
|
||||
basic_op(ctx, int, "__mod__") # int % int = int
|
||||
basic_op(ctx, int, "__and__") # int & int = int
|
||||
basic_op(ctx, int, "__or__") # int | int = int
|
||||
basic_op(ctx, int, "__xor__") # int ^ int = int
|
||||
op(ctx, int, "__lt__", int, bool) # int < int = bool
|
||||
op(ctx, int, "__gt__", int, bool) # int > int = bool
|
||||
op(ctx, int, "__le__", int, bool) # int <= int = bool
|
||||
op(ctx, int, "__ge__", int, bool) # int >= int = bool
|
||||
op(ctx, int, "__eq__", int, bool) # int == int = bool
|
||||
basic_op(ctx, float, "__add__") # float + float = float
|
||||
basic_op(ctx, float, "__sub__") # float - float = float
|
||||
basic_op(ctx, float, "__mul__") # float * float = float
|
||||
basic_op(ctx, float, "__truediv__") # float / float = float
|
||||
op(ctx, float, "__lt__", float, bool) # float < float = bool
|
||||
op(ctx, float, "__gt__", float, bool) # float > float = bool
|
||||
op(ctx, float, "__le__", float, bool) # float <= float = bool
|
||||
op(ctx, float, "__ge__", float, bool) # float >= float = bool
|
||||
op(ctx, float, "__eq__", float, bool) # float == float = bool
|
||||
basic_op(ctx, str, "__add__") # str + str = str
|
||||
op(ctx, str, "__eq__", str, bool) # str == str = bool
|
||||
|
||||
op(ctx, int, "__lt__", float, bool) # int < float = bool
|
||||
op(ctx, int, "__gt__", float, bool) # int > float = bool
|
||||
op(ctx, int, "__le__", float, bool) # int <= float = bool
|
||||
op(ctx, int, "__ge__", float, bool) # int >= float = bool
|
||||
op(ctx, int, "__eq__", float, bool) # int == float = bool
|
||||
|
||||
op(ctx, float, "__lt__", int, bool) # float < int = bool
|
||||
op(ctx, float, "__gt__", int, bool) # float > int = bool
|
||||
op(ctx, float, "__le__", int, bool) # float <= int = bool
|
||||
op(ctx, float, "__ge__", int, bool) # float >= int = bool
|
||||
op(ctx, float, "__eq__", int, bool) # float == int = bool
|
||||
186
midas/resolver/midas.py
Normal file
186
midas/resolver/midas.py
Normal file
@@ -0,0 +1,186 @@
|
||||
from typing import Optional
|
||||
|
||||
import midas.ast.midas as m
|
||||
from midas.checker.types import (
|
||||
AliasType,
|
||||
ComplexType,
|
||||
Operation,
|
||||
Type,
|
||||
UnknownType,
|
||||
)
|
||||
from midas.resolver.builtin import define_builtins
|
||||
|
||||
|
||||
class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
|
||||
"""A resolver which evaluates Midas type definitions and build a registry"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._types: dict[str, Type] = {}
|
||||
self._operations: dict[Operation.CallSignature, Type] = {}
|
||||
|
||||
define_builtins(self)
|
||||
|
||||
def get_type(self, name: str) -> Type:
|
||||
"""Get a type from its name
|
||||
|
||||
Args:
|
||||
name (str): the name of the type
|
||||
|
||||
Raises:
|
||||
NameError: if the type is not defined
|
||||
|
||||
Returns:
|
||||
Type: the type
|
||||
"""
|
||||
type: Optional[Type] = self._types.get(name)
|
||||
if type is None:
|
||||
raise NameError(f"Undefined type {name}")
|
||||
return type
|
||||
|
||||
def get_operation_result(
|
||||
self, left: Type, operator: str, right: Type
|
||||
) -> Optional[Type]:
|
||||
"""Get the resulting type of an operation
|
||||
|
||||
Args:
|
||||
left (Type): the type of the left operand
|
||||
operator (str): the operation name
|
||||
right (Type): the type of the right operand
|
||||
|
||||
Returns:
|
||||
Optional[Type]: the result type, or None if no matching operation was found
|
||||
"""
|
||||
signature: Operation.CallSignature = Operation.CallSignature(
|
||||
left=left,
|
||||
method=operator,
|
||||
right=right,
|
||||
)
|
||||
result: Optional[Type] = self._operations.get(signature)
|
||||
return result
|
||||
|
||||
def get_operations_by_name(self, name: str) -> list[Operation]:
|
||||
operations: list[Operation] = []
|
||||
for signature, result in self._operations.items():
|
||||
if signature.method == name:
|
||||
operations.append(
|
||||
Operation(
|
||||
signature=signature,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
return operations
|
||||
|
||||
def define_type(self, name: str, type: Type) -> Type:
|
||||
"""Define a type in the registry
|
||||
|
||||
Args:
|
||||
name (str): the name of the type
|
||||
type (Type): the type to define
|
||||
|
||||
Raises:
|
||||
ValueError: if a type is already defined with that name
|
||||
|
||||
Returns:
|
||||
Type: the defined type
|
||||
"""
|
||||
if name in self._types:
|
||||
raise ValueError(f"Type {name} already defined")
|
||||
self._types[name] = type
|
||||
return type
|
||||
|
||||
def define_operation(self, left: Type, operator: str, right: Type, result: Type):
|
||||
"""Define an operation in the registry
|
||||
|
||||
Args:
|
||||
left (Type): the type of the left operand
|
||||
operator (str): the operation name
|
||||
right (Type): the type of the right operand
|
||||
result (Type): the result type
|
||||
|
||||
Raises:
|
||||
ValueError: if an operation is already defined with these operands and name
|
||||
"""
|
||||
signature: Operation.CallSignature = Operation.CallSignature(
|
||||
left=left,
|
||||
method=operator,
|
||||
right=right,
|
||||
)
|
||||
if signature in self._operations:
|
||||
raise ValueError(
|
||||
f"Operation {operator} already defined between {left} and {right}"
|
||||
)
|
||||
self._operations[signature] = result
|
||||
|
||||
def resolve(self, stmts: list[m.Stmt]):
|
||||
"""Process a sequence of statements
|
||||
|
||||
Args:
|
||||
stmts (list[m.Stmt]): the statements
|
||||
"""
|
||||
for stmt in stmts:
|
||||
stmt.accept(self)
|
||||
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||
type: Type = stmt.type.accept(self)
|
||||
for param in stmt.params:
|
||||
if param.bound is not None:
|
||||
param.bound.accept(self)
|
||||
name: str = stmt.name.lexeme
|
||||
self.define_type(name, AliasType(name=name, type=type))
|
||||
|
||||
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||
base: Type = stmt.type.accept(self)
|
||||
for op in stmt.operations:
|
||||
right: Type = op.operand.accept(self)
|
||||
result: Type = op.result.accept(self)
|
||||
self.define_operation(
|
||||
left=base,
|
||||
operator=op.name.lexeme,
|
||||
right=right,
|
||||
result=result,
|
||||
)
|
||||
|
||||
def visit_op_stmt(self, stmt: m.OpStmt) -> None: ...
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ...
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ...
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> None: ...
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> None: ...
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> None: ...
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr) -> None: ...
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
|
||||
return expr.expr.accept(self)
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ...
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
|
||||
|
||||
def visit_named_type(self, type: m.NamedType) -> Type:
|
||||
return self.get_type(type.name.lexeme)
|
||||
|
||||
def visit_generic_type(self, type: m.GenericType) -> Type:
|
||||
type_: Type = type.type.accept(self)
|
||||
params: list[Type] = [param.accept(self) for param in type.params]
|
||||
# TODO
|
||||
return UnknownType()
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
|
||||
type_: Type = type.type.accept(self)
|
||||
type.constraint.accept(self)
|
||||
# TODO
|
||||
return UnknownType()
|
||||
|
||||
def visit_complex_type(self, type: m.ComplexType) -> Type:
|
||||
return ComplexType(
|
||||
properties={
|
||||
prop.name.lexeme: prop.type.accept(self) for prop in type.properties
|
||||
}
|
||||
)
|
||||
182
midas/resolver/resolver.py
Normal file
182
midas/resolver/resolver.py
Normal file
@@ -0,0 +1,182 @@
|
||||
import midas.ast.python as p
|
||||
|
||||
|
||||
class ResolverError(Exception): ...
|
||||
|
||||
|
||||
class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
||||
"""A variable assignment and reference resolver
|
||||
|
||||
This class keeps track of which scope a variable is defined in and which
|
||||
scope is referred to when a variable is referenced
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.locals: dict[p.Expr, int] = {}
|
||||
self.scopes: list[dict[str, bool]] = []
|
||||
|
||||
def resolve(self, *objects: p.Stmt | p.Expr) -> None:
|
||||
"""Resolve the given statements or expressions"""
|
||||
|
||||
for obj in objects:
|
||||
obj.accept(self)
|
||||
|
||||
def begin_scope(self):
|
||||
"""Begin a new scope inside the current one"""
|
||||
self.scopes.append({})
|
||||
|
||||
def end_scope(self):
|
||||
"""Close the current scope"""
|
||||
self.scopes.pop()
|
||||
|
||||
def declare(self, name: str) -> None:
|
||||
"""Declare a variable in the current scope
|
||||
|
||||
This method must be called *before* evaluating the variable initializer
|
||||
|
||||
Args:
|
||||
name (str): the name of the variable
|
||||
|
||||
Raises:
|
||||
ResolverError: if the variable has already been declared in the current scope
|
||||
"""
|
||||
if len(self.scopes) == 0:
|
||||
return
|
||||
scope: dict[str, bool] = self.scopes[-1]
|
||||
if name in scope:
|
||||
raise ResolverError(
|
||||
f"A variable with the name {name} is already declared in this scope"
|
||||
)
|
||||
scope[name] = False
|
||||
|
||||
def define(self, name: str) -> None:
|
||||
"""Define a variable in the current scope
|
||||
|
||||
This method must be called *after* evaluating the variable initializer
|
||||
|
||||
Args:
|
||||
name (str): the name of the variable
|
||||
"""
|
||||
if len(self.scopes) == 0:
|
||||
return
|
||||
self.scopes[-1][name] = True
|
||||
|
||||
def resolve_local(self, expr: p.Expr, name: str) -> None:
|
||||
"""Resolve a variable reference and store the scope distance
|
||||
|
||||
This method associates to the variable expression a number representing
|
||||
the "distance" of the variable declaration, i.e. the number of scope
|
||||
levels to go "up" to find the closest declaration for that variable.
|
||||
|
||||
Args:
|
||||
expr (p.Expr): the variable expression
|
||||
name (str): the name of the variable
|
||||
"""
|
||||
for i, scope in enumerate(reversed(self.scopes)):
|
||||
if name in scope:
|
||||
self.locals[expr] = i
|
||||
return
|
||||
|
||||
def resolve_function(self, function: p.Function) -> None:
|
||||
"""Resolve a function definition
|
||||
|
||||
This method creates a new scope for the function, resolves all the
|
||||
parameter declarations and then the body.
|
||||
|
||||
Args:
|
||||
function (p.Function): the function to resolve
|
||||
"""
|
||||
self.begin_scope()
|
||||
for param in function.all_args:
|
||||
self.declare(param.name)
|
||||
self.define(param.name)
|
||||
self.resolve(*function.body)
|
||||
self.end_scope()
|
||||
|
||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||
stmt.expr.accept(self)
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> None:
|
||||
# Declare before resolving body to allow recursion
|
||||
self.declare(stmt.name)
|
||||
self.define(stmt.name)
|
||||
self.resolve_function(stmt)
|
||||
|
||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||
self.declare(stmt.name)
|
||||
# NOTE: resolve type here?
|
||||
self.define(stmt.name)
|
||||
|
||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||
self.resolve(stmt.value)
|
||||
for target in stmt.targets:
|
||||
match target:
|
||||
case p.VariableExpr() | p.GetExpr():
|
||||
target.accept(self)
|
||||
case _:
|
||||
raise Exception(f"Unsupported assignment to {target}")
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||
if stmt.value is not None:
|
||||
self.resolve(stmt.value)
|
||||
|
||||
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||
# Not resolved in sub-environment because assignments in the test leak out of the if
|
||||
# For example:
|
||||
# if (m := 1 + 1) < 2:
|
||||
# ...
|
||||
# print(m) # <- m is still defined
|
||||
self.resolve(stmt.test)
|
||||
|
||||
# Body
|
||||
self.begin_scope()
|
||||
self.resolve(*stmt.body)
|
||||
self.end_scope()
|
||||
|
||||
# Else
|
||||
self.begin_scope()
|
||||
self.resolve(*stmt.orelse)
|
||||
self.end_scope()
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
|
||||
self.resolve(expr.left)
|
||||
self.resolve(expr.right)
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> None:
|
||||
self.resolve(expr.left)
|
||||
self.resolve(expr.right)
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> None:
|
||||
self.resolve(expr.right)
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> None:
|
||||
self.resolve(expr.callee)
|
||||
for arg in expr.arguments:
|
||||
self.resolve(arg)
|
||||
for arg in expr.keywords.values():
|
||||
self.resolve(arg)
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> None:
|
||||
self.resolve(expr.object)
|
||||
|
||||
def visit_literal_expr(self, expr: p.LiteralExpr) -> None:
|
||||
pass
|
||||
|
||||
def visit_variable_expr(self, expr: p.VariableExpr) -> None:
|
||||
if len(self.scopes) != 0 and self.scopes[-1].get(expr.name) is False:
|
||||
raise ResolverError(
|
||||
f"Cannot use local variable '{expr.name}' in its own initializer"
|
||||
) # aka. UnboundLocalError
|
||||
self.resolve_local(expr, expr.name)
|
||||
|
||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> None:
|
||||
self.resolve(expr.left)
|
||||
self.resolve(expr.right)
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> None:
|
||||
self.resolve(expr.expr)
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
|
||||
self.resolve(expr.test)
|
||||
self.resolve(expr.if_true)
|
||||
self.resolve(expr.if_false)
|
||||
54
midas/utils.py
Normal file
54
midas/utils.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
AllowRepeat = Callable[[object], bool]
|
||||
|
||||
|
||||
class UniversalJSONDumper:
|
||||
@classmethod
|
||||
def dump(
|
||||
cls,
|
||||
obj: Any,
|
||||
include_keys: Optional[list[str | tuple[str, str]]] = None,
|
||||
allow_repeat: Optional[AllowRepeat] = None,
|
||||
) -> Any:
|
||||
if include_keys is None:
|
||||
include_keys = []
|
||||
return cls._dump(obj, include_keys, allow_repeat, [])
|
||||
|
||||
@classmethod
|
||||
def _dump(
|
||||
cls,
|
||||
obj: Any,
|
||||
include_keys: list[str | tuple[str, str]],
|
||||
allow_repeat: Optional[AllowRepeat],
|
||||
visited: list[Any],
|
||||
) -> Any:
|
||||
if obj in visited:
|
||||
return None
|
||||
match obj:
|
||||
case str() | int() | float() | None:
|
||||
return obj
|
||||
case list() | set() | tuple():
|
||||
return [
|
||||
cls._dump(child, include_keys, allow_repeat, visited)
|
||||
for child in obj
|
||||
]
|
||||
case dict():
|
||||
return {
|
||||
str(k): cls._dump(v, include_keys, allow_repeat, visited)
|
||||
for k, v in obj.items()
|
||||
}
|
||||
case object():
|
||||
if allow_repeat is None or not allow_repeat(obj):
|
||||
visited.append(obj)
|
||||
return {
|
||||
"_type": obj.__class__.__name__,
|
||||
} | {
|
||||
k: cls._dump(v, include_keys, allow_repeat, visited)
|
||||
for k, v in obj.__dict__.items()
|
||||
if not k.startswith("_")
|
||||
or k in include_keys
|
||||
or (obj.__class__.__name__, k) in include_keys
|
||||
}
|
||||
case _:
|
||||
raise ValueError(f"Unsupported value: {obj}")
|
||||
@@ -1,152 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.ast.annotations import (
|
||||
AnnotationStmt,
|
||||
ConstraintExpr,
|
||||
Expr,
|
||||
LiteralExpr,
|
||||
SchemaElementExpr,
|
||||
SchemaExpr,
|
||||
Stmt,
|
||||
TypeExpr,
|
||||
WildcardExpr,
|
||||
)
|
||||
from lexer.token import Token, TokenType
|
||||
from parser.base import Parser
|
||||
from parser.errors import ParsingError
|
||||
|
||||
|
||||
class AnnotationParser(Parser):
|
||||
"""A simple parser for custom type annotations"""
|
||||
|
||||
SYNC_BOUNDARY: set[TokenType] = set()
|
||||
|
||||
def parse(self) -> Optional[Stmt]:
|
||||
stmt: Optional[Stmt] = None
|
||||
try:
|
||||
stmt = self.annotation()
|
||||
except ParsingError:
|
||||
self.synchronize()
|
||||
if not self.is_at_end():
|
||||
self.error(self.peek(), "Extra tokens")
|
||||
return stmt
|
||||
|
||||
def synchronize(self):
|
||||
"""Skip tokens until a synchronization boundary is found
|
||||
|
||||
This method allows gracefully recovering from a parse error
|
||||
to a safe place and continue parsing
|
||||
"""
|
||||
self.advance()
|
||||
while not self.is_at_end():
|
||||
if self.peek().type in self.SYNC_BOUNDARY:
|
||||
return
|
||||
self.advance()
|
||||
|
||||
def annotation(self) -> AnnotationStmt:
|
||||
"""Parse an annotation
|
||||
|
||||
An annotation is written as `Type` or `Type[Schema]`
|
||||
|
||||
Returns:
|
||||
AnnotationStmt: the parsed annotation statement
|
||||
"""
|
||||
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type identifier")
|
||||
schema: Optional[SchemaExpr] = None
|
||||
if self.match(TokenType.LEFT_BRACKET):
|
||||
schema = self.schema()
|
||||
return AnnotationStmt(name=name, schema=schema)
|
||||
|
||||
def type_expr(self) -> TypeExpr:
|
||||
"""Parse a type expression
|
||||
|
||||
Returns:
|
||||
TypeExpr: the parsed type expression
|
||||
"""
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
constraints: list[ConstraintExpr] = []
|
||||
|
||||
while not self.is_at_end() and self.match(TokenType.PLUS):
|
||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before type constraint")
|
||||
constraints.append(self.constraint_expr())
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after type constraint")
|
||||
|
||||
return TypeExpr(name=name, constraints=constraints)
|
||||
|
||||
def constraint_expr(self) -> ConstraintExpr:
|
||||
"""Parse a type constraint
|
||||
|
||||
Returns:
|
||||
ConstraintExpr: the parsed type constraint expression
|
||||
"""
|
||||
|
||||
left: Expr = self.constraint_value()
|
||||
op: Token = self.constraint_operator()
|
||||
right: Expr = self.constraint_value()
|
||||
return ConstraintExpr(left=left, op=op, right=right)
|
||||
|
||||
def constraint_value(self) -> Expr:
|
||||
if self.match(TokenType.UNDERSCORE):
|
||||
return WildcardExpr(self.previous())
|
||||
return self.literal()
|
||||
|
||||
def literal(self) -> LiteralExpr:
|
||||
if self.match(TokenType.FALSE):
|
||||
return LiteralExpr(False)
|
||||
if self.match(TokenType.TRUE):
|
||||
return LiteralExpr(True)
|
||||
if self.match(TokenType.NONE):
|
||||
return LiteralExpr(None)
|
||||
|
||||
if self.match(TokenType.NUMBER):
|
||||
return LiteralExpr(self.previous().value)
|
||||
|
||||
raise self.error(self.peek(), "Expected literal")
|
||||
|
||||
def constraint_operator(self) -> Token:
|
||||
if self.match(TokenType.LESS, TokenType.LESS_EQUAL, TokenType.GREATER, TokenType.GREATER_EQUAL, TokenType.EQUAL_EQUAL, TokenType.BANG_EQUAL):
|
||||
return self.previous()
|
||||
raise self.error(self.peek(), "Expected constraint operator")
|
||||
|
||||
def schema(self) -> SchemaExpr:
|
||||
"""Parse a schema definition
|
||||
|
||||
A comma separated list of schema elements
|
||||
|
||||
Returns:
|
||||
SchemaExpr: the parsed schema expression
|
||||
"""
|
||||
left: Token = self.previous()
|
||||
elements: list[Expr] = []
|
||||
while not self.check(TokenType.RIGHT_BRACKET) and not self.is_at_end():
|
||||
elements.append(self.schema_element())
|
||||
if not self.check(TokenType.RIGHT_BRACKET):
|
||||
self.consume(TokenType.COMMA, "Expected ',' between schema elements")
|
||||
|
||||
right: Token = self.consume(TokenType.RIGHT_BRACKET, "Unclosed schema")
|
||||
return SchemaExpr(left=left, elements=elements, right=right)
|
||||
|
||||
def schema_element(self) -> SchemaElementExpr:
|
||||
"""Parse a schema element
|
||||
|
||||
An anonymous element (`_`), a type, an untyped named column (`name: _`),
|
||||
or a named column (`name: Type`)
|
||||
|
||||
Returns:
|
||||
SchemaElementExpr: the parsed schema element expression
|
||||
"""
|
||||
if self.match(TokenType.UNDERSCORE):
|
||||
return SchemaElementExpr(name=None, type=None)
|
||||
|
||||
if not self.check(TokenType.IDENTIFIER):
|
||||
raise self.error(self.peek(), "Expected schema element")
|
||||
|
||||
name: Optional[Token] = None
|
||||
type: Optional[TypeExpr] = None
|
||||
if self.check_next(TokenType.COLON):
|
||||
name = self.advance()
|
||||
self.advance()
|
||||
if not self.match(TokenType.UNDERSCORE):
|
||||
type = self.type_expr()
|
||||
return SchemaElementExpr(name=name, type=type)
|
||||
217
parser/midas.py
217
parser/midas.py
@@ -1,217 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.ast.midas import (
|
||||
ConstraintExpr,
|
||||
ConstraintStmt,
|
||||
Expr,
|
||||
LiteralExpr,
|
||||
OpStmt,
|
||||
PropertyStmt,
|
||||
Stmt,
|
||||
TypeBodyExpr,
|
||||
TypeExpr,
|
||||
TypeStmt,
|
||||
WildcardExpr,
|
||||
)
|
||||
from lexer.token import Token, TokenType
|
||||
from parser.base import Parser
|
||||
from parser.errors import ParsingError
|
||||
|
||||
|
||||
class MidasParser(Parser):
|
||||
"""A simple parser for midas type definitions"""
|
||||
|
||||
SYNC_BOUNDARY: set[TokenType] = {TokenType.TYPE, TokenType.OP, TokenType.CONSTRAINT}
|
||||
|
||||
def parse(self) -> list[Stmt]:
|
||||
statements: list[Stmt] = []
|
||||
while not self.is_at_end():
|
||||
stmt: Optional[Stmt] = self.declaration()
|
||||
if stmt is None:
|
||||
print("Early stop")
|
||||
break
|
||||
statements.append(stmt)
|
||||
return statements
|
||||
|
||||
def synchronize(self):
|
||||
"""Skip tokens until a synchronization boundary is found
|
||||
|
||||
This method allows gracefully recovering from a parse error
|
||||
to a safe place and continue parsing
|
||||
"""
|
||||
self.advance()
|
||||
while not self.is_at_end():
|
||||
if self.previous().type == TokenType.NEWLINE:
|
||||
return
|
||||
if self.peek().type in self.SYNC_BOUNDARY:
|
||||
return
|
||||
self.advance()
|
||||
|
||||
def declaration(self) -> Optional[Stmt]:
|
||||
"""Try and parse a declaration
|
||||
|
||||
Any parsing error is caught and None is returned
|
||||
|
||||
Returns:
|
||||
Optional[Stmt]: the parsed Midas statement, or None if a ParsingError was raised
|
||||
"""
|
||||
try:
|
||||
if self.match(TokenType.TYPE):
|
||||
return self.type_declaration()
|
||||
if self.match(TokenType.OP):
|
||||
return self.op_declaration()
|
||||
if self.match(TokenType.CONSTRAINT):
|
||||
return self.constraint_declaration()
|
||||
raise self.error(self.peek(), "Unexpected token")
|
||||
except ParsingError:
|
||||
self.synchronize()
|
||||
return None
|
||||
|
||||
def type_declaration(self) -> TypeStmt:
|
||||
"""Parse a type declaration
|
||||
|
||||
A type declaration is written `type Name<TypeExpr, ...>` optionally followed by a brace-wrapped body
|
||||
|
||||
Returns:
|
||||
TypeStmt: the parsed type declaration statement
|
||||
"""
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
self.consume(TokenType.LESS, "Expected '<' after type name")
|
||||
bases: list[TypeExpr] = []
|
||||
while not self.check(TokenType.GREATER) and not self.is_at_end():
|
||||
bases.append(self.type_expr())
|
||||
if not self.check(TokenType.GREATER):
|
||||
self.consume(TokenType.COMMA, "Expected ',' between type bases")
|
||||
self.consume(TokenType.GREATER, "Expected '>' after base type")
|
||||
|
||||
body: Optional[TypeBodyExpr] = None
|
||||
|
||||
if self.check(TokenType.LEFT_BRACE):
|
||||
body = self.type_body_expr()
|
||||
return TypeStmt(name=name, bases=bases, body=body)
|
||||
|
||||
def type_expr(self) -> TypeExpr:
|
||||
"""Parse a type expression
|
||||
|
||||
Returns:
|
||||
TypeExpr: the parsed type expression
|
||||
"""
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
constraints: list[ConstraintExpr] = []
|
||||
|
||||
while not self.is_at_end() and self.match(TokenType.PLUS):
|
||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before type constraint")
|
||||
constraints.append(self.constraint_expr())
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after type constraint")
|
||||
|
||||
return TypeExpr(name=name, constraints=constraints)
|
||||
|
||||
def constraint_expr(self) -> ConstraintExpr:
|
||||
"""Parse a type constraint
|
||||
|
||||
Returns:
|
||||
ConstraintExpr: the parsed type constraint expression
|
||||
"""
|
||||
|
||||
left: Expr = self.constraint_value()
|
||||
op: Token = self.constraint_operator()
|
||||
right: Expr = self.constraint_value()
|
||||
return ConstraintExpr(left=left, op=op, right=right)
|
||||
|
||||
def constraint_value(self) -> Expr:
|
||||
if self.match(TokenType.UNDERSCORE):
|
||||
return WildcardExpr(self.previous())
|
||||
return self.literal()
|
||||
|
||||
def literal(self) -> LiteralExpr:
|
||||
if self.match(TokenType.FALSE):
|
||||
return LiteralExpr(False)
|
||||
if self.match(TokenType.TRUE):
|
||||
return LiteralExpr(True)
|
||||
if self.match(TokenType.NONE):
|
||||
return LiteralExpr(None)
|
||||
|
||||
if self.match(TokenType.NUMBER):
|
||||
return LiteralExpr(self.previous().value)
|
||||
|
||||
raise self.error(self.peek(), "Expected literal")
|
||||
|
||||
def constraint_operator(self) -> Token:
|
||||
if self.match(
|
||||
TokenType.LESS,
|
||||
TokenType.LESS_EQUAL,
|
||||
TokenType.GREATER,
|
||||
TokenType.GREATER_EQUAL,
|
||||
TokenType.EQUAL_EQUAL,
|
||||
TokenType.BANG_EQUAL,
|
||||
):
|
||||
return self.previous()
|
||||
raise self.error(self.peek(), "Expected constraint operator")
|
||||
|
||||
def type_body_expr(self) -> TypeBodyExpr:
|
||||
"""Parse a type definition body
|
||||
|
||||
A type definition body is a set of whitespace-separated
|
||||
property statements enclosed in curly braces
|
||||
|
||||
Returns:
|
||||
TypeBodyExpr: the parsed type body expression
|
||||
"""
|
||||
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start type body")
|
||||
properties: list[PropertyStmt] = []
|
||||
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
|
||||
properties.append(self.property_stmt())
|
||||
self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
|
||||
return TypeBodyExpr(properties=properties)
|
||||
|
||||
def property_stmt(self) -> PropertyStmt:
|
||||
"""Parse a property statement
|
||||
|
||||
A type property statement is written `name: Type`
|
||||
|
||||
Returns:
|
||||
PropertyStmt: the parsed property statement
|
||||
"""
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name")
|
||||
self.consume(TokenType.COLON, "Expected ':' after property name")
|
||||
type: TypeExpr = self.type_expr()
|
||||
return PropertyStmt(name=name, type=type)
|
||||
|
||||
def op_declaration(self) -> OpStmt:
|
||||
"""Parse an operation definition
|
||||
|
||||
An operation is written `op <Type1> operator <Type2> = <Type3>` where `operator` can be any single token
|
||||
|
||||
Returns:
|
||||
OpStmt: the parsed operation statement
|
||||
"""
|
||||
self.consume(TokenType.LESS, "Expected '<' before first type")
|
||||
left: TypeExpr = self.type_expr()
|
||||
self.consume(TokenType.GREATER, "Expected '>' after first type")
|
||||
|
||||
op: Token = self.advance()
|
||||
|
||||
self.consume(TokenType.LESS, "Expected '<' before second type")
|
||||
right: TypeExpr = self.type_expr()
|
||||
self.consume(TokenType.GREATER, "Expected '>' after second type")
|
||||
|
||||
self.consume(TokenType.EQUAL, "Expected '=' after second type")
|
||||
|
||||
self.consume(TokenType.LESS, "Expected '<' before result type")
|
||||
result: TypeExpr = self.type_expr()
|
||||
self.consume(TokenType.GREATER, "Expected '>' after result type")
|
||||
|
||||
return OpStmt(left=left, op=op, right=right, result=result)
|
||||
|
||||
def constraint_declaration(self) -> ConstraintStmt:
|
||||
"""Parse a type constraint declaration
|
||||
|
||||
A constraint is written `constraint Name = constraint_expression`
|
||||
|
||||
Returns:
|
||||
ConstraintStmt: the parsed constraint declaration statement
|
||||
"""
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected constraint name")
|
||||
self.consume(TokenType.EQUAL, "Expected '=' after constraint name")
|
||||
constraint: ConstraintExpr = self.constraint_expr()
|
||||
return ConstraintStmt(name=name, constraint=constraint)
|
||||
22
pyproject.toml
Normal file
22
pyproject.toml
Normal file
@@ -0,0 +1,22 @@
|
||||
[project]
|
||||
name = "midas"
|
||||
version = "0.1.0"
|
||||
description = "A static-first type checking framework for Python data-frames"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
authors = [
|
||||
{ name = "Louis Heredero", email = "louis.heredero@students.hevs.ch" },
|
||||
]
|
||||
classifiers = ["Programming Language :: Python :: 3"]
|
||||
dependencies = ["click>=8.4.1"]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://git.kbk28.ch/HEL/midas"
|
||||
Repository = "https://git.kbk28.ch/HEL/midas"
|
||||
|
||||
[project.scripts]
|
||||
midas = "midas.cli.main:midas"
|
||||
|
||||
[build-system]
|
||||
requires = ['hatchling']
|
||||
build-backend = 'hatchling.build'
|
||||
@@ -1,26 +1,43 @@
|
||||
identifier ::= '[a-zA-Z][a-zA-Z_]*'
|
||||
// W3C EBNF syntax definition for Midas
|
||||
Identifier ::= [a-zA-Z_] [a-zA-Z_0-9]*
|
||||
|
||||
integer ::= '\d+'
|
||||
number ::= integer ["." integer]
|
||||
boolean ::= "False" | "True"
|
||||
none ::= "None"
|
||||
Integer ::= '\d+'
|
||||
Number ::= "-"? Integer ("." Integer)?
|
||||
Boolean ::= "False" | "True"
|
||||
None ::= "None"
|
||||
|
||||
value ::= number | boolean | none
|
||||
lambda-value ::= "_" | value
|
||||
lambda-operator ::= ">" | "<" | ">=" | "<=" | "==" | "!="
|
||||
lambda ::= lambda-value lambda-operator lambda-value
|
||||
Value ::= Number | Boolean | None
|
||||
|
||||
constraint ::= identifier | "(" lambda ")"
|
||||
base-type ::= identifier
|
||||
type ::= base-type { "+" constraint }
|
||||
ComparisonOp ::= ">" | "<" | ">=" | "<="
|
||||
EqualityOp ::= "==" | "!="
|
||||
|
||||
type-property ::= 'identifier' ":" 'type'
|
||||
type-body ::= "{" { 'type-property' } "}"
|
||||
Grouping ::= "(" Constraint ")"
|
||||
Primary ::= "_" | Value | Identifier | Grouping
|
||||
Reference ::= Primary ("." Identifier)*
|
||||
Unary ::= "-"? Unary | Reference
|
||||
Comparison ::= Unary (ComparisonOp Unary)*
|
||||
Equality ::= Comparison (EqualityOp Comparison)*
|
||||
Constraint ::= Equality ("&" Equality)*
|
||||
|
||||
operation-type ::= "<" 'type' ">"
|
||||
TemplateParam ::= Identifier ("<:" Type)?
|
||||
Template ::= "[" (TemplateParam ("," TemplateParam)*)? "]"
|
||||
|
||||
type-statement ::= "type" 'identifier' "<" 'type' {"," 'type'} ">" ['type-body']
|
||||
operation-statement ::= "op" 'operation-type' 'operator' 'operation-type' "=" 'operation-type'
|
||||
constraint-statement ::= "constraint" 'identifier' "=" 'lambda'
|
||||
|
||||
statement ::= type-statement | operation-statement | constraint-statement
|
||||
TypeProperty ::= Identifier ":" Type
|
||||
ComplexType ::= "{" TypeProperty* "}"
|
||||
NamedType ::= Identifier
|
||||
TypeParams ::= "[" (Type ("," Type)*)? "]"
|
||||
GenericType ::= NamedType TypeParams?
|
||||
GroupedType ::= "(" Type ")"
|
||||
BaseType ::= GroupedType | ComplexType | GenericType
|
||||
ConstraintType ::= BaseType ("where" Constraint)?
|
||||
Type ::= ConstraintType
|
||||
|
||||
OpDefinition ::= "op" Identifier "(" Type ")" "->" Type
|
||||
ExtendBody ::= "{" OpDefinition* "}"
|
||||
|
||||
TypeStatement ::= "type" Identifier Template? "=" Type
|
||||
ExtendStatement ::= "extend" Type ExtendBody
|
||||
PredicateStatement ::= "predicate" Identifier "(" Identifier ":" Type ")" "=" Constraint
|
||||
|
||||
Statement ::= TypeStatement | ExtendStatement | PredicateStatement
|
||||
|
||||
208
syntax/midas.typ
208
syntax/midas.typ
@@ -1,4 +1,11 @@
|
||||
#import "@preview/fervojo:0.1.1": render
|
||||
#import "@preview/fervojo:0.1.1": default-css, render
|
||||
|
||||
#let extra-css = ```css
|
||||
svg.railroad .terminal rect {
|
||||
fill: #F7DCD4;
|
||||
}
|
||||
```
|
||||
#let css = default-css() + bytes(extra-css.text)
|
||||
|
||||
#let value = ```
|
||||
{[`value` <
|
||||
@@ -8,90 +15,193 @@
|
||||
>]}
|
||||
```
|
||||
|
||||
#let constraint = ```
|
||||
{[`constraint` <"_", 'value'> <">", "<", ">=", "<=", "==", "!="> <"_", 'value'>]}
|
||||
#let grouping = ```
|
||||
{[`grouping` "(" 'constraint' ")"]}
|
||||
```
|
||||
|
||||
#let type-with-constraints = ```
|
||||
{[`type-with-constraints` 'identifier' <!, ["+" "(" 'constraint' ")"] * !>]}
|
||||
#let primary = ```
|
||||
{[`primary` <"_", 'value', 'identifier', 'grouping'>]}
|
||||
```
|
||||
|
||||
#let reference = ```
|
||||
{[`reference` 'primary' <!, ["." 'identifier']*!>]}
|
||||
```
|
||||
|
||||
#let unary = ```
|
||||
{[`unary` <[<!, "-"> 'unary'], 'reference'>]}
|
||||
```
|
||||
|
||||
#let comparison = ```
|
||||
{[`comparison` 'unary'*<">", "<", ">=", "<=">]}
|
||||
```
|
||||
|
||||
#let equality = ```
|
||||
{[`equality` 'comparison'*<"==", "!=">]}
|
||||
```
|
||||
|
||||
#let constraint = ```
|
||||
{[`constraint` 'equality'*"&"]}
|
||||
```
|
||||
|
||||
#let template-param = ```
|
||||
{[`template-param` 'identifier' <!, ["<:" 'type']>]}
|
||||
```
|
||||
|
||||
#let template = ```
|
||||
{[`template` "[" <!, 'template-param'*","> "]"]}
|
||||
```
|
||||
|
||||
#let type-property = ```
|
||||
{[`type-property` 'identifier' ":" 'type-with-constraints']}
|
||||
{[`type-property` 'identifier' ":" 'type']}
|
||||
```
|
||||
|
||||
#let type-body = ```
|
||||
{[`type-body` "{" <!, 'type-property'*!> "}"]}
|
||||
#let complex-type = ```
|
||||
{[`complex-type` "{" <!, 'type-property'*!> "}"]}
|
||||
```
|
||||
|
||||
#let operation-type = ```
|
||||
{[`operation-type` "<" 'type-with-constraints' ">"]}
|
||||
#let named-type = ```
|
||||
{[`named-type` 'identifier']}
|
||||
```
|
||||
|
||||
#let type-params = ```
|
||||
{[`type-params` "[" <!, 'type'*","> "]"]}
|
||||
```
|
||||
|
||||
#let generic-type = ```
|
||||
{[`generic-type` 'named-type' <!, 'type-params'>]}
|
||||
```
|
||||
|
||||
#let grouped-type = ```
|
||||
{[`grouped-type` "(" 'type' ")"]}
|
||||
```
|
||||
|
||||
#let base-type = ```
|
||||
{[`base-type` <'grouped-type', 'complex-type', 'generic-type'>]}
|
||||
```
|
||||
|
||||
#let constraint-type = ```
|
||||
{[`constraint-type` 'base-type' <!, ["where" 'constraint']>]}
|
||||
```
|
||||
|
||||
#let type = ```
|
||||
{[`type` 'constraint-type']}
|
||||
```
|
||||
|
||||
#let type-statement = ```
|
||||
{[`type-statement` "type" 'identifier' "<" 'type-with-constraints'*"," ">" <!, 'type-body'>]}
|
||||
{[`type-statement` "type" 'identifier' <!, 'template'> "=" 'type']}
|
||||
```
|
||||
|
||||
#let operation-statement = ```
|
||||
{[`operation-statement` "op" 'operation-type' "operator" 'operation-type' "=" 'operation-type']}
|
||||
#let op-definition = ```
|
||||
{[`op-definition` "op" 'identifier' "(" 'type' ")" "->" 'type']}
|
||||
```
|
||||
|
||||
#let constraint-statement = ```
|
||||
{[`constraint-statement` "constraint" 'identifier' "=" 'constraint']}
|
||||
#let extend-statement = ```
|
||||
{[`extend-statement` "extend" 'type' "{" <!, 'op-definition'*!> "}"]}
|
||||
```
|
||||
|
||||
#let predicate-statement = ```
|
||||
{[`predicate-statement` "predicate" 'identifier' "(" 'identifier' ":" 'type' ")" "=" 'constraint']}
|
||||
```
|
||||
|
||||
#let statement = ```
|
||||
{[`statement` <'type-statement', 'operation-statement', 'constraint-statement'>]}
|
||||
{[`statement` <'type-statement', 'extend-statement', 'predicate-statement'>]}
|
||||
```
|
||||
|
||||
#let rules = (
|
||||
value,
|
||||
constraint,
|
||||
type-with-constraints,
|
||||
type-property,
|
||||
type-body,
|
||||
operation-type,
|
||||
type-statement,
|
||||
operation-statement,
|
||||
constraint-statement,
|
||||
statement,
|
||||
value: value,
|
||||
grouping: grouping,
|
||||
primary: primary,
|
||||
reference: reference,
|
||||
unary: unary,
|
||||
comparison: comparison,
|
||||
equality: equality,
|
||||
constraint: constraint,
|
||||
template-param: template-param,
|
||||
template: template,
|
||||
type-property: type-property,
|
||||
complex-type: complex-type,
|
||||
named-type: named-type,
|
||||
type-params: type-params,
|
||||
generic-type: generic-type,
|
||||
grouped-type: grouped-type,
|
||||
base-type: base-type,
|
||||
constraint-type: constraint-type,
|
||||
type: type,
|
||||
type-statement: type-statement,
|
||||
op-definition: op-definition,
|
||||
extend-statement: extend-statement,
|
||||
predicate-statement: predicate-statement,
|
||||
statement: statement,
|
||||
)
|
||||
|
||||
#let inline = (
|
||||
"grouping",
|
||||
"value",
|
||||
"template-param",
|
||||
"template",
|
||||
"type-property",
|
||||
"complex-type",
|
||||
"type-params",
|
||||
"named-type",
|
||||
"grouped-type",
|
||||
"generic-type",
|
||||
"base-type",
|
||||
"constraint-type",
|
||||
"op-definition",
|
||||
"type-statement",
|
||||
"extend-statement",
|
||||
"predicate-statement",
|
||||
)
|
||||
|
||||
#set text(font: "Source Sans 3")
|
||||
|
||||
= Midas type definition syntax
|
||||
#title[Midas type definition syntax]
|
||||
|
||||
#for rule in rules {
|
||||
render(rule)
|
||||
}
|
||||
= Outline
|
||||
|
||||
/*
|
||||
#let by-name = (
|
||||
value: value,
|
||||
constraint: constraint,
|
||||
type-with-constraints: type-with-constraints,
|
||||
type-property: type-property,
|
||||
type-body: type-body,
|
||||
operation-type: operation-type,
|
||||
type-statement: type-statement,
|
||||
operation-statement: operation-statement,
|
||||
constraint-statement: constraint-statement,
|
||||
#box(
|
||||
columns(
|
||||
2,
|
||||
outline(title: none),
|
||||
),
|
||||
height: 9cm,
|
||||
stroke: 1pt,
|
||||
inset: 1em,
|
||||
)
|
||||
|
||||
= Statements and expressions
|
||||
|
||||
#for (name, rule) in rules.pairs().rev() {
|
||||
[== #name]
|
||||
render(rule, css: css)
|
||||
}
|
||||
|
||||
#let substitute(base-rule) = {
|
||||
let new-rule = base-rule
|
||||
for (key, rule) in by-name.pairs() {
|
||||
new-rule = new-rule.replace("'" + key + "'", rule.text.slice(1, -1))
|
||||
for name in inline {
|
||||
let rule = rules.at(name)
|
||||
let replacement = rule.text.slice(1, -1).replace(regex("\[`.*?`"), "[")
|
||||
replacement = "[" + replacement + "#`" + name + "`]"
|
||||
new-rule = new-rule.replace(
|
||||
"'" + name + "'",
|
||||
replacement,
|
||||
)
|
||||
}
|
||||
if new-rule != base-rule {
|
||||
new-rule = substitute(new-rule)
|
||||
}
|
||||
return new-rule.replace(regex("`.*?`"), "")
|
||||
return new-rule
|
||||
}
|
||||
|
||||
#let combined = raw(substitute(statement.text))
|
||||
|
||||
|
||||
#set page(flipped: true)
|
||||
#render(combined)
|
||||
*/
|
||||
|
||||
= Combined rules
|
||||
|
||||
#for (name, rule) in rules.pairs() {
|
||||
if not name in inline {
|
||||
[== #name]
|
||||
let combined = substitute(rule.text)
|
||||
render(raw(combined), css: css)
|
||||
//raw(block: true, combined)
|
||||
}
|
||||
}
|
||||
|
||||
35
test.py
35
test.py
@@ -1,40 +1,21 @@
|
||||
import importlib
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from core.ast.printer import AnnotationAstPrinter, MidasAstPrinter
|
||||
from lexer.annotations import AnnotationLexer
|
||||
from lexer.midas import MidasLexer
|
||||
from lexer.token import Token
|
||||
from parser.annotations import AnnotationParser
|
||||
from parser.midas import MidasParser
|
||||
|
||||
|
||||
def test_annotation():
|
||||
# Frame annotation
|
||||
mod = importlib.import_module("examples.00_syntax_prototype.01_simple_types")
|
||||
|
||||
annotation: str = mod.__annotations__["df"]
|
||||
lexer: AnnotationLexer = AnnotationLexer(annotation, "01_simple_types.py")
|
||||
tokens: list[Token] = lexer.process()
|
||||
# print([f"{t.type.name}('{t.lexeme}')" for t in tokens])
|
||||
|
||||
parser = AnnotationParser(tokens)
|
||||
parsed = parser.parse()
|
||||
print(parsed)
|
||||
for err in parser.errors:
|
||||
print(err.get_report())
|
||||
printer = AnnotationAstPrinter()
|
||||
if parsed is not None:
|
||||
print(printer.print(parsed))
|
||||
from midas.ast.printer import MidasAstPrinter
|
||||
from midas.lexer.midas import MidasLexer
|
||||
from midas.lexer.token import Token
|
||||
from midas.parser.midas import MidasParser
|
||||
|
||||
|
||||
def test_midas():
|
||||
# Midas type definitions
|
||||
path: Path = Path("examples") / "00_syntax_prototype" / "02_custom_types.midas"
|
||||
path: Path = Path("examples") / "00_syntax_prototype" / "03_custom_types_v2.midas"
|
||||
definitions: str = path.read_text()
|
||||
midas_lexer: MidasLexer = MidasLexer(definitions, path.name)
|
||||
tokens: list[Token] = midas_lexer.process()
|
||||
# print([f"{t.type.name}('{t.lexeme}')" for t in tokens])
|
||||
with open("tokens.json", "w") as f:
|
||||
json.dump([f"{t.type.name}('{t.lexeme}')" for t in tokens], f, indent=4)
|
||||
|
||||
parser = MidasParser(tokens)
|
||||
parsed = parser.parse()
|
||||
|
||||
149
tests/base.py
Normal file
149
tests/base.py
Normal file
@@ -0,0 +1,149 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import difflib
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Protocol
|
||||
|
||||
|
||||
class CaseResult(Protocol):
|
||||
def dumps(self) -> str: ...
|
||||
|
||||
|
||||
class Tester(ABC):
|
||||
"""A test runner to check for regressions in the lexer and parser"""
|
||||
|
||||
CASES_DIR: Path = Path(__file__).parent / "cases"
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def namespace(self) -> str: ...
|
||||
|
||||
@property
|
||||
def base_dir(self) -> Path:
|
||||
return self.CASES_DIR / self.namespace
|
||||
|
||||
@abstractmethod
|
||||
def _list_tests(self) -> list[Path]: ...
|
||||
|
||||
def run_all_tests(self) -> bool:
|
||||
paths: list[Path] = sorted(self._list_tests())
|
||||
return self.run_tests(paths)
|
||||
|
||||
def run_tests(self, tests: list[Path]) -> bool:
|
||||
rule: str = "-" * 80
|
||||
n: int = len(tests)
|
||||
successes: int = 0
|
||||
failures: int = 0
|
||||
|
||||
print(rule)
|
||||
for i, test in enumerate(tests):
|
||||
print(f"Case {i+1}/{n}: {test.resolve().relative_to(self.CASES_DIR)}")
|
||||
success: bool = self._run_test(test)
|
||||
if success:
|
||||
successes += 1
|
||||
else:
|
||||
failures += 1
|
||||
|
||||
print(rule)
|
||||
print(f"Success: {successes}/{n}")
|
||||
print(f"Failed: {failures}/{n}")
|
||||
print(rule)
|
||||
return failures == 0
|
||||
|
||||
def _run_test(self, path: Path) -> bool:
|
||||
result_path: Path = self._result_path(path)
|
||||
if not result_path.exists():
|
||||
print("Missing snapshot. Please run the update command first")
|
||||
return False
|
||||
result: CaseResult = self._exec_case(path)
|
||||
expected: str = result_path.read_text()
|
||||
actual: str = result.dumps()
|
||||
|
||||
if expected == actual:
|
||||
return True
|
||||
|
||||
diff = difflib.unified_diff(
|
||||
expected.splitlines(keepends=True),
|
||||
actual.splitlines(keepends=True),
|
||||
fromfile="Snapshot",
|
||||
tofile="Result",
|
||||
)
|
||||
self._print_diff(diff)
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def _exec_case(self, path: Path) -> CaseResult: ...
|
||||
|
||||
def update_all_tests(self):
|
||||
paths: list[Path] = sorted(self._list_tests())
|
||||
return self.update_tests(paths)
|
||||
|
||||
def update_tests(self, tests: list[Path]):
|
||||
updated: int = 0
|
||||
for test in tests:
|
||||
if self._update_test(test):
|
||||
updated += 1
|
||||
print(f"Updated {updated}/{len(tests)} tests")
|
||||
|
||||
def _update_test(self, path: Path) -> bool:
|
||||
result: CaseResult = self._exec_case(path)
|
||||
result_path: Path = self._result_path(path)
|
||||
current: str = result_path.read_text() if result_path.exists() else ""
|
||||
new: str = result.dumps()
|
||||
if current == new:
|
||||
return False
|
||||
result_path.write_text(new)
|
||||
return True
|
||||
|
||||
def _result_path(self, test_path: Path) -> Path:
|
||||
return test_path.parent / (test_path.name + ".ref.json")
|
||||
|
||||
def _print_diff(self, diff: Iterator[str]):
|
||||
for line in diff:
|
||||
if line.startswith("+") and not line.startswith("+++"):
|
||||
print(f"\033[92m{line}\033[0m", end="")
|
||||
elif line.startswith("-") and not line.startswith("---"):
|
||||
print(f"\033[91m{line}\033[0m", end="")
|
||||
else:
|
||||
print(line, end="")
|
||||
print()
|
||||
|
||||
@classmethod
|
||||
def main(cls):
|
||||
parser = argparse.ArgumentParser()
|
||||
subparsers = parser.add_subparsers(dest="subcommand")
|
||||
|
||||
update = subparsers.add_parser("update")
|
||||
update.add_argument("-a", "--all", action="store_true")
|
||||
update.add_argument("FILE", type=Path, nargs="*")
|
||||
|
||||
run = subparsers.add_parser("run")
|
||||
run.add_argument("-a", "--all", action="store_true")
|
||||
run.add_argument("FILE", type=Path, nargs="*")
|
||||
args = parser.parse_args()
|
||||
|
||||
tester: Tester = cls()
|
||||
|
||||
match args.subcommand:
|
||||
case "update":
|
||||
if args.all:
|
||||
tester.update_all_tests()
|
||||
else:
|
||||
tester.update_tests(args.FILE)
|
||||
case "run":
|
||||
success: bool
|
||||
if args.all:
|
||||
success = tester.run_all_tests()
|
||||
else:
|
||||
success = tester.run_tests(args.FILE)
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
case None:
|
||||
print("No subcommand provided. Available subcommands: run, update")
|
||||
sys.exit(1)
|
||||
case _:
|
||||
print(f"Unknown subcommand '{args.subcommand}'")
|
||||
sys.exit(1)
|
||||
14
tests/cases/checker/01_simple_types.py
Normal file
14
tests/cases/checker/01_simple_types.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# type: ignore
|
||||
# ruff: disable[F821]
|
||||
from __future__ import annotations
|
||||
|
||||
df: Frame[
|
||||
verified: bool,
|
||||
birth_year: int,
|
||||
height: float + ( _ > 0 ) + ( _ < 250 ),
|
||||
name: str,
|
||||
date: datetime,
|
||||
float,
|
||||
unknown: _,
|
||||
_
|
||||
]
|
||||
4
tests/cases/checker/01_simple_types.py.ref.json
Normal file
4
tests/cases/checker/01_simple_types.py.ref.json
Normal file
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"diagnostics": [],
|
||||
"judgments": []
|
||||
}
|
||||
11
tests/cases/checker/02_simple_operations.py
Normal file
11
tests/cases/checker/02_simple_operations.py
Normal file
@@ -0,0 +1,11 @@
|
||||
a: int = 3
|
||||
b: int = 4
|
||||
|
||||
c = a + b
|
||||
|
||||
c = "invalid"
|
||||
|
||||
d = True
|
||||
e = d + d
|
||||
|
||||
f: float = a
|
||||
179
tests/cases/checker/02_simple_operations.py.ref.json
Normal file
179
tests/cases/checker/02_simple_operations.py.ref.json
Normal file
@@ -0,0 +1,179 @@
|
||||
{
|
||||
"diagnostics": [
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
6,
|
||||
0
|
||||
],
|
||||
"end": [
|
||||
6,
|
||||
13
|
||||
]
|
||||
},
|
||||
"message": "Cannot assign BaseType(name='str') to c of type BaseType(name='int')"
|
||||
}
|
||||
],
|
||||
"judgments": [
|
||||
{
|
||||
"location": {
|
||||
"from": "L1:9",
|
||||
"to": "L1:10"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 3
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L2:9",
|
||||
"to": "L2:10"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 4
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L4:4",
|
||||
"to": "L4:5"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L4:8",
|
||||
"to": "L4:9"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L4:4",
|
||||
"to": "L4:9"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"operator": "+",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:4",
|
||||
"to": "L6:13"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": "invalid"
|
||||
},
|
||||
"type": {
|
||||
"name": "str"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L8:4",
|
||||
"to": "L8:8"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": true
|
||||
},
|
||||
"type": {
|
||||
"name": "bool"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L9:4",
|
||||
"to": "L9:5"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "d"
|
||||
},
|
||||
"type": {
|
||||
"name": "bool"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L9:8",
|
||||
"to": "L9:9"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "d"
|
||||
},
|
||||
"type": {
|
||||
"name": "bool"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L9:4",
|
||||
"to": "L9:9"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "d"
|
||||
},
|
||||
"operator": "+",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "d"
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:11",
|
||||
"to": "L11:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
18
tests/cases/checker/03_functions.py
Normal file
18
tests/cases/checker/03_functions.py
Normal file
@@ -0,0 +1,18 @@
|
||||
def foo(a: int, /, b: float, *, c: str):
|
||||
return True
|
||||
|
||||
|
||||
r1 = foo()
|
||||
r2 = foo(1)
|
||||
r3 = foo(1, 2.0)
|
||||
r4 = foo(1, b=2.0)
|
||||
r5 = foo(1, 2.0, "test")
|
||||
r6 = foo(1, 2.0, b=3.0)
|
||||
r7 = foo(a=1)
|
||||
r8 = foo(g="test")
|
||||
|
||||
r9a = foo(1, 2.0, c="test")
|
||||
r9b = foo(1, b=2.0, c="test")
|
||||
r9c = foo(1, c="test", b=2.0)
|
||||
|
||||
r10 = foo("a", 3, c=False)
|
||||
1468
tests/cases/checker/03_functions.py.ref.json
Normal file
1468
tests/cases/checker/03_functions.py.ref.json
Normal file
File diff suppressed because it is too large
Load Diff
14
tests/cases/checker/04_custom_types.midas
Normal file
14
tests/cases/checker/04_custom_types.midas
Normal file
@@ -0,0 +1,14 @@
|
||||
type Meter = float
|
||||
type Second = float
|
||||
type MeterPerSecond = float
|
||||
|
||||
extend Meter {
|
||||
op __add__(Meter) -> Meter
|
||||
op __sub__(Meter) -> Meter
|
||||
op __truediv__(Second) -> MeterPerSecond
|
||||
}
|
||||
|
||||
extend Second {
|
||||
op __add__(Second) -> Second
|
||||
op __sub__(Second) -> Second
|
||||
}
|
||||
6
tests/cases/checker/04_custom_types.py
Normal file
6
tests/cases/checker/04_custom_types.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# type: ignore
|
||||
# ruff: disable [F821]
|
||||
|
||||
distance: Meter = cast(Meter, 123.45)
|
||||
time: Second = cast(Second, 6.7)
|
||||
speed = distance / time
|
||||
109
tests/cases/checker/04_custom_types.py.ref.json
Normal file
109
tests/cases/checker/04_custom_types.py.ref.json
Normal file
@@ -0,0 +1,109 @@
|
||||
{
|
||||
"diagnostics": [],
|
||||
"judgments": [
|
||||
{
|
||||
"location": {
|
||||
"from": "L4:18",
|
||||
"to": "L4:37"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "CastExpr",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Meter",
|
||||
"param": null
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 123.45
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "Meter",
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L5:15",
|
||||
"to": "L5:32"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "CastExpr",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Second",
|
||||
"param": null
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 6.7
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "Second",
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:8",
|
||||
"to": "L6:16"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "distance"
|
||||
},
|
||||
"type": {
|
||||
"name": "Meter",
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:19",
|
||||
"to": "L6:23"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "time"
|
||||
},
|
||||
"type": {
|
||||
"name": "Second",
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:8",
|
||||
"to": "L6:23"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "distance"
|
||||
},
|
||||
"operator": "/",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "time"
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "MeterPerSecond",
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
25
tests/cases/checker/05_control_flow.py
Normal file
25
tests/cases/checker/05_control_flow.py
Normal file
@@ -0,0 +1,25 @@
|
||||
def valid(a: int, b: int) -> int:
|
||||
return a + b
|
||||
|
||||
def with_if(a: int, b: int) -> int:
|
||||
if a < b:
|
||||
return b - a
|
||||
else:
|
||||
return a - b
|
||||
|
||||
def unreachable1():
|
||||
return
|
||||
a = 0
|
||||
|
||||
def unreachable2(a: int) -> int:
|
||||
if a > 10:
|
||||
return a - 10
|
||||
else:
|
||||
return a
|
||||
b = 0
|
||||
|
||||
def mixed(a: int, b: int):
|
||||
if a < b:
|
||||
return b - a
|
||||
else:
|
||||
return "oops"
|
||||
256
tests/cases/checker/05_control_flow.py.ref.json
Normal file
256
tests/cases/checker/05_control_flow.py.ref.json
Normal file
@@ -0,0 +1,256 @@
|
||||
{
|
||||
"diagnostics": [
|
||||
{
|
||||
"type": "Warning",
|
||||
"location": {
|
||||
"start": [
|
||||
12,
|
||||
4
|
||||
],
|
||||
"end": [
|
||||
12,
|
||||
9
|
||||
]
|
||||
},
|
||||
"message": "Unreachable statement"
|
||||
},
|
||||
{
|
||||
"type": "Warning",
|
||||
"location": {
|
||||
"start": [
|
||||
19,
|
||||
4
|
||||
],
|
||||
"end": [
|
||||
19,
|
||||
9
|
||||
]
|
||||
},
|
||||
"message": "Unreachable statement"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
21,
|
||||
0
|
||||
],
|
||||
"end": [
|
||||
25,
|
||||
21
|
||||
]
|
||||
},
|
||||
"message": "Mixed return types: [BaseType(name='int'), BaseType(name='str')]"
|
||||
}
|
||||
],
|
||||
"judgments": [
|
||||
{
|
||||
"location": {
|
||||
"from": "L2:11",
|
||||
"to": "L2:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L2:15",
|
||||
"to": "L2:16"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L5:7",
|
||||
"to": "L5:8"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L5:11",
|
||||
"to": "L5:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:15",
|
||||
"to": "L6:16"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:19",
|
||||
"to": "L6:20"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L8:15",
|
||||
"to": "L8:16"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L8:19",
|
||||
"to": "L8:20"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L15:7",
|
||||
"to": "L15:8"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L15:11",
|
||||
"to": "L15:13"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 10
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L16:15",
|
||||
"to": "L16:16"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L16:19",
|
||||
"to": "L16:21"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 10
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L22:7",
|
||||
"to": "L22:8"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L22:11",
|
||||
"to": "L22:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L23:15",
|
||||
"to": "L23:16"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L23:19",
|
||||
"to": "L23:20"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
12
tests/cases/checker/06_subtyping.py
Normal file
12
tests/cases/checker/06_subtyping.py
Normal file
@@ -0,0 +1,12 @@
|
||||
v1: int = 3
|
||||
v2: float = 4
|
||||
|
||||
|
||||
def maximum(a: float, b: float):
|
||||
if b > a:
|
||||
return b
|
||||
return a
|
||||
|
||||
|
||||
v3 = maximum(v1, v2)
|
||||
v3 = v1 + v2
|
||||
193
tests/cases/checker/06_subtyping.py.ref.json
Normal file
193
tests/cases/checker/06_subtyping.py.ref.json
Normal file
@@ -0,0 +1,193 @@
|
||||
{
|
||||
"diagnostics": [],
|
||||
"judgments": [
|
||||
{
|
||||
"location": {
|
||||
"from": "L1:10",
|
||||
"to": "L1:11"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 3
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L2:12",
|
||||
"to": "L2:13"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 4
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:7",
|
||||
"to": "L6:8"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:11",
|
||||
"to": "L6:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:5",
|
||||
"to": "L11:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "maximum"
|
||||
},
|
||||
"type": {
|
||||
"name": "maximum",
|
||||
"pos_args": [],
|
||||
"args": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "a",
|
||||
"type": {
|
||||
"name": "float"
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"pos": 1,
|
||||
"name": "b",
|
||||
"type": {
|
||||
"name": "float"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [],
|
||||
"returns": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:13",
|
||||
"to": "L11:15"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "v1"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:17",
|
||||
"to": "L11:19"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "v2"
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:5",
|
||||
"to": "L11:20"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "CallExpr",
|
||||
"callee": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "maximum"
|
||||
},
|
||||
"arguments": [
|
||||
{
|
||||
"_type": "VariableExpr",
|
||||
"name": "v1"
|
||||
},
|
||||
{
|
||||
"_type": "VariableExpr",
|
||||
"name": "v2"
|
||||
}
|
||||
],
|
||||
"keywords": {}
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L12:5",
|
||||
"to": "L12:7"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "v1"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L12:10",
|
||||
"to": "L12:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "v2"
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L12:5",
|
||||
"to": "L12:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "v1"
|
||||
},
|
||||
"operator": "+",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "v2"
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
57
tests/cases/midas-parser/01_simple_types.midas
Normal file
57
tests/cases/midas-parser/01_simple_types.midas
Normal file
@@ -0,0 +1,57 @@
|
||||
// Simple custom type derived from float
|
||||
type Custom = float
|
||||
|
||||
// Simple custom types with constraints
|
||||
type Latitude = float where (-90 <= _ <= 90)
|
||||
type Longitude = float where (-180 <= _ <= 180)
|
||||
|
||||
// Generic custom type (a Difference of T is derived from T, e.g. a difference of floats is a float
|
||||
type Difference[T] = T
|
||||
|
||||
// Complex custom type, containing two values accessible through properties
|
||||
type GeoLocation = {
|
||||
lat: Latitude
|
||||
lon: Longitude
|
||||
}
|
||||
|
||||
// Define operations on our custom type
|
||||
extend GeoLocation {
|
||||
// This type is compatible with the `-` operation with another GeoLocation
|
||||
// i.e. you can subtract a GeoLocation from another GeoLocation, resulting
|
||||
// in a Difference of GeoLocations
|
||||
op __sub__(GeoLocation) -> Difference[GeoLocation]
|
||||
}
|
||||
|
||||
// For complex generics, you need to specify how the genericity the properties
|
||||
// are handled
|
||||
type Difference[GeoLocation] = {
|
||||
lat: Difference[Latitude]
|
||||
lon: Difference[Longitude]
|
||||
}
|
||||
|
||||
// Simple operation defined on our custom types
|
||||
extend Latitude {
|
||||
op __sub__(Latitude) -> Difference[Latitude]
|
||||
}
|
||||
|
||||
extend Longitude {
|
||||
op __sub__(Longitude) -> Difference[Longitude]
|
||||
}
|
||||
|
||||
// Predefined custom predicates that can be referenced in other definitions
|
||||
predicate Positive(v: float) = v >= 0
|
||||
predicate StrictlyPositive(v: float) = v > 0
|
||||
predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10)
|
||||
predicate Arctic(loc: GeoLocation) = (loc.lat >= 66)
|
||||
|
||||
type Person = {
|
||||
name: str
|
||||
|
||||
// Property with an inline constraint
|
||||
age: Optional[int where (0 <= _ < 150)]
|
||||
|
||||
// Property referencing a predicate
|
||||
height: float where StrictlyPositive
|
||||
|
||||
home: GeoLocation
|
||||
}
|
||||
2702
tests/cases/midas-parser/01_simple_types.midas.ref.json
Normal file
2702
tests/cases/midas-parser/01_simple_types.midas.ref.json
Normal file
File diff suppressed because it is too large
Load Diff
14
tests/cases/python-parser/01_simple_types.py
Normal file
14
tests/cases/python-parser/01_simple_types.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# type: ignore
|
||||
# ruff: disable[F821]
|
||||
from __future__ import annotations
|
||||
|
||||
df: Frame[
|
||||
verified: bool,
|
||||
birth_year: int,
|
||||
height: float + ( _ > 0 ) + ( _ < 250 ),
|
||||
name: str,
|
||||
date: datetime,
|
||||
float,
|
||||
unknown: _,
|
||||
_
|
||||
]
|
||||
85
tests/cases/python-parser/01_simple_types.py.ref.json
Normal file
85
tests/cases/python-parser/01_simple_types.py.ref.json
Normal file
@@ -0,0 +1,85 @@
|
||||
{
|
||||
"stmts": [
|
||||
{
|
||||
"_type": "TypeAssign",
|
||||
"name": "df",
|
||||
"type": {
|
||||
"_type": "FrameType",
|
||||
"columns": [
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "verified",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "bool",
|
||||
"param": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "birth_year",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "int",
|
||||
"param": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "height",
|
||||
"type": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "(_ > 0) + (_ < 250)"
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "name",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "str",
|
||||
"param": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "date",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "datetime",
|
||||
"param": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": null,
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "unknown",
|
||||
"type": null
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": null,
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "_",
|
||||
"param": null
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
25
tests/cases/python-parser/02_custom_types.py
Normal file
25
tests/cases/python-parser/02_custom_types.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# type: ignore
|
||||
# ruff: disable[F821]
|
||||
from __future__ import annotations
|
||||
|
||||
df: Frame[
|
||||
location: GeoLocation
|
||||
]
|
||||
|
||||
lat: Column[GeoLocation] = df["location"].lat
|
||||
lon: Column[GeoLocation] = df["location"].lon
|
||||
|
||||
lat + lon
|
||||
|
||||
lat1: Latitude = lat[0]
|
||||
lat2: Latitude = lat[1]
|
||||
lat_diff: Difference[Latitude] = lat2 - lat1
|
||||
|
||||
df2: Frame[
|
||||
age: int + (_ >= 0),
|
||||
height: float + (_ >= 0),
|
||||
]
|
||||
df2_bis: Frame[
|
||||
age: int + Positive,
|
||||
height: float + Positive,
|
||||
]
|
||||
141
tests/cases/python-parser/02_custom_types.py.ref.json
Normal file
141
tests/cases/python-parser/02_custom_types.py.ref.json
Normal file
@@ -0,0 +1,141 @@
|
||||
{
|
||||
"stmts": [
|
||||
{
|
||||
"_type": "TypeAssign",
|
||||
"name": "df",
|
||||
"type": {
|
||||
"_type": "FrameType",
|
||||
"columns": [
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "location",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "GeoLocation",
|
||||
"param": null
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "ExpressionStmt",
|
||||
"expr": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "lat"
|
||||
},
|
||||
"operator": "+",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "lon"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "TypeAssign",
|
||||
"name": "lat_diff",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Difference",
|
||||
"param": {
|
||||
"_type": "BaseType",
|
||||
"base": "Latitude",
|
||||
"param": null
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "AssignStmt",
|
||||
"targets": [
|
||||
{
|
||||
"_type": "VariableExpr",
|
||||
"name": "lat_diff"
|
||||
}
|
||||
],
|
||||
"value": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "lat2"
|
||||
},
|
||||
"operator": "-",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "lat1"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "TypeAssign",
|
||||
"name": "df2",
|
||||
"type": {
|
||||
"_type": "FrameType",
|
||||
"columns": [
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "age",
|
||||
"type": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "int",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "_ >= 0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "height",
|
||||
"type": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "_ >= 0"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "TypeAssign",
|
||||
"name": "df2_bis",
|
||||
"type": {
|
||||
"_type": "FrameType",
|
||||
"columns": [
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "age",
|
||||
"type": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "int",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "Positive"
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "height",
|
||||
"type": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "Positive"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
15
tests/cases/python-parser/03_functions.py
Normal file
15
tests/cases/python-parser/03_functions.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# type: ignore
|
||||
# ruff: disable[F821]
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def func(
|
||||
col1: Column[float + (0 <= _ <= 1)],
|
||||
col2: Column[float + (0 <= _ <= 1)],
|
||||
) -> Column[float + (0 <= _ <= 2)]:
|
||||
result: Column[float + (0 <= _ <= 2)] = col1 + col2
|
||||
return result
|
||||
|
||||
|
||||
def func2(a: int, /, b: float, *, c: str):
|
||||
pass
|
||||
149
tests/cases/python-parser/03_functions.py.ref.json
Normal file
149
tests/cases/python-parser/03_functions.py.ref.json
Normal file
@@ -0,0 +1,149 @@
|
||||
{
|
||||
"stmts": [
|
||||
{
|
||||
"_type": "Function",
|
||||
"name": "func",
|
||||
"posonlyargs": [],
|
||||
"args": [
|
||||
{
|
||||
"name": "col1",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "0 <= _ <= 1"
|
||||
}
|
||||
},
|
||||
"default": null
|
||||
},
|
||||
{
|
||||
"name": "col2",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "0 <= _ <= 1"
|
||||
}
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
],
|
||||
"sink": null,
|
||||
"kwonlyargs": [],
|
||||
"kw_sink": null,
|
||||
"returns": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "0 <= _ <= 2"
|
||||
}
|
||||
},
|
||||
"body": [
|
||||
{
|
||||
"_type": "TypeAssign",
|
||||
"name": "result",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "0 <= _ <= 2"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "AssignStmt",
|
||||
"targets": [
|
||||
{
|
||||
"_type": "VariableExpr",
|
||||
"name": "result"
|
||||
}
|
||||
],
|
||||
"value": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "col1"
|
||||
},
|
||||
"operator": "+",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "col2"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "ReturnStmt",
|
||||
"value": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "result"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"_type": "Function",
|
||||
"name": "func2",
|
||||
"posonlyargs": [
|
||||
{
|
||||
"name": "a",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "int",
|
||||
"param": null
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
],
|
||||
"args": [
|
||||
{
|
||||
"name": "b",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
],
|
||||
"sink": null,
|
||||
"kwonlyargs": [
|
||||
{
|
||||
"name": "c",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "str",
|
||||
"param": null
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
],
|
||||
"kw_sink": null,
|
||||
"returns": null,
|
||||
"body": []
|
||||
}
|
||||
]
|
||||
}
|
||||
94
tests/checker.py
Normal file
94
tests/checker.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import ast
|
||||
import json
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.checker.checker import Checker
|
||||
from midas.checker.diagnostic import Diagnostic
|
||||
from midas.checker.types import Type
|
||||
from midas.parser.python import PythonParser
|
||||
from midas.resolver.resolver import Resolver
|
||||
from tests.base import Tester
|
||||
from tests.serializer.python import PythonAstJsonSerializer
|
||||
|
||||
|
||||
@dataclass
|
||||
class CaseResult:
|
||||
diagnostics: list[dict] = field(default_factory=list)
|
||||
judgments: list = field(default_factory=list)
|
||||
|
||||
def dumps(self) -> str:
|
||||
return json.dumps(asdict(self), indent=2)
|
||||
|
||||
|
||||
class CheckerTester(Tester):
|
||||
@property
|
||||
def namespace(self) -> str:
|
||||
return "checker"
|
||||
|
||||
def _list_tests(self) -> list[Path]:
|
||||
return list(self.base_dir.rglob("*.py"))
|
||||
|
||||
def _exec_case(self, path: Path) -> CaseResult:
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Could not find test '{path}'")
|
||||
if not path.is_file():
|
||||
raise TypeError(f"Test '{path}' is not a file")
|
||||
|
||||
types_paths: list[Path] = []
|
||||
types_path: Path = path.with_suffix(".midas")
|
||||
if types_path.exists():
|
||||
types_paths.append(types_path)
|
||||
source: str = path.read_text()
|
||||
tree: ast.Module = ast.parse(source, filename=path)
|
||||
parser = PythonParser()
|
||||
stmts: list[p.Stmt] = parser.parse_module(tree)
|
||||
resolver = Resolver()
|
||||
resolver.resolve(*stmts)
|
||||
result: CaseResult = CaseResult()
|
||||
checker = Checker(
|
||||
resolver.locals,
|
||||
source_path=path,
|
||||
types_paths=types_paths,
|
||||
)
|
||||
|
||||
diagnostics: list[Diagnostic] = checker.check(stmts)
|
||||
for diagnostic in diagnostics:
|
||||
result.diagnostics.append(
|
||||
{
|
||||
"type": str(diagnostic.type),
|
||||
"location": {
|
||||
"start": (
|
||||
diagnostic.location.lineno,
|
||||
diagnostic.location.col_offset,
|
||||
),
|
||||
"end": (
|
||||
diagnostic.location.end_lineno,
|
||||
diagnostic.location.end_col_offset,
|
||||
),
|
||||
},
|
||||
"message": diagnostic.message,
|
||||
}
|
||||
)
|
||||
|
||||
judgements: list[tuple[p.Expr, Type]] = checker.judgements
|
||||
serializer = PythonAstJsonSerializer()
|
||||
for expr, type in judgements:
|
||||
loc = expr.location
|
||||
result.judgments.append(
|
||||
{
|
||||
"location": {
|
||||
"from": f"L{loc.lineno}:{loc.col_offset}",
|
||||
"to": f"L{loc.end_lineno}:{loc.end_col_offset}",
|
||||
},
|
||||
"expr": expr.accept(serializer),
|
||||
"type": asdict(type),
|
||||
}
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
CheckerTester.main()
|
||||
@@ -1,129 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from lexer.annotations import AnnotationLexer
|
||||
from lexer.token import Token, TokenType
|
||||
|
||||
|
||||
def scan(source: str) -> list[Token]:
|
||||
return AnnotationLexer(source).process()
|
||||
|
||||
|
||||
def assert_n_tokens(tokens: list[Token], n: int):
|
||||
assert len(tokens) == n + 1
|
||||
assert tokens[-1].type == TokenType.EOF
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,expected",
|
||||
[
|
||||
("(", TokenType.LEFT_PAREN),
|
||||
(")", TokenType.RIGHT_PAREN),
|
||||
("[", TokenType.LEFT_BRACKET),
|
||||
("]", TokenType.RIGHT_BRACKET),
|
||||
(":", TokenType.COLON),
|
||||
(",", TokenType.COMMA),
|
||||
("_", TokenType.UNDERSCORE),
|
||||
],
|
||||
)
|
||||
def test_punctuation(src: str, expected: TokenType):
|
||||
tokens: list[Token] = scan(src)
|
||||
assert_n_tokens(tokens, 1)
|
||||
assert tokens[0].type == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,expected",
|
||||
[
|
||||
("+", TokenType.PLUS),
|
||||
(">", TokenType.GREATER),
|
||||
(">=", TokenType.GREATER_EQUAL),
|
||||
("<", TokenType.LESS),
|
||||
("<=", TokenType.LESS_EQUAL),
|
||||
("=", TokenType.EQUAL),
|
||||
("==", TokenType.EQUAL_EQUAL),
|
||||
("!=", TokenType.BANG_EQUAL),
|
||||
],
|
||||
)
|
||||
def test_operators(src: str, expected: TokenType):
|
||||
tokens: list[Token] = scan(src)
|
||||
assert_n_tokens(tokens, 1)
|
||||
assert tokens[0].type == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,expected",
|
||||
[
|
||||
("a", TokenType.IDENTIFIER),
|
||||
("foo", TokenType.IDENTIFIER),
|
||||
("foo1", TokenType.IDENTIFIER),
|
||||
("foo_", TokenType.IDENTIFIER),
|
||||
("foo_bar1_baz2", TokenType.IDENTIFIER),
|
||||
("FOO_BAR1_BAZ2", TokenType.IDENTIFIER),
|
||||
("True", TokenType.TRUE),
|
||||
("False", TokenType.FALSE),
|
||||
("None", TokenType.NONE),
|
||||
],
|
||||
)
|
||||
def test_identifiers_keywords(src: str, expected: TokenType):
|
||||
tokens: list[Token] = scan(src)
|
||||
assert_n_tokens(tokens, 1)
|
||||
assert tokens[0].type == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,expected",
|
||||
[
|
||||
("#", TokenType.COMMENT),
|
||||
("# This is a comment", TokenType.COMMENT),
|
||||
(" ", TokenType.WHITESPACE),
|
||||
("\t", TokenType.WHITESPACE),
|
||||
("\r", TokenType.WHITESPACE),
|
||||
(" \t \t", TokenType.WHITESPACE),
|
||||
("\n", TokenType.NEWLINE),
|
||||
],
|
||||
)
|
||||
def test_misc(src: str, expected: TokenType):
|
||||
tokens: list[Token] = scan(src)
|
||||
assert_n_tokens(tokens, 1)
|
||||
assert tokens[0].type == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,expected_type,expected_value",
|
||||
[
|
||||
("0", TokenType.NUMBER, 0),
|
||||
("0.0", TokenType.NUMBER, 0),
|
||||
("1234.56", TokenType.NUMBER, 1234.56),
|
||||
],
|
||||
)
|
||||
def test_literals(src: str, expected_type: TokenType, expected_value: Any):
|
||||
tokens: list[Token] = scan(src)
|
||||
assert_n_tokens(tokens, 1)
|
||||
assert tokens[0].type == expected_type
|
||||
assert tokens[0].value == expected_value
|
||||
|
||||
|
||||
def test_single_bang_error():
|
||||
with pytest.raises(SyntaxError):
|
||||
scan("!")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src",
|
||||
[
|
||||
"-",
|
||||
"*",
|
||||
"/",
|
||||
"{",
|
||||
"}",
|
||||
"@",
|
||||
'"',
|
||||
"'",
|
||||
".",
|
||||
],
|
||||
)
|
||||
def test_unexpected_character(src: str):
|
||||
with pytest.raises(SyntaxError):
|
||||
scan(src)
|
||||
@@ -1,129 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from lexer.midas import MidasLexer
|
||||
from lexer.token import Token, TokenType
|
||||
|
||||
|
||||
def scan(source: str) -> list[Token]:
|
||||
return MidasLexer(source).process()
|
||||
|
||||
|
||||
def assert_n_tokens(tokens: list[Token], n: int):
|
||||
assert len(tokens) == n + 1
|
||||
assert tokens[-1].type == TokenType.EOF
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,expected",
|
||||
[
|
||||
("(", TokenType.LEFT_PAREN),
|
||||
(")", TokenType.RIGHT_PAREN),
|
||||
("[", TokenType.LEFT_BRACKET),
|
||||
("]", TokenType.RIGHT_BRACKET),
|
||||
("{", TokenType.LEFT_BRACE),
|
||||
("}", TokenType.RIGHT_BRACE),
|
||||
(":", TokenType.COLON),
|
||||
(",", TokenType.COMMA),
|
||||
("_", TokenType.UNDERSCORE),
|
||||
],
|
||||
)
|
||||
def test_punctuation(src: str, expected: TokenType):
|
||||
tokens: list[Token] = scan(src)
|
||||
assert_n_tokens(tokens, 1)
|
||||
assert tokens[0].type == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,expected",
|
||||
[
|
||||
("+", TokenType.PLUS),
|
||||
("-", TokenType.MINUS),
|
||||
("*", TokenType.STAR),
|
||||
("/", TokenType.SLASH),
|
||||
(">", TokenType.GREATER),
|
||||
(">=", TokenType.GREATER_EQUAL),
|
||||
("<", TokenType.LESS),
|
||||
("<=", TokenType.LESS_EQUAL),
|
||||
("=", TokenType.EQUAL),
|
||||
("==", TokenType.EQUAL_EQUAL),
|
||||
("!=", TokenType.BANG_EQUAL),
|
||||
],
|
||||
)
|
||||
def test_operators(src: str, expected: TokenType):
|
||||
tokens: list[Token] = scan(src)
|
||||
assert_n_tokens(tokens, 1)
|
||||
assert tokens[0].type == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,expected",
|
||||
[
|
||||
("a", TokenType.IDENTIFIER),
|
||||
("foo", TokenType.IDENTIFIER),
|
||||
("foo1", TokenType.IDENTIFIER),
|
||||
("foo_", TokenType.IDENTIFIER),
|
||||
("foo_bar1_baz2", TokenType.IDENTIFIER),
|
||||
("FOO_BAR1_BAZ2", TokenType.IDENTIFIER),
|
||||
("true", TokenType.TRUE),
|
||||
("false", TokenType.FALSE),
|
||||
("none", TokenType.NONE),
|
||||
],
|
||||
)
|
||||
def test_identifiers_keywords(src: str, expected: TokenType):
|
||||
tokens: list[Token] = scan(src)
|
||||
assert_n_tokens(tokens, 1)
|
||||
assert tokens[0].type == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,expected",
|
||||
[
|
||||
("// This is a comment", TokenType.COMMENT),
|
||||
("/* This is a comment */", TokenType.COMMENT),
|
||||
(" ", TokenType.WHITESPACE),
|
||||
("\t", TokenType.WHITESPACE),
|
||||
("\r", TokenType.WHITESPACE),
|
||||
(" \t \t", TokenType.WHITESPACE),
|
||||
("\n", TokenType.NEWLINE),
|
||||
],
|
||||
)
|
||||
def test_misc(src: str, expected: TokenType):
|
||||
tokens: list[Token] = scan(src)
|
||||
assert_n_tokens(tokens, 1)
|
||||
assert tokens[0].type == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,expected_type,expected_value",
|
||||
[
|
||||
("0", TokenType.NUMBER, 0),
|
||||
("0.0", TokenType.NUMBER, 0),
|
||||
("1234.56", TokenType.NUMBER, 1234.56),
|
||||
],
|
||||
)
|
||||
def test_literals(src: str, expected_type: TokenType, expected_value: Any):
|
||||
tokens: list[Token] = scan(src)
|
||||
assert_n_tokens(tokens, 1)
|
||||
assert tokens[0].type == expected_type
|
||||
assert tokens[0].value == expected_value
|
||||
|
||||
|
||||
def test_single_bang_error():
|
||||
with pytest.raises(SyntaxError):
|
||||
scan("!")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src",
|
||||
[
|
||||
"@",
|
||||
'"',
|
||||
"'",
|
||||
".",
|
||||
],
|
||||
)
|
||||
def test_unexpected_character(src: str):
|
||||
with pytest.raises(SyntaxError):
|
||||
scan(src)
|
||||
82
tests/midas.py
Normal file
82
tests/midas.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import json
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.midas import Stmt
|
||||
from midas.lexer.base import MidasSyntaxError
|
||||
from midas.lexer.midas import MidasLexer
|
||||
from midas.lexer.token import Token
|
||||
from midas.parser.midas import MidasParser
|
||||
from tests.base import Tester
|
||||
from tests.serializer.midas import MidasAstJsonSerializer
|
||||
|
||||
|
||||
@dataclass
|
||||
class CaseResult:
|
||||
tokens: Optional[list[dict]] = None
|
||||
stmts: Optional[list[dict]] = None
|
||||
errors: list[dict] = field(default_factory=list)
|
||||
|
||||
def dumps(self) -> str:
|
||||
return json.dumps(asdict(self), indent=2)
|
||||
|
||||
|
||||
class MidasTester(Tester):
|
||||
@property
|
||||
def namespace(self) -> str:
|
||||
return "midas-parser"
|
||||
|
||||
def _list_tests(self) -> list[Path]:
|
||||
return list(self.base_dir.rglob("*.midas"))
|
||||
|
||||
def _exec_case(self, path: Path) -> CaseResult:
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Could not find test '{path}'")
|
||||
if not path.is_file():
|
||||
raise TypeError(f"Test '{path}' is not a file")
|
||||
|
||||
result: CaseResult = CaseResult()
|
||||
content: str = path.read_text()
|
||||
lexer: MidasLexer = MidasLexer(content)
|
||||
tokens: list[Token] = []
|
||||
try:
|
||||
tokens = lexer.process()
|
||||
result.tokens = [
|
||||
{
|
||||
"type": token.type.name,
|
||||
"lexeme": token.lexeme,
|
||||
"line": token.position.line,
|
||||
"column": token.position.column,
|
||||
}
|
||||
for token in tokens
|
||||
]
|
||||
except MidasSyntaxError as e:
|
||||
result.errors.append(
|
||||
{
|
||||
"type": "SyntaxError",
|
||||
"line": e.pos.line,
|
||||
"column": e.pos.column,
|
||||
"message": e.message,
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
parser: MidasParser = MidasParser(tokens)
|
||||
stmts: list[Stmt] = parser.parse()
|
||||
result.stmts = MidasAstJsonSerializer().serialize(stmts)
|
||||
result.errors.extend(
|
||||
[
|
||||
{
|
||||
"line": e.token.position.line,
|
||||
"column": e.token.position.column,
|
||||
"message": e.message,
|
||||
}
|
||||
for e in parser.errors
|
||||
]
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
MidasTester.main()
|
||||
@@ -1,130 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ast.annotations import (
|
||||
AnnotationStmt,
|
||||
ConstraintExpr,
|
||||
Expr,
|
||||
LiteralExpr,
|
||||
SchemaElementExpr,
|
||||
SchemaExpr,
|
||||
Stmt,
|
||||
TypeExpr,
|
||||
WildcardExpr,
|
||||
)
|
||||
from lexer.annotations import AnnotationLexer
|
||||
from lexer.position import Position
|
||||
from lexer.token import Token
|
||||
from parser.annotations import AnnotationParser
|
||||
|
||||
|
||||
class AstSerializer(Stmt.Visitor[str], Expr.Visitor[str]):
|
||||
def serialize(self, stmt: Stmt):
|
||||
return stmt.accept(self)
|
||||
|
||||
def visit_annotation_stmt(self, stmt: AnnotationStmt) -> str:
|
||||
schema: str = ""
|
||||
if stmt.schema is not None:
|
||||
schema = " " + stmt.schema.accept(self)
|
||||
return f"(annotation {stmt.name.lexeme}{schema})"
|
||||
|
||||
def visit_schema_expr(self, expr: SchemaExpr) -> str:
|
||||
elements: list[str] = [elmt.accept(self) for elmt in expr.elements]
|
||||
return f"(schema {' '.join(elements)})"
|
||||
|
||||
def visit_schema_element_expr(self, expr: SchemaElementExpr) -> str:
|
||||
name: str = expr.name.lexeme if expr.name is not None else "_"
|
||||
type: str = expr.type.accept(self) if expr.type is not None else "_"
|
||||
return f"({name} {type})"
|
||||
|
||||
def visit_type_expr(self, expr: TypeExpr) -> str:
|
||||
res: str = f"({expr.name.lexeme}"
|
||||
for constraint in expr.constraints:
|
||||
res += " " + constraint.accept(self)
|
||||
res += ")"
|
||||
return res
|
||||
|
||||
def visit_constraint_expr(self, expr: ConstraintExpr) -> str:
|
||||
return f"(constraint {expr.left.accept(self)} {expr.op.lexeme} {expr.right.accept(self)})"
|
||||
|
||||
def visit_wildcard_expr(self, expr: WildcardExpr) -> str:
|
||||
return "(_)"
|
||||
|
||||
def visit_literal_expr(self, expr: LiteralExpr) -> str:
|
||||
return f"({expr.value})"
|
||||
|
||||
|
||||
def parse(source: str) -> Optional[Stmt]:
|
||||
tokens: list[Token] = AnnotationLexer(source).process()
|
||||
return AnnotationParser(tokens).parse()
|
||||
|
||||
|
||||
def must_parse(source: str) -> Stmt:
|
||||
stmt: Optional[Stmt] = parse(source)
|
||||
assert stmt is not None
|
||||
return stmt
|
||||
|
||||
|
||||
def ast_str(source: str) -> str:
|
||||
stmt: Stmt = must_parse(source)
|
||||
return AstSerializer().serialize(stmt)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,expected",
|
||||
[
|
||||
("Type", "(annotation Type)"),
|
||||
("Type[]", "(annotation Type (schema ))"),
|
||||
(
|
||||
"""
|
||||
Frame[
|
||||
verified: bool,
|
||||
birth_year: int,
|
||||
height: float + ( _ > 0 ) + ( _ < 250 ),
|
||||
name: str,
|
||||
date: datetime,
|
||||
float, # unnamed
|
||||
unknown: _, # untyped
|
||||
_ # unnamed and untyped
|
||||
]
|
||||
""",
|
||||
"(annotation Frame (schema (verified (bool)) (birth_year (int)) (height (float (constraint (_) > (0.0)) (constraint (_) < (250.0)))) (name (str)) (date (datetime)) (_ (float)) (unknown _) (_ _)))",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_expressions(src: str, expected: str):
|
||||
assert ast_str(src) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,pos,should_fail",
|
||||
[
|
||||
("", (1, 1), True),
|
||||
("42", (1, 1), True),
|
||||
("True", (1, 1), True),
|
||||
("Type[", (1, 6), True),
|
||||
("Type[] Type2", (1, 8), False),
|
||||
("Type[bool:]", (1, 11), True),
|
||||
("Type[3]", (1, 6), True),
|
||||
("Type[bool float]", (1, 11), True),
|
||||
("Type[bool (_ < 2)]", (1, 11), True),
|
||||
("Type[bool + _ < 2)]", (1, 13), True),
|
||||
("Type[bool + (_ < 2]", (1, 19), True),
|
||||
("Type[bool + (< 2)]", (1, 14), True),
|
||||
("Type[bool + (_ + 2)]", (1, 16), True),
|
||||
("Type[bool + (Foo + Bar)]", (1, 14), True),
|
||||
# ("Type[bool,]", (1, 11), True), # trailing comma is accepted, TODO: update parser or EBNF
|
||||
("Type[bool, Type[]]", (1, 16), True),
|
||||
("Type[foo: 3]", (1, 11), True),
|
||||
],
|
||||
)
|
||||
def test_parsing_error(src: str, pos: tuple[int, int], should_fail: bool):
|
||||
tokens: list[Token] = AnnotationLexer(src).process()
|
||||
parser: AnnotationParser = AnnotationParser(tokens)
|
||||
stmt: Optional[Stmt] = parser.parse()
|
||||
if should_fail:
|
||||
assert stmt is None
|
||||
assert len(parser.errors) != 0
|
||||
error_pos: Position = parser.errors[0].token.position
|
||||
assert (error_pos.line, error_pos.column) == pos
|
||||
@@ -1,202 +0,0 @@
|
||||
import textwrap
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ast.midas import (
|
||||
ConstraintExpr,
|
||||
ConstraintStmt,
|
||||
Expr,
|
||||
LiteralExpr,
|
||||
OpStmt,
|
||||
PropertyStmt,
|
||||
Stmt,
|
||||
TypeBodyExpr,
|
||||
TypeExpr,
|
||||
TypeStmt,
|
||||
WildcardExpr,
|
||||
)
|
||||
from lexer.midas import MidasLexer
|
||||
from lexer.position import Position
|
||||
from lexer.token import Token
|
||||
from parser.midas import MidasParser
|
||||
|
||||
|
||||
class AstSerializer(Stmt.Visitor[str], Expr.Visitor[str]):
|
||||
def serialize(self, stmt: Stmt):
|
||||
return stmt.accept(self)
|
||||
|
||||
def visit_type_stmt(self, stmt: TypeStmt) -> str:
|
||||
res: str = f"(type_def {stmt.name.lexeme}"
|
||||
for base in stmt.bases:
|
||||
res += " " + base.accept(self)
|
||||
if stmt.body is not None:
|
||||
res += " " + stmt.body.accept(self)
|
||||
res += ")"
|
||||
return res
|
||||
|
||||
def visit_type_expr(self, expr: TypeExpr) -> str:
|
||||
res: str = f"({expr.name.lexeme}"
|
||||
for constraint in expr.constraints:
|
||||
res += " " + constraint.accept(self)
|
||||
res += ")"
|
||||
return res
|
||||
|
||||
def visit_constraint_expr(self, expr: ConstraintExpr) -> str:
|
||||
return f"(constraint {expr.left.accept(self)} {expr.op.lexeme} {expr.right.accept(self)})"
|
||||
|
||||
def visit_wildcard_expr(self, expr: WildcardExpr) -> str:
|
||||
return "(_)"
|
||||
|
||||
def visit_literal_expr(self, expr: LiteralExpr) -> str:
|
||||
return f"({expr.value})"
|
||||
|
||||
def visit_type_body_expr(self, expr: TypeBodyExpr) -> str:
|
||||
res: str = "(body"
|
||||
for prop in expr.properties:
|
||||
res += " " + prop.accept(self)
|
||||
res += ")"
|
||||
return res
|
||||
|
||||
def visit_property_stmt(self, stmt: PropertyStmt) -> str:
|
||||
return f"(property {stmt.name.lexeme} {stmt.type.accept(self)})"
|
||||
|
||||
def visit_op_stmt(self, stmt: OpStmt) -> str:
|
||||
left: str = stmt.left.accept(self)
|
||||
right: str = stmt.right.accept(self)
|
||||
result: str = stmt.result.accept(self)
|
||||
return f"(op_def {left} {stmt.op.lexeme} {right} {result})"
|
||||
|
||||
def visit_constraint_stmt(self, stmt: ConstraintStmt) -> str:
|
||||
return f"(constraint_def {stmt.name.lexeme} {stmt.constraint.accept(self)})"
|
||||
|
||||
|
||||
def parse(source: str) -> list[Stmt]:
|
||||
tokens: list[Token] = MidasLexer(source).process()
|
||||
return MidasParser(tokens).parse()
|
||||
|
||||
|
||||
def ast_str(source: str) -> list[str]:
|
||||
stmts: list[Stmt] = parse(source)
|
||||
return [AstSerializer().serialize(stmt) for stmt in stmts]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,expected",
|
||||
[
|
||||
("type Foo<>", "(type_def Foo)"),
|
||||
("type Foo<Bar>", "(type_def Foo (Bar))"),
|
||||
("type Foo<Bar, Baz>", "(type_def Foo (Bar) (Baz))"),
|
||||
(
|
||||
"type Foo<Bar + (_ < 2), Baz>",
|
||||
"(type_def Foo (Bar (constraint (_) < (2.0))) (Baz))",
|
||||
),
|
||||
(
|
||||
"""
|
||||
type Foo<> {
|
||||
foo: Bar
|
||||
}
|
||||
""",
|
||||
"(type_def Foo (body (property foo (Bar))))",
|
||||
),
|
||||
(
|
||||
"""
|
||||
type Foo<> {
|
||||
foo: Bar + (_ != none)
|
||||
foo2: Bar2 + (0 <= _) + (_ <= 100)
|
||||
}
|
||||
""",
|
||||
"(type_def Foo (body (property foo (Bar (constraint (_) != (None)))) (property foo2 (Bar2 (constraint (0.0) <= (_)) (constraint (_) <= (100.0))))))",
|
||||
),
|
||||
("op <A> + <B> = <C>", "(op_def (A) + (B) (C))"),
|
||||
(
|
||||
"op <A + (_ < 100)> + <B + (_ < 100)> = <C + (_ < 200)>",
|
||||
"(op_def (A (constraint (_) < (100.0))) + (B (constraint (_) < (100.0))) (C (constraint (_) < (200.0))))",
|
||||
),
|
||||
(
|
||||
"constraint Positive = _ >= 0",
|
||||
"(constraint_def Positive (constraint (_) >= (0.0)))",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_expressions(src: str, expected: str | list[str]):
|
||||
if isinstance(expected, str):
|
||||
expected = [expected]
|
||||
assert ast_str(src) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,pos",
|
||||
[
|
||||
###
|
||||
# Misc
|
||||
###
|
||||
("42", (1, 1)),
|
||||
("true", (1, 1)),
|
||||
("foo", (1, 1)),
|
||||
###
|
||||
# Type statements
|
||||
###
|
||||
("type", (1, 5)),
|
||||
("type true", (1, 6)),
|
||||
("type Foo", (1, 9)),
|
||||
("type Foo<1>", (1, 10)),
|
||||
# ("type Foo<float,>", (1, 16)), # trailing comma is accepted, TODO: update parser or EBNF
|
||||
("type Foo<float, 1>", (1, 17)),
|
||||
("type Foo<float", (1, 15)),
|
||||
("type Foo<float> { 3 }", (1, 19)),
|
||||
(
|
||||
"""
|
||||
type Foo<float> {
|
||||
foo
|
||||
}
|
||||
""",
|
||||
(4, 1),
|
||||
),
|
||||
(
|
||||
"""
|
||||
type Foo<float> {
|
||||
foo: 3
|
||||
}
|
||||
""",
|
||||
(3, 10),
|
||||
),
|
||||
###
|
||||
# Operation statements
|
||||
###
|
||||
("op", (1, 3)),
|
||||
("op float", (1, 4)),
|
||||
("op <", (1, 5)),
|
||||
("op <float", (1, 10)),
|
||||
("op <float>", (1, 11)),
|
||||
("op <float> +", (1, 13)),
|
||||
("op <float> + float", (1, 14)),
|
||||
("op <float> + <", (1, 15)),
|
||||
("op <float> + <float", (1, 20)),
|
||||
("op <float> + <float>", (1, 21)),
|
||||
("op <float> + <float> =", (1, 23)),
|
||||
("op <float> + <float> = float", (1, 24)),
|
||||
("op <float> + <float> = <", (1, 25)),
|
||||
("op <float> + <float> = <float", (1, 30)),
|
||||
("op <float + 3> + <float> = <float>", (1, 13)),
|
||||
("op <float> + <float + 3> = <float>", (1, 23)),
|
||||
("op <float> + <float> = <float + 3>", (1, 33)),
|
||||
###
|
||||
# Constraint statements
|
||||
###
|
||||
("constraint", (1, 11)),
|
||||
("constraint 3", (1, 12)),
|
||||
("constraint Foo", (1, 15)),
|
||||
("constraint Foo =", (1, 17)),
|
||||
("constraint Foo = 3", (1, 19)),
|
||||
("constraint Foo = 3 <", (1, 21)),
|
||||
],
|
||||
)
|
||||
def test_parsing_error(src: str, pos: tuple[int, int]):
|
||||
src = textwrap.dedent(src)
|
||||
tokens: list[Token] = MidasLexer(src).process()
|
||||
parser: MidasParser = MidasParser(tokens)
|
||||
stmt: list[Stmt] = parser.parse()
|
||||
assert len(stmt) == 0
|
||||
assert len(parser.errors) != 0
|
||||
error_pos: Position = parser.errors[0].token.position
|
||||
assert (error_pos.line, error_pos.column) == pos
|
||||
46
tests/python.py
Normal file
46
tests/python.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import ast
|
||||
import json
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.python import Stmt
|
||||
from midas.parser.python import PythonParser
|
||||
from tests.base import Tester
|
||||
from tests.serializer.python import PythonAstJsonSerializer
|
||||
|
||||
|
||||
@dataclass
|
||||
class CaseResult:
|
||||
stmts: Optional[list[dict]] = None
|
||||
|
||||
def dumps(self) -> str:
|
||||
return json.dumps(asdict(self), indent=2)
|
||||
|
||||
|
||||
class PythonTester(Tester):
|
||||
@property
|
||||
def namespace(self) -> str:
|
||||
return "python-parser"
|
||||
|
||||
def _list_tests(self) -> list[Path]:
|
||||
return list(self.base_dir.rglob("*.py"))
|
||||
|
||||
def _exec_case(self, path: Path) -> CaseResult:
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Could not find test '{path}'")
|
||||
if not path.is_file():
|
||||
raise TypeError(f"Test '{path}' is not a file")
|
||||
|
||||
result: CaseResult = CaseResult()
|
||||
content: str = path.read_text()
|
||||
tree: ast.Module = ast.parse(content)
|
||||
|
||||
parser: PythonParser = PythonParser()
|
||||
stmts: list[Stmt] = parser.parse_module(tree)
|
||||
result.stmts = PythonAstJsonSerializer().serialize(stmts)
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
PythonTester.main()
|
||||
167
tests/serializer/midas.py
Normal file
167
tests/serializer/midas.py
Normal file
@@ -0,0 +1,167 @@
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from midas.ast.midas import (
|
||||
BinaryExpr,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
Expr,
|
||||
ExtendStmt,
|
||||
GenericType,
|
||||
GetExpr,
|
||||
GroupingExpr,
|
||||
LiteralExpr,
|
||||
LogicalExpr,
|
||||
NamedType,
|
||||
OpStmt,
|
||||
PredicateStmt,
|
||||
PropertyStmt,
|
||||
Stmt,
|
||||
Type,
|
||||
TypeStmt,
|
||||
UnaryExpr,
|
||||
VariableExpr,
|
||||
WildcardExpr,
|
||||
)
|
||||
|
||||
|
||||
class MidasAstJsonSerializer(
|
||||
Stmt.Visitor[dict], Expr.Visitor[dict], Type.Visitor[dict]
|
||||
):
|
||||
"""An AST serializer which produces a JSON-compatible structure"""
|
||||
|
||||
def serialize(self, stmts: list[Stmt]) -> list[dict]:
|
||||
return [stmt.accept(self) for stmt in stmts]
|
||||
|
||||
def _serialize_optional(
|
||||
self, element: Optional[Stmt | Expr | Type]
|
||||
) -> Optional[dict]:
|
||||
if element is None:
|
||||
return None
|
||||
return element.accept(self)
|
||||
|
||||
def _serialize_list(self, elements: Sequence[Stmt | Expr | Type]) -> list[dict]:
|
||||
return [element.accept(self) for element in elements]
|
||||
|
||||
def visit_type_stmt(self, stmt: TypeStmt) -> dict:
|
||||
return {
|
||||
"_type": "TypeStmt",
|
||||
"name": stmt.name.lexeme,
|
||||
"params": [
|
||||
self._serialize_type_stmt_template_param(param) for param in stmt.params
|
||||
],
|
||||
"type": stmt.type.accept(self),
|
||||
}
|
||||
|
||||
def _serialize_type_stmt_template_param(self, param: TypeStmt.Param) -> dict:
|
||||
return {
|
||||
"name": param.name.lexeme,
|
||||
"bound": self._serialize_optional(param.bound),
|
||||
}
|
||||
|
||||
def visit_property_stmt(self, stmt: PropertyStmt) -> dict:
|
||||
return {
|
||||
"_type": "PropertyStmt",
|
||||
"name": stmt.name.lexeme,
|
||||
"type": stmt.type.accept(self),
|
||||
}
|
||||
|
||||
def visit_extend_stmt(self, stmt: ExtendStmt) -> dict:
|
||||
return {
|
||||
"_type": "ExtendStmt",
|
||||
"type": stmt.type.accept(self),
|
||||
"operations": self._serialize_list(stmt.operations),
|
||||
}
|
||||
|
||||
def visit_op_stmt(self, stmt: OpStmt) -> dict:
|
||||
return {
|
||||
"_type": "OpStmt",
|
||||
"name": stmt.name.lexeme,
|
||||
"operand": stmt.operand.accept(self),
|
||||
"result": stmt.result.accept(self),
|
||||
}
|
||||
|
||||
def visit_predicate_stmt(self, stmt: PredicateStmt) -> dict:
|
||||
return {
|
||||
"_type": "PredicateStmt",
|
||||
"name": stmt.name.lexeme,
|
||||
"subject": stmt.subject.lexeme,
|
||||
"type": stmt.type.accept(self),
|
||||
"condition": stmt.condition.accept(self),
|
||||
}
|
||||
|
||||
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
|
||||
return {
|
||||
"_type": "LogicalExpr",
|
||||
"left": expr.left.accept(self),
|
||||
"operator": expr.operator.lexeme,
|
||||
"right": expr.right.accept(self),
|
||||
}
|
||||
|
||||
def visit_binary_expr(self, expr: BinaryExpr) -> dict:
|
||||
return {
|
||||
"_type": "BinaryExpr",
|
||||
"left": expr.left.accept(self),
|
||||
"operator": expr.operator.lexeme,
|
||||
"right": expr.right.accept(self),
|
||||
}
|
||||
|
||||
def visit_unary_expr(self, expr: UnaryExpr) -> dict:
|
||||
return {
|
||||
"_type": "UnaryExpr",
|
||||
"operator": expr.operator.lexeme,
|
||||
"right": expr.right.accept(self),
|
||||
}
|
||||
|
||||
def visit_get_expr(self, expr: GetExpr) -> dict:
|
||||
return {
|
||||
"_type": "GetExpr",
|
||||
"expr": expr.expr.accept(self),
|
||||
"name": expr.name.lexeme,
|
||||
}
|
||||
|
||||
def visit_variable_expr(self, expr: VariableExpr) -> dict:
|
||||
return {
|
||||
"_type": "VariableExpr",
|
||||
"name": expr.name.lexeme,
|
||||
}
|
||||
|
||||
def visit_grouping_expr(self, expr: GroupingExpr) -> dict:
|
||||
return {
|
||||
"_type": "GroupingExpr",
|
||||
"expr": expr.expr.accept(self),
|
||||
}
|
||||
|
||||
def visit_literal_expr(self, expr: LiteralExpr) -> dict:
|
||||
return {
|
||||
"_type": "LiteralExpr",
|
||||
"value": expr.value,
|
||||
}
|
||||
|
||||
def visit_wildcard_expr(self, expr: WildcardExpr) -> dict:
|
||||
return {"_type": "WildcardExpr"}
|
||||
|
||||
def visit_named_type(self, type: NamedType) -> dict:
|
||||
return {
|
||||
"_type": "NamedType",
|
||||
"name": type.name.lexeme,
|
||||
}
|
||||
|
||||
def visit_generic_type(self, type: GenericType) -> dict:
|
||||
return {
|
||||
"_type": "GenericType",
|
||||
"type": type.type.accept(self),
|
||||
"params": self._serialize_list(type.params),
|
||||
}
|
||||
|
||||
def visit_constraint_type(self, type: ConstraintType) -> dict:
|
||||
return {
|
||||
"_type": "ConstraintType",
|
||||
"type": type.type.accept(self),
|
||||
"constraint": type.constraint.accept(self),
|
||||
}
|
||||
|
||||
def visit_complex_type(self, type: ComplexType) -> dict:
|
||||
return {
|
||||
"_type": "ComplexType",
|
||||
"properties": self._serialize_list(type.properties),
|
||||
}
|
||||
247
tests/serializer/python.py
Normal file
247
tests/serializer/python.py
Normal file
@@ -0,0 +1,247 @@
|
||||
import ast
|
||||
from typing import Optional, Sequence, Type
|
||||
|
||||
from midas.ast.python import (
|
||||
AssignStmt,
|
||||
BaseType,
|
||||
BinaryExpr,
|
||||
CallExpr,
|
||||
CastExpr,
|
||||
CompareExpr,
|
||||
ConstraintType,
|
||||
Expr,
|
||||
ExpressionStmt,
|
||||
FrameColumn,
|
||||
FrameType,
|
||||
Function,
|
||||
GetExpr,
|
||||
IfStmt,
|
||||
LiteralExpr,
|
||||
LogicalExpr,
|
||||
MidasType,
|
||||
ReturnStmt,
|
||||
Stmt,
|
||||
TernaryExpr,
|
||||
TypeAssign,
|
||||
UnaryExpr,
|
||||
VariableExpr,
|
||||
)
|
||||
|
||||
unary_ops: dict[Type[ast.unaryop], str] = {
|
||||
ast.Invert: "~",
|
||||
ast.Not: "not",
|
||||
ast.UAdd: "+",
|
||||
ast.USub: "-",
|
||||
}
|
||||
binary_ops: dict[Type[ast.operator], str] = {
|
||||
ast.Add: "+",
|
||||
ast.Sub: "-",
|
||||
ast.Mult: "*",
|
||||
ast.MatMult: "@",
|
||||
ast.Div: "/",
|
||||
ast.Mod: "%",
|
||||
ast.LShift: "<<",
|
||||
ast.RShift: ">>",
|
||||
ast.BitOr: "|",
|
||||
ast.BitXor: "^",
|
||||
ast.BitAnd: "&",
|
||||
ast.FloorDiv: "//",
|
||||
ast.Pow: "**",
|
||||
}
|
||||
compare_ops: dict[Type[ast.cmpop], str] = {
|
||||
ast.Eq: "==",
|
||||
ast.NotEq: "!=",
|
||||
ast.Lt: "<",
|
||||
ast.LtE: "<=",
|
||||
ast.Gt: ">",
|
||||
ast.GtE: ">=",
|
||||
ast.Is: "is",
|
||||
ast.IsNot: "is not",
|
||||
ast.In: "in",
|
||||
ast.NotIn: "not in",
|
||||
}
|
||||
boolean_ops: dict[Type[ast.boolop], str] = {
|
||||
ast.And: "and",
|
||||
ast.Or: "or",
|
||||
}
|
||||
|
||||
|
||||
class PythonAstJsonSerializer(
|
||||
Stmt.Visitor[dict], Expr.Visitor[dict], MidasType.Visitor[dict]
|
||||
):
|
||||
"""An AST serializer which produces a JSON-compatible structure"""
|
||||
|
||||
def serialize(self, stmts: list[Stmt]) -> list[dict]:
|
||||
return [stmt.accept(self) for stmt in stmts]
|
||||
|
||||
def _serialize_optional(
|
||||
self, element: Optional[Stmt | Expr | MidasType]
|
||||
) -> Optional[dict]:
|
||||
if element is None:
|
||||
return None
|
||||
return element.accept(self)
|
||||
|
||||
def _serialize_list(
|
||||
self, elements: Sequence[Stmt | Expr | MidasType]
|
||||
) -> list[dict]:
|
||||
return [element.accept(self) for element in elements]
|
||||
|
||||
def visit_base_type(self, node: BaseType) -> dict:
|
||||
return {
|
||||
"_type": "BaseType",
|
||||
"base": node.base,
|
||||
"param": self._serialize_optional(node.param),
|
||||
}
|
||||
|
||||
def visit_constraint_type(self, node: ConstraintType) -> dict:
|
||||
return {
|
||||
"_type": "ConstraintType",
|
||||
"type": node.type.accept(self),
|
||||
"constraint": ast.unparse(node.constraint),
|
||||
}
|
||||
|
||||
def visit_frame_column(self, node: FrameColumn) -> dict:
|
||||
return {
|
||||
"_type": "FrameColumn",
|
||||
"name": node.name,
|
||||
"type": self._serialize_optional(node.type),
|
||||
}
|
||||
|
||||
def visit_frame_type(self, node: FrameType) -> dict:
|
||||
return {
|
||||
"_type": "FrameType",
|
||||
"columns": self._serialize_list(node.columns),
|
||||
}
|
||||
|
||||
def visit_expression_stmt(self, stmt: ExpressionStmt) -> dict:
|
||||
return {
|
||||
"_type": "ExpressionStmt",
|
||||
"expr": stmt.expr.accept(self),
|
||||
}
|
||||
|
||||
def _serialize_argument(self, arg: Function.Argument) -> dict:
|
||||
return {
|
||||
"name": arg.name,
|
||||
"type": self._serialize_optional(arg.type),
|
||||
"default": self._serialize_optional(arg.default),
|
||||
}
|
||||
|
||||
def visit_function(self, stmt: Function) -> dict:
|
||||
return {
|
||||
"_type": "Function",
|
||||
"name": stmt.name,
|
||||
"posonlyargs": [self._serialize_argument(arg) for arg in stmt.posonlyargs],
|
||||
"args": [self._serialize_argument(arg) for arg in stmt.args],
|
||||
"sink": (
|
||||
self._serialize_argument(stmt.sink) if stmt.sink is not None else None
|
||||
),
|
||||
"kwonlyargs": [self._serialize_argument(arg) for arg in stmt.kwonlyargs],
|
||||
"kw_sink": (
|
||||
self._serialize_argument(stmt.kw_sink)
|
||||
if stmt.kw_sink is not None
|
||||
else None
|
||||
),
|
||||
"returns": self._serialize_optional(stmt.returns),
|
||||
"body": self._serialize_list(stmt.body),
|
||||
}
|
||||
|
||||
def visit_type_assign(self, stmt: TypeAssign) -> dict:
|
||||
return {
|
||||
"_type": "TypeAssign",
|
||||
"name": stmt.name,
|
||||
"type": stmt.type.accept(self),
|
||||
}
|
||||
|
||||
def visit_assign_stmt(self, stmt: AssignStmt) -> dict:
|
||||
return {
|
||||
"_type": "AssignStmt",
|
||||
"targets": self._serialize_list(stmt.targets),
|
||||
"value": stmt.value.accept(self),
|
||||
}
|
||||
|
||||
def visit_return_stmt(self, stmt: ReturnStmt) -> dict:
|
||||
return {
|
||||
"_type": "ReturnStmt",
|
||||
"value": self._serialize_optional(stmt.value),
|
||||
}
|
||||
|
||||
def visit_if_stmt(self, stmt: IfStmt) -> dict:
|
||||
return {
|
||||
"_type": "IfStmt",
|
||||
"test": stmt.test.accept(self),
|
||||
"body": self._serialize_list(stmt.body),
|
||||
"orelse": self._serialize_list(stmt.orelse),
|
||||
}
|
||||
|
||||
def visit_binary_expr(self, expr: BinaryExpr) -> dict:
|
||||
return {
|
||||
"_type": "BinaryExpr",
|
||||
"left": expr.left.accept(self),
|
||||
"operator": binary_ops[expr.operator.__class__],
|
||||
"right": expr.right.accept(self),
|
||||
}
|
||||
|
||||
def visit_compare_expr(self, expr: CompareExpr) -> dict:
|
||||
return {
|
||||
"_type": "CompareExpr",
|
||||
"left": expr.left.accept(self),
|
||||
"operator": compare_ops[expr.operator.__class__],
|
||||
"right": expr.right.accept(self),
|
||||
}
|
||||
|
||||
def visit_unary_expr(self, expr: UnaryExpr) -> dict:
|
||||
return {
|
||||
"_type": "UnaryExpr",
|
||||
"operator": unary_ops[expr.operator.__class__],
|
||||
"right": expr.right.accept(self),
|
||||
}
|
||||
|
||||
def visit_call_expr(self, expr: CallExpr) -> dict:
|
||||
return {
|
||||
"_type": "CallExpr",
|
||||
"callee": expr.callee.accept(self),
|
||||
"arguments": self._serialize_list(expr.arguments),
|
||||
"keywords": {name: arg.accept(self) for name, arg in expr.keywords.items()},
|
||||
}
|
||||
|
||||
def visit_get_expr(self, expr: GetExpr) -> dict:
|
||||
return {
|
||||
"_type": "GetExpr",
|
||||
"object": expr.object.accept(self),
|
||||
"name": expr.name,
|
||||
}
|
||||
|
||||
def visit_literal_expr(self, expr: LiteralExpr) -> dict:
|
||||
return {
|
||||
"_type": "LiteralExpr",
|
||||
"value": expr.value,
|
||||
}
|
||||
|
||||
def visit_variable_expr(self, expr: VariableExpr) -> dict:
|
||||
return {
|
||||
"_type": "VariableExpr",
|
||||
"name": expr.name,
|
||||
}
|
||||
|
||||
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
|
||||
return {
|
||||
"_type": "LogicalExpr",
|
||||
"left": expr.left.accept(self),
|
||||
"operator": boolean_ops[expr.operator.__class__],
|
||||
"right": expr.right.accept(self),
|
||||
}
|
||||
|
||||
def visit_cast_expr(self, expr: CastExpr) -> dict:
|
||||
return {
|
||||
"_type": "CastExpr",
|
||||
"type": expr.type.accept(self),
|
||||
"expr": expr.expr.accept(self),
|
||||
}
|
||||
|
||||
def visit_ternary_expr(self, expr: TernaryExpr) -> dict:
|
||||
return {
|
||||
"_type": "TernaryExpr",
|
||||
"test": expr.test.accept(self),
|
||||
"if_true": expr.if_true.accept(self),
|
||||
"if_false": expr.if_false.accept(self),
|
||||
}
|
||||
@@ -31,22 +31,32 @@
|
||||
]
|
||||
},
|
||||
"type-base": {
|
||||
"begin": "<",
|
||||
"end": ">",
|
||||
"begin": "(\\()([a-zA-Z_][a-zA-Z_\\d]*)(\\))",
|
||||
"end": "$",
|
||||
"beginCaptures": {
|
||||
"0": {
|
||||
"1": {
|
||||
"name": "punctuation.definition.base.begin.midas"
|
||||
}
|
||||
},
|
||||
"endCaptures": {
|
||||
"0": {
|
||||
},
|
||||
"2": {
|
||||
"name": "variable.name"
|
||||
},
|
||||
"3": {
|
||||
"name": "punctuation.definition.base.end.midas"
|
||||
}
|
||||
},
|
||||
"patterns": [
|
||||
{"include": "source.python"}
|
||||
{ "include": "#type-cond" }
|
||||
]
|
||||
},
|
||||
"type-cond": {
|
||||
"begin": "where",
|
||||
"end": "$",
|
||||
"beginCaptures": {
|
||||
"0": {
|
||||
"name": "keyword.control.where.midas"
|
||||
}
|
||||
}
|
||||
},
|
||||
"type-body": {
|
||||
"begin": "\\{",
|
||||
"end": "\\}",
|
||||
@@ -61,7 +71,8 @@
|
||||
}
|
||||
},
|
||||
"patterns": [
|
||||
{"include": "#type-prop"}
|
||||
{"include": "#type-prop"},
|
||||
{"include": "#comment"}
|
||||
]
|
||||
},
|
||||
"type-prop": {
|
||||
@@ -78,44 +89,67 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"op-def": {
|
||||
"match": "\\b(op)\\s+<([a-zA-Z_][a-zA-Z_\\d]*)>\\s+(\\S+)\\s+<([a-zA-Z_][a-zA-Z_\\d]*)>\\s+(=)\\s+<([a-zA-Z_][a-zA-Z_\\d]*)>",
|
||||
"captures": {
|
||||
"1": {
|
||||
"name": "keyword.control.op.midas"
|
||||
},
|
||||
"2": {
|
||||
"name" : "variable.name"
|
||||
},
|
||||
"3": {
|
||||
"name" : "keyword.operator"
|
||||
},
|
||||
"4": {
|
||||
"name" : "variable.name"
|
||||
},
|
||||
"5": {
|
||||
"name" : "keyword.operator.assignment"
|
||||
},
|
||||
"6": {
|
||||
"name" : "variable.name"
|
||||
}
|
||||
},
|
||||
"patterns": [
|
||||
{ "include": "#type-base" },
|
||||
{ "include": "#type-body" }
|
||||
]
|
||||
},
|
||||
"constr-def": {
|
||||
"begin": "(constraint)\\s+([a-zA-Z_][a-zA-Z_\\d]*)\\s*(=)",
|
||||
"end": "$",
|
||||
"extend-def": {
|
||||
"begin": "\\b(extend)\\s*([a-zA-Z_][a-zA-Z_\\d]*)\\s+(\\{)",
|
||||
"end": "\\}",
|
||||
"beginCaptures": {
|
||||
"1": {
|
||||
"name": "keyword.control.constr.midas"
|
||||
"name": "keyword.control.extend.midas"
|
||||
},
|
||||
"2": {
|
||||
"name": "variable.name"
|
||||
},
|
||||
"3": {
|
||||
"name": "punctuation.definition.extend-body.begin.midas"
|
||||
}
|
||||
},
|
||||
"endCaptures": {
|
||||
"0": {
|
||||
"name": "punctuation.definition.extend-body.end.midas"
|
||||
}
|
||||
},
|
||||
"patterns": [
|
||||
{"include": "#op-def"},
|
||||
{"include": "#comment"}
|
||||
]
|
||||
},
|
||||
"op-def": {
|
||||
"match": "\\b(op)\\s+(\\S+)\\s*\\(\\s*([a-zA-Z_][a-zA-Z_\\d]*)\\s*\\)\\s*(->)\\s*([a-zA-Z_][a-zA-Z_\\d]*)",
|
||||
"captures": {
|
||||
"1": {
|
||||
"name": "keyword.control.op.midas"
|
||||
},
|
||||
"2": {
|
||||
"name" : "keyword.operator"
|
||||
},
|
||||
"3": {
|
||||
"name" : "variable.name"
|
||||
},
|
||||
"4": {
|
||||
"name" : "keyword.operator.assignment"
|
||||
},
|
||||
"5": {
|
||||
"name" : "variable.name"
|
||||
}
|
||||
}
|
||||
},
|
||||
"pred-def": {
|
||||
"begin": "(predicate)\\s+([a-zA-Z_][a-zA-Z_\\d]*)\\(([a-zA-Z_][a-zA-Z_\\d]*):\\s*([a-zA-Z_][a-zA-Z_\\d]*)\\)\\s*(=)",
|
||||
"end": "$",
|
||||
"beginCaptures": {
|
||||
"1": {
|
||||
"name": "keyword.control.pred.midas"
|
||||
},
|
||||
"2": {
|
||||
"name": "variable.name"
|
||||
},
|
||||
"3": {
|
||||
"name": "variable.name"
|
||||
},
|
||||
"4": {
|
||||
"name": "variable.name"
|
||||
},
|
||||
"5": {
|
||||
"name": "keyword.operator.assignment"
|
||||
}
|
||||
},
|
||||
@@ -127,8 +161,8 @@
|
||||
"patterns": [
|
||||
{ "include": "#comment" },
|
||||
{ "include": "#type-def" },
|
||||
{ "include": "#op-def" },
|
||||
{ "include": "#constr-def" }
|
||||
{ "include": "#extend-def" },
|
||||
{ "include": "#pred-def" }
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user