Compare commits
101 Commits
d039a8e4b3
...
feat/add-d
| Author | SHA1 | Date | |
|---|---|---|---|
|
9764484fd9
|
|||
|
5b9e322c91
|
|||
|
c18d9c18de
|
|||
|
9229f00375
|
|||
|
6b7a682dc5
|
|||
|
35b97fd17b
|
|||
| 03bc32400b | |||
|
4a93ee45d9
|
|||
|
8197131d8d
|
|||
|
cf91187b7a
|
|||
|
1b2bdf0b79
|
|||
| c6cc38bfeb | |||
|
4d3e3f44a1
|
|||
|
ec80b1e92e
|
|||
|
4ea15519f3
|
|||
|
7a6e01cff8
|
|||
|
733c8736b8
|
|||
|
20173a0b07
|
|||
|
a143972ef1
|
|||
|
0c70048b62
|
|||
|
1c0c917873
|
|||
|
1f6189daa4
|
|||
|
66b585c3d6
|
|||
|
819ab3c2bf
|
|||
|
d8c0b17512
|
|||
|
6e06f9078e
|
|||
|
ece2e3a6a3
|
|||
|
74c07c9afb
|
|||
|
be2fd4c837
|
|||
|
1bc4c704c3
|
|||
|
0288a05901
|
|||
|
b14f46d405
|
|||
|
8e8ed62266
|
|||
|
2fce2f4bfc
|
|||
|
640f2d1771
|
|||
|
b48dfe5301
|
|||
|
0d5840a4ce
|
|||
|
3c92f0867d
|
|||
|
b5acae4078
|
|||
|
5d20f8ec3e
|
|||
|
955c2233ed
|
|||
|
ff69b65171
|
|||
|
8df01afd8c
|
|||
|
47b2dfdd73
|
|||
|
bd4d793ce0
|
|||
|
f7a36f61b6
|
|||
|
ad2fabf471
|
|||
|
a59a58d21a
|
|||
| 3260ae4a1e | |||
|
bd1c9581c7
|
|||
|
663642ea6c
|
|||
|
e2abc04fe4
|
|||
|
a4016b55ce
|
|||
|
1ea5da7024
|
|||
|
a017a8cf1f
|
|||
|
8fc5ab623e
|
|||
|
14007db846
|
|||
|
6ad2ce4b68
|
|||
|
9a276c34c7
|
|||
|
6e717a3f9e
|
|||
|
77aadfa264
|
|||
| c81287df7f | |||
|
ffccc1bedd
|
|||
|
d14f208897
|
|||
|
293953a078
|
|||
|
bccc96e4d0
|
|||
|
9db56adf56
|
|||
|
3f99563ac8
|
|||
|
b36896cc7b
|
|||
|
cb75878ae9
|
|||
|
a5fe985eb2
|
|||
|
e324f414e6
|
|||
|
256536562f
|
|||
|
64f4314f0d
|
|||
|
6f6245d283
|
|||
|
3392bc347d
|
|||
|
7e0319906a
|
|||
|
75bd203d4a
|
|||
|
db40198357
|
|||
|
d79e1dee18
|
|||
|
4ea400265c
|
|||
|
24bffdabd4
|
|||
|
d7bb6326de
|
|||
|
dbf6f9e2db
|
|||
|
3cdc9031d3
|
|||
|
00e2ca8fe3
|
|||
|
4efb01285c
|
|||
|
f84a19159f
|
|||
|
946b2e0d2e
|
|||
|
08dd7408ec
|
|||
|
b33fadf768
|
|||
|
7219109e5d
|
|||
|
cdf1725c26
|
|||
|
7074b074bc
|
|||
|
ede7272c09
|
|||
|
87d5e286d2
|
|||
|
c91b206791
|
|||
|
a31d295eb1
|
|||
|
0d20993f02
|
|||
|
5357ca8e58
|
|||
|
556765fd35
|
120
docs/manual.typ
@@ -198,10 +198,26 @@ python3 build/midas/script.py
|
||||
In this chapter, you will find a complete reference for the Midas definition language.
|
||||
|
||||
A `*.midas` file contains a number of statements, which can be:
|
||||
- *`alias`* statements (see @alias-stmt): to define a new type alias
|
||||
- *`type`* statements (see @type-stmt): to define a new type
|
||||
- *`extend`* statements (see @extend-stmt): to define members of a type
|
||||
- *`predicate`* statements (see @predicate-stmt): to define named predicates that can be used in constraint types
|
||||
|
||||
== Alias Statement <alias-stmt>
|
||||
|
||||
An *`alias`* statement lets you define a new type alias. It requires a unique name and base type.
|
||||
|
||||
While a `type` statement (see @type-stmt) allows generic definitions, aliases are purely for giving an alternative name to a type.
|
||||
|
||||
#figure(
|
||||
```midas
|
||||
alias MyType = float
|
||||
```,
|
||||
caption: [Simple `alias` statement declaring a new type "`MyType`" equivalent to `float`],
|
||||
) <midas-simple-alias>
|
||||
|
||||
This statement defines a new type called `MyType` which is equivalent to `float`. `MyType` and `float` can be used interchangeably.
|
||||
|
||||
== Type Statement <type-stmt>
|
||||
|
||||
A *`type`* statement lets you define a new type. It requires a unique name and base type.
|
||||
@@ -212,7 +228,7 @@ The simplest form of a *`type`* statement is:
|
||||
type MyType = float
|
||||
```,
|
||||
caption: [Simple `type` statement declaring a new type "`MyType`" as a subtype of `float`],
|
||||
) <midas-simple-alias>
|
||||
) <midas-simple-type>
|
||||
|
||||
This statement defines a new type called `MyType` which is a subtype of `float`. `MyType` is a `float` but a `float` is not necessarily `MyType`.
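For readers coming from Python, the difference between `alias` and `type` can be loosely compared to a plain type alias versus `typing.NewType`. This is only an analogy under that assumption, not a description of how Midas implements either statement; the names `MyAlias` and `takes_my_type` are made up for illustration.

```python
from typing import NewType

# Analogue of `alias MyType = float`: another name for the very same type.
MyAlias = float

# Analogue of `type MyType = float`: a distinct subtype for static checkers.
MyType = NewType("MyType", float)


def takes_my_type(x: MyType) -> float:
    # A MyType is a float, so float operations are available.
    return x * 2.0


value = MyType(1.5)           # explicit wrapping is what a type checker expects
doubled = takes_my_type(value)
alias_value: MyAlias = 1.5    # a plain float is accepted wherever the alias is used
```

A static checker would reject `takes_my_type(1.5)`, mirroring "a `float` is not necessarily `MyType`", while `MyAlias` and `float` remain interchangeable.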
|
||||
|
||||
@@ -291,8 +307,7 @@ To better refine a generic type, you can also bound type parameters using the fo
|
||||
caption: [Generic container type definition with a bound],
|
||||
)
|
||||
|
||||
This can be read as "`Container` is a generic type which takes one type parameter `T` that must be a subtype of `float`".
|
||||
|
||||
This can be read as "`Container` is a generic type which takes one type parameter `T` that must be a subtype of `float`".\
|
||||
You can use a generic type, i.e. instantiate it, by using a similar syntax with concrete type as arguments:
|
||||
|
||||
#figure(
|
||||
@@ -318,6 +333,46 @@ The _body_ of a generic type, i.e. the right-hand side of the definition, can co
|
||||
caption: [Type parameters in a generic type's body],
|
||||
)
|
||||
|
||||
=== `Column` / `Frame` types
|
||||
|
||||
To provide useful type-checking for data engineers, Midas offers two special types: `Column` and `Frame`.
|
||||
Their goal is to help type check Pandas' `Series` and `DataFrame` respectively.
|
||||
|
||||
==== `Column`
|
||||
|
||||
The `Column` type is a generic type used to represent a `pandas.Series` object.
|
||||
You can use it like any other generic type and it will provide type checking for some common methods and attributes offered by Pandas.
|
||||
|
||||
#figure(
|
||||
```midas
|
||||
type Temperature = float
|
||||
alias Temperatures = Column[Temperature]
|
||||
```,
|
||||
caption: [Simple column type definition],
|
||||
)
|
||||
|
||||
==== `Frame` <frame-type>
|
||||
|
||||
The `Frame` type is a super-powered generic type used to represent a `pandas.DataFrame` object.
|
||||
In place of type arguments, `Frame` accepts a schema, i.e. a series of column definitions.
|
||||
@simple-frame shows how you can define a simple frame type with 3 columns:
|
||||
- `name`: a column of `Name` values
|
||||
- `age`: a column of `int` values
|
||||
- `height`: a column of `float where _ >= 0` values
|
||||
|
||||
Notice that you don't need to specify `Column` types.
|
||||
|
||||
#figure(
|
||||
```midas
|
||||
type Name = str where len(_) != 0
|
||||
alias Data = Frame[
|
||||
name: Name,
|
||||
age: int,
|
||||
height: float where _ >= 0
|
||||
]
|
||||
```,
|
||||
) <simple-frame>
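To make the idea of a schema more concrete, here is a minimal Python sketch of what checking data against the three columns above could look like. It uses plain dictionaries of lists rather than Pandas, and `Schema`, `DATA_SCHEMA` and `validate` are hypothetical names; it does not reflect how Midas actually checks frames.

```python
from typing import Any, Callable

# Hypothetical schema: column name -> predicate every value must satisfy.
Schema = dict[str, Callable[[Any], bool]]

DATA_SCHEMA: Schema = {
    "name": lambda v: isinstance(v, str) and len(v) != 0,   # str where len(_) != 0
    "age": lambda v: isinstance(v, int),                     # int
    "height": lambda v: isinstance(v, float) and v >= 0,     # float where _ >= 0
}


def validate(columns: dict[str, list[Any]], schema: Schema) -> list[str]:
    """Return a list of human-readable problems (empty if the data conforms)."""
    problems: list[str] = []
    for name, check in schema.items():
        if name not in columns:
            problems.append(f"missing column: {name}")
            continue
        for i, value in enumerate(columns[name]):
            if not check(value):
                problems.append(f"{name}[{i}] rejected: {value!r}")
    return problems


good = {"name": ["Ada"], "age": [36], "height": [1.7]}
bad = {"name": [""], "age": [36]}
```

Calling `validate(good, DATA_SCHEMA)` returns an empty list, while `validate(bad, DATA_SCHEMA)` reports the empty name and the missing `height` column.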
|
||||
|
||||
#pagebreak()
|
||||
|
||||
== Extend Statement <extend-stmt>
|
||||
@@ -503,6 +558,7 @@ A simple annotation declaration, without assigning a value, is enough to declare
|
||||
)
|
||||
|
||||
Because unpacking is not supported, assigning to multiple values is also not handled by the type checker.
|
||||
For more information about type annotations, see @type-annotations
|
||||
|
||||
== Arithmetic
|
||||
|
||||
@@ -578,7 +634,7 @@ Conditional statements are checked relatively strictly by Midas. The test expres
|
||||
|
||||
Simple forms of `for` loops can be used, that is, loops using a single variable and iterating over an object implementing the `__getitem__` method. As in @if-else, variables leaking out of the loop are ignored.
|
||||
|
||||
The `for`-`else` statements are not supported. `while` loops are also not not supported.
|
||||
`for`-`else` statements are not supported. `while` loops are also not supported.
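As a reminder of what "an object implementing the `__getitem__` method" means for iteration, the sketch below relies on Python's legacy sequence protocol: a `for` loop calls `__getitem__` with 0, 1, 2, ... until `IndexError` is raised. The `Countdown` class is a made-up example, not part of Midas.

```python
class Countdown:
    """Iterable through the legacy sequence protocol: only `__getitem__`."""

    def __init__(self, start: int) -> None:
        self.start = start

    def __getitem__(self, index: int) -> int:
        if index >= self.start:
            raise IndexError(index)  # signals the end of iteration
        return self.start - index


collected: list[int] = []
for n in Countdown(3):  # single loop variable, the supported form
    collected.append(n)
```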
|
||||
|
||||
== Functions
|
||||
|
||||
@@ -678,10 +734,43 @@ In the following example, a runtime check would be generated to ensure that the
|
||||
caption: [Typing of `cast` expression],
|
||||
)
|
||||
|
||||
#gc.warning[
|
||||
Assertions are statements inserted just before a statement using a `cast` expression. This means that the expression is evaluated _before_ its actual intended usage location, which might cause issues if you rely on logical operator short-circuiting. See @eager-eval for more information.
|
||||
]
|
||||
|
||||
There may be some cases where the cost of checking a value at runtime is simply not worth the safety, for example when dealing with a big dataset. If you wish to do so, you can use `unsafe_cast`, which only tells the type checker the type of the value, without generating a runtime assertion. This maps to the default behavior of `typing`'s own `cast` function.
|
||||
|
||||
If the value passed to `cast` or `unsafe_cast` is a literal (e.g. an integer, a string, a list of literals, etc.), the assertion is evaluated _at compile-time_ and no runtime assertion is generated.
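The contrast between a checked and an unchecked cast can be illustrated in plain Python. `typing.cast` never inspects its argument at runtime, which is the behavior `unsafe_cast` is described as matching. The `checked_cast` helper below is hypothetical and only approximates the idea of a runtime assertion; it is not the code Midas generates.

```python
from typing import Any, TypeVar, cast

T = TypeVar("T")


def checked_cast(tp: type[T], value: Any) -> T:
    """Hypothetical helper: verify the runtime type, then return the value."""
    if not isinstance(value, tp):
        raise TypeError(f"expected {tp.__name__}, got {type(value).__name__}")
    return value


# typing.cast only informs the type checker; the value passes through untouched.
unchecked = cast(int, "not an int")

checked = checked_cast(int, 42)

try:
    checked_cast(int, "not an int")
    failed = False
except TypeError:
    failed = True
```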
|
||||
|
||||
== Annotations / Type Hints <type-annotations>
|
||||
|
||||
Vanilla Python already lets you use type hints to specify the type of variables and function parameters.
|
||||
|
||||
Midas uses them to type check your code. Additionally, it allows you to use a special syntax to define `Frame` types directly in these annotations.
|
||||
|
||||
Because these annotations are not interpretable by Python, your integrated type checker might complain loudly about them being invalid.
|
||||
A workaround is to silence it by adding a type comment at the end of the line, as shown in @silence-errors.
|
||||
|
||||
#figure(
|
||||
```python
|
||||
var: Frame[name: str, age: float] # type: ignore # noqa: F821
|
||||
```,
|
||||
caption: [MyPy's and Pylance's complaints about custom type annotations can be silenced with type comments],
|
||||
) <silence-errors>
|
||||
|
||||
=== Frame type annotation
|
||||
|
||||
The syntax is similar to how you can define frame types in the Midas language (see @frame-type). The only difference is that types can only be name references; you cannot inline constraint types.
|
||||
|
||||
The example of @python-frame-type shows how you can annotate a dataframe with some columns directly in Python.
|
||||
|
||||
#figure(
|
||||
```python
|
||||
df: Frame[name: Name, age: float, height: Length[Meter]] = ...
|
||||
```,
|
||||
caption: [Frame type annotation in Python],
|
||||
) <python-frame-type>
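Although type checkers reject it, this annotation is still syntactically valid Python: each `name: Type` entry parses as a slice inside the subscript. The sketch below uses the standard `ast` module (Python 3.9 or later) to recover the column names and types from such an annotation. The `frame_columns` function is a hypothetical illustration and is not taken from the Midas sources.

```python
import ast


def frame_columns(source: str) -> dict[str, str]:
    """Extract {column: type} from an annotated assignment like `df: Frame[...]`."""
    stmt = ast.parse(source).body[0]
    if not isinstance(stmt, ast.AnnAssign) or not isinstance(
        stmt.annotation, ast.Subscript
    ):
        raise ValueError("expected an annotated assignment with a subscript")
    index = stmt.annotation.slice
    # Several columns parse as a tuple of slices, a single column as one slice.
    items = index.elts if isinstance(index, ast.Tuple) else [index]
    columns: dict[str, str] = {}
    for item in items:
        if not (
            isinstance(item, ast.Slice)
            and isinstance(item.lower, ast.Name)
            and item.upper is not None
        ):
            raise ValueError("expected `name: Type` entries")
        # The slice's lower bound is the column name, its upper bound the type.
        columns[item.lower.id] = ast.unparse(item.upper)
    return columns


cols = frame_columns("df: Frame[name: Name, age: float, height: Length[Meter]] = ...")
```

Here `cols` maps `name` to `Name`, `age` to `float` and `height` to `Length[Meter]`.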
|
||||
|
||||
= Commands <commands>
|
||||
|
||||
#TODO
|
||||
@@ -695,3 +784,26 @@ If the value passed to `cast` or `unsafe_cast` is a literal (e.g. an integer, a
|
||||
== Generating Stubs (`stubs`) <cmd-stubs>
|
||||
== Showing Type Judgements (`types`) <cmd-types>
|
||||
== Validating Definitions (`validate`) <cmd-validate>
|
||||
|
||||
= Known limitations <limitations>
|
||||
|
||||
== Eager evaluation in runtime assertions <eager-eval>
|
||||
|
||||
The process of generating assertions to ensure safety at runtime, mainly for `cast` expressions, leads to the creation of aliases for the expressions being cast. These alias definitions are evaluated eagerly, before the assertion and, most importantly, before the real usage location. This means that you should avoid using `cast` expressions inside logical expressions like `and` or `or`, because the usual short-circuit behavior no longer applies: the cast operand is evaluated regardless of the other operand's value.
|
||||
|
||||
For example:
|
||||
|
||||
#figure(
|
||||
```py
|
||||
def foo():
|
||||
print("Foo")
|
||||
return True
|
||||
def bar():
|
||||
print("Bar")
|
||||
return True
|
||||
result = foo() or bar()
|
||||
# Foo
|
||||
# Bar
|
||||
```,
|
||||
caption: [Runtime assertions may eagerly evaluate expressions and bypass a logical operator's short-circuiting],
|
||||
)
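The effect can be reproduced in plain Python by hoisting an operand into a temporary variable by hand, which is roughly what eager pre-evaluation amounts to. This is a sketch under that assumption: the `_tmp` name and the exact shape of the hoisted code are hypothetical, since the code Midas actually generates is not shown here.

```python
calls: list[str] = []


def foo() -> bool:
    calls.append("foo")
    return True


def bar() -> bool:
    calls.append("bar")
    return True


# Plain Python: `or` short-circuits, so bar() is never called.
result = foo() or bar()
short_circuit_calls = list(calls)

# Hoisted form: the right operand is evaluated up front (e.g. so that it
# could be asserted on), before the logical expression is reached.
calls.clear()
_tmp = bar()
hoisted = foo() or _tmp
hoisted_calls = list(calls)
```

In the first form only `foo` runs; in the hoisted form `bar` runs first, even though its value is never needed.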
|
||||
|
||||
@@ -37,6 +37,9 @@ contexts:
|
||||
pop: true
|
||||
|
||||
keywords:
|
||||
- match: \balias\b
|
||||
scope: keyword.declaration.midas
|
||||
push: alias-stmt
|
||||
- match: \btype\b
|
||||
scope: keyword.declaration.midas
|
||||
push: type-stmt
|
||||
@@ -47,6 +50,15 @@ contexts:
|
||||
scope: keyword.declaration.midas
|
||||
push: predicate-stmt
|
||||
|
||||
alias-stmt:
|
||||
- match: "{{identifier}}"
|
||||
scope: entity.name.type
|
||||
- match: "="
|
||||
scope: keyword.operator.equal.midas
|
||||
push: type-expr
|
||||
- match: $
|
||||
pop: true
|
||||
|
||||
type-stmt:
|
||||
- match: "{{identifier}}"
|
||||
scope: entity.name.type
|
||||
@@ -67,6 +79,13 @@ contexts:
|
||||
- match: \b(where)\b
|
||||
scope: keyword.other.midas
|
||||
set: constraint
|
||||
- match: "Frame"
|
||||
scope: entity.name.type
|
||||
push:
|
||||
- match: \[
|
||||
push: frame-schema
|
||||
- match: $
|
||||
pop: true
|
||||
- match: "{{identifier}}"
|
||||
scope: entity.name.type
|
||||
- match: $
|
||||
@@ -178,3 +197,15 @@ contexts:
|
||||
|
||||
- match: '\)'
|
||||
pop: true
|
||||
|
||||
frame-schema:
|
||||
- include: frame-column
|
||||
- match: \]
|
||||
# scope: punctuation.section.block.end
|
||||
pop: true
|
||||
|
||||
frame-column:
|
||||
- match: "{{identifier}}"
|
||||
scope: variable.other.member
|
||||
- match: ":"
|
||||
push: type-expr
|
||||
|
||||
@@ -1,3 +1,9 @@
|
||||
"""
|
||||
Helper script to generate AST nodes for Midas and Python.
|
||||
|
||||
Takes in simple templates and generates full dataclasses and a visitor interface
|
||||
"""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
18
gen/midas.py
@@ -29,9 +29,9 @@ class MemberKind(Enum):
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ParamSpec:
|
||||
l_paren: Token
|
||||
pos: list[FunctionType.Argument]
|
||||
mixed: list[FunctionType.Argument]
|
||||
kw: list[FunctionType.Argument]
|
||||
pos: list[FunctionType.Parameter]
|
||||
mixed: list[FunctionType.Parameter]
|
||||
kw: list[FunctionType.Parameter]
|
||||
|
||||
|
||||
###<
|
||||
@@ -150,11 +150,21 @@ class FunctionType:
|
||||
returns: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
class Parameter:
|
||||
location: Optional[Location] = None
|
||||
name: Optional[Token]
|
||||
type: Type
|
||||
required: bool
|
||||
|
||||
|
||||
class FrameType:
|
||||
columns: list[Column]
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Column:
|
||||
location: Optional[Location] = None
|
||||
name: Token
|
||||
type: Type
|
||||
|
||||
|
||||
###<
|
||||
|
||||
@@ -12,10 +12,25 @@ from midas.ast.location import Location
|
||||
###<
|
||||
|
||||
|
||||
###> Preamble
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ParamSpec:
|
||||
pos: list[Function.Parameter]
|
||||
mixed: list[Function.Parameter]
|
||||
kw: list[Function.Parameter]
|
||||
|
||||
@property
|
||||
def all(self) -> list[Function.Parameter]:
|
||||
return self.pos + self.mixed + self.kw
|
||||
|
||||
|
||||
###<
|
||||
|
||||
|
||||
###> MidasType | Type annotations | node
|
||||
class BaseType:
|
||||
base: str
|
||||
param: Optional[MidasType]
|
||||
args: tuple[MidasType, ...]
|
||||
|
||||
|
||||
class ConstraintType:
|
||||
@@ -42,25 +57,17 @@ class ExpressionStmt:
|
||||
|
||||
class Function:
|
||||
name: str
|
||||
posonlyargs: list[Argument]
|
||||
args: list[Argument]
|
||||
sink: Optional[Argument]
|
||||
kwonlyargs: list[Argument]
|
||||
kw_sink: Optional[Argument]
|
||||
params: ParamSpec
|
||||
returns: Optional[MidasType]
|
||||
body: list[Stmt]
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
class Parameter:
|
||||
location: Optional[Location] = None
|
||||
name: str
|
||||
type: Optional[MidasType]
|
||||
default: Optional[Expr]
|
||||
|
||||
@property
|
||||
def all_args(self) -> list[Argument]:
|
||||
return self.posonlyargs + self.args + self.kwonlyargs
|
||||
|
||||
|
||||
class TypeAssign:
|
||||
name: str
|
||||
@@ -174,6 +181,10 @@ class SliceExpr:
|
||||
step: Optional[Expr]
|
||||
|
||||
|
||||
class TupleExpr:
|
||||
items: tuple[Expr, ...]
|
||||
|
||||
|
||||
class RawExpr:
|
||||
expr: ast.expr
|
||||
|
||||
|
||||
@@ -13,6 +13,8 @@ class HasLocation(Protocol):
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Location:
|
||||
"""Information about the location of an AST node"""
|
||||
|
||||
lineno: int
|
||||
col_offset: int
|
||||
end_lineno: Optional[int]
|
||||
@@ -29,6 +31,16 @@ class Location:
|
||||
|
||||
@staticmethod
|
||||
def span(start: Location, end: Location) -> Location:
|
||||
"""Create a new location spanning from one location to another
|
||||
|
||||
Args:
|
||||
start (Location): the starting location
|
||||
end (Location): the end location
|
||||
|
||||
Returns:
|
||||
Location: a new location spanning from the start of `start`
|
||||
to the end of `end`
|
||||
"""
|
||||
return Location(
|
||||
lineno=start.lineno,
|
||||
col_offset=start.col_offset,
|
||||
|
||||
@@ -30,9 +30,9 @@ class MemberKind(Enum):
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ParamSpec:
|
||||
l_paren: Token
|
||||
pos: list[FunctionType.Argument]
|
||||
mixed: list[FunctionType.Argument]
|
||||
kw: list[FunctionType.Argument]
|
||||
pos: list[FunctionType.Parameter]
|
||||
mixed: list[FunctionType.Parameter]
|
||||
kw: list[FunctionType.Parameter]
|
||||
|
||||
|
||||
##############
|
||||
@@ -265,6 +265,9 @@ class Type(ABC):
|
||||
@abstractmethod
|
||||
def visit_function_type(self, type: FunctionType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_frame_type(self, type: FrameType) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NamedType(Type):
|
||||
@@ -315,7 +318,7 @@ class FunctionType(Type):
|
||||
returns: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
class Parameter:
|
||||
location: Optional[Location] = None
|
||||
name: Optional[Token]
|
||||
type: Type
|
||||
@@ -323,3 +326,17 @@ class FunctionType(Type):
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_function_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FrameType(Type):
|
||||
columns: list[Column]
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Column:
|
||||
location: Optional[Location] = None
|
||||
name: Token
|
||||
type: Type
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_frame_type(self)
|
||||
|
||||
@@ -1,843 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import io
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum, auto
|
||||
from typing import Generator, Generic, Optional, Protocol, TypeVar
|
||||
|
||||
import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
|
||||
|
||||
class _Level(Enum):
|
||||
EMPTY = auto()
|
||||
ACTIVE = auto()
|
||||
LAST = auto()
|
||||
|
||||
|
||||
class Expr(Protocol):
|
||||
def accept(self, printer: AstPrinter) -> None: ...
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Expr)
|
||||
|
||||
|
||||
class AstPrinter(Generic[T]):
|
||||
LAST_CHILD = "└── "
|
||||
CHILD = "├── "
|
||||
VERTICAL = "│ "
|
||||
EMPTY = " "
|
||||
|
||||
def __init__(self):
|
||||
self._levels: list[_Level] = []
|
||||
self._idx: Optional[int] = None
|
||||
self._buf: io.StringIO = io.StringIO()
|
||||
|
||||
def print(self, expr: T):
|
||||
self._buf = io.StringIO()
|
||||
expr.accept(self)
|
||||
return self._buf.getvalue()
|
||||
|
||||
@contextmanager
|
||||
def _child_level(self, single: bool = False) -> Generator[None, None, None]:
|
||||
self._levels.append(_Level.LAST if single else _Level.ACTIVE)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._levels.pop()
|
||||
|
||||
def _mark_last(self):
|
||||
if self._levels:
|
||||
self._levels[-1] = _Level.LAST
|
||||
|
||||
def _write_line(self, text: str, *, last: bool = False):
|
||||
if last:
|
||||
self._mark_last()
|
||||
indent: str = self._build_indent()
|
||||
if self._idx is not None:
|
||||
text = f"[{self._idx}] {text}"
|
||||
self._idx = None
|
||||
self._buf.write(indent + text + "\n")
|
||||
|
||||
def _build_indent(self) -> str:
|
||||
parts: list[str] = []
|
||||
for level in self._levels[:-1]:
|
||||
parts.append(self.EMPTY if level == _Level.EMPTY else self.VERTICAL)
|
||||
if self._levels:
|
||||
if self._levels[-1] == _Level.LAST:
|
||||
parts.append(self.LAST_CHILD)
|
||||
self._levels[-1] = _Level.EMPTY
|
||||
else:
|
||||
parts.append(self.CHILD)
|
||||
return "".join(parts)
|
||||
|
||||
def _write_optional_child(
|
||||
self, label: str, child: Optional[T], *, last: bool = False
|
||||
):
|
||||
if last:
|
||||
self._mark_last()
|
||||
if child is None:
|
||||
self._write_line(f"{label}: None")
|
||||
else:
|
||||
self._write_line(label)
|
||||
with self._child_level(single=True):
|
||||
child.accept(self)
|
||||
|
||||
|
||||
class MidasAstPrinter(
|
||||
AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None], m.Type.Visitor[None]
|
||||
):
|
||||
# Statements
|
||||
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||
self._write_line("TypeStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line("params")
|
||||
with self._child_level():
|
||||
for i, param in enumerate(stmt.params):
|
||||
self._idx = i
|
||||
if i == len(stmt.params) - 1:
|
||||
self._mark_last()
|
||||
self._print_type_param(param)
|
||||
self._write_line("type", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
|
||||
def visit_alias_stmt(self, stmt: m.AliasStmt) -> None:
|
||||
self._write_line("AliasStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line("type", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
|
||||
def _print_type_param(self, param: m.TypeParam) -> None:
|
||||
self._write_line("Param")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{param.name.lexeme}"')
|
||||
self._write_optional_child("bound", param.bound, last=True)
|
||||
|
||||
def visit_member_stmt(self, stmt: m.MemberStmt):
|
||||
self._write_line("MemberStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f"kind: {stmt.kind.name}")
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line("type", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||
self._write_line("ExtendStmt")
|
||||
with self._child_level():
|
||||
self._write_line("params")
|
||||
with self._child_level():
|
||||
for i, param in enumerate(stmt.params):
|
||||
self._idx = i
|
||||
if i == len(stmt.params) - 1:
|
||||
self._mark_last()
|
||||
self._print_type_param(param)
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line("params")
|
||||
with self._child_level():
|
||||
for i, param in enumerate(stmt.params):
|
||||
self._idx = i
|
||||
if i == len(stmt.params) - 1:
|
||||
self._mark_last()
|
||||
self._print_type_param(param)
|
||||
self._write_line("members", last=True)
|
||||
with self._child_level():
|
||||
for i, member in enumerate(stmt.members):
|
||||
self._idx = i
|
||||
if i == len(stmt.members) - 1:
|
||||
self._mark_last()
|
||||
member.accept(self)
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
||||
self._write_line("PredicateStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line("params")
|
||||
with self._child_level():
|
||||
for i, spec in enumerate(stmt.params):
|
||||
self._idx = i
|
||||
if i == len(stmt.params) - 1:
|
||||
self._mark_last()
|
||||
self._visit_param_spec(spec)
|
||||
|
||||
self._write_line("body", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.body.accept(self)
|
||||
|
||||
# Expressions
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr):
|
||||
self._write_line("LogicalExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level(single=True):
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.operator.lexeme}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr):
|
||||
self._write_line("BinaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level(single=True):
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.operator.lexeme}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr):
|
||||
self._write_line("UnaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f"operator: {expr.operator.lexeme}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_call_expr(self, expr: m.CallExpr) -> None:
|
||||
self._write_line("CallExpr")
|
||||
with self._child_level():
|
||||
self._write_line("callee")
|
||||
with self._child_level(single=True):
|
||||
expr.callee.accept(self)
|
||||
self._write_line("arguments")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(expr.arguments):
|
||||
self._idx = i
|
||||
if i == len(expr.arguments) - 1:
|
||||
self._mark_last()
|
||||
arg.accept(self)
|
||||
self._write_line("keywords", last=True)
|
||||
with self._child_level():
|
||||
for i, (name, arg) in enumerate(expr.keywords.items()):
|
||||
self._idx = i
|
||||
if i == len(expr.keywords) - 1:
|
||||
self._mark_last()
|
||||
self._write_line(name)
|
||||
with self._child_level(single=True):
|
||||
arg.accept(self)
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr):
|
||||
self._write_line("GetExpr")
|
||||
with self._child_level():
|
||||
self._write_line("expr")
|
||||
with self._child_level(single=True):
|
||||
expr.expr.accept(self)
|
||||
self._write_line(f'name: "{expr.name.lexeme}"', last=True)
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr):
|
||||
self._write_line("VariableExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{expr.name.lexeme}"', last=True)
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr):
|
||||
self._write_line("GroupingExpr")
|
||||
with self._child_level():
|
||||
self._write_line("expr", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.expr.accept(self)
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
|
||||
self._write_line("LiteralExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f"value: {expr.value}", last=True)
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
|
||||
self._write_line("WildcardExpr")
|
||||
|
||||
def visit_named_type(self, type: m.NamedType) -> None:
|
||||
self._write_line("NamedType")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{type.name.lexeme}"', last=True)
|
||||
|
||||
def visit_generic_type(self, type: m.GenericType) -> None:
|
||||
self._write_line("GenericType")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level():
|
||||
type.type.accept(self)
|
||||
self._write_line("args", last=True)
|
||||
with self._child_level():
|
||||
for i, param in enumerate(type.args):
|
||||
self._idx = i
|
||||
if i == len(type.args) - 1:
|
||||
self._mark_last()
|
||||
param.accept(self)
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> None:
|
||||
self._write_line("ConstraintType")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
type.type.accept(self)
|
||||
self._write_line("constraint", last=True)
|
||||
with self._child_level(single=True):
|
||||
type.constraint.accept(self)
|
||||
|
||||
def visit_complex_type(self, type: m.ComplexType) -> None:
|
||||
self._write_line("ComplexType")
|
||||
with self._child_level():
|
||||
self._write_line("members", last=True)
|
||||
with self._child_level():
|
||||
for i, member in enumerate(type.members):
|
||||
self._idx = i
|
||||
if i == len(type.members) - 1:
|
||||
self._mark_last()
|
||||
member.accept(self)
|
||||
|
||||
def visit_extension_type(self, type: m.ExtensionType) -> None:
|
||||
self._write_line("ExtensionType")
|
||||
with self._child_level():
|
||||
self._write_line("base")
|
||||
with self._child_level(single=True):
|
||||
type.base.accept(self)
|
||||
self._write_line("extension", last=True)
|
||||
with self._child_level(single=True):
|
||||
type.extension.accept(self)
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> None:
|
||||
self._write_line("FunctionType")
|
||||
with self._child_level():
|
||||
self._write_line("params")
|
||||
with self._child_level(single=True):
|
||||
self._visit_param_spec(type.params)
|
||||
|
||||
self._write_line("returns", last=True)
|
||||
with self._child_level(single=True):
|
||||
type.returns.accept(self)
|
||||
|
||||
def _visit_param_spec(self, spec: m.ParamSpec) -> None:
|
||||
self._write_line("ParamSpec")
|
||||
with self._child_level():
|
||||
self._write_line("pos")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(spec.pos):
|
||||
self._idx = i
|
||||
if i == len(spec.pos) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("mixed")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(spec.mixed):
|
||||
self._idx = i
|
||||
if i == len(spec.mixed) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("kw", last=True)
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(spec.kw):
|
||||
self._idx = i
|
||||
if i == len(spec.kw) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
|
||||
self._write_line("Argument")
|
||||
with self._child_level():
|
||||
name: str = "None"
|
||||
if arg.name is not None:
|
||||
name = f'"{arg.name.lexeme}"'
|
||||
self._write_line(f"name: {name}")
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
arg.type.accept(self)
|
||||
self._write_line(f"required: {arg.required}", last=True)
|
||||
|
||||
|
||||
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
|
||||
def __init__(self, indent: int = 4):
|
||||
self.indent: int = indent
|
||||
self.level: int = 0
|
||||
|
||||
def indented(self, text: str) -> str:
|
||||
return " " * (self.level * self.indent) + text
|
||||
|
||||
    def print(self, expr: m.Expr | m.Stmt | m.Type) -> str:
        self.level = 0
        return expr.accept(self)

    def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
        template: str = ""
        if len(stmt.params) != 0:
            params: list[str] = [self._print_type_param(param) for param in stmt.params]
            template = f"[{', '.join(params)}]"
        res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
        return self.indented(res)

    def visit_alias_stmt(self, stmt: m.AliasStmt) -> str:
        return self.indented(f"alias {stmt.name.lexeme} = {stmt.type.accept(self)}")

    def _print_type_param(self, param: m.TypeParam) -> str:
        res: str = param.name.lexeme
        if param.bound is not None:
            res += "<:" + param.bound.accept(self)
        return res

    def visit_member_stmt(self, stmt: m.MemberStmt):
        keyword: str = {
            m.MemberKind.PROPERTY: "prop",
            m.MemberKind.METHOD: "def",
        }.get(stmt.kind, "")
        res: str = f"{keyword} {stmt.name.lexeme}: {stmt.type.accept(self)}"
        return self.indented(res)

    def visit_extend_stmt(self, stmt: m.ExtendStmt):
        template: str = ""
        if len(stmt.params) != 0:
            params: list[str] = [self._print_type_param(param) for param in stmt.params]
            template = f"[{', '.join(params)}]"
        res: str = self.indented(f"extend {stmt.name.lexeme}{template}")
        res += " {\n"
        self.level += 1
        for member in stmt.members:
            res += member.accept(self) + "\n"
        self.level -= 1
        res += self.indented("}")
        return res

    def visit_predicate_stmt(self, stmt: m.PredicateStmt):
        name: str = stmt.name.lexeme
        sig: str = "".join(self._visit_param_spec(spec) for spec in stmt.params)
        body: str = stmt.body.accept(self)
        return self.indented(f"predicate {name}{sig} = {body}")

    def visit_logical_expr(self, expr: m.LogicalExpr):
        left: str = expr.left.accept(self)
        operator: str = expr.operator.lexeme
        right: str = expr.right.accept(self)
        return f"{left} {operator} {right}"

    def visit_binary_expr(self, expr: m.BinaryExpr):
        left: str = expr.left.accept(self)
        operator: str = expr.operator.lexeme
        right: str = expr.right.accept(self)
        return f"{left} {operator} {right}"

    def visit_unary_expr(self, expr: m.UnaryExpr):
        operator: str = expr.operator.lexeme
        right: str = expr.right.accept(self)
        return f"{operator}{right}"

    def visit_call_expr(self, expr: m.CallExpr) -> str:
        args: list[str] = [arg.accept(self) for arg in expr.arguments] + [
            f"{name}={arg.accept(self)}" for name, arg in expr.keywords.items()
        ]
        return f"{expr.callee.accept(self)}({', '.join(args)})"

    def visit_get_expr(self, expr: m.GetExpr):
        expr_: str = expr.expr.accept(self)
        name: str = expr.name.lexeme
        return f"{expr_}.{name}"

    def visit_variable_expr(self, expr: m.VariableExpr):
        return expr.name.lexeme

    def visit_grouping_expr(self, expr: m.GroupingExpr):
        expr_: str = expr.expr.accept(self)
        return f"({expr_})"

    def visit_literal_expr(self, expr: m.LiteralExpr):
        return str(expr.value)

    def visit_wildcard_expr(self, expr: m.WildcardExpr):
        return "_"

    def visit_named_type(self, type: m.NamedType) -> str:
        return type.name.lexeme

    def visit_generic_type(self, type: m.GenericType) -> str:
        res: str = type.type.accept(self)
        if len(type.args) != 0:
            args: list[str] = [param.accept(self) for param in type.args]
            res += f"[{', '.join(args)}]"
        return res

    def visit_constraint_type(self, type: m.ConstraintType) -> str:
        res: str = type.type.accept(self)
        res += " where " + type.constraint.accept(self)
        return res

    def visit_complex_type(self, type: m.ComplexType) -> str:
        res: str = "{\n"
        self.level += 1
        for member in type.members:
            res += member.accept(self)
            res += "\n"
        self.level -= 1
        res += self.indented("}")
        return res

    def visit_extension_type(self, type: m.ExtensionType) -> str:
        return f"{type.base.accept(self)} & {type.extension.accept(self)}"

    def visit_function_type(self, type: m.FunctionType) -> str:
        spec: str = self._visit_param_spec(type.params)
        return f"fn {spec} -> {type.returns.accept(self)}"

    def _visit_param_spec(self, spec: m.ParamSpec) -> str:
        pos_args: list[str] = [self._print_arg(arg) for arg in spec.pos]
        mixed_args: list[str] = [self._print_arg(arg) for arg in spec.mixed]
        kw_args: list[str] = [self._print_arg(arg) for arg in spec.kw]
        args: list[str] = pos_args

        if len(pos_args) != 0:
            args.append("/")
        args += mixed_args
        if len(kw_args) != 0:
            args.append("*")
        args += kw_args
        return f"({', '.join(args)})"

    def _print_arg(self, arg: m.FunctionType.Argument) -> str:
        res: str = ""
        if arg.name is not None:
            res += arg.name.lexeme
            res += ": "
        res += arg.type.accept(self)
        if not arg.required:
            res += "?"
        return res


class PythonAstPrinter(
    AstPrinter,
    p.MidasType.Visitor[None],
    p.Stmt.Visitor[None],
    p.Expr.Visitor[None],
):
    def visit_base_type(self, node: p.BaseType) -> None:
        self._write_line("BaseType")
        with self._child_level():
            self._write_line(f"base: {node.base}")
            self._write_optional_child("param", node.param, last=True)

    def visit_constraint_type(self, node: p.ConstraintType) -> None:
        self._write_line("ConstraintType")
        with self._child_level():
            self._write_line("type")
            with self._child_level(single=True):
                node.type.accept(self)
            self._write_line(f"constraint: {ast.unparse(node.constraint)}", last=True)

    def visit_frame_column(self, node: p.FrameColumn) -> None:
        self._write_line("FrameColumn")
        with self._child_level():
            self._write_line(f"name: {node.name}")
            self._write_optional_child("type", node.type, last=True)

    def visit_frame_type(self, node: p.FrameType) -> None:
        self._write_line("FrameType")
        with self._child_level():
            self._write_line("columns", last=True)
            with self._child_level():
                for i, col in enumerate(node.columns):
                    self._idx = i
                    if i == len(node.columns) - 1:
                        self._mark_last()
                    col.accept(self)

    def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
        stmt.expr.accept(self)

    def visit_function(self, stmt: p.Function) -> None:
        self._write_line("Function")
        with self._child_level():
            self._write_line(f"name: {stmt.name}")

            self._write_line("posonlyargs")
            with self._child_level():
                for i, arg in enumerate(stmt.posonlyargs):
                    self._idx = i
                    if i == len(stmt.posonlyargs) - 1:
                        self._mark_last()
                    self._print_argument(arg)

            self._write_line("args")
            with self._child_level():
                for i, arg in enumerate(stmt.args):
                    self._idx = i
                    if i == len(stmt.args) - 1:
                        self._mark_last()
                    self._print_argument(arg)

            self._write_line("kwonlyargs")
            with self._child_level():
                for i, arg in enumerate(stmt.kwonlyargs):
                    self._idx = i
                    if i == len(stmt.kwonlyargs) - 1:
                        self._mark_last()
                    self._print_argument(arg)

            self._write_optional_child("returns", stmt.returns)
            self._write_line("body", last=True)
            with self._child_level():
                for i, body_stmt in enumerate(stmt.body):
                    self._idx = i
                    if i == len(stmt.body) - 1:
                        self._mark_last()
                    body_stmt.accept(self)

    def _print_argument(self, arg: p.Function.Argument) -> None:
        self._write_line("FunctionArgument")
        with self._child_level():
            self._write_line(f"name: {arg.name}")
            self._write_optional_child("type", arg.type, last=True)

    def visit_type_assign(self, stmt: p.TypeAssign) -> None:
        self._write_line("TypeAssign")
        with self._child_level():
            self._write_line(f"name: {stmt.name}")
            self._write_line("type", last=True)
            with self._child_level(single=True):
                stmt.type.accept(self)

    def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
        self._write_line("AssignStmt")
        with self._child_level():
            self._write_line("targets")
            with self._child_level():
                for i, target in enumerate(stmt.targets):
                    self._idx = i
                    if i == len(stmt.targets) - 1:
                        self._mark_last()
                    target.accept(self)
            self._write_line("value", last=True)
            with self._child_level(single=True):
                stmt.value.accept(self)

    def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
        self._write_line("ReturnStmt")
        with self._child_level():
            self._write_optional_child("value", stmt.value, last=True)

    def visit_if_stmt(self, stmt: p.IfStmt) -> None:
        self._write_line("IfStmt")
        with self._child_level():
            self._write_line("test")
            with self._child_level(single=True):
                stmt.test.accept(self)
            self._write_line("body")
            with self._child_level():
                for i, body_stmt in enumerate(stmt.body):
                    self._idx = i
                    if i == len(stmt.body) - 1:
                        self._mark_last()
                    body_stmt.accept(self)
            self._write_line("orelse", last=True)
            with self._child_level():
                for i, else_stmt in enumerate(stmt.orelse):
                    self._idx = i
                    if i == len(stmt.orelse) - 1:
                        self._mark_last()
                    else_stmt.accept(self)

    def visit_pass(self, stmt: p.Pass) -> None:
        self._write_line("Pass")

    def visit_for_stmt(self, stmt: p.ForStmt) -> None:
        self._write_line("ForStmt")
        with self._child_level():
            self._write_line("target")
            with self._child_level(single=True):
                stmt.target.accept(self)
            self._write_line("iterator")
            with self._child_level(single=True):
                stmt.iterator.accept(self)
            self._write_line("body", last=True)
            with self._child_level():
                for i, body_stmt in enumerate(stmt.body):
                    self._idx = i
                    if i == len(stmt.body) - 1:
                        self._mark_last()
                    body_stmt.accept(self)

    def visit_raw_stmt(self, stmt: p.RawStmt) -> None:
        self._write_line("RawStmt")
        with self._child_level(single=True):
            self._write_line(f"stmt: {ast.unparse(stmt.stmt)}")

    def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
        self._write_line("BinaryExpr")
        with self._child_level():
            self._write_line("left")
            with self._child_level(single=True):
                expr.left.accept(self)

            self._write_line(f"operator: {expr.operator.__class__.__name__}")

            self._write_line("right", last=True)
            with self._child_level(single=True):
                expr.right.accept(self)

    def visit_compare_expr(self, expr: p.CompareExpr) -> None:
        self._write_line("CompareExpr")
        with self._child_level():
            self._write_line("left")
            with self._child_level(single=True):
                expr.left.accept(self)

            self._write_line(f"operator: {expr.operator.__class__.__name__}")

            self._write_line("right", last=True)
            with self._child_level(single=True):
                expr.right.accept(self)

    def visit_unary_expr(self, expr: p.UnaryExpr) -> None:
        self._write_line("UnaryExpr")
        with self._child_level():
            self._write_line(f"operator: {expr.operator.__class__.__name__}")

            self._write_line("right", last=True)
            with self._child_level(single=True):
                expr.right.accept(self)

    def visit_call_expr(self, expr: p.CallExpr) -> None:
        self._write_line("CallExpr")
        with self._child_level():
            self._write_line("callee")
            with self._child_level(single=True):
                expr.callee.accept(self)

            self._write_line("arguments")
            with self._child_level():
                for i, arg in enumerate(expr.arguments):
                    self._idx = i
                    if i == len(expr.arguments) - 1:
                        self._mark_last()
                    arg.accept(self)

            self._write_line("keywords", last=True)
            with self._child_level():
                for i, (name, arg) in enumerate(expr.keywords.items()):
                    self._idx = i
                    if i == len(expr.keywords) - 1:
                        self._mark_last()
                    self._write_line(name)
                    with self._child_level(single=True):
                        arg.accept(self)

    def visit_get_expr(self, expr: p.GetExpr) -> None:
        self._write_line("GetExpr")
        with self._child_level():
            self._write_line("object")
            with self._child_level(single=True):
                expr.object.accept(self)
            self._write_line(f"name: {expr.name}", last=True)

    def visit_literal_expr(self, expr: p.LiteralExpr) -> None:
        self._write_line("LiteralExpr")
        with self._child_level(single=True):
            self._write_line(f"value: {expr.value!r}")

    def visit_variable_expr(self, expr: p.VariableExpr) -> None:
        self._write_line("VariableExpr")
        with self._child_level(single=True):
            self._write_line(f"name: {expr.name}")

    def visit_logical_expr(self, expr: p.LogicalExpr) -> None:
        self._write_line("LogicalExpr")
        with self._child_level():
            self._write_line("left")
            with self._child_level(single=True):
                expr.left.accept(self)

            self._write_line(f"operator: {expr.operator.__class__.__name__}")

            self._write_line("right", last=True)
            with self._child_level(single=True):
                expr.right.accept(self)

    def visit_cast_expr(self, expr: p.CastExpr) -> None:
        self._write_line("CastExpr")
        with self._child_level():
            self._write_line("type")
            with self._child_level(single=True):
                expr.type.accept(self)
            self._write_line("expr")
            with self._child_level(single=True):
                expr.expr.accept(self)
            self._write_line(f"unsafe: {expr.unsafe}", last=True)

    def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
        self._write_line("TernaryExpr")
        with self._child_level():
            self._write_line("test")
            with self._child_level(single=True):
                expr.test.accept(self)

            self._write_line("if_true")
            with self._child_level(single=True):
                expr.if_true.accept(self)

            self._write_line("if_false", last=True)
            with self._child_level(single=True):
                expr.if_false.accept(self)

    def visit_list_expr(self, expr: p.ListExpr) -> None:
        self._write_line("ListExpr")
        with self._child_level():
            self._write_line("items", last=True)
            with self._child_level():
                for i, item in enumerate(expr.items):
                    self._idx = i
                    if i == len(expr.items) - 1:
                        self._mark_last()
                    item.accept(self)

    def visit_dict_expr(self, expr: p.DictExpr) -> None:
        self._write_line("DictExpr")
        with self._child_level():
            self._write_line("keys")
            with self._child_level():
                for i, key in enumerate(expr.keys):
                    self._idx = i
                    if i == len(expr.keys) - 1:
                        self._mark_last()
                    if key is None:
                        self._write_line("None")
                    else:
                        key.accept(self)
            self._write_line("values", last=True)
            with self._child_level():
                for i, value in enumerate(expr.values):
                    self._idx = i
                    if i == len(expr.values) - 1:
                        self._mark_last()
                    value.accept(self)

    def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
        self._write_line("SubscriptExpr")
        with self._child_level():
            self._write_line("object")
            with self._child_level(single=True):
                expr.object.accept(self)
            self._write_line("index", last=True)
            with self._child_level(single=True):
                expr.index.accept(self)

    def visit_slice_expr(self, expr: p.SliceExpr) -> None:
        self._write_line("SliceExpr")
        with self._child_level():
            self._write_optional_child("lower", expr.lower)
            self._write_optional_child("upper", expr.upper)
            self._write_optional_child("step", expr.step, last=True)

    def visit_raw_expr(self, expr: p.RawExpr) -> None:
        self._write_line("RawExpr")
        with self._child_level(single=True):
            self._write_line(f"expr: {ast.unparse(expr.expr)}")
3
midas/ast/printer/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .midas import MidasPrinter as MidasPrinter
from .midas_ast import MidasAstPrinter as MidasAstPrinter
from .python_ast import PythonAstPrinter as PythonAstPrinter
103
midas/ast/printer/base.py
Normal file
@@ -0,0 +1,103 @@
from __future__ import annotations

import io
from contextlib import contextmanager
from enum import Enum, auto
from typing import Callable, Generator, Generic, Optional, Protocol, Sequence, TypeVar


class _Level(Enum):
    EMPTY = auto()
    ACTIVE = auto()
    LAST = auto()


class Expr(Protocol):
    def accept(self, printer: AstPrinter) -> None: ...


T = TypeVar("T", bound=Expr)


class AstPrinter(Generic[T]):
    LAST_CHILD = "└── "
    CHILD = "├── "
    VERTICAL = "│ "
    EMPTY = " "

    def __init__(self):
        self._levels: list[_Level] = []
        self._idx: Optional[int] = None
        self._buf: io.StringIO = io.StringIO()

    def print(self, expr: T):
        self._buf = io.StringIO()
        expr.accept(self)
        return self._buf.getvalue()

    @contextmanager
    def _child_level(self, single: bool = False) -> Generator[None, None, None]:
        self._levels.append(_Level.LAST if single else _Level.ACTIVE)
        try:
            yield
        finally:
            self._levels.pop()

    def _mark_last(self):
        if self._levels:
            self._levels[-1] = _Level.LAST

    def _write_line(self, text: str, *, last: bool = False):
        if last:
            self._mark_last()
        indent: str = self._build_indent()
        if self._idx is not None:
            text = f"[{self._idx}] {text}"
            self._idx = None
        self._buf.write(indent + text + "\n")

    def _build_indent(self) -> str:
        parts: list[str] = []
        for level in self._levels[:-1]:
            parts.append(self.EMPTY if level == _Level.EMPTY else self.VERTICAL)
        if self._levels:
            if self._levels[-1] == _Level.LAST:
                parts.append(self.LAST_CHILD)
                self._levels[-1] = _Level.EMPTY
            else:
                parts.append(self.CHILD)
        return "".join(parts)

    def _write_optional_child(
        self, label: str, child: Optional[T], *, last: bool = False
    ):
        if last:
            self._mark_last()
        if child is None:
            self._write_line(f"{label}: None")
        else:
            self._write_line(label)
            with self._child_level(single=True):
                child.accept(self)

    def _write_sequence(
        self,
        label: str,
        list_: Sequence[T],
        *,
        last: bool = False,
        print_func: Optional[Callable[[T], None]] = None,
    ):
        if last:
            self._mark_last()

        self._write_line(label)
        with self._child_level():
            for i, item in enumerate(list_):
                self._idx = i
                if i == len(list_) - 1:
                    self._mark_last()
                if print_func is not None:
                    print_func(item)
                else:
                    item.accept(self)
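The tree prefixes produced by `AstPrinter._build_indent` can be tried in isolation. The following is a minimal sketch and not part of this change set: `render_tree`, the `(label, children)` tuple shape, and the four-character prefix constants are assumptions made for illustration. It passes the prefix down through recursion instead of keeping the mutable `_levels` stack used in `base.py`.

```python
# Illustrative sketch only: `render_tree` and the (label, children) tuples are
# hypothetical and do not exist in midas. The prefix constants are assumed to
# be four characters wide so that nested levels line up.
LAST_CHILD = "└── "
CHILD = "├── "
VERTICAL = "│   "
EMPTY = "    "


def render_tree(node, prefix: str = "", connector: str = "") -> str:
    """Render a (label, children) tuple as an indented tree."""
    label, children = node
    lines = [prefix + connector + label]
    # Below a last child the column is blank; below any other child a
    # vertical bar continues down to the next sibling.
    if connector == LAST_CHILD:
        child_prefix = prefix + EMPTY
    elif connector == CHILD:
        child_prefix = prefix + VERTICAL
    else:
        child_prefix = prefix  # root node: no connector, no extra indent
    for i, child in enumerate(children):
        is_last = i == len(children) - 1
        lines.append(
            render_tree(child, child_prefix, LAST_CHILD if is_last else CHILD)
        )
    return "\n".join(lines)


tree = (
    "BinaryExpr",
    [
        ("left", [("VariableExpr", [])]),
        ("right", [("LiteralExpr", [])]),
    ],
)
print(render_tree(tree))
```

The recursive form avoids shared state, at the cost of not supporting the "mark the current level as last after the fact" pattern that `_mark_last` provides in the diff.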
183
midas/ast/printer/midas.py
Normal file
@@ -0,0 +1,183 @@
import midas.ast.midas as m


class MidasPrinter(
    m.Expr.Visitor[str],
    m.Stmt.Visitor[str],
    m.Type.Visitor[str],
):
    def __init__(self, indent: int = 4):
        self.indent: int = indent
        self.level: int = 0

    def indented(self, text: str) -> str:
        return " " * (self.level * self.indent) + text

    def print(self, expr: m.Expr | m.Stmt | m.Type) -> str:
        self.level = 0
        return expr.accept(self)

    # Statements

    def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
        template: str = ""
        if len(stmt.params) != 0:
            params: list[str] = [self._print_type_param(param) for param in stmt.params]
            template = f"[{', '.join(params)}]"
        res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
        return self.indented(res)

    def visit_alias_stmt(self, stmt: m.AliasStmt) -> str:
        return self.indented(f"alias {stmt.name.lexeme} = {stmt.type.accept(self)}")

    def _print_type_param(self, param: m.TypeParam) -> str:
        res: str = param.name.lexeme
        if param.bound is not None:
            res += "<:" + param.bound.accept(self)
        return res

    def visit_member_stmt(self, stmt: m.MemberStmt):
        keyword: str = {
            m.MemberKind.PROPERTY: "prop",
            m.MemberKind.METHOD: "def",
        }.get(stmt.kind, "")
        res: str = f"{keyword} {stmt.name.lexeme}: {stmt.type.accept(self)}"
        return self.indented(res)

    def visit_extend_stmt(self, stmt: m.ExtendStmt):
        template: str = ""
        if len(stmt.params) != 0:
            params: list[str] = [self._print_type_param(param) for param in stmt.params]
            template = f"[{', '.join(params)}]"
        res: str = self.indented(f"extend {stmt.name.lexeme}{template}")
        res += " {\n"
        self.level += 1
        for member in stmt.members:
            res += member.accept(self) + "\n"
        self.level -= 1
        res += self.indented("}")
        return res

    def visit_predicate_stmt(self, stmt: m.PredicateStmt):
        name: str = stmt.name.lexeme
        sig: str = "".join(self._visit_param_spec(spec) for spec in stmt.params)
        body: str = stmt.body.accept(self)
        return self.indented(f"predicate {name}{sig} = {body}")

    # Expressions

    def visit_logical_expr(self, expr: m.LogicalExpr):
        left: str = expr.left.accept(self)
        operator: str = expr.operator.lexeme
        right: str = expr.right.accept(self)
        return f"{left} {operator} {right}"

    def visit_binary_expr(self, expr: m.BinaryExpr):
        left: str = expr.left.accept(self)
        operator: str = expr.operator.lexeme
        right: str = expr.right.accept(self)
        return f"{left} {operator} {right}"

    def visit_unary_expr(self, expr: m.UnaryExpr):
        operator: str = expr.operator.lexeme
        right: str = expr.right.accept(self)
        return f"{operator}{right}"

    def visit_call_expr(self, expr: m.CallExpr) -> str:
        args: list[str] = [arg.accept(self) for arg in expr.arguments] + [
            f"{name}={arg.accept(self)}" for name, arg in expr.keywords.items()
        ]
        return f"{expr.callee.accept(self)}({', '.join(args)})"

    def visit_get_expr(self, expr: m.GetExpr):
        expr_: str = expr.expr.accept(self)
        name: str = expr.name.lexeme
        return f"{expr_}.{name}"

    def visit_variable_expr(self, expr: m.VariableExpr):
        return expr.name.lexeme

    def visit_grouping_expr(self, expr: m.GroupingExpr):
        expr_: str = expr.expr.accept(self)
        return f"({expr_})"

    def visit_literal_expr(self, expr: m.LiteralExpr):
        return str(expr.value)

    def visit_wildcard_expr(self, expr: m.WildcardExpr):
        return "_"

    # Types

    def visit_named_type(self, type: m.NamedType) -> str:
        return type.name.lexeme

    def visit_generic_type(self, type: m.GenericType) -> str:
        res: str = type.type.accept(self)
        if len(type.args) != 0:
            args: list[str] = [param.accept(self) for param in type.args]
            res += f"[{', '.join(args)}]"
        return res

    def visit_constraint_type(self, type: m.ConstraintType) -> str:
        res: str = type.type.accept(self)
        res += " where " + type.constraint.accept(self)
        return res

    def visit_complex_type(self, type: m.ComplexType) -> str:
        res: str = "{\n"
        self.level += 1
        for member in type.members:
            res += member.accept(self)
            res += "\n"
        self.level -= 1
        res += self.indented("}")
        return res

    def visit_extension_type(self, type: m.ExtensionType) -> str:
        return f"{type.base.accept(self)} & {type.extension.accept(self)}"

    def visit_function_type(self, type: m.FunctionType) -> str:
        spec: str = self._visit_param_spec(type.params)
        return f"fn {spec} -> {type.returns.accept(self)}"

    def _visit_param_spec(self, spec: m.ParamSpec) -> str:
        pos: list[str] = [self._print_param(param) for param in spec.pos]
        mixed: list[str] = [self._print_param(param) for param in spec.mixed]
        kw: list[str] = [self._print_param(param) for param in spec.kw]
        params: list[str] = pos

        if len(pos) != 0:
            params.append("/")
        params += mixed
        if len(kw) != 0:
            params.append("*")
        params += kw
        return f"({', '.join(params)})"

    def _print_param(self, param: m.FunctionType.Parameter) -> str:
        res: str = ""
        if param.name is not None:
            res += param.name.lexeme
            res += ": "
        res += param.type.accept(self)
        if not param.required:
            res += "?"
        return res

    def visit_frame_type(self, type: m.FrameType) -> str:
        res: str = self.indented("Frame[")
        if len(type.columns) != 0:
            res += "\n"
            self.level += 1
            columns: list[str] = []
            for column in type.columns:
                columns.append(self.indented(self._print_frame_column(column)))
            res += ",\n".join(columns)
            self.level -= 1
            res += "\n"
        res += "]"
        return res

    def _print_frame_column(self, column: m.FrameType.Column) -> str:
        return f"{column.name.lexeme}: {column.type.accept(self)}"
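The separator placement in `MidasPrinter._visit_param_spec` follows the Python convention of `/` after positional-only parameters and `*` before keyword-only ones. A minimal sketch of that rule on plain strings, assuming the parameters have already been rendered; `format_param_spec` is a hypothetical helper and not part of midas:

```python
# Illustrative sketch only: `format_param_spec` is a hypothetical helper that
# takes already-rendered parameter strings instead of an m.ParamSpec.
def format_param_spec(pos: list[str], mixed: list[str], kw: list[str]) -> str:
    params = list(pos)  # copy so the caller's list is not mutated
    if pos:
        params.append("/")  # closes the positional-only group
    params += mixed
    if kw:
        params.append("*")  # opens the keyword-only group
    params += kw
    return f"({', '.join(params)})"


print(format_param_spec(["a: Int"], ["b: Str"], ["c: Bool?"]))
print(format_param_spec([], ["x: Int"], []))
```

Copying `pos` before appending is a small deviation from the diff, where `params` aliases `pos`; the output is the same because the emptiness check happens before the append.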
253
midas/ast/printer/midas_ast.py
Normal file
@@ -0,0 +1,253 @@
import midas.ast.midas as m
from midas.ast.printer.base import AstPrinter


class MidasAstPrinter(
    AstPrinter,
    m.Expr.Visitor[None],
    m.Stmt.Visitor[None],
    m.Type.Visitor[None],
):
    # Statements

    def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
        self._write_line("TypeStmt")
        with self._child_level():
            self._write_line(f'name: "{stmt.name.lexeme}"')
            self._write_sequence(
                "params",
                stmt.params,
                print_func=self._print_type_param,
            )
            self._write_line("type", last=True)
            with self._child_level(single=True):
                stmt.type.accept(self)

    def visit_alias_stmt(self, stmt: m.AliasStmt) -> None:
        self._write_line("AliasStmt")
        with self._child_level():
            self._write_line(f'name: "{stmt.name.lexeme}"')
            self._write_line("type", last=True)
            with self._child_level(single=True):
                stmt.type.accept(self)

    def _print_type_param(self, param: m.TypeParam) -> None:
        self._write_line("Param")
        with self._child_level():
            self._write_line(f'name: "{param.name.lexeme}"')
            self._write_optional_child("bound", param.bound, last=True)

    def visit_member_stmt(self, stmt: m.MemberStmt):
        self._write_line("MemberStmt")
        with self._child_level():
            self._write_line(f"kind: {stmt.kind.name}")
            self._write_line(f'name: "{stmt.name.lexeme}"')
            self._write_line("type", last=True)
            with self._child_level(single=True):
                stmt.type.accept(self)

    def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
        self._write_line("ExtendStmt")
        with self._child_level():
            self._write_line(f'name: "{stmt.name.lexeme}"')
            self._write_sequence(
                "params",
                stmt.params,
                print_func=self._print_type_param,
            )
            self._write_sequence("members", stmt.members, last=True)

    def visit_predicate_stmt(self, stmt: m.PredicateStmt):
        self._write_line("PredicateStmt")
        with self._child_level():
            self._write_line(f'name: "{stmt.name.lexeme}"')
            self._write_sequence(
                "params",
                stmt.params,
                print_func=self._visit_param_spec,
            )
            self._write_line("body", last=True)
            with self._child_level(single=True):
                stmt.body.accept(self)

    # Expressions

    def visit_logical_expr(self, expr: m.LogicalExpr):
        self._write_line("LogicalExpr")
        with self._child_level():
            self._write_line("left")
            with self._child_level(single=True):
                expr.left.accept(self)

            self._write_line(f"operator: {expr.operator.lexeme}")

            self._write_line("right", last=True)
            with self._child_level(single=True):
                expr.right.accept(self)

    def visit_binary_expr(self, expr: m.BinaryExpr):
        self._write_line("BinaryExpr")
        with self._child_level():
            self._write_line("left")
            with self._child_level(single=True):
                expr.left.accept(self)

            self._write_line(f"operator: {expr.operator.lexeme}")

            self._write_line("right", last=True)
            with self._child_level(single=True):
                expr.right.accept(self)

    def visit_unary_expr(self, expr: m.UnaryExpr):
        self._write_line("UnaryExpr")
        with self._child_level():
            self._write_line(f"operator: {expr.operator.lexeme}")

            self._write_line("right", last=True)
            with self._child_level(single=True):
                expr.right.accept(self)

    def visit_call_expr(self, expr: m.CallExpr) -> None:
        self._write_line("CallExpr")
        with self._child_level():
            self._write_line("callee")
            with self._child_level(single=True):
                expr.callee.accept(self)
            self._write_sequence("arguments", expr.arguments)
            self._write_line("keywords", last=True)
            with self._child_level():
                for i, (name, arg) in enumerate(expr.keywords.items()):
                    self._idx = i
                    if i == len(expr.keywords) - 1:
                        self._mark_last()
                    self._write_line(name)
                    with self._child_level(single=True):
                        arg.accept(self)

    def visit_get_expr(self, expr: m.GetExpr):
        self._write_line("GetExpr")
        with self._child_level():
            self._write_line("expr")
            with self._child_level(single=True):
                expr.expr.accept(self)
            self._write_line(f'name: "{expr.name.lexeme}"', last=True)

    def visit_variable_expr(self, expr: m.VariableExpr):
        self._write_line("VariableExpr")
        with self._child_level():
            self._write_line(f'name: "{expr.name.lexeme}"', last=True)

    def visit_grouping_expr(self, expr: m.GroupingExpr):
        self._write_line("GroupingExpr")
        with self._child_level():
            self._write_line("expr", last=True)
            with self._child_level(single=True):
                expr.expr.accept(self)

    def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
        self._write_line("LiteralExpr")
        with self._child_level():
            self._write_line(f"value: {expr.value}", last=True)

    def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
        self._write_line("WildcardExpr")

    # Types

    def visit_named_type(self, type: m.NamedType) -> None:
        self._write_line("NamedType")
        with self._child_level():
            self._write_line(f'name: "{type.name.lexeme}"', last=True)

    def visit_generic_type(self, type: m.GenericType) -> None:
        self._write_line("GenericType")
        with self._child_level():
            self._write_line("type")
            with self._child_level():
                type.type.accept(self)
            self._write_sequence("args", type.args, last=True)

    def visit_constraint_type(self, type: m.ConstraintType) -> None:
        self._write_line("ConstraintType")
        with self._child_level():
            self._write_line("type")
            with self._child_level(single=True):
                type.type.accept(self)
            self._write_line("constraint", last=True)
            with self._child_level(single=True):
                type.constraint.accept(self)

    def visit_complex_type(self, type: m.ComplexType) -> None:
        self._write_line("ComplexType")
        with self._child_level():
            self._write_sequence("members", type.members, last=True)

    def visit_extension_type(self, type: m.ExtensionType) -> None:
        self._write_line("ExtensionType")
        with self._child_level():
            self._write_line("base")
            with self._child_level(single=True):
                type.base.accept(self)
            self._write_line("extension", last=True)
            with self._child_level(single=True):
                type.extension.accept(self)

    def visit_function_type(self, type: m.FunctionType) -> None:
        self._write_line("FunctionType")
        with self._child_level():
            self._write_line("params")
            with self._child_level(single=True):
                self._visit_param_spec(type.params)

            self._write_line("returns", last=True)
            with self._child_level(single=True):
                type.returns.accept(self)

    def _visit_param_spec(self, spec: m.ParamSpec) -> None:
        self._write_line("ParamSpec")
        with self._child_level():
            self._write_sequence(
                "pos",
                spec.pos,
                print_func=self._print_param,
            )
            self._write_sequence(
                "mixed",
                spec.mixed,
                print_func=self._print_param,
            )
            self._write_sequence(
                "kw",
                spec.kw,
                print_func=self._print_param,
                last=True,
            )

    def _print_param(self, param: m.FunctionType.Parameter) -> None:
        self._write_line("Parameter")
        with self._child_level():
            name: str = "None"
            if param.name is not None:
                name = f'"{param.name.lexeme}"'
            self._write_line(f"name: {name}")
            self._write_line("type")
            with self._child_level(single=True):
                param.type.accept(self)
            self._write_line(f"required: {param.required}", last=True)

    def visit_frame_type(self, type: m.FrameType) -> None:
        self._write_line("FrameType")
        with self._child_level(single=True):
            self._write_sequence(
                "columns",
                type.columns,
                print_func=self._print_frame_column,
            )

    def _print_frame_column(self, column: m.FrameType.Column) -> None:
        self._write_line("Column")
        with self._child_level():
            self._write_line(f'name: "{column.name.lexeme}"')
            self._write_line("type")
            with self._child_level(single=True):
                column.type.accept(self)
285
midas/ast/printer/python_ast.py
Normal file
@@ -0,0 +1,285 @@
|
||||
import ast
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.printer.base import AstPrinter
|
||||
|
||||
|
||||
class PythonAstPrinter(
|
||||
AstPrinter,
|
||||
p.MidasType.Visitor[None],
|
||||
p.Stmt.Visitor[None],
|
||||
p.Expr.Visitor[None],
|
||||
):
|
||||
# Types
|
||||
|
||||
def visit_base_type(self, node: p.BaseType) -> None:
|
||||
self._write_line("BaseType")
|
||||
with self._child_level():
|
||||
self._write_line(f"base: {node.base}")
|
||||
self._write_sequence("args", node.args, last=True)
|
||||
|
||||
def visit_constraint_type(self, node: p.ConstraintType) -> None:
|
||||
self._write_line("ConstraintType")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
node.type.accept(self)
|
||||
self._write_line(f"constraint: {ast.unparse(node.constraint)}", last=True)
|
||||
|
||||
def visit_frame_column(self, node: p.FrameColumn) -> None:
|
||||
self._write_line("FrameColumn")
|
||||
with self._child_level():
|
||||
self._write_line(f"name: {node.name}")
|
||||
self._write_optional_child("type", node.type, last=True)
|
||||
|
||||
def visit_frame_type(self, node: p.FrameType) -> None:
|
||||
self._write_line("FrameType")
|
||||
with self._child_level(single=True):
|
||||
self._write_sequence("columns", node.columns)
|
||||
|
||||
# Statements
|
||||
|
||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||
stmt.expr.accept(self)
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> None:
|
||||
self._write_line("Function")
|
||||
with self._child_level():
|
||||
self._write_line(f"name: {stmt.name}")
|
||||
self._write_line("params")
|
||||
with self._child_level():
|
||||
self._print_param_spec(stmt.params)
|
||||
|
||||
self._write_optional_child("returns", stmt.returns)
|
||||
self._write_sequence("body", stmt.body, last=True)
|
||||
|
||||
def _print_param_spec(self, spec: p.ParamSpec) -> None:
|
||||
self._write_line("ParamSpec")
|
||||
with self._child_level():
|
||||
self._write_sequence(
|
||||
"pos",
|
||||
spec.pos,
|
||||
print_func=self._print_param,
|
||||
)
|
||||
self._write_sequence(
|
||||
"mixed",
|
||||
spec.mixed,
|
||||
print_func=self._print_param,
|
||||
)
|
||||
self._write_sequence(
|
||||
"kw",
|
||||
spec.kw,
|
||||
print_func=self._print_param,
|
||||
last=True,
|
||||
)
|
||||
|
||||
def _print_param(self, param: p.Function.Parameter) -> None:
|
||||
self._write_line("Parameter")
|
||||
with self._child_level():
|
||||
self._write_line(f"name: {param.name}")
|
||||
self._write_optional_child("type", param.type, last=True)
|
||||
|
||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||
self._write_line("TypeAssign")
|
||||
with self._child_level():
|
||||
self._write_line(f"name: {stmt.name}")
|
||||
self._write_line("type", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
|
||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||
self._write_line("AssignStmt")
|
||||
with self._child_level():
|
||||
self._write_sequence("targets", stmt.targets)
|
||||
self._write_line("value", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.value.accept(self)
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||
self._write_line("ReturnStmt")
|
||||
with self._child_level():
|
||||
self._write_optional_child("value", stmt.value, last=True)
|
||||
|
||||
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||
self._write_line("IfStmt")
|
||||
with self._child_level():
|
||||
self._write_line("test")
|
||||
with self._child_level(single=True):
|
||||
stmt.test.accept(self)
|
||||
self._write_sequence("body", stmt.body)
|
||||
self._write_sequence("orelse", stmt.orelse, last=True)
|
||||
|
||||
def visit_pass(self, stmt: p.Pass) -> None:
|
||||
self._write_line("Pass")
|
||||
|
||||
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
|
||||
self._write_line("ForStmt")
|
||||
with self._child_level():
|
||||
self._write_line("target")
|
||||
with self._child_level(single=True):
|
||||
stmt.target.accept(self)
|
||||
self._write_line("iterator")
|
||||
with self._child_level(single=True):
|
||||
stmt.iterator.accept(self)
|
||||
self._write_sequence("body", stmt.body, last=True)
|
||||
|
||||
def visit_raw_stmt(self, stmt: p.RawStmt) -> None:
|
||||
self._write_line("RawStmt")
|
||||
with self._child_level(single=True):
|
||||
self._write_line(f"stmt: {ast.unparse(stmt.stmt)}")
|
||||
|
||||
# Expressions
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
|
||||
self._write_line("BinaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level(single=True):
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.operator.__class__.__name__}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> None:
|
||||
self._write_line("CompareExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level(single=True):
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.operator.__class__.__name__}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> None:
|
||||
self._write_line("UnaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f"operator: {expr.operator.__class__.__name__}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> None:
|
||||
self._write_line("CallExpr")
|
||||
with self._child_level():
|
||||
self._write_line("callee")
|
||||
with self._child_level(single=True):
|
||||
expr.callee.accept(self)
|
||||
|
||||
self._write_sequence("arguments", expr.arguments)
|
||||
self._write_line("keywords", last=True)
|
||||
with self._child_level():
|
||||
for i, (name, arg) in enumerate(expr.keywords.items()):
|
||||
self._idx = i
|
||||
if i == len(expr.keywords) - 1:
|
||||
self._mark_last()
|
||||
self._write_line(name)
|
||||
with self._child_level(single=True):
|
||||
arg.accept(self)
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> None:
|
||||
self._write_line("GetExpr")
|
||||
with self._child_level():
|
||||
self._write_line("object")
|
||||
with self._child_level(single=True):
|
||||
expr.object.accept(self)
|
||||
self._write_line(f"name: {expr.name}", last=True)
|
||||
|
||||
def visit_literal_expr(self, expr: p.LiteralExpr) -> None:
|
||||
self._write_line("LiteralExpr")
|
||||
with self._child_level(single=True):
|
||||
self._write_line(f"value: {expr.value!r}")
|
||||
|
||||
def visit_variable_expr(self, expr: p.VariableExpr) -> None:
|
||||
self._write_line("VariableExpr")
|
||||
with self._child_level(single=True):
|
||||
self._write_line(f"name: {expr.name}")
|
||||
|
||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> None:
|
||||
self._write_line("LogicalExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level(single=True):
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.operator.__class__.__name__}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> None:
|
||||
self._write_line("CastExpr")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
expr.type.accept(self)
|
||||
self._write_line("expr")
|
||||
with self._child_level(single=True):
|
||||
expr.expr.accept(self)
|
||||
self._write_line(f"unsafe: {expr.unsafe}", last=True)
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
|
||||
self._write_line("TernaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line("test")
|
||||
with self._child_level(single=True):
|
||||
expr.test.accept(self)
|
||||
|
||||
self._write_line("if_true")
|
||||
with self._child_level(single=True):
|
||||
expr.if_true.accept(self)
|
||||
|
||||
self._write_line("if_false", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.if_false.accept(self)
|
||||
|
||||
def visit_list_expr(self, expr: p.ListExpr) -> None:
|
||||
self._write_line("ListExpr")
|
||||
with self._child_level():
|
||||
self._write_sequence("items", expr.items, last=True)
|
||||
|
||||
def visit_dict_expr(self, expr: p.DictExpr) -> None:
|
||||
self._write_line("DictExpr")
|
||||
with self._child_level():
|
||||
self._write_sequence(
|
||||
"keys",
|
||||
expr.keys,
|
||||
print_func=lambda k: (
|
||||
self._write_line("None") if k is None else k.accept(self)
|
||||
),
|
||||
)
|
||||
self._write_sequence("values", expr.values, last=True)
|
||||
|
||||
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
|
||||
self._write_line("SubscriptExpr")
|
||||
with self._child_level():
|
||||
self._write_line("object")
|
||||
with self._child_level(single=True):
|
||||
expr.object.accept(self)
|
||||
self._write_line("index", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.index.accept(self)
|
||||
|
||||
def visit_slice_expr(self, expr: p.SliceExpr) -> None:
|
||||
self._write_line("SliceExpr")
|
||||
with self._child_level():
|
||||
self._write_optional_child("lower", expr.lower)
|
||||
self._write_optional_child("upper", expr.upper)
|
||||
self._write_optional_child("step", expr.step, last=True)
|
||||
|
||||
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
|
||||
self._write_line("TupleExpr")
|
||||
with self._child_level():
|
||||
self._write_sequence("items", expr.items, last=True)
|
||||
|
||||
def visit_raw_expr(self, expr: p.RawExpr) -> None:
|
||||
self._write_line("RawExpr")
|
||||
with self._child_level(single=True):
|
||||
self._write_line(f"expr: {ast.unparse(expr.expr)}")
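The printer above follows one pattern throughout: write the node name, open a child level, then either write a scalar field or recurse with `accept`. The sketch below reproduces that pattern in isolation. `TreePrinter`, `Literal` and `Binary` are hypothetical stand-ins, and the plain two-space indentation is an assumption; the real `AstPrinter` base class is not shown in this diff and may draw its tree differently.

```python
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator


@dataclass(frozen=True)
class Literal:
    value: object

    def accept(self, visitor: "TreePrinter") -> None:
        visitor.visit_literal(self)


@dataclass(frozen=True)
class Binary:
    left: object
    operator: str
    right: object

    def accept(self, visitor: "TreePrinter") -> None:
        visitor.visit_binary(self)


class TreePrinter:
    """Hypothetical stand-in for AstPrinter: indents by nesting depth only."""

    def __init__(self) -> None:
        self.lines: list[str] = []
        self._depth: int = 0

    def _write_line(self, text: str) -> None:
        self.lines.append("  " * self._depth + text)

    @contextmanager
    def _child_level(self) -> Iterator[None]:
        # Everything written inside the `with` block is one level deeper
        self._depth += 1
        try:
            yield
        finally:
            self._depth -= 1

    def visit_literal(self, expr: Literal) -> None:
        self._write_line("LiteralExpr")
        with self._child_level():
            self._write_line(f"value: {expr.value!r}")

    def visit_binary(self, expr: Binary) -> None:
        self._write_line("BinaryExpr")
        with self._child_level():
            self._write_line("left")
            with self._child_level():
                expr.left.accept(self)
            self._write_line(f"operator: {expr.operator}")
            self._write_line("right")
            with self._child_level():
                expr.right.accept(self)


printer = TreePrinter()
Binary(Literal(1), "Add", Literal("x")).accept(printer)
print("\n".join(printer.lines))
```

Keeping the depth in a context manager means a visitor method cannot forget to restore the indentation, even if a child's `accept` raises.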
|
||||
@@ -14,6 +14,16 @@ from midas.ast.location import Location
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ParamSpec:
|
||||
pos: list[Function.Parameter]
|
||||
mixed: list[Function.Parameter]
|
||||
kw: list[Function.Parameter]
|
||||
|
||||
@property
|
||||
def all(self) -> list[Function.Parameter]:
|
||||
return self.pos + self.mixed + self.kw
|
||||
|
||||
|
||||
####################
|
||||
# Type annotations #
|
||||
@@ -44,7 +54,7 @@ class MidasType(ABC):
|
||||
@dataclass(frozen=True)
|
||||
class BaseType(MidasType):
|
||||
base: str
|
||||
param: Optional[MidasType]
|
||||
args: tuple[MidasType, ...]
|
||||
|
||||
def accept(self, visitor: MidasType.Visitor[T]) -> T:
|
||||
return visitor.visit_base_type(self)
|
||||
@@ -128,25 +138,17 @@ class ExpressionStmt(Stmt):
|
||||
@dataclass(frozen=True)
|
||||
class Function(Stmt):
|
||||
name: str
|
||||
posonlyargs: list[Argument]
|
||||
args: list[Argument]
|
||||
sink: Optional[Argument]
|
||||
kwonlyargs: list[Argument]
|
||||
kw_sink: Optional[Argument]
|
||||
params: ParamSpec
|
||||
returns: Optional[MidasType]
|
||||
body: list[Stmt]
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
class Parameter:
|
||||
location: Optional[Location] = None
|
||||
name: str
|
||||
type: Optional[MidasType]
|
||||
default: Optional[Expr]
|
||||
|
||||
@property
|
||||
def all_args(self) -> list[Argument]:
|
||||
return self.posonlyargs + self.args + self.kwonlyargs
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_function(self)
|
||||
|
||||
@@ -268,6 +270,9 @@ class Expr(ABC):
|
||||
@abstractmethod
|
||||
def visit_slice_expr(self, expr: SliceExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_tuple_expr(self, expr: TupleExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_raw_expr(self, expr: RawExpr) -> T: ...
|
||||
|
||||
@@ -402,6 +407,14 @@ class SliceExpr(Expr):
|
||||
return visitor.visit_slice_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TupleExpr(Expr):
|
||||
items: tuple[Expr, ...]
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_tuple_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RawExpr(Expr):
|
||||
expr: ast.expr
|
||||
|
||||
@@ -178,4 +178,100 @@ extend dict[K, V] {
|
||||
// def __ior__: fn(value: SupportsKeysAndGetItem[K, V], /) -> dict[K, V]
|
||||
// def __ior__: fn(value: Iterable[tuple[K, V]], /) -> dict[K, V]
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
extend str {
|
||||
def capitalize: fn() -> str
|
||||
def casefold: fn() -> str
|
||||
def center: fn(width: int, fillchar: str?, /) -> str
|
||||
def count: fn(sub: str, start: None?, end: None?, /) -> int
|
||||
def count: fn(sub: str, start: int, end: None?, /) -> int
|
||||
def count: fn(sub: str, start: None, end: int, /) -> int
|
||||
def count: fn(sub: str, start: int, end: int, /) -> int
|
||||
def encode: fn(encoding: str?, errors: str?) -> bytes
|
||||
def endswith: fn(suffix: str, start: None?, end: None?, /) -> bool
|
||||
def endswith: fn(suffix: str, start: int, end: None?, /) -> bool
|
||||
def endswith: fn(suffix: str, start: None, end: int, /) -> bool
|
||||
def endswith: fn(suffix: str, start: int, end: int, /) -> bool
|
||||
def expandtabs: fn(tabsize: int?) -> str
|
||||
def find: fn(sub: str, start: None?, end: None?, /) -> int
|
||||
def find: fn(sub: str, start: int, end: None?, /) -> int
|
||||
def find: fn(sub: str, start: None, end: int, /) -> int
|
||||
def find: fn(sub: str, start: int, end: int, /) -> int
|
||||
// def format: fn(*args: object, **kwargs: object) -> str
|
||||
// def format_map: fn(mapping: _FormatMapMapping, /) -> str
|
||||
def index: fn(sub: str, start: None?, end: None?, /) -> int
|
||||
def index: fn(sub: str, start: int, end: None?, /) -> int
|
||||
def index: fn(sub: str, start: None, end: int, /) -> int
|
||||
def index: fn(sub: str, start: int, end: int, /) -> int
|
||||
def isalnum: fn() -> bool
|
||||
def isalpha: fn() -> bool
|
||||
def isascii: fn() -> bool
|
||||
def isdecimal: fn() -> bool
|
||||
def isdigit: fn() -> bool
|
||||
def isidentifier: fn() -> bool
|
||||
def islower: fn() -> bool
|
||||
def isnumeric: fn() -> bool
|
||||
def isprintable: fn() -> bool
|
||||
def isspace: fn() -> bool
|
||||
def istitle: fn() -> bool
|
||||
def isupper: fn() -> bool
|
||||
def join: fn(iterable: list[str], /) -> str // TODO: use Iterable
|
||||
def ljust: fn(width: int, fillchar: str?, /) -> str
|
||||
def lower: fn() -> str
|
||||
def lstrip: fn(chars: None?, /) -> str
|
||||
def lstrip: fn(chars: str, /) -> str
|
||||
def partition: fn(sep: str, /) -> tuple[str, str, str]
|
||||
|
||||
def replace: fn(old: str, new: str, count: int?, /) -> str
|
||||
|
||||
def removeprefix: fn(prefix: str, /) -> str
|
||||
def removesuffix: fn(suffix: str, /) -> str
|
||||
def rfind: fn(sub: str, start: None?, end: None?, /) -> int
|
||||
def rfind: fn(sub: str, start: int, end: None?, /) -> int
|
||||
def rfind: fn(sub: str, start: None, end: int, /) -> int
|
||||
def rfind: fn(sub: str, start: int, end: int, /) -> int
|
||||
def rindex: fn(sub: str, start: None?, end: None?, /) -> int
|
||||
def rindex: fn(sub: str, start: int, end: None?, /) -> int
|
||||
def rindex: fn(sub: str, start: None, end: int, /) -> int
|
||||
def rindex: fn(sub: str, start: int, end: int, /) -> int
|
||||
def rjust: fn(width: int, fillchar: str?, /) -> str
|
||||
def rpartition: fn(sep: str, /) -> tuple[str, str, str]
|
||||
def rsplit: fn(sep: None?, maxsplit: int?) -> list[str]
|
||||
def rsplit: fn(sep: str, maxsplit: int?) -> list[str]
|
||||
def rstrip: fn(chars: None?, /) -> str
|
||||
def rstrip: fn(chars: str, /) -> str
|
||||
def split: fn(sep: None?, maxsplit: int?) -> list[str]
|
||||
def split: fn(sep: str, maxsplit: int?) -> list[str]
|
||||
def splitlines: fn(keepends: bool?) -> list[str]
|
||||
def startswith: fn(prefix: str, start: None?, end: None?, /) -> bool
|
||||
def startswith: fn(prefix: str, start: int, end: None?, /) -> bool
|
||||
def startswith: fn(prefix: str, start: None, end: int, /) -> bool
|
||||
def startswith: fn(prefix: str, start: int, end: int, /) -> bool
|
||||
def strip: fn(chars: None?, /) -> str
|
||||
def strip: fn(chars: str, /) -> str
|
||||
def swapcase: fn() -> str
|
||||
def title: fn() -> str
|
||||
// def translate: fn(table: _TranslateTable, /) -> str
|
||||
def upper: fn() -> str
|
||||
def zfill: fn(width: int, /) -> str
|
||||
def __add__: fn(value: str, /) -> str
|
||||
// Incompatible with Sequence.__contains__
|
||||
def __contains__: fn(key: str, /) -> bool
|
||||
def __eq__: fn(value: object, /) -> bool
|
||||
def __ge__: fn(value: str, /) -> bool
|
||||
def __getitem__: fn(key: slice, /) -> str
|
||||
def __getitem__: fn(key: int, /) -> str
|
||||
def __gt__: fn(value: str, /) -> bool
|
||||
def __hash__: fn() -> int
|
||||
// def __iter__: fn() -> Iterator[str]
|
||||
def __le__: fn(value: str, /) -> bool
|
||||
def __len__: fn() -> int
|
||||
def __lt__: fn(value: str, /) -> bool
|
||||
def __mod__: fn(value: Any, /) -> str
|
||||
def __mul__: fn(value: int, /) -> str
|
||||
def __ne__: fn(value: object, /) -> bool
|
||||
def __rmul__: fn(value: int, /) -> str
|
||||
def __getnewargs__: fn() -> tuple[str]
|
||||
def __format__: fn(format_spec: str, /) -> str
|
||||
}
|
||||
|
||||
@@ -15,10 +15,14 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
BUILTIN_SUBTYPES: dict[str, set[str]] = {
|
||||
"object": {"float", "list", "dict", "str"},
|
||||
"object": {"float", "list", "dict", "str", "bytes", "tuple"},
|
||||
"float": {"int"},
|
||||
"int": {"bool"},
|
||||
}
|
||||
"""
|
||||
Hard-coded subtype relationships between builtin types
|
||||
|
||||
Circular dependencies and diamond inheritance MUST be avoided
|
||||
"""
|
||||
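As an illustration of why the table must stay acyclic, here is a minimal sketch of a transitive lookup over it. `is_builtin_subtype` is a hypothetical helper written for this note, not a function from the registry; the table is copied from the new version of the hunk above.

```python
BUILTIN_SUBTYPES: dict[str, set[str]] = {
    "object": {"float", "list", "dict", "str", "bytes", "tuple"},
    "float": {"int"},
    "int": {"bool"},
}


def is_builtin_subtype(sub: str, sup: str) -> bool:
    """Hypothetical helper: True if `sub` is `sup` or a transitive subtype of it."""
    if sub == sup:
        return True
    # Walk down from the supertype; this only terminates because the
    # table contains no circular dependencies
    return any(
        is_builtin_subtype(sub, child) for child in BUILTIN_SUBTYPES.get(sup, set())
    )


print(is_builtin_subtype("bool", "float"))
print(is_builtin_subtype("float", "int"))
```

With a cycle in the table, the recursion above would never finish for an unrelated pair of types, which is one reason for the constraint stated in the docstring.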
|
||||
|
||||
def define_builtins(reg: TypesRegistry):
|
||||
@@ -26,12 +30,15 @@ def define_builtins(reg: TypesRegistry):
|
||||
any = reg.define_type("Any", TopType())
|
||||
unit = reg.define_type("None", UnitType())
|
||||
object = reg.define_type("object", BaseType(name="object"))
|
||||
bytes = reg.define_type("bytes", BaseType(name="bytes"))
|
||||
bool = reg.define_type("bool", BaseType(name="bool"))
|
||||
int = reg.define_type("int", BaseType(name="int"))
|
||||
float = reg.define_type("float", BaseType(name="float"))
|
||||
str = reg.define_type("str", BaseType(name="str"))
|
||||
slice = reg.define_type("slice", BaseType(name="slice"))
|
||||
|
||||
tuple = reg.define_type("tuple", BaseType(name="tuple"))
|
||||
|
||||
list = reg.define_type(
|
||||
"list",
|
||||
GenericType(
|
||||
|
||||
@@ -10,6 +10,11 @@ from midas.utils import TypedAST
|
||||
|
||||
|
||||
class TypeChecker:
|
||||
"""Type checking dispatcher
|
||||
|
||||
Contains a typer for Midas and one for Python, as well as the types registry
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.types: TypesRegistry = TypesRegistry()
|
||||
self.reporter: Reporter = Reporter()
|
||||
|
||||
@@ -14,6 +14,15 @@ class DiagnosticType(StrEnum):
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Diagnostic:
|
||||
"""Information about a diagnostic (warning, errors, etc.)
|
||||
|
||||
Holds a location, a diagnostic type and a message.
|
||||
Optionally bound to a file path
|
||||
|
||||
|
||||
"""
|
||||
|
||||
file_path: Optional[str]
|
||||
location: Location
|
||||
type: DiagnosticType
|
||||
@@ -21,6 +30,18 @@ class Diagnostic:
|
||||
|
||||
@property
|
||||
def location_str(self) -> str:
|
||||
"""The diagnostic type and location as a human readable string
|
||||
|
||||
The location is formatted as "<Type> in <file> from L<start_line>:<start_col> to L<end_line>:<end_col>",
|
||||
for example: "Error in /home/user/Desktop/script.py from L12:5 to L12:8"
|
||||
|
||||
If the file is `None`, the "in ..." section is excluded from the result.
|
||||
If the location's end is not specified, the formulation "at L<start_line>:<start_col>" is used.
|
||||
|
||||
Returns:
|
||||
str: the formatted diagnostic type and location
|
||||
"""
|
||||
|
||||
start_loc: str = f"L{self.location.lineno}:{self.location.col_offset+1}"
|
||||
end_loc: Optional[str] = ""
|
||||
if (
|
||||
|
||||
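The hunk above is cut off before the end location is handled, so the following is only a sketch of the format described in the docstring. `Span` is a hypothetical stand-in for `Location`, the end field names are assumptions, and so is the `+1` applied to the end column (the diff only shows it for the start column).

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Span:
    """Hypothetical stand-in for midas.ast.location.Location."""

    lineno: int
    col_offset: int
    end_lineno: Optional[int] = None
    end_col_offset: Optional[int] = None


def location_str(kind: str, file_path: Optional[str], span: Span) -> str:
    # Columns are stored 0-based and displayed 1-based
    start = f"L{span.lineno}:{span.col_offset + 1}"
    in_file = f" in {file_path}" if file_path is not None else ""
    if span.end_lineno is None or span.end_col_offset is None:
        return f"{kind}{in_file} at {start}"
    # Assumption: the end column is shifted the same way as the start column
    end = f"L{span.end_lineno}:{span.end_col_offset + 1}"
    return f"{kind}{in_file} from {start} to {end}"


print(location_str("Error", "/home/user/Desktop/script.py", Span(12, 4, 12, 7)))
print(location_str("Warning", None, Span(3, 0)))
```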
486
midas/checker/dispatcher.py
Normal file
@@ -0,0 +1,486 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Generic, Optional, Protocol, TypeVar, Union
|
||||
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter
|
||||
from midas.checker.types import (
|
||||
AppliedType,
|
||||
DerivedType,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
Type,
|
||||
UnknownType,
|
||||
)
|
||||
from midas.checker.unifier import Unifier
|
||||
|
||||
|
||||
class HasLocation(Protocol):
|
||||
@property
|
||||
def location(self) -> Location: ...
|
||||
|
||||
|
||||
E = TypeVar("E", bound=HasLocation)
|
||||
|
||||
TypedExpr = tuple[E, Type]
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class MappedArgument(Generic[E]):
|
||||
arg_expr: E
|
||||
arg_type: Type
|
||||
parameter: Function.Parameter
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class OverloadCandidate:
|
||||
function: Function
|
||||
mapped: list[MappedArgument]
|
||||
|
||||
|
||||
class CallError(StrEnum):
|
||||
INVALID_ARGS = "Invalid arguments"
|
||||
NO_MATCHING_OVERLOAD = "No matching overload"
|
||||
IMPOSSIBLE_UNIFICATION = "Parameters unification failed"
|
||||
NOT_CALLABLE = "Not callable"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class CallResult:
|
||||
error: Optional[CallError] = None
|
||||
result: Type = UnknownType()
|
||||
message: Optional[str] = None
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
return self.error is None
|
||||
|
||||
@property
|
||||
def error_message(self) -> str:
|
||||
if self.message is not None:
|
||||
return self.message
|
||||
if self.error is not None:
|
||||
return str(self.error)
|
||||
return ""
|
||||
|
||||
|
||||
class CallDispatcher(Generic[E]):
|
||||
def __init__(self, types: TypesRegistry, reporter: FileReporter) -> None:
|
||||
self.types: TypesRegistry = types
|
||||
self.reporter: FileReporter = reporter
|
||||
self.logger: logging.Logger = logging.getLogger("CallDispatcher")
|
||||
|
||||
def set_reporter(self, reporter: FileReporter):
|
||||
self.reporter = reporter
|
||||
|
||||
def get_result(
|
||||
self,
|
||||
location: Location,
|
||||
callee: Type,
|
||||
positional: list[TypedExpr[E]],
|
||||
keywords: dict[str, TypedExpr[E]],
|
||||
report_errors: bool = True,
|
||||
) -> CallResult:
|
||||
"""Get the result type of a function call
|
||||
|
||||
If the function has overloads, the function will try to resolve the
|
||||
appropriate signature.
|
||||
Argument types are matched to the defined parameters.
|
||||
The function doesn't take the raw call expression as a parameter, so that it
can also handle desugared calls, such as operators.
|
||||
|
||||
Args:
|
||||
location (Location): the call location
|
||||
callee (Type): the called function
|
||||
positional (list[TypedExpr]): the list of positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
CallResult: the return type of the call in `result`, or an `error` (and optional `message`) if either the call is invalid or no overload matched the arguments uniquely
|
||||
"""
|
||||
match callee:
|
||||
case Function() as function:
|
||||
valid: bool
|
||||
mapped: list[MappedArgument[E]]
|
||||
valid, mapped = self.map_call_arguments(
|
||||
function, location, positional, keywords
|
||||
)
|
||||
valid = valid and self._are_arguments_valid(mapped, report_errors)
|
||||
if not valid:
|
||||
return CallResult(error=CallError.INVALID_ARGS)
|
||||
return CallResult(result=function.returns)
|
||||
|
||||
case OverloadedFunction(overloads=overloads):
|
||||
res = self._match_overload(
|
||||
overloads, location, positional, keywords, report_errors
|
||||
)
|
||||
if res[0] is None:
|
||||
return CallResult(
|
||||
error=CallError.NO_MATCHING_OVERLOAD,
|
||||
message=res[1],
|
||||
)
|
||||
return CallResult(result=res[0].returns)
|
||||
|
||||
case AppliedType(body=body):
|
||||
return self.get_result(
|
||||
location, body, positional, keywords, report_errors
|
||||
)
|
||||
|
||||
case UnknownType():
|
||||
return CallResult(result=UnknownType())
|
||||
|
||||
case DerivedType(type=base):
|
||||
return self.get_result(
|
||||
location, base, positional, keywords, report_errors
|
||||
)
|
||||
|
||||
case GenericType():
|
||||
unifier: Unifier = Unifier(self.types)
|
||||
pos: list[Type] = [a[1] for a in positional]
|
||||
kw: dict[str, Type] = {k: v[1] for k, v in keywords.items()}
|
||||
unified: Optional[Type] = unifier.unify_call(callee, pos, kw)
|
||||
if unified is None:
|
||||
pos_str: str = ", ".join(str(t) for t in pos)
|
||||
kw_str: str = ", ".join(f"{k}: {v}" for k, v in kw.items())
|
||||
message: str = (
|
||||
f"Could not unify {callee}={callee.body} with pos=[{pos_str}] and kw={{{kw_str}}}"
|
||||
)
|
||||
if report_errors:
|
||||
self.reporter.error(location, message)
|
||||
return CallResult(
|
||||
error=CallError.IMPOSSIBLE_UNIFICATION,
|
||||
message=message,
|
||||
)
|
||||
return self.get_result(
|
||||
location,
|
||||
unified,
|
||||
positional,
|
||||
keywords,
|
||||
report_errors,
|
||||
)
|
||||
|
||||
case _:
|
||||
message: str = f"{callee} ({callee.__class__.__name__}) is not callable"
|
||||
if report_errors:
|
||||
self.reporter.error(location, message)
|
||||
return CallResult(
|
||||
error=CallError.NOT_CALLABLE,
|
||||
message=message,
|
||||
)
|
||||
|
||||
def _unwrap_function(
|
||||
self,
|
||||
callee: Type,
|
||||
positional: list[TypedExpr[E]],
|
||||
keywords: dict[str, TypedExpr[E]],
|
||||
) -> Union[tuple[Function, None], tuple[None, CallError]]:
|
||||
match callee:
|
||||
case DerivedType(type=base):
|
||||
return self._unwrap_function(base, positional, keywords)
|
||||
|
||||
case GenericType():
|
||||
unifier: Unifier = Unifier(self.types)
|
||||
unified: Optional[Type] = unifier.unify_call(
|
||||
callee,
|
||||
[a[1] for a in positional],
|
||||
{k: v[1] for k, v in keywords.items()},
|
||||
)
|
||||
if unified is None:
|
||||
return None, CallError.IMPOSSIBLE_UNIFICATION
|
||||
return self._unwrap_function(unified, positional, keywords)
|
||||
|
||||
case Function():
|
||||
return callee, None
|
||||
|
||||
case AppliedType(body=body):
|
||||
return self._unwrap_function(body, positional, keywords)
|
||||
|
||||
case _:
|
||||
return None, CallError.NOT_CALLABLE
|
||||
|
||||
def _are_arguments_valid(
|
||||
self,
|
||||
arguments: list[MappedArgument[E]],
|
||||
report_errors: bool = True,
|
||||
) -> bool:
|
||||
"""Check whether the passed argument types correspond to their matched parameter definitions
|
||||
|
||||
Args:
|
||||
arguments (list[MappedArgument]): the list of argument/parameter pairs
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
bool: True if all arguments fit the matching parameter definitions, False otherwise
|
||||
"""
|
||||
valid: bool = True
|
||||
for arg in arguments:
|
||||
if not self.types.is_subtype(arg.arg_type, arg.parameter.type):
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
arg.arg_expr.location,
|
||||
f"Wrong type for argument '{arg.parameter.name}', expected {arg.parameter.type}, got {arg.arg_type}",
|
||||
)
|
||||
valid = False
|
||||
return valid
|
||||
|
||||
def _match_overload(
|
||||
self,
|
||||
overloads: list[Type],
|
||||
location: Location,
|
||||
positional: list[TypedExpr[E]],
|
||||
keywords: dict[str, TypedExpr[E]],
|
||||
report_errors: bool = True,
|
||||
) -> Union[tuple[Function, None], tuple[None, str]]:
|
||||
"""Try and resolve the appropriate overload for the given arguments
|
||||
|
||||
Args:
|
||||
overloads (list[Type]): the list of possible overloads
|
||||
location (Location): the call location
|
||||
positional (list[TypedExpr]): the list of positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Union[tuple[Function, None], tuple[None, str]]: the resolved function signature and `None` if it can be determined unambiguously, otherwise `None` and an error message.
|
||||
"""
|
||||
candidates: list[OverloadCandidate] = []
|
||||
errors: list[CallError] = []
|
||||
for overload in overloads:
|
||||
function, unwrap_error = self._unwrap_function(
|
||||
overload, positional, keywords
|
||||
)
|
||||
if function is None:
|
||||
errors.append(unwrap_error) # type: ignore
|
||||
continue
|
||||
|
||||
valid, mapped = self.map_call_arguments(
|
||||
function=function,
|
||||
location=location,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
report_errors=False,
|
||||
)
|
||||
if valid and self._are_arguments_valid(mapped, report_errors=False):
|
||||
candidates.append(
|
||||
OverloadCandidate(
|
||||
function=function,
|
||||
mapped=mapped,
|
||||
)
|
||||
)
|
||||
|
||||
pos_types: str = ", ".join(str(type) for _, type in positional)
|
||||
kw_types: str = ", ".join(
|
||||
f"{name}: {type}" for name, (_, type) in keywords.items()
|
||||
)
|
||||
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
|
||||
|
||||
n_candidates: int = len(candidates)
|
||||
|
||||
# Exactly 1 match -> return it
|
||||
if n_candidates == 1:
|
||||
return candidates[0].function, None
|
||||
|
||||
# No match -> invalid call
|
||||
if n_candidates == 0:
|
||||
overloads_str: str = ", ".join(map(str, overloads))
|
||||
errors_str: str = ", ".join(map(str, errors))
|
||||
message: str = (
|
||||
f"No matching overload in [{overloads_str}] {for_args} (errors: {errors_str})"
|
||||
)
|
||||
if report_errors:
|
||||
self.reporter.error(location, message)
|
||||
return None, message
|
||||
|
||||
# Multiple matches -> see if one <: all others (more specific)
|
||||
for i1, c1 in enumerate(candidates):
|
||||
mapped1: list[MappedArgument[E]] = c1.mapped
|
||||
best_match: bool = True
|
||||
for i2, c2 in enumerate(candidates):
|
||||
if i1 == i2:
|
||||
continue
|
||||
mapped2: list[MappedArgument[E]] = c2.mapped
|
||||
if not self._are_mapped_subtypes(mapped1, mapped2):
|
||||
best_match = False
|
||||
break
|
||||
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
|
||||
if best_match:
|
||||
return c1.function, None
|
||||
|
||||
candidates_str: str = ", ".join(
|
||||
str(candidate.function) for candidate in candidates
|
||||
)
|
||||
message: str = f"Multiple matching overloads {for_args}: {candidates_str}"
|
||||
if report_errors:
|
||||
self.reporter.error(location, message)
|
||||
return None, message
|
||||
|
||||
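Review note: the "multiple matches" branch above picks the candidate whose parameter types are subtypes of every other candidate's. A minimal, self-contained sketch of that selection rule follows. It assumes plain Python classes and `issubclass` in place of the checker's `Type` and `self.types.is_subtype`, and it compares parameters by position rather than by argument expression; `Signature`, `is_at_least_as_specific` and `pick_most_specific` are hypothetical names, not part of this PR.

```python
from typing import Optional

# Hypothetical stand-in for a candidate: just its tuple of parameter types.
Signature = tuple[type, ...]


def is_at_least_as_specific(sig1: Signature, sig2: Signature) -> bool:
    """True if every parameter type of sig1 is a subtype of the matching one in sig2."""
    return len(sig1) == len(sig2) and all(
        issubclass(t1, t2) for t1, t2 in zip(sig1, sig2)
    )


def pick_most_specific(candidates: list[Signature]) -> Optional[Signature]:
    """Return the candidate that is a subtype of all others, or None if ambiguous."""
    for i, c1 in enumerate(candidates):
        if all(
            is_at_least_as_specific(c1, c2)
            for j, c2 in enumerate(candidates)
            if i != j
        ):
            return c1
    return None
```

With this rule, `(int, object)` and `(object, int)` are mutually incomparable, so the call stays ambiguous, which corresponds to the "Multiple matching overloads" error path.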
def map_call_arguments(
|
||||
self,
|
||||
function: Function,
|
||||
location: Location,
|
||||
positional: list[TypedExpr[E]],
|
||||
keywords: dict[str, TypedExpr[E]],
|
||||
report_errors: bool = True,
|
||||
) -> tuple[bool, list[MappedArgument]]:
|
||||
"""Map call arguments to a function's parameters as defined in its signature
|
||||
|
||||
This method maps positional-only, keyword-only and mixed parameter definitions
|
||||
with the arguments passed at the call site
|
||||
|
||||
Any mismatched, missing or unexpected argument is reported as a diagnostic,
|
||||
unless `report_errors` is set to `False`
|
||||
|
||||
Args:
|
||||
function (Function): the function definition
|
||||
location (Location): the call location
|
||||
positional (list[TypedExpr]): the list of positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
tuple[bool, list[MappedArgument]]: a boolean reporting whether
|
||||
the call is valid and the list of mapped arguments
|
||||
"""
|
||||
set_params: set[str] = set()
|
||||
|
||||
required_positional: list[str] = [
|
||||
param.name
|
||||
for param in function.params.pos + function.params.mixed
|
||||
if param.required
|
||||
]
|
||||
required_keyword: list[str] = [
|
||||
param.name for param in function.params.kw if param.required
|
||||
]
|
||||
|
||||
mapped: list[MappedArgument[E]] = []
|
||||
|
||||
pos_params: list[Function.Parameter] = list(function.params.pos)
|
||||
mixed_params: list[Function.Parameter] = list(function.params.mixed)
|
||||
kw_params: dict[str, Function.Parameter] = {
|
||||
param.name: param for param in function.params.kw
|
||||
}
|
||||
|
||||
valid_call: bool = True
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
for arg in positional:
|
||||
param: Function.Parameter
|
||||
if len(pos_params) != 0:
|
||||
param = pos_params.pop(0)
|
||||
elif len(mixed_params) != 0:
|
||||
param = mixed_params.pop(0)
|
||||
else:
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
arg[0].location, "Too many positional arguments"
|
||||
)
|
||||
valid_call = False
|
||||
break
|
||||
name: str = param.name
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_params.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
arg_expr=arg[0],
|
||||
arg_type=arg[1],
|
||||
parameter=param,
|
||||
)
|
||||
)
|
||||
|
||||
kw_params.update({param.name: param for param in mixed_params})
|
||||
for name, arg in keywords.items():
|
||||
param: Function.Parameter
|
||||
if name not in kw_params:
|
||||
if report_errors:
|
||||
if name in set_params:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Multiple values for parameter '{name}'"
|
||||
)
|
||||
else:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Unknown keyword parameter '{name}'"
|
||||
)
|
||||
valid_call = False
|
||||
continue
|
||||
param = kw_params.pop(name)
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_params.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
arg_expr=arg[0],
|
||||
arg_type=arg[1],
|
||||
parameter=param,
|
||||
)
|
||||
)
|
||||
|
||||
def join_params(params: list[str]) -> str:
|
||||
params = list(map(lambda p: f"'{p}'", params))
|
||||
if len(params) == 0:
|
||||
return ""
|
||||
if len(params) == 1:
|
||||
return params[0]
|
||||
return ", ".join(params[:-1]) + " and " + params[-1]
|
||||
|
||||
if len(required_positional) != 0:
|
||||
plural: str = "" if len(required_positional) == 1 else "s"
|
||||
params: str = join_params(required_positional)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Missing required positional argument{plural}: {params}",
|
||||
)
|
||||
valid_call = False
|
||||
|
||||
if len(required_keyword) != 0:
|
||||
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||
params: str = join_params(required_keyword)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Missing required keyword argument{plural}: {params}",
|
||||
)
|
||||
valid_call = False
|
||||
|
||||
return valid_call, mapped
|
||||
|
||||
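Review note: the mapping order in `map_call_arguments` (positional-only parameters first, then mixed ones, then keywords against the remaining mixed and keyword-only parameters) can be illustrated with a reduced sketch. It is an assumption-laden simplification: parameters are plain names, required-parameter checks are left out, and errors are collected as strings instead of being sent to a reporter. `map_arguments` is a hypothetical helper, not code from this PR.

```python
from typing import Any


def map_arguments(
    pos_only: list[str],
    mixed: list[str],
    kw_only: list[str],
    positional: list[Any],
    keywords: dict[str, Any],
) -> tuple[dict[str, Any], list[str]]:
    """Bind call arguments to parameter names; return (mapping, errors)."""
    # Positional arguments consume positional-only parameters, then mixed ones.
    queue: list[str] = list(pos_only) + list(mixed)
    # Only mixed and keyword-only parameters may be passed by keyword.
    keyword_allowed: set[str] = set(mixed) | set(kw_only)
    mapping: dict[str, Any] = {}
    errors: list[str] = []

    for value in positional:
        if not queue:
            errors.append("Too many positional arguments")
            break
        mapping[queue.pop(0)] = value

    for name, value in keywords.items():
        if name in mapping:
            errors.append(f"Multiple values for parameter '{name}'")
        elif name not in keyword_allowed:
            errors.append(f"Unknown keyword parameter '{name}'")
        else:
            mapping[name] = value

    return mapping, errors
```

For example, `map_arguments(["a"], ["b"], ["c"], [1, 2], {"c": 3})` binds all three parameters without errors, while passing `b` both positionally and by keyword yields a "Multiple values" error.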
def _are_mapped_subtypes(
|
||||
self, mapped1: list[MappedArgument[E]], mapped2: list[MappedArgument[E]]
|
||||
) -> bool:
|
||||
"""Check whether the given argument mappings are subtype/supertype of one another
|
||||
|
||||
This function checks whether the argument mappings `mapped1` are subtypes
|
||||
of `mapped2`. If any of the parameter types in `mapped1` is not a subtype
|
||||
of the corresponding parameter in `mapped2`, `False` is returned.
|
||||
|
||||
This is used to check whether a given overload is
|
||||
a more specific function, i.e. a subtype of another.
|
||||
|
||||
Args:
|
||||
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
|
||||
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
|
||||
|
||||
Returns:
|
||||
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
|
||||
"""
|
||||
by_expr: dict[E, Type] = {}
|
||||
for arg in mapped1:
|
||||
by_expr[arg.arg_expr] = arg.parameter.type
|
||||
|
||||
for arg in mapped2:
|
||||
type2: Type = arg.parameter.type
|
||||
type1: Type = by_expr[arg.arg_expr]
|
||||
if not self.types.is_subtype(type1, type2):
|
||||
return False
|
||||
return True
|
||||
@@ -158,15 +158,17 @@ class Evaluator(m.Expr.Visitor[Any]):
|
||||
return res
|
||||
|
||||
def _map_args(self, function: Function, args: list[Any], kwargs: dict[str, Any]):
|
||||
- positional: list[Function.Argument] = function.pos_args + function.args
- keywords: dict[str, Function.Argument] = {
- arg.name: arg for arg in function.args + function.kw_args
+ positional: list[Function.Parameter] = (
+ function.params.pos + function.params.mixed
+ )
+ keywords: dict[str, Function.Parameter] = {
+ param.name: param for param in function.params.mixed + function.params.kw
|
||||
}
|
||||
|
||||
for i, arg in enumerate(args):
|
||||
- param: Function.Argument = positional[i]
+ param: Function.Parameter = positional[i]
|
||||
self.set_value(param.name, arg)
|
||||
|
||||
for name, arg in kwargs.items():
|
||||
- param: Function.Argument = keywords[name]
+ param: Function.Parameter = keywords[name]
|
||||
self.set_value(param.name, arg)
|
||||
|
||||
210
midas/checker/frames/column_groupby_methods.py
Normal file
@@ -0,0 +1,210 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallResult
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import (
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
Function,
|
||||
ParamSpec,
|
||||
TopType,
|
||||
Type,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import TypedExpr
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Call:
|
||||
location: Location
|
||||
call_expr: p.Expr
|
||||
groupby: ColumnGroupBy
|
||||
groupby_expr: p.Expr
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
@property
|
||||
def subject(self) -> TypedExpr:
|
||||
return (self.groupby_expr, self.groupby)
|
||||
|
||||
|
||||
class ColumnGroupByMethodRegistry(MethodRegistry[Call]):
|
||||
NAMED_ARGS: dict[str, str] = {
|
||||
"numeric_only": "bool",
|
||||
"skipna": "bool",
|
||||
"engine": "str",
|
||||
"engine_kwargs": "dict",
|
||||
}
|
||||
|
||||
def _aggregate(
|
||||
self,
|
||||
call: Call,
|
||||
params: list[str | tuple[str, str, bool]] = [],
|
||||
*,
|
||||
preserve_inner_type: bool = False,
|
||||
) -> Type:
|
||||
real_params: list[Function.Parameter] = []
|
||||
for i, param in enumerate(params):
|
||||
match param:
|
||||
case str() as name:
|
||||
param = Function.Parameter(
|
||||
pos=i,
|
||||
name=name,
|
||||
type=self.types.get_type(self.NAMED_ARGS[name]),
|
||||
required=False,
|
||||
)
|
||||
case (name, type, required):
|
||||
param = Function.Parameter(
|
||||
pos=i,
|
||||
name=name,
|
||||
type=self.types.get_type(type),
|
||||
required=required,
|
||||
)
|
||||
real_params.append(param)
|
||||
|
||||
signature = Function(
|
||||
params=ParamSpec(mixed=real_params),
|
||||
returns=(
|
||||
call.groupby.column
|
||||
if preserve_inner_type
|
||||
else ColumnType(type=TopType())
|
||||
),
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def kurt(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
["skipna", "numeric_only"],
|
||||
)
|
||||
|
||||
@method()
|
||||
def max(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
],
|
||||
preserve_inner_type=True,
|
||||
)
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
["numeric_only", "skipna", "engine", "engine_kwargs"],
|
||||
)
|
||||
|
||||
@method()
|
||||
def median(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
["numeric_only", "skipna"],
|
||||
preserve_inner_type=True,
|
||||
)
|
||||
|
||||
@method()
|
||||
def min(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
],
|
||||
preserve_inner_type=True,
|
||||
)
|
||||
|
||||
@method()
|
||||
def prod(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def std(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
(
|
||||
"ddof",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
"numeric_only",
|
||||
"skipna",
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def sum(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def var(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
(
"ddof",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
"numeric_only",
|
||||
"skipna",
|
||||
],
|
||||
)
|
||||
78
midas/checker/frames/column_manager.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.frames.column_groupby_methods import Call as GroupByCall
|
||||
from midas.checker.frames.column_groupby_methods import ColumnGroupByMethodRegistry
|
||||
from midas.checker.frames.column_methods import Call, ColumnMethodRegistry
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import ColumnGroupBy, ColumnType, Type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import PythonTyper, TypedExpr
|
||||
|
||||
|
||||
class ColumnManager:
|
||||
def __init__(self, typer: PythonTyper) -> None:
|
||||
self.typer: PythonTyper = typer
|
||||
self.method_resolver: ColumnMethodRegistry = ColumnMethodRegistry(self.typer)
|
||||
self.groupby_method_resolver: ColumnGroupByMethodRegistry = (
|
||||
ColumnGroupByMethodRegistry(self.typer)
|
||||
)
|
||||
|
||||
def call(
|
||||
self,
|
||||
method: str,
|
||||
location: Location,
|
||||
call_expr: p.Expr,
|
||||
column: ColumnType,
|
||||
column_expr: p.Expr,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Type:
|
||||
call: Call = Call(
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
column=column,
|
||||
column_expr=column_expr,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
return self.method_resolver.call(method, call)
|
||||
|
||||
def groupby_call(
|
||||
self,
|
||||
method: str,
|
||||
location: Location,
|
||||
call_expr: p.Expr,
|
||||
groupby: ColumnGroupBy,
|
||||
groupby_expr: p.Expr,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Type:
|
||||
call: GroupByCall = GroupByCall(
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
groupby=groupby,
|
||||
groupby_expr=groupby_expr,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
return self.groupby_method_resolver.call(method, call)
|
||||
|
||||
def get_attribute(self, column: ColumnType, name: str) -> Optional[Type]:
|
||||
types: TypesRegistry = self.typer.types
|
||||
match name:
|
||||
case "ndim" | "size":
|
||||
return types.get_type("int")
|
||||
|
||||
case "shape":
|
||||
return types.tuple_of("int")
|
||||
|
||||
case "T":
|
||||
return column
|
||||
|
||||
case _:
|
||||
return None
|
||||
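Review note: `ColumnManager` forwards calls by name to `self.method_resolver.call(method, call)`, and the registries declare handlers with `@method(...)`. `midas/checker/frames/utils.py` is not part of this excerpt, so the following is only a guess at how such a name-based registry could work. Every name here (`method`, `Registry`, `_registered_names`) is an assumption made for illustration and may differ from the real `MethodRegistry`.

```python
from typing import Any, Callable


def method(*names: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Mark a function as a handler for the given names (default: its own name)."""

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        func._registered_names = names or (func.__name__,)  # type: ignore[attr-defined]
        return func

    return decorator


class Registry:
    """Collects decorated methods at construction time and dispatches by name."""

    def __init__(self) -> None:
        self._methods: dict[str, Callable[[Any], Any]] = {}
        for attr in dir(self):
            member = getattr(self, attr)
            # Bound methods expose attributes set on the underlying function.
            for name in getattr(member, "_registered_names", ()):
                self._methods[name] = member

    def call(self, name: str, call: Any) -> Any:
        handler = self._methods.get(name)
        if handler is None:
            raise KeyError(f"Unknown method '{name}'")
        return handler(call)
```

Registering several aliases for one handler, as in `@method("add", "__add__")`, then costs one dictionary entry per alias.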
400
midas/checker/frames/column_methods.py
Normal file
@@ -0,0 +1,400 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallResult
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import (
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
Function,
|
||||
GenericType,
|
||||
ParamSpec,
|
||||
TopType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnknownType,
|
||||
unfold_type,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import TypedExpr
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Call:
|
||||
location: Location
|
||||
call_expr: p.Expr
|
||||
column: ColumnType
|
||||
column_expr: p.Expr
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
@property
|
||||
def subject(self) -> TypedExpr:
|
||||
return (self.column_expr, self.column)
|
||||
|
||||
|
||||
class ColumnMethodRegistry(MethodRegistry[Call]):
|
||||
def _element_binary_op(self, call: Call, method: str) -> ColumnType:
|
||||
"""Compute the result of an element-wise binary operation
|
||||
|
||||
This function delegates to the inner types for computing the resulting
|
||||
type.
|
||||
|
||||
Args:
|
||||
call (Call): the call that triggered this resolution
|
||||
method (str): the method name
|
||||
|
||||
Returns:
|
||||
ColumnType: the resulting column type
|
||||
"""
|
||||
column2: Optional[ColumnType] = None
|
||||
|
||||
col_type1: Type = call.column.type
|
||||
new_column: Type = ColumnType(type=UnknownType())
|
||||
if len(call.positional) != 0:
|
||||
other: Type = call.positional[0][1]
|
||||
unfolded_other: Type = unfold_type(other)
|
||||
if isinstance(unfolded_other, ColumnType):
|
||||
column2 = unfolded_other
|
||||
col_type2: Type = column2.type
|
||||
|
||||
new_inner_type = self.typer.result_of_binary_op(
|
||||
location=call.location,
|
||||
expr=call.call_expr,
|
||||
left=(call.column_expr, col_type1),
|
||||
right=(call.positional[0][0], col_type2),
|
||||
method=method,
|
||||
)
|
||||
new_column = ColumnType(type=new_inner_type)
|
||||
return new_column
|
||||
|
||||
def _element_wise(self, call: Call, method: str) -> Type:
|
||||
# TODO: support add with scalar
|
||||
|
||||
# Build signature with new column type and generic operand
|
||||
param_type: TypeVar = TypeVar(name="T", bound=None)
|
||||
signature = GenericType(
|
||||
name="add",
|
||||
params=[param_type],
|
||||
body=Function(
|
||||
params=ParamSpec(
|
||||
mixed=[
|
||||
Function.Parameter(
|
||||
pos=0,
|
||||
name="other",
|
||||
type=ColumnType(type=param_type),
|
||||
required=True,
|
||||
),
|
||||
],
|
||||
),
|
||||
returns=self._element_binary_op(call, method),
|
||||
),
|
||||
)
|
||||
|
||||
# Map arguments and compute result type
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
if result.is_valid:
|
||||
self._assert_same_length(
|
||||
call.call_expr, call.column_expr, call.positional[0][0]
|
||||
)
|
||||
|
||||
return result.result
|
||||
|
||||
@method("add", "__add__")
|
||||
def add(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__add__")
|
||||
|
||||
@method("sub", "__sub__")
|
||||
def sub(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__sub__")
|
||||
|
||||
@method("mul", "__mul__")
|
||||
def mul(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__mul__")
|
||||
|
||||
@method("div", "truediv", "__truediv__")
|
||||
def truediv(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__truediv__")
|
||||
|
||||
@method("floordiv", "__floordiv__")
|
||||
def floordiv(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__floordiv__")
|
||||
|
||||
@method("mod", "__mod__")
|
||||
def mod(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__mod__")
|
||||
|
||||
@method("pow", "__pow__")
|
||||
def pow(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__pow__")
|
||||
|
||||
@method("lt", "__lt__")
|
||||
def lt(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__lt__")
|
||||
|
||||
@method("gt", "__gt__")
|
||||
def gt(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__gt__")
|
||||
|
||||
@method("le", "__le__")
|
||||
def le(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__le__")
|
||||
|
||||
@method("ge", "__ge__")
|
||||
def ge(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__ge__")
|
||||
|
||||
@method("ne", "__ne__")
|
||||
def ne(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__ne__")
|
||||
|
||||
@method("eq", "__eq__")
|
||||
def eq(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__eq__")
|
||||
|
||||
def _aggregate(
|
||||
self,
|
||||
call: Call,
|
||||
kwargs: list[Function.Parameter] = [],
|
||||
*,
|
||||
preserve_inner_type: bool = False,
|
||||
) -> Type:
|
||||
signature = Function(
|
||||
params=ParamSpec(
|
||||
kw=[
|
||||
Function.Parameter(
|
||||
pos=0,
|
||||
name="axis",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
),
|
||||
*kwargs,
|
||||
],
|
||||
),
|
||||
returns=call.column if preserve_inner_type else ColumnType(type=TopType()),
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method("kurtosis", "kurt")
|
||||
def kurtosis(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def max(self, call: Call) -> Type:
|
||||
return self._aggregate(call, preserve_inner_type=True)
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def median(self, call: Call) -> Type:
|
||||
return self._aggregate(call, preserve_inner_type=True)
|
||||
|
||||
@method()
|
||||
def min(self, call: Call) -> Type:
|
||||
return self._aggregate(call, preserve_inner_type=True)
|
||||
|
||||
@method()
|
||||
def mode(self, call: Call) -> Type:
|
||||
return self._aggregate(call, preserve_inner_type=True)
|
||||
|
||||
@method("product", "prod")
|
||||
def product(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def std(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
Function.Parameter(
|
||||
pos=1,
|
||||
name="ddof",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def sum(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def var(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
Function.Parameter(
|
||||
pos=1,
|
||||
name="ddof",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def head(self, call: Call) -> Type:
|
||||
signature = Function(
|
||||
params=ParamSpec(
|
||||
mixed=[
|
||||
Function.Parameter(
|
||||
pos=0,
|
||||
name="n",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
),
|
||||
returns=call.column,
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def tail(self, call: Call) -> Type:
|
||||
signature = Function(
|
||||
params=ParamSpec(
|
||||
mixed=[
|
||||
Function.Parameter(
|
||||
pos=0,
|
||||
name="n",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
),
|
||||
returns=call.column,
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def groupby(self, call: Call) -> Type:
|
||||
bool_: Type = self.types.get_type("bool")
|
||||
function: Function = Function(
|
||||
params=ParamSpec(
|
||||
mixed=[
|
||||
Function.Parameter(
|
||||
pos=0,
|
||||
name="by",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
),
|
||||
Function.Parameter(
|
||||
pos=1,
|
||||
name="level",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
kw=[
|
||||
Function.Parameter(
|
||||
pos=i + 2,
|
||||
name=name,
|
||||
type=bool_,
|
||||
required=False,
|
||||
)
|
||||
for i, name in enumerate(
|
||||
["as_index", "sort", "group_keys", "observed", "dropna"]
|
||||
)
|
||||
],
|
||||
),
|
||||
returns=ColumnGroupBy(column=call.column),
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=function,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
def _assert_same_length(self, call_expr: p.Expr, column1: p.Expr, column2: p.Expr):
|
||||
func_name: str = "__midas_column_same_length__"
|
||||
|
||||
# Efficiently compute length
|
||||
# https://stackoverflow.com/a/15943975/11109181
|
||||
def len_of_col(col: ast.expr) -> ast.expr:
|
||||
return ast.Call(
|
||||
func=ast.Name(id="len"),
|
||||
args=[
|
||||
ast.Attribute(
|
||||
value=col,
|
||||
attr="index",
|
||||
)
|
||||
],
|
||||
keywords=[],
|
||||
)
|
||||
|
||||
self.assertions.define(
|
||||
func_name,
|
||||
ast.FunctionDef(
|
||||
name=func_name,
|
||||
args=ast.arguments(
|
||||
posonlyargs=[],
|
||||
args=[
|
||||
ast.arg(arg="column1"),
|
||||
ast.arg(arg="column2"),
|
||||
],
|
||||
kwonlyargs=[],
|
||||
defaults=[],
|
||||
kw_defaults=[],
|
||||
),
|
||||
body=[
|
||||
ast.Return(
|
||||
value=ast.Compare(
|
||||
left=len_of_col(ast.Name(id="column1")),
|
||||
ops=[ast.Eq()],
|
||||
comparators=[
|
||||
len_of_col(ast.Name(id="column2")),
|
||||
],
|
||||
)
|
||||
)
|
||||
],
|
||||
decorator_list=[],
|
||||
),
|
||||
)
|
||||
self.assertions.add(
|
||||
bound_expr=call_expr,
|
||||
inputs=[column1, column2],
|
||||
builder=lambda c1, c2: ast.Call(
|
||||
func=ast.Name(id=func_name),
|
||||
args=[c1, c2],
|
||||
keywords=[],
|
||||
),
|
||||
message="Columns must have the same length",
|
||||
)
|
||||
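Review note: `_assert_same_length` builds its runtime check as an AST instead of source text. The sketch below shows the same idea in a form that can be run in isolation. It is not the PR's assertion machinery: it builds a single expression rather than a `FunctionDef`, sets `ctx=ast.Load()` explicitly (older Python versions reject nodes without it at compile time), and uses `SimpleNamespace` objects as stand-ins for columns. `len_of_index` and `columns_have_same_length` are hypothetical names.

```python
import ast
from types import CodeType, SimpleNamespace


def len_of_index(name: str) -> ast.expr:
    """Build the AST for `len(<name>.index)`."""
    return ast.Call(
        func=ast.Name(id="len", ctx=ast.Load()),
        args=[
            ast.Attribute(
                value=ast.Name(id=name, ctx=ast.Load()),
                attr="index",
                ctx=ast.Load(),
            )
        ],
        keywords=[],
    )


def build_check() -> CodeType:
    """Compile `len(column1.index) == len(column2.index)` from AST nodes."""
    tree = ast.Expression(
        body=ast.Compare(
            left=len_of_index("column1"),
            ops=[ast.Eq()],
            comparators=[len_of_index("column2")],
        )
    )
    # Hand-built nodes carry no line/column info; compile() requires it.
    ast.fix_missing_locations(tree)
    return compile(tree, "<same_length>", "eval")


CHECK: CodeType = build_check()


def columns_have_same_length(column1: object, column2: object) -> bool:
    """Evaluate the compiled check against two column-like objects."""
    return eval(CHECK, {"len": len}, {"column1": column1, "column2": column2})


# Stand-in columns: anything with an `index` attribute that supports len().
short = SimpleNamespace(index=[0, 1])
long = SimpleNamespace(index=[0, 1, 2])
```

Comparing `len(col.index)` avoids touching the column's data, which is the point of the Stack Overflow answer linked in the source comment.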
103
midas/checker/frames/frame_groupby_methods.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import (
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
DataFrameType,
|
||||
FrameGroupBy,
|
||||
Type,
|
||||
UnknownType,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import TypedExpr
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Call:
|
||||
location: Location
|
||||
call_expr: p.Expr
|
||||
groupby: FrameGroupBy
|
||||
groupby_expr: p.Expr
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
@property
|
||||
def subject(self) -> TypedExpr:
|
||||
return (self.groupby_expr, self.groupby)
|
||||
|
||||
|
||||
class FrameGroupByMethodRegistry(MethodRegistry[Call]):
|
||||
NAMED_ARGS: dict[str, str] = {
|
||||
"numeric_only": "bool",
|
||||
"skipna": "bool",
|
||||
"engine": "str",
|
||||
"engine_kwargs": "dict",
|
||||
}
|
||||
|
||||
def _aggregate(self, call: Call, method: str) -> Type:
|
||||
new_columns: list[DataFrameType.Column] = []
|
||||
|
||||
for column in call.groupby.frame.columns:
|
||||
column_groupby: ColumnGroupBy = ColumnGroupBy(column=column.type)
|
||||
result_type: Type = self.typer.call_method(
|
||||
location=call.location,
|
||||
call_expr=call.call_expr,
|
||||
obj=(call.groupby_expr, column_groupby),
|
||||
method_name=method,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
if not isinstance(result_type, ColumnType):
|
||||
result_type = ColumnType(type=UnknownType())
|
||||
new_columns.append(
|
||||
DataFrameType.Column(
|
||||
index=column.index,
|
||||
name=column.name,
|
||||
type=result_type,
|
||||
)
|
||||
)
|
||||
|
||||
return DataFrameType(columns=new_columns)
|
||||
|
||||
@method()
|
||||
def kurt(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "kurt")
|
||||
|
||||
@method()
|
||||
def max(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "max")
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "mean")
|
||||
|
||||
@method()
|
||||
def median(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "median")
|
||||
|
||||
@method()
|
||||
def min(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "min")
|
||||
|
||||
@method()
|
||||
def prod(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "prod")
|
||||
|
||||
@method()
|
||||
def std(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "std")
|
||||
|
||||
@method()
|
||||
def sum(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "sum")
|
||||
|
||||
@method()
|
||||
def var(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "var")
|
||||
255
midas/checker/frames/frame_manager.py
Normal file
@@ -0,0 +1,255 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, TypeGuard, cast
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.frames.frame_groupby_methods import Call as GroupByCall
|
||||
from midas.checker.frames.frame_groupby_methods import FrameGroupByMethodRegistry
|
||||
from midas.checker.frames.frame_methods import Call, FrameMethodRegistry
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter
|
||||
from midas.checker.types import (
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
DataFrameType,
|
||||
FrameGroupBy,
|
||||
TupleType,
|
||||
Type,
|
||||
UnknownType,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import PythonTyper, TypedExpr
|
||||
|
||||
|
||||
def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
|
||||
return all(isinstance(expr, p.LiteralExpr) for expr in exprs)
|
||||
|
||||
|
||||
class FrameManager:
|
||||
def __init__(self, typer: PythonTyper) -> None:
|
||||
self.typer: PythonTyper = typer
|
||||
self.method_resolver: FrameMethodRegistry = FrameMethodRegistry(self.typer)
|
||||
self.groupby_method_resolver: FrameGroupByMethodRegistry = (
|
||||
FrameGroupByMethodRegistry(self.typer)
|
||||
)
|
||||
|
||||
def assign(
|
||||
self,
|
||||
reporter: FileReporter,
|
||||
location: Location,
|
||||
frame: DataFrameType,
|
||||
index: p.Expr,
|
||||
value_type: Type,
|
||||
) -> Type:
|
||||
match index:
|
||||
case p.LiteralExpr(value=str() as name):
|
||||
return self.assign_column(reporter, location, frame, name, value_type)
|
||||
|
||||
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
|
||||
isinstance(index.value, str) for index in indices
|
||||
):
|
||||
names: list[str] = [cast(str, index.value) for index in indices]
|
||||
|
||||
if not isinstance(value_type, TupleType):
|
||||
reporter.error(
|
||||
location,
f"Cannot assign {value_type} to dataframe columns. Must be a tuple of columns",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
if len(names) != len(value_type.items):
|
||||
reporter.error(
|
||||
location,
|
||||
f"Wrong number of columns. Cannot assign {len(value_type.items)} to {len(names)} targets",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
new_frame: Type = frame
|
||||
for name, value in zip(names, value_type.items):
|
||||
new_frame = self.assign_column(
|
||||
reporter,
|
||||
location,
|
||||
new_frame,
|
||||
name,
|
||||
value,
|
||||
)
|
||||
if not isinstance(new_frame, DataFrameType):
|
||||
return new_frame
|
||||
return new_frame
|
||||
|
||||
case _:
|
||||
reporter.error(
|
||||
location, f"Invalid index type {index} on {frame} (assignment)"
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
def assign_column(
|
||||
self,
|
||||
reporter: FileReporter,
|
||||
location: Location,
|
||||
frame: DataFrameType,
|
||||
name: str,
|
||||
type: Type,
|
||||
) -> Type:
|
||||
if not isinstance(type, ColumnType):
|
||||
reporter.error(
|
||||
location,
|
||||
f"Cannot assign {type} to dataframe column. Must be a ColumnType",
|
||||
)
|
||||
return frame
|
||||
return self._set_column(frame, name, type)
|
||||
|
||||
def get(
|
||||
self,
|
||||
reporter: FileReporter,
|
||||
location: Location,
|
||||
frame: DataFrameType,
|
||||
index: p.Expr,
|
||||
) -> Type:
|
||||
match index:
|
||||
case p.LiteralExpr(value=str() as name):
|
||||
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
|
||||
if column is None:
|
||||
reporter.error(location, f"Unknown column '{name}' on {frame}")
|
||||
return UnknownType()
|
||||
return column
|
||||
|
||||
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
|
||||
isinstance(index.value, str) for index in indices
|
||||
):
|
||||
names: list[str] = [cast(str, index.value) for index in indices]
|
||||
columns: list[ColumnType] = []
|
||||
for name in names:
|
||||
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
|
||||
if column is None:
|
||||
reporter.error(location, f"Unknown column '{name}' on {frame}")
|
||||
return UnknownType()
|
||||
columns.append(column)
|
||||
return TupleType(items=tuple(columns))
|
||||
|
||||
case _:
|
||||
reporter.error(
|
||||
location, f"Invalid index type {index} on {frame} (access)"
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
def groupby_get(
|
||||
self,
|
||||
reporter: FileReporter,
|
||||
location: Location,
|
||||
groupby: FrameGroupBy,
|
||||
index: p.Expr,
|
||||
) -> Type:
|
||||
result: Type = self.get(reporter, location, groupby.frame, index)
|
||||
match result:
|
||||
case ColumnType():
|
||||
result = ColumnGroupBy(column=result)
|
||||
case TupleType(items=columns):
|
||||
result = TupleType(
|
||||
items=tuple(
|
||||
ColumnGroupBy(column=cast(ColumnType, column))
|
||||
for column in columns
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _set_column(
|
||||
cls, frame: DataFrameType, name: str, column: ColumnType
|
||||
) -> DataFrameType:
|
||||
new_columns: list[DataFrameType.Column] = []
|
||||
index: int = len(frame.columns)
|
||||
replace: bool = False
|
||||
for i, col in enumerate(frame.columns):
|
||||
if col.name == name:
|
||||
index = i
|
||||
replace = True
|
||||
# TODO: check column type here to prevent changing it
|
||||
new_columns.append(col)
|
||||
|
||||
new_col: DataFrameType.Column = DataFrameType.Column(
|
||||
index=index,
|
||||
name=name,
|
||||
type=column,
|
||||
)
|
||||
if replace:
|
||||
new_columns[index] = new_col
|
||||
else:
|
||||
new_columns.append(new_col)
|
||||
|
||||
return DataFrameType(columns=new_columns)
|
||||
|
||||
@classmethod
|
||||
def _set_columns(
|
||||
cls, frame: DataFrameType, names: list[str], columns: list[ColumnType]
|
||||
) -> DataFrameType:
|
||||
for name, col in zip(names, columns):
|
||||
frame = cls._set_column(frame, name, col)
|
||||
return frame
|
||||
|
||||
@classmethod
|
||||
def _get_column(cls, frame: DataFrameType, name: str) -> Optional[ColumnType]:
|
||||
for col in frame.columns:
|
||||
if col.name == name:
|
||||
return col.type
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _get_columns(
|
||||
cls, frame: DataFrameType, names: list[str]
|
||||
) -> list[Optional[ColumnType]]:
|
||||
return [cls._get_column(frame, name) for name in names]
|
||||
|
||||
def call(
|
||||
self,
|
||||
method: str,
|
||||
location: Location,
|
||||
call_expr: p.Expr,
|
||||
frame: DataFrameType,
|
||||
frame_expr: p.Expr,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Type:
|
||||
call: Call = Call(
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
frame=frame,
|
||||
frame_expr=frame_expr,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
return self.method_resolver.call(method, call)
|
||||
|
||||
def groupby_call(
|
||||
self,
|
||||
method: str,
|
||||
location: Location,
|
||||
call_expr: p.Expr,
|
||||
groupby: FrameGroupBy,
|
||||
groupby_expr: p.Expr,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Type:
|
||||
call: GroupByCall = GroupByCall(
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
groupby=groupby,
|
||||
groupby_expr=groupby_expr,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
return self.groupby_method_resolver.call(method, call)
|
||||
|
||||
def get_attribute(self, frame: DataFrameType, name: str) -> Optional[Type]:
|
||||
types: TypesRegistry = self.typer.types
|
||||
match name:
|
||||
case "ndim" | "size":
|
||||
return types.get_type("int")
|
||||
|
||||
case "shape":
|
||||
return types.tuple_of("int", "int")
|
||||
|
||||
case _:
|
||||
return None
|
||||
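The replace-or-append behaviour of `_set_column` above can be sketched in isolation. This is a simplified model, not the checker's code: `Column` and `set_column` are hypothetical stand-ins, column types are plain strings, and column names are assumed to be unique.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Column:
    # Hypothetical stand-in for DataFrameType.Column; `type` is a plain string here
    index: int
    name: str
    type: str


def set_column(columns: list[Column], name: str, type_: str) -> list[Column]:
    """Return a new column list with `name` replaced in place or appended."""
    new_columns = list(columns)  # never mutate the input schema
    for i, col in enumerate(columns):
        if col.name == name:
            # Existing column: keep its position, swap its type
            new_columns[i] = Column(index=i, name=name, type=type_)
            return new_columns
    # Unknown column: append at the end of the schema
    new_columns.append(Column(index=len(columns), name=name, type=type_))
    return new_columns


schema = [Column(0, "a", "int"), Column(1, "b", "str")]
replaced = set_column(schema, "a", "float")
appended = set_column(schema, "c", "bool")
print([(c.index, c.name, c.type) for c in replaced])
print([(c.index, c.name, c.type) for c in appended])
```

Returning a fresh list mirrors the way the method builds a new `DataFrameType` instead of modifying the frame it was given.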
479
midas/checker/frames/frame_methods.py
Normal file
@@ -0,0 +1,479 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallResult
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import (
|
||||
ColumnType,
|
||||
DataFrameType,
|
||||
FrameGroupBy,
|
||||
Function,
|
||||
OverloadedFunction,
|
||||
ParamSpec,
|
||||
TopType,
|
||||
Type,
|
||||
UnknownType,
|
||||
unfold_type,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import TypedExpr
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Call:
|
||||
location: Location
|
||||
call_expr: p.Expr
|
||||
frame: DataFrameType
|
||||
frame_expr: p.Expr
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
@property
|
||||
def subject(self) -> TypedExpr:
|
||||
return (self.frame_expr, self.frame)
|
||||
|
||||
|
||||
class FrameMethodRegistry(MethodRegistry[Call]):
|
||||
def _get_method_result(
|
||||
self,
|
||||
call: Call,
|
||||
column1: ColumnType,
|
||||
column2: ColumnType,
|
||||
method: str,
|
||||
) -> ColumnType:
|
||||
"""Get the result of calling a method on a column, passing a second column as operand
|
||||
|
||||
This function delegates to the main typer the resolution of the method
|
||||
member, as well as computing the result type. Because we don't have any
|
||||
AST expression for the individual columns, the frame expressions are
|
||||
used instead.
|
||||
|
||||
Args:
|
||||
call (Call): the call that triggered this resolution
|
||||
column1 (ColumnType): the first column, i.e. left operand
|
||||
column2 (ColumnType): the second column, i.e. right operand
|
||||
method (str): the method name
|
||||
|
||||
Returns:
|
||||
ColumnType: the resulting column.
|
||||
If the operation is invalid / doesn't exist,
|
||||
`ColumnType(type=UnknownType())` is returned
|
||||
"""
|
||||
|
||||
result: Type = self.typer.result_of_binary_op(
|
||||
location=call.location,
|
||||
expr=call.call_expr,
|
||||
left=(call.frame_expr, column1),
|
||||
right=(call.positional[0][0], column2),
|
||||
method=method,
|
||||
)
|
||||
|
||||
if not isinstance(result, ColumnType):
|
||||
return ColumnType(type=UnknownType())
|
||||
return result
|
||||
|
||||
def _element_binary_op(self, call: Call, method: str) -> DataFrameType:
|
||||
"""Compute the result of an element-wise binary operation
|
||||
|
||||
This function delegates to the matching columns for computing resulting
|
||||
types. Any column only present in one of the frames is forwarded as a
|
||||
generic `ColumnType(type=UnknownType())`. Columns only in the second
|
||||
frame are appended at the end of the schema.
|
||||
|
||||
Args:
|
||||
call (Call): the call that triggered this resolution
|
||||
method (str): the method name
|
||||
|
||||
Returns:
|
||||
DataFrameType: the resulting frame type
|
||||
"""
|
||||
new_columns: list[DataFrameType.Column] = []
|
||||
|
||||
by_name: dict[str, DataFrameType.Column] = {}
|
||||
frame2: Optional[DataFrameType] = None
|
||||
# Map the operand's columns by name, if the first positional operand is a dataframe
|
||||
if len(call.positional) != 0:
|
||||
operand: TypedExpr = call.positional[0]
|
||||
unfolded_other: Type = unfold_type(operand[1])
|
||||
if isinstance(unfolded_other, DataFrameType):
|
||||
frame2 = unfolded_other
|
||||
by_name = {
|
||||
col.name: col for col in frame2.columns if col.name is not None
|
||||
}
|
||||
|
||||
# Compute new schema:
|
||||
# Step 1: for all columns in frame1:
|
||||
# - if present in frame2 -> delegate operation to columns
|
||||
# - if not -> add to schema as unknown
|
||||
in_frame1: set[str] = set()
|
||||
for column in call.frame.columns:
|
||||
if column.name is not None:
|
||||
in_frame1.add(column.name)
|
||||
|
||||
col_type1: ColumnType = column.type
|
||||
col_type: ColumnType = ColumnType(type=UnknownType())
|
||||
if column.name in by_name:
|
||||
column2 = by_name[column.name]
|
||||
col_type2: ColumnType = column2.type
|
||||
|
||||
col_type = self._get_method_result(call, col_type1, col_type2, method)
|
||||
|
||||
new_column = DataFrameType.Column(
|
||||
index=column.index,
|
||||
name=column.name,
|
||||
type=col_type,
|
||||
)
|
||||
new_columns.append(new_column)
|
||||
|
||||
# Step 2: for all columns in frame2
|
||||
# - if not in frame1 -> add to schema as unknown
|
||||
if frame2 is not None:
|
||||
for column in frame2.columns:
|
||||
if column.name in in_frame1:
|
||||
continue
|
||||
new_columns.append(
|
||||
DataFrameType.Column(
|
||||
index=len(new_columns),
|
||||
name=column.name,
|
||||
type=ColumnType(type=UnknownType()),
|
||||
)
|
||||
)
|
||||
|
||||
return DataFrameType(columns=new_columns)
|
||||
|
||||
def _element_wise(self, call: Call, method: str) -> Type:
|
||||
# TODO: support scalar, sequence, Series, dict operand
|
||||
# Build signature with new schema and generic operand
|
||||
signature = Function(
|
||||
params=ParamSpec(
|
||||
mixed=[
|
||||
Function.Parameter(
|
||||
pos=0,
|
||||
name="other",
|
||||
type=DataFrameType(columns=[]),
|
||||
required=True,
|
||||
),
|
||||
],
|
||||
),
|
||||
returns=self._element_binary_op(call, method),
|
||||
)
|
||||
|
||||
# Map arguments and compute result type
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
if result.is_valid:
|
||||
self._assert_same_length(
|
||||
call.call_expr, call.frame_expr, call.positional[0][0]
|
||||
)
|
||||
|
||||
return result.result
|
||||
|
||||
@method("add", "__add__")
|
||||
def add(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__add__")
|
||||
|
||||
@method("sub", "__sub__")
|
||||
def sub(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__sub__")
|
||||
|
||||
@method("mul", "__mul__")
|
||||
def mul(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__mul__")
|
||||
|
||||
@method("div", "truediv", "__truediv__")
|
||||
def truediv(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__truediv__")
|
||||
|
||||
@method("floordiv", "__floordiv__")
|
||||
def floordiv(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__floordiv__")
|
||||
|
||||
@method("mod", "__mod__")
|
||||
def mod(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__mod__")
|
||||
|
||||
@method("pow", "__pow__")
|
||||
def pow(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__pow__")
|
||||
|
||||
@method("lt", "__lt__")
|
||||
def lt(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__lt__")
|
||||
|
||||
@method("gt", "__gt__")
|
||||
def gt(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__gt__")
|
||||
|
||||
@method("le", "__le__")
|
||||
def le(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__le__")
|
||||
|
||||
@method("ge", "__ge__")
|
||||
def ge(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__ge__")
|
||||
|
||||
@method("ne", "__ne__")
|
||||
def ne(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__ne__")
|
||||
|
||||
@method("eq", "__eq__")
|
||||
def eq(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__eq__")
|
||||
|
||||
def _aggregate(self, call: Call, kwargs: list[Function.Parameter] = []) -> Type:
|
||||
with_axis = Function(
|
||||
params=ParamSpec(
|
||||
kw=[
|
||||
Function.Parameter(
|
||||
pos=0,
|
||||
name="axis",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
),
|
||||
*kwargs,
|
||||
],
|
||||
),
|
||||
returns=ColumnType(type=TopType()),
|
||||
)
|
||||
without_axis = Function(
|
||||
params=ParamSpec(
|
||||
kw=[
|
||||
Function.Parameter(
|
||||
pos=0,
|
||||
name="axis",
|
||||
type=self.types.get_type("None"),
|
||||
required=True,
|
||||
),
|
||||
*kwargs,
|
||||
],
|
||||
),
|
||||
returns=TopType(),
|
||||
)
|
||||
overload = OverloadedFunction(
|
||||
overloads=[
|
||||
with_axis,
|
||||
without_axis,
|
||||
]
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=overload,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method("kurtosis", "kurt")
|
||||
def kurtosis(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def max(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def median(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def min(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def mode(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method("product", "prod")
|
||||
def product(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def std(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
Function.Parameter(
|
||||
pos=1,
|
||||
name="ddof",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def sum(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def var(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
Function.Parameter(
|
||||
pos=1,
|
||||
name="ddof",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def head(self, call: Call) -> Type:
|
||||
signature = Function(
|
||||
params=ParamSpec(
|
||||
mixed=[
|
||||
Function.Parameter(
|
||||
pos=0,
|
||||
name="n",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
),
|
||||
returns=call.frame,
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def tail(self, call: Call) -> Type:
|
||||
signature = Function(
|
||||
params=ParamSpec(
|
||||
mixed=[
|
||||
Function.Parameter(
|
||||
pos=0,
|
||||
name="n",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
),
|
||||
returns=call.frame,
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def groupby(self, call: Call) -> Type:
|
||||
bool_: Type = self.types.get_type("bool")
|
||||
function: Function = Function(
|
||||
params=ParamSpec(
|
||||
mixed=[
|
||||
Function.Parameter(
|
||||
pos=0,
|
||||
name="by",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
),
|
||||
Function.Parameter(
|
||||
pos=1,
|
||||
name="level",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
kw=[
|
||||
Function.Parameter(
|
||||
pos=i + 2,
|
||||
name=name,
|
||||
type=bool_,
|
||||
required=False,
|
||||
)
|
||||
for i, name in enumerate(
|
||||
["as_index", "sort", "group_keys", "observed", "dropna"]
|
||||
)
|
||||
],
|
||||
),
|
||||
returns=FrameGroupBy(frame=call.frame),
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=function,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
def _assert_same_length(self, call_expr: p.Expr, frame1: p.Expr, frame2: p.Expr):
|
||||
func_name: str = "__midas_frame_same_length__"
|
||||
|
||||
# Efficiently compute length
|
||||
# https://stackoverflow.com/a/15943975/11109181
|
||||
def len_of_df(df: ast.expr) -> ast.expr:
|
||||
return ast.Call(
|
||||
func=ast.Name(id="len"),
|
||||
args=[
|
||||
ast.Attribute(
|
||||
value=df,
|
||||
attr="index",
|
||||
)
|
||||
],
|
||||
keywords=[],
|
||||
)
|
||||
|
||||
self.assertions.define(
|
||||
func_name,
|
||||
ast.FunctionDef(
|
||||
name=func_name,
|
||||
args=ast.arguments(
|
||||
posonlyargs=[],
|
||||
args=[
|
||||
ast.arg(arg="frame1"),
|
||||
ast.arg(arg="frame2"),
|
||||
],
|
||||
kwonlyargs=[],
|
||||
defaults=[],
|
||||
kw_defaults=[],
|
||||
),
|
||||
body=[
|
||||
ast.Return(
|
||||
value=ast.Compare(
|
||||
left=len_of_df(ast.Name(id="frame1")),
|
||||
ops=[ast.Eq()],
|
||||
comparators=[len_of_df(ast.Name(id="frame2"))],
|
||||
)
|
||||
)
|
||||
],
|
||||
decorator_list=[],
|
||||
),
|
||||
)
|
||||
self.assertions.add(
|
||||
bound_expr=call_expr,
|
||||
inputs=[frame1, frame2],
|
||||
builder=lambda f1, f2: ast.Call(
|
||||
func=ast.Name(id=func_name),
|
||||
args=[f1, f2],
|
||||
keywords=[],
|
||||
),
|
||||
message="DataFrames must have the same length",
|
||||
)
|
||||
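To make the generated check in `_assert_same_length` easier to picture, the sketch below builds the same comparison expression with the standard `ast` module, prints its source form, and evaluates it. It is an illustration only: the `SimpleNamespace` objects are hypothetical stand-ins for dataframes (anything with an `index` attribute), and the explicit `ctx=ast.Load()` arguments are added here so the tree compiles on its own.

```python
import ast
from types import SimpleNamespace


def len_of_df(df: ast.expr) -> ast.expr:
    # Builds the expression `len(<df>.index)`
    return ast.Call(
        func=ast.Name(id="len", ctx=ast.Load()),
        args=[ast.Attribute(value=df, attr="index", ctx=ast.Load())],
        keywords=[],
    )


# `len(frame1.index) == len(frame2.index)`
comparison = ast.Compare(
    left=len_of_df(ast.Name(id="frame1", ctx=ast.Load())),
    ops=[ast.Eq()],
    comparators=[len_of_df(ast.Name(id="frame2", ctx=ast.Load()))],
)
source = ast.unparse(comparison)
print(source)

# Compile the expression and evaluate it against stand-in frames
tree = ast.fix_missing_locations(ast.Expression(body=comparison))
check = compile(tree, "<generated>", "eval")

a = SimpleNamespace(index=range(3))
b = SimpleNamespace(index=range(3))
c = SimpleNamespace(index=range(5))
same = eval(check, {"frame1": a, "frame2": b})
different = eval(check, {"frame1": a, "frame2": c})
print(same, different)
```

Comparing `len(df.index)` avoids materialising any column data, which is presumably why the helper uses it.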
100
midas/checker/frames/utils.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Generic,
|
||||
Optional,
|
||||
Protocol,
|
||||
Self,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallDispatcher
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter
|
||||
from midas.checker.types import Type, UnknownType
|
||||
from midas.generator.collector import AssertionCollector
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import PythonTyper, TypedExpr
|
||||
|
||||
|
||||
class _MethodRegistryMeta(type):
|
||||
_methods: dict[str, Callable[..., Type]] = {}
|
||||
|
||||
def __new__(
|
||||
cls,
|
||||
name: str,
|
||||
bases: tuple[type, ...],
|
||||
namespace: dict[str, Any],
|
||||
):
|
||||
new_class = super().__new__(cls, name, bases, namespace)
|
||||
new_class._methods = {}
|
||||
for attr in namespace.values():
|
||||
if callable(attr) and hasattr(attr, "__method_names__"):
|
||||
for method_name in attr.__method_names__: # type: ignore
|
||||
new_class._methods[method_name] = attr # type: ignore
|
||||
return new_class
|
||||
|
||||
|
||||
class MethodCall(Protocol):
|
||||
@property
|
||||
def location(self) -> Location: ...
|
||||
|
||||
@property
|
||||
def call_expr(self) -> p.Expr: ...
|
||||
|
||||
@property
|
||||
def subject(self) -> TypedExpr: ...
|
||||
|
||||
|
||||
T = TypeVar("T", bound=MethodCall)
|
||||
|
||||
|
||||
class MethodRegistry(Generic[T], metaclass=_MethodRegistryMeta):
|
||||
def __init__(self, typer: PythonTyper) -> None:
|
||||
self.typer: PythonTyper = typer
|
||||
|
||||
@property
|
||||
def reporter(self) -> FileReporter:
|
||||
return self.typer.reporter
|
||||
|
||||
@property
|
||||
def types(self) -> TypesRegistry:
|
||||
return self.typer.types
|
||||
|
||||
@property
|
||||
def dispatcher(self) -> CallDispatcher[p.Expr]:
|
||||
return self.typer.dispatcher
|
||||
|
||||
@property
|
||||
def assertions(self) -> AssertionCollector:
|
||||
return self.typer.assertions
|
||||
|
||||
def call(self, method: str, call: T) -> Type:
|
||||
func: Optional[Callable[[Self, T], Type]] = self._methods.get(method)
|
||||
if func is None:
|
||||
self.reporter.warning(
|
||||
call.location, f"Unknown method {method} on {call.subject[1]}"
|
||||
)
|
||||
return UnknownType()
|
||||
return func(self, call)
|
||||
|
||||
|
||||
_Self = TypeVar("_Self", bound=MethodRegistry[Any])
|
||||
Method = Callable[[_Self, T], Type]
|
||||
|
||||
|
||||
def method(*names: str) -> Callable[[Method[_Self, T]], Method[_Self, T]]:
|
||||
def wrapper(func: Method[_Self, T]) -> Method[_Self, T]:
|
||||
names_: tuple[str, ...] = names
|
||||
if len(names_) == 0:
|
||||
names_ = (func.__name__,)
|
||||
setattr(func, "__method_names__", names_)
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
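The registration mechanism in this file (a decorator that tags functions, plus a metaclass that collects the tagged functions per class) can be tried out in a reduced form. The sketch below is a minimal model under simplifying assumptions: `RegistryMeta`, `Registry` and `Arithmetic` are hypothetical names, handlers take a plain `int`, and an unknown method returns `None` instead of reporting a warning.

```python
from typing import Any, Callable


def method(*names: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
        # Fall back to the function's own name when no alias is given
        func.__method_names__ = names or (func.__name__,)  # type: ignore[attr-defined]
        return func

    return wrapper


class RegistryMeta(type):
    def __new__(mcls, name: str, bases: tuple[type, ...], namespace: dict[str, Any]):
        cls = super().__new__(mcls, name, bases, namespace)
        # Each class gets its own table, built from its own namespace only
        cls._methods = {}
        for attr in namespace.values():
            if callable(attr) and hasattr(attr, "__method_names__"):
                for alias in attr.__method_names__:
                    cls._methods[alias] = attr
        return cls


class Registry(metaclass=RegistryMeta):
    def call(self, name: str, arg: int) -> Any:
        func = self._methods.get(name)
        if func is None:
            return None  # simplified: no diagnostic is emitted here
        return func(self, arg)


class Arithmetic(Registry):
    @method("double", "twice")
    def double(self, x: int) -> int:
        return 2 * x

    @method()
    def negate(self, x: int) -> int:
        return -x


registry = Arithmetic()
print(registry.call("twice", 4), registry.call("negate", 3))
```

One consequence worth noting: because the table is rebuilt from each class's own namespace, a subclass does not see handlers registered on its parent unless it redefines them.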
@@ -6,25 +6,26 @@ from typing import Optional
|
||||
import midas.ast.midas as m
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.builtins import define_builtins
|
||||
from midas.checker.dispatcher import CallDispatcher, CallResult
|
||||
from midas.checker.environment import Environment
|
||||
from midas.checker.operators import MIDAS_BINARY_METHODS, MIDAS_UNARY_METHODS
|
||||
from midas.checker.preamble import Preamble
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter, Reporter
|
||||
from midas.checker.types import (
|
||||
AppliedType,
|
||||
ColumnType,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
DataFrameType,
|
||||
DerivedType,
|
||||
ExtensionType,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
ParamSpec,
|
||||
Predicate,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnknownType,
|
||||
unfold_type,
|
||||
)
|
||||
from midas.checker.variance import VarianceInferrer
|
||||
from midas.lexer.midas import MidasLexer
|
||||
@@ -32,16 +33,6 @@ from midas.lexer.token import Token
|
||||
from midas.parser.midas import MidasParser
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class TypedParamSpec:
|
||||
pos: list[Function.Argument]
|
||||
mixed: list[Function.Argument]
|
||||
kw: list[Function.Argument]
|
||||
|
||||
|
||||
TypedExpr = tuple[m.Expr, Type]
|
||||
|
||||
|
||||
class ReturnException(Exception):
|
||||
pass
|
||||
|
||||
@@ -50,7 +41,7 @@ class ReturnException(Exception):
|
||||
class MappedArgument:
|
||||
expr: m.Expr
|
||||
type: Type
|
||||
argument: Function.Argument
|
||||
argument: Function.Parameter
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
@@ -65,8 +56,11 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
|
||||
self.logger: logging.Logger = logging.getLogger("MidasTyper")
|
||||
self.reporter: FileReporter = reporter.for_file(None)
|
||||
|
||||
self.types: TypesRegistry = types
|
||||
self.dispatcher: CallDispatcher[m.Expr] = CallDispatcher[m.Expr](
|
||||
self.types, self.reporter
|
||||
)
|
||||
|
||||
self._local_variables: dict[str, TypeVar] = {}
|
||||
|
||||
self._predicate_params: dict[str, Type] = {}
|
||||
@@ -81,8 +75,14 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
|
||||
self._preamble: Environment = Preamble(self.types)
|
||||
|
||||
def set_reporter(self, reporter: FileReporter):
|
||||
self.reporter = reporter
|
||||
self.dispatcher.set_reporter(reporter)
|
||||
|
||||
def process(self, source: str, path: Optional[str]):
|
||||
self.reporter = self.reporter.for_file(path)
|
||||
reporter: FileReporter = self.reporter.for_file(path)
|
||||
self.set_reporter(reporter)
|
||||
|
||||
lexer: MidasLexer = MidasLexer(source)
|
||||
tokens: list[Token] = lexer.process()
|
||||
parser: MidasParser = MidasParser(tokens)
|
||||
@@ -190,9 +190,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
self._predicate_params[param.name.lexeme] = param.type.accept(self)
|
||||
|
||||
type: Type = self.type_of(stmt.body)
|
||||
params: list[TypedParamSpec] = [
|
||||
self._visit_param_spec(spec) for spec in stmt.params
|
||||
]
|
||||
params: list[ParamSpec] = [self._visit_param_spec(spec) for spec in stmt.params]
|
||||
|
||||
if not self._is_valid_predicate(type):
|
||||
self.reporter.error(
|
||||
@@ -203,9 +201,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
type = self._bool
|
||||
for spec in reversed(params):
|
||||
type = Function(
|
||||
pos_args=spec.pos,
|
||||
args=spec.mixed,
|
||||
kw_args=spec.kw,
|
||||
params=spec,
|
||||
returns=type,
|
||||
)
|
||||
self._predicate_params = {}
|
||||
@@ -257,13 +253,13 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
result: Optional[Type] = self._get_call_result(
|
||||
location,
|
||||
operation,
|
||||
[(right_expr, right)],
|
||||
{},
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=location,
|
||||
callee=operation,
|
||||
positional=[(right_expr, right)],
|
||||
keywords={},
|
||||
)
|
||||
return result or UnknownType()
|
||||
return result.result
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
|
||||
method: Optional[str] = MIDAS_UNARY_METHODS.get(expr.operator.type)
|
||||
@@ -283,31 +279,29 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
result: Optional[Type] = self._get_call_result(
|
||||
expr.location,
|
||||
operation,
|
||||
[],
|
||||
{},
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=expr.location,
|
||||
callee=operation,
|
||||
positional=[],
|
||||
keywords={},
|
||||
)
|
||||
return result or UnknownType()
|
||||
return result.result
|
||||
|
||||
def visit_call_expr(self, expr: m.CallExpr) -> Type:
|
||||
callee: Type = expr.callee.accept(self)
|
||||
positional: list[TypedExpr] = [
|
||||
positional: list[tuple[m.Expr, Type]] = [
|
||||
(arg, self.type_of(arg)) for arg in expr.arguments
|
||||
]
|
||||
keywords: dict[str, TypedExpr] = {
|
||||
keywords: dict[str, tuple[m.Expr, Type]] = {
|
||||
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
|
||||
}
|
||||
return (
|
||||
self._get_call_result(
|
||||
expr.location,
|
||||
callee,
|
||||
positional,
|
||||
keywords,
|
||||
)
|
||||
or UnknownType()
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=expr.location,
|
||||
callee=callee,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> Type:
|
||||
object: Type = expr.expr.accept(self)
|
||||
@@ -382,30 +376,46 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
)
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> Type:
|
||||
params: TypedParamSpec = self._visit_param_spec(type.params)
|
||||
return Function(
|
||||
pos_args=params.pos,
|
||||
args=params.mixed,
|
||||
kw_args=params.kw,
|
||||
params=self._visit_param_spec(type.params),
|
||||
returns=type.returns.accept(self),
|
||||
)
|
||||
|
||||
def _visit_param_spec(self, spec: m.ParamSpec) -> TypedParamSpec:
|
||||
def _visit_param_spec(self, spec: m.ParamSpec) -> ParamSpec:
|
||||
n_pos: int = len(spec.pos)
|
||||
n_mixed: int = len(spec.mixed)
|
||||
|
||||
def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument:
|
||||
return Function.Argument(
|
||||
def process_param(
|
||||
param: m.FunctionType.Parameter, i: int
|
||||
) -> Function.Parameter:
|
||||
return Function.Parameter(
|
||||
pos=i,
|
||||
name=arg.name.lexeme if arg.name is not None else str(i),
|
||||
type=arg.type.accept(self),
|
||||
required=arg.required,
|
||||
name=param.name.lexeme if param.name is not None else str(i),
|
||||
type=param.type.accept(self),
|
||||
required=param.required,
|
||||
)
|
||||
|
||||
return TypedParamSpec(
|
||||
pos=[process_arg(arg, i) for i, arg in enumerate(spec.pos)],
|
||||
mixed=[process_arg(arg, i + n_pos) for i, arg in enumerate(spec.mixed)],
|
||||
kw=[process_arg(arg, i + n_pos + n_mixed) for i, arg in enumerate(spec.kw)],
|
||||
return ParamSpec(
|
||||
pos=[process_param(param, i) for i, param in enumerate(spec.pos)],
|
||||
mixed=[
|
||||
process_param(param, i + n_pos) for i, param in enumerate(spec.mixed)
|
||||
],
|
||||
kw=[
|
||||
process_param(param, i + n_pos + n_mixed)
|
||||
for i, param in enumerate(spec.kw)
|
||||
],
|
||||
)
|
||||
|
||||
def visit_frame_type(self, type: m.FrameType) -> Type:
|
||||
def process_column(i: int, col: m.FrameType.Column) -> DataFrameType.Column:
|
||||
return DataFrameType.Column(
|
||||
index=i,
|
||||
name=col.name.lexeme,
|
||||
type=ColumnType(type=col.type.accept(self)),
|
||||
)
|
||||
|
||||
return DataFrameType(
|
||||
columns=[process_column(i, col) for i, col in enumerate(type.columns)]
|
||||
)
|
||||
|
||||
def _resolve_type_params(self, params: list[m.TypeParam]):
|
||||
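The position numbering used by `_visit_param_spec` (positional-only first, then mixed, then keyword-only, with unnamed parameters named after their position) can be checked with a small stand-alone sketch. `Parameter` and `number_params` are hypothetical simplifications: names are plain strings or `None`, and types and the `required` flag are left out.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Parameter:
    # Hypothetical stand-in for Function.Parameter, reduced to position and name
    pos: int
    name: str


def number_params(
    pos: list[Optional[str]],
    mixed: list[Optional[str]],
    kw: list[Optional[str]],
) -> tuple[list[Parameter], list[Parameter], list[Parameter]]:
    n_pos, n_mixed = len(pos), len(mixed)

    def make(name: Optional[str], i: int) -> Parameter:
        # Unnamed parameters fall back to their position as a name
        return Parameter(pos=i, name=name if name is not None else str(i))

    return (
        [make(n, i) for i, n in enumerate(pos)],
        [make(n, i + n_pos) for i, n in enumerate(mixed)],
        [make(n, i + n_pos + n_mixed) for i, n in enumerate(kw)],
    )


pos_only, mixed, kw_only = number_params([None, "x"], ["y"], ["flag", "mode"])
print(pos_only, mixed, kw_only)
```

Positions run continuously across the three groups, so every parameter of a signature has a unique index.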
@@ -419,343 +429,3 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
self._local_variables[name] = var
|
||||
vars.append(var)
|
||||
return vars
|
||||
|
||||
def _get_call_result(
|
||||
self,
|
||||
location: Location,
|
||||
callee: Type,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
report_errors: bool = True,
|
||||
) -> Optional[Type]:
|
||||
"""Get the result type of a function call
|
||||
|
||||
If the function has overloads, the function will try to resolve the
|
||||
appropriate signature.
|
||||
Argument types are matched to the defined parameters.
|
||||
The function doesn't take the raw expression as a parameter to accommodate
|
||||
for desugared calls such as for operators.
|
||||
|
||||
Args:
|
||||
location (Location): the call location
|
||||
callee (Type): the called function
|
||||
positional (list[TypedExpr]): the list of positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Type: the return type of the call, or `None` if either
|
||||
the call is invalid or no overload matched the arguments uniquely
|
||||
"""
|
||||
match callee:
|
||||
case Function() as function:
|
||||
valid: bool
|
||||
mapped: list[MappedArgument]
|
||||
valid, mapped = self.map_call_arguments(
|
||||
function, location, positional, keywords
|
||||
)
|
||||
valid = valid and self._are_arguments_valid(mapped, report_errors)
|
||||
if not valid:
|
||||
return None
|
||||
return function.returns
|
||||
|
||||
case OverloadedFunction(overloads=overloads):
|
||||
function = self._match_overload(
|
||||
overloads, location, positional, keywords, report_errors
|
||||
)
|
||||
if function is None:
|
||||
return None
|
||||
return function.returns
|
||||
|
||||
case AppliedType(body=body):
|
||||
return self._get_call_result(
|
||||
location, body, positional, keywords, report_errors
|
||||
)
|
||||
|
||||
case UnknownType():
|
||||
return UnknownType()
|
||||
|
||||
case _:
|
||||
if report_errors:
|
||||
self.reporter.error(location, f"{callee} is not callable")
|
||||
return None
|
||||
|
||||
def _are_arguments_valid(
|
||||
self,
|
||||
arguments: list[MappedArgument],
|
||||
report_errors: bool = True,
|
||||
) -> bool:
|
||||
"""Check whether the passed argument types correspond to their matched parameter definitions
|
||||
|
||||
Args:
|
||||
arguments (list[MappedArgument]): the list of argument/parameter pairs
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
bool: True if all arguments fit the matching parameter definitions, False otherwise
|
||||
"""
|
||||
valid: bool = True
|
||||
for arg in arguments:
|
||||
if not self.types.is_subtype(arg.type, arg.argument.type):
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
arg.expr.location,
|
||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||
)
|
||||
valid = False
|
||||
return valid
|
||||
|
||||
def _match_overload(
|
||||
self,
|
||||
overloads: list[Type],
|
||||
location: Location,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
report_errors: bool = True,
|
||||
) -> Optional[Function]:
|
||||
"""Try to resolve the appropriate overload for the given arguments
|
||||
|
||||
Args:
|
||||
overloads (list[Type]): the list of possible overloads
|
||||
location (Location): the call location
|
||||
positional (list[TypedExpr]): the list of positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Optional[Function]: the resolved function signature if it can be
|
||||
determined unambiguously, or `None`.
|
||||
"""
|
||||
candidates: list[OverloadCandidate] = []
|
||||
for overload in overloads:
|
||||
function: Type = unfold_type(overload)
|
||||
if not isinstance(function, Function):
|
||||
if report_errors:
|
||||
self.logger.error(
|
||||
f"Overload is not a function: {overload} is {function}"
|
||||
)
|
||||
continue
|
||||
valid, mapped = self.map_call_arguments(
|
||||
function=function,
|
||||
location=location,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
report_errors=False,
|
||||
)
|
||||
if valid and self._are_arguments_valid(mapped, report_errors=False):
|
||||
candidates.append(
|
||||
OverloadCandidate(
|
||||
function=function,
|
||||
mapped=mapped,
|
||||
)
|
||||
)
|
||||
|
||||
pos_types: str = ", ".join(str(type) for _, type in positional)
|
||||
kw_types: str = ", ".join(
|
||||
f"{name}: {type}" for name, (_, type) in keywords.items()
|
||||
)
|
||||
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
|
||||
|
||||
n_candidates: int = len(candidates)
|
||||
|
||||
# Exactly 1 match -> return it
|
||||
if n_candidates == 1:
|
||||
return candidates[0].function
|
||||
|
||||
# No match -> invalid call
|
||||
if n_candidates == 0:
|
||||
overloads_str: str = ", ".join(map(str, overloads))
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"No matching overload in [{overloads_str}] {for_args}",
|
||||
)
|
||||
return None
|
||||
|
||||
# Multiple matches -> see if one <: all others (more specific)
|
||||
for i1, c1 in enumerate(candidates):
|
||||
mapped1: list[MappedArgument] = c1.mapped
|
||||
best_match: bool = True
|
||||
for i2, c2 in enumerate(candidates):
|
||||
if i1 == i2:
|
||||
continue
|
||||
mapped2: list[MappedArgument] = c2.mapped
|
||||
if not self._are_mapped_subtypes(mapped1, mapped2):
|
||||
best_match = False
|
||||
break
|
||||
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
|
||||
if best_match:
|
||||
return c1.function
|
||||
|
||||
candidates_str: str = ", ".join(
|
||||
str(candidate.function) for candidate in candidates
|
||||
)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Multiple matching overloads {for_args}: {candidates_str}",
|
||||
)
|
||||
return None
|
||||
|
||||
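The resolution strategy of `_match_overload` (collect every overload that accepts the arguments, return a single match directly, otherwise look for one candidate whose parameter types are subtypes of all the others) can be sketched in isolation. This is only a minimal sketch under stated assumptions: `Candidate`, `matches`, `more_specific` and `pick_overload` are hypothetical names, and the toy `is_subtype` uses Python class inheritance in place of `TypesRegistry.is_subtype`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Candidate:
    """Hypothetical stand-in for OverloadCandidate: a name plus parameter types."""

    name: str
    param_types: tuple[type, ...]


def is_subtype(sub: type, sup: type) -> bool:
    # Toy subtype relation (assumption): plain Python class inheritance.
    return issubclass(sub, sup)


def matches(candidate: Candidate, arg_types: tuple[type, ...]) -> bool:
    # A candidate matches when the arity agrees and every argument fits its parameter.
    if len(candidate.param_types) != len(arg_types):
        return False
    return all(is_subtype(a, p) for a, p in zip(arg_types, candidate.param_types))


def more_specific(c1: Candidate, c2: Candidate) -> bool:
    # c1 is at least as specific as c2 if each of its parameters is a subtype of c2's.
    return all(is_subtype(p1, p2) for p1, p2 in zip(c1.param_types, c2.param_types))


def pick_overload(
    overloads: list[Candidate], arg_types: tuple[type, ...]
) -> Optional[Candidate]:
    candidates = [c for c in overloads if matches(c, arg_types)]
    # Exactly one match: return it.
    if len(candidates) == 1:
        return candidates[0]
    # Several matches: return the one that is more specific than all others, if any.
    for i, c1 in enumerate(candidates):
        if all(more_specific(c1, c2) for j, c2 in enumerate(candidates) if i != j):
            return c1
    # No match, or an ambiguous call.
    return None
```

In this sketch an ambiguous call and a call with no match both yield `None`; the checker additionally reports a distinct diagnostic for each case.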
def map_call_arguments(
|
||||
self,
|
||||
function: Function,
|
||||
location: Location,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
report_errors: bool = True,
|
||||
) -> tuple[bool, list[MappedArgument]]:
|
||||
"""Map call arguments to a function's parameters as defined in its signature
|
||||
|
||||
This method maps positional-only, keyword-only and mixed parameter definitions
|
||||
with the arguments passed at the call site
|
||||
|
||||
Any mismatched, missing or unexpected argument is reported as a diagnostic,
|
||||
unless `report_errors` is set to `False`
|
||||
|
||||
Args:
|
||||
function (Function): the function definition
|
||||
location (Location): the call location
|
||||
positional (list[TypedExpr]): the list of positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
tuple[bool, list[MappedArgument]]: a boolean reporting whether
|
||||
the call is valid and the list of mapped arguments
|
||||
"""
|
||||
set_args: set[str] = set()
|
||||
|
||||
required_positional: list[str] = [
|
||||
arg.name for arg in function.pos_args + function.args if arg.required
|
||||
]
|
||||
required_keyword: list[str] = [
|
||||
arg.name for arg in function.kw_args if arg.required
|
||||
]
|
||||
|
||||
mapped: list[MappedArgument] = []
|
||||
|
||||
pos_params: list[Function.Argument] = list(function.pos_args)
|
||||
mixed_params: list[Function.Argument] = list(function.args)
|
||||
kw_params: dict[str, Function.Argument] = {
|
||||
arg.name: arg for arg in function.kw_args
|
||||
}
|
||||
|
||||
valid_call: bool = True
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
for arg in positional:
|
||||
param: Function.Argument
|
||||
if len(pos_params) != 0:
|
||||
param = pos_params.pop(0)
|
||||
elif len(mixed_params) != 0:
|
||||
param = mixed_params.pop(0)
|
||||
else:
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
arg[0].location, "Too many positional arguments"
|
||||
)
|
||||
valid_call = False
|
||||
break
|
||||
name: str = param.name
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
kw_params.update({arg.name: arg for arg in mixed_params})
|
||||
for name, arg in keywords.items():
|
||||
param: Function.Argument
|
||||
if name not in kw_params:
|
||||
if report_errors:
|
||||
if name in set_args:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Multiple values for argument '{name}'"
|
||||
)
|
||||
else:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Unknown keyword argument '{name}'"
|
||||
)
|
||||
valid_call = False
|
||||
continue
|
||||
param = kw_params.pop(name)
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
def join_args(args: list[str]) -> str:
|
||||
args = list(map(lambda a: f"'{a}'", args))
|
||||
if len(args) == 0:
|
||||
return ""
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return ", ".join(args[:-1]) + " and " + args[-1]
|
||||
|
||||
if len(required_positional) != 0:
|
||||
plural: str = "" if len(required_positional) == 1 else "s"
|
||||
args: str = join_args(required_positional)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Missing required positional argument{plural}: {args}",
|
||||
)
|
||||
valid_call = False
|
||||
|
||||
if len(required_keyword) != 0:
|
||||
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||
args: str = join_args(required_keyword)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Missing required keyword argument{plural}: {args}",
|
||||
)
|
||||
valid_call = False
|
||||
|
||||
return valid_call, mapped
|
||||
|
||||
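The mapping rules of `map_call_arguments` can be illustrated with a reduced model: positional arguments fill positional-only parameters first, then mixed ones; keyword arguments may only target mixed parameters that are still free, or keyword-only parameters; any required parameter left unset is an error. The following is a hedged sketch, not the checker's implementation: `Param` and `map_arguments` are hypothetical, errors are collected as strings instead of being sent to a reporter, and missing positional and keyword parameters share one message.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Param:
    """Hypothetical, reduced parameter definition (no type, no position)."""

    name: str
    required: bool = True


def map_arguments(
    pos_only: list[Param],
    mixed: list[Param],
    kw_only: list[Param],
    positional: list[object],
    keywords: dict[str, object],
) -> tuple[bool, dict[str, object], list[str]]:
    errors: list[str] = []
    mapping: dict[str, object] = {}

    # Positional arguments consume positional-only parameters, then mixed ones.
    by_position = pos_only + mixed
    if len(positional) > len(by_position):
        errors.append("Too many positional arguments")
    for param, value in zip(by_position, positional):
        mapping[param.name] = value

    # Mixed parameters not consumed positionally can still be set by keyword.
    remaining_mixed = mixed[max(0, len(positional) - len(pos_only)):]
    by_keyword = {p.name for p in remaining_mixed + kw_only}
    for name, value in keywords.items():
        if name in mapping:
            errors.append(f"Multiple values for argument '{name}'")
        elif name not in by_keyword:
            errors.append(f"Unknown keyword argument '{name}'")
        else:
            mapping[name] = value

    # Every required parameter must have received a value.
    for param in pos_only + mixed + kw_only:
        if param.required and param.name not in mapping:
            errors.append(f"Missing required argument '{param.name}'")

    return len(errors) == 0, mapping, errors
```

Returning the collected errors alongside the validity flag plays the role of `report_errors=False`: the caller decides whether to surface them.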
def _are_mapped_subtypes(
|
||||
self, mapped1: list[MappedArgument], mapped2: list[MappedArgument]
|
||||
) -> bool:
|
||||
"""Check whether the parameter types of one argument mapping are subtypes of another's
|
||||
|
||||
This function checks whether the argument mappings `mapped1` are subtypes
|
||||
of `mapped2`. If any parameter type in `mapped1` is not a subtype
|
||||
of the corresponding parameter in `mapped2`, `False` is returned.
|
||||
|
||||
This is used to check whether a given overload is
|
||||
a more specific function (a subtype) than another.
|
||||
|
||||
Args:
|
||||
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
|
||||
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
|
||||
|
||||
Returns:
|
||||
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
|
||||
"""
|
||||
by_expr: dict[m.Expr, Type] = {}
|
||||
for arg in mapped1:
|
||||
by_expr[arg.expr] = arg.argument.type
|
||||
|
||||
for arg in mapped2:
|
||||
type2: Type = arg.argument.type
|
||||
type1: Type = by_expr[arg.expr]
|
||||
if not self.types.is_subtype(type1, type2):
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -41,7 +41,7 @@ PY_UNARY_METHODS: dict[Type[ast.unaryop], str] = {
|
||||
|
||||
|
||||
MIDAS_BINARY_METHODS: dict[TokenType, str] = {
|
||||
# TokenType.PLUS: "__add__",
|
||||
TokenType.PLUS: "__add__",
|
||||
TokenType.MINUS: "__sub__",
|
||||
TokenType.STAR: "__mul__",
|
||||
TokenType.SLASH: "__truediv__",
|
||||
|
||||
@@ -1,9 +1,18 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from midas.checker.environment import Environment
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import Function, GenericType, TopType, Type, TypeVar, UnitType
|
||||
from midas.checker.types import (
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
ParamSpec,
|
||||
TopType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnitType,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -17,7 +26,7 @@ class Preamble(Environment):
|
||||
def __init__(self, types: TypesRegistry) -> None:
|
||||
super().__init__()
|
||||
self._types: TypesRegistry = types
|
||||
self._python_funcs: dict[str, Callable] = {}
|
||||
self._python_funcs: dict[str, Callable[..., Any]] = {}
|
||||
|
||||
self._def_type_constructor("object", object)
|
||||
self._def_type_constructor("float", float)
|
||||
@@ -34,7 +43,7 @@ class Preamble(Environment):
|
||||
# TODO: use sink
|
||||
self._def_function(
|
||||
name="print",
|
||||
pos=[Param("object", TopType())],
|
||||
pos=[Param("object", TopType(), required=False)],
|
||||
returns=UnitType(),
|
||||
py_function=print,
|
||||
)
|
||||
@@ -64,11 +73,48 @@ class Preamble(Environment):
|
||||
pos=[Param("prompt", TopType(), required=False)],
|
||||
returns=self._types.get_type("str"),
|
||||
)
|
||||
self._def_function(
|
||||
name="len",
|
||||
pos=[Param("object", TopType())],
|
||||
returns=self._types.get_type("int"),
|
||||
)
|
||||
|
||||
def _list_of(self, item_type: Type) -> Type:
|
||||
return self._types.apply_generic(self._types.get_type("list"), [item_type])
|
||||
T = TypeVar(name="T", bound=None)
|
||||
self._def_overloads(
|
||||
name="max",
|
||||
py_function=max,
|
||||
signatures=[
|
||||
(
|
||||
[Param("arg1", T), Param("arg2", T)],
|
||||
[],
|
||||
[],
|
||||
T,
|
||||
[T],
|
||||
),
|
||||
([Param("iterable", self._list_of(T))], [], [], T, [T]),
|
||||
],
|
||||
)
|
||||
self._def_overloads(
|
||||
name="min",
|
||||
py_function=min,
|
||||
signatures=[
|
||||
(
|
||||
[Param("arg1", T), Param("arg2", T)],
|
||||
[],
|
||||
[],
|
||||
T,
|
||||
[T],
|
||||
),
|
||||
([Param("iterable", self._list_of(T))], [], [], T, [T]),
|
||||
],
|
||||
)
|
||||
|
||||
def _def_type_constructor(self, name: str, py_function: Optional[Callable] = None):
|
||||
def _list_of(self, item_type: str | Type) -> Type:
|
||||
return self._types.list_of(item_type)
|
||||
|
||||
def _def_type_constructor(
|
||||
self, name: str, py_function: Optional[Callable[..., Any]] = None
|
||||
):
|
||||
# TODO: more specific arg types
|
||||
self._def_function(
|
||||
name=name,
|
||||
@@ -87,9 +133,9 @@ class Preamble(Environment):
|
||||
returns: Type = UnitType(),
|
||||
type_vars: list[TypeVar] = [],
|
||||
) -> Type:
|
||||
def map_args(params: list[Param], offset: int) -> list[Function.Argument]:
|
||||
def map_params(params: list[Param], offset: int) -> list[Function.Parameter]:
|
||||
return [
|
||||
Function.Argument(
|
||||
Function.Parameter(
|
||||
pos=i + offset,
|
||||
name=param.name,
|
||||
type=param.type,
|
||||
@@ -99,9 +145,11 @@ class Preamble(Environment):
|
||||
]
|
||||
|
||||
function = Function(
|
||||
pos_args=map_args(pos, 0),
|
||||
args=map_args(mixed, len(pos)),
|
||||
kw_args=map_args(kw, len(pos) + len(mixed)),
|
||||
params=ParamSpec(
|
||||
pos=map_params(pos, 0),
|
||||
mixed=map_params(mixed, len(pos)),
|
||||
kw=map_params(kw, len(pos) + len(mixed)),
|
||||
),
|
||||
returns=returns,
|
||||
)
|
||||
if len(type_vars) != 0:
|
||||
@@ -121,7 +169,7 @@ class Preamble(Environment):
|
||||
kw: list[Param] = [],
|
||||
returns: Type = UnitType(),
|
||||
type_vars: list[TypeVar] = [],
|
||||
py_function: Optional[Callable] = None,
|
||||
py_function: Optional[Callable[..., Any]] = None,
|
||||
):
|
||||
function: Type = self._make_function(
|
||||
name=name,
|
||||
@@ -135,5 +183,31 @@ class Preamble(Environment):
|
||||
if py_function is not None:
|
||||
self._python_funcs[name] = py_function
|
||||
|
||||
def get_py_func(self, name: str) -> Optional[Callable]:
|
||||
def _def_overloads(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
signatures: list[
|
||||
tuple[list[Param], list[Param], list[Param], Type, list[TypeVar]]
|
||||
],
|
||||
py_function: Optional[Callable[..., Any]] = None,
|
||||
):
|
||||
overloads: list[Type] = []
|
||||
for pos, mixed, kw, returns, type_vars in signatures:
|
||||
overloads.append(
|
||||
self._make_function(
|
||||
name=name,
|
||||
pos=pos,
|
||||
mixed=mixed,
|
||||
kw=kw,
|
||||
returns=returns,
|
||||
type_vars=type_vars,
|
||||
)
|
||||
)
|
||||
function: Type = OverloadedFunction(overloads=overloads)
|
||||
self.define(name, function)
|
||||
if py_function is not None:
|
||||
self._python_funcs[name] = py_function
|
||||
|
||||
def get_py_func(self, name: str) -> Optional[Callable[..., Any]]:
|
||||
return self._python_funcs.get(name)
|
||||
|
||||
File diff suppressed because it is too large
@@ -7,8 +7,10 @@ from midas.checker.builtins import BUILTIN_SUBTYPES
|
||||
from midas.checker.types import (
|
||||
AppliedType,
|
||||
BaseType,
|
||||
ColumnType,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
DataFrameType,
|
||||
DerivedType,
|
||||
ExtensionType,
|
||||
Function,
|
||||
@@ -16,6 +18,7 @@ from midas.checker.types import (
|
||||
OverloadedFunction,
|
||||
Predicate,
|
||||
TopType,
|
||||
TupleType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnknownType,
|
||||
@@ -110,6 +113,15 @@ class TypesRegistry:
|
||||
raise ValueError(f"Predicate {name} already defined")
|
||||
self._predicates[name] = predicate
|
||||
|
||||
def is_builtin_subtype(self, name1: str, name2: str) -> bool:
|
||||
subtypes: set[str] = BUILTIN_SUBTYPES.get(name2, set())
|
||||
if name1 in subtypes:
|
||||
return True
|
||||
for subtype in subtypes:
|
||||
if self.is_builtin_subtype(name1, subtype):
|
||||
return True
|
||||
return False
|
||||
|
||||
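The change from a direct lookup to `is_builtin_subtype` makes the builtin subtype relation transitive. A minimal sketch of that recursion follows; the contents of `BUILTIN_SUBTYPES` here are a hypothetical table chosen for illustration (the real one lives in `midas.checker.builtins` and is not shown in this diff), and the function is written as a free function rather than a method.

```python
# Hypothetical table: maps a type name to the names of its direct subtypes.
BUILTIN_SUBTYPES: dict[str, set[str]] = {
    "object": {"float", "str"},
    "float": {"int"},
    "int": {"bool"},
}


def is_builtin_subtype(name1: str, name2: str) -> bool:
    """Return True if name1 is a direct or indirect subtype of name2."""
    subtypes: set[str] = BUILTIN_SUBTYPES.get(name2, set())
    # Direct subtype.
    if name1 in subtypes:
        return True
    # Indirect subtype: recurse through each direct subtype of name2.
    # Termination assumes the table contains no cycles.
    return any(is_builtin_subtype(name1, subtype) for subtype in subtypes)
```

With the previous direct lookup, `bool` would not be found under `object` in this table; the recursive version reaches it through `float` and `int`. Note that the helper is not reflexive: a name is not reported as a subtype of itself.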
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||
"""Check whether `type1` is a subtype of `type2`
|
||||
|
||||
@@ -147,7 +159,7 @@ class TypesRegistry:
|
||||
return self.is_subtype(base1, type2)
|
||||
|
||||
case (BaseType(name=name1), BaseType(name=name2)):
|
||||
return name1 in BUILTIN_SUBTYPES.get(name2, set())
|
||||
return self.is_builtin_subtype(name1, name2)
|
||||
|
||||
case (ComplexType(properties=props1), ComplexType(properties=props2)):
|
||||
for k, t in props2.items():
|
||||
@@ -157,6 +169,24 @@ class TypesRegistry:
|
||||
return False
|
||||
return True
|
||||
|
||||
case (DataFrameType(columns=columns1), DataFrameType(columns=columns2)):
|
||||
# TODO: check order?
|
||||
by_name1: dict[str, DataFrameType.Column] = {
|
||||
col.name: col for col in columns1 if col.name is not None
|
||||
}
|
||||
for col2 in columns2:
|
||||
if col2.name not in by_name1:
|
||||
return False
|
||||
if not self.is_subtype(by_name1[col2.name].type, col2.type):
|
||||
return False
|
||||
return True
|
||||
|
||||
case (ColumnType(type=inner1), ColumnType(type=inner2)):
|
||||
# TODO: invariant, replace ColumnType with simple GenericType
|
||||
if not self.are_equivalent(inner1, inner2):
|
||||
return False
|
||||
return True
|
||||
|
||||
case (Function(), Function()):
|
||||
return self.is_func_subtype(type1, type2)
|
||||
|
||||
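The `DataFrameType` case above amounts to width subtyping on named columns: a frame is accepted wherever a frame with a subset of its columns is expected, provided the shared columns have compatible types. A reduced sketch, assuming columns are modelled as a hypothetical `dict` from column name to a type name, and assuming column types are compared for equality (mirroring the invariant `ColumnType` case):

```python
def is_frame_subtype(columns1: dict[str, str], columns2: dict[str, str]) -> bool:
    """Return True if a frame with columns1 can be used where columns2 is expected."""
    for name, col_type in columns2.items():
        # Every expected column must be present...
        if name not in columns1:
            return False
        # ...with the same element type (columns are treated as invariant here).
        if columns1[name] != col_type:
            return False
    return True


wide = {"id": "int", "name": "str", "age": "int"}
narrow = {"id": "int", "name": "str"}
print(is_frame_subtype(wide, narrow))
print(is_frame_subtype(narrow, wide))
```

As in the diff, column order is ignored in this sketch, which the `TODO: check order?` comment leaves open.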
@@ -187,6 +217,9 @@ class TypesRegistry:
|
||||
|
||||
return False
|
||||
|
||||
def are_equivalent(self, type1: Type, type2: Type) -> bool:
|
||||
return self.is_subtype(type1, type2) and self.is_subtype(type2, type1)
|
||||
|
||||
# TODO: verify the logic in here
|
||||
def is_func_subtype(self, func1: Function, func2: Function) -> bool:
|
||||
"""Check whether a function is a subtype of another
|
||||
@@ -201,92 +234,100 @@ class TypesRegistry:
|
||||
if not self.is_subtype(func1.returns, func2.returns):
|
||||
return False
|
||||
|
||||
pos1: list[Function.Argument] = func1.pos_args
|
||||
mixed1: list[Function.Argument] = func1.args
|
||||
kw1: dict[str, Function.Argument] = {a.name: a for a in func1.kw_args}
|
||||
pos2: list[Function.Argument] = func2.pos_args
|
||||
mixed2: list[Function.Argument] = func2.args
|
||||
kw2: dict[str, Function.Argument] = {a.name: a for a in func2.kw_args}
|
||||
pos1: list[Function.Parameter] = func1.params.pos
|
||||
mixed1: list[Function.Parameter] = func1.params.mixed
|
||||
kw1: dict[str, Function.Parameter] = {
|
||||
param.name: param for param in func1.params.kw
|
||||
}
|
||||
pos2: list[Function.Parameter] = func2.params.pos
|
||||
mixed2: list[Function.Parameter] = func2.params.mixed
|
||||
kw2: dict[str, Function.Parameter] = {
|
||||
param.name: param for param in func2.params.kw
|
||||
}
|
||||
|
||||
mixed_by_pos: dict[int, Function.Argument] = {arg.pos: arg for arg in mixed2}
|
||||
mixed_by_name: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2}
|
||||
mixed_by_pos: dict[int, Function.Parameter] = {
|
||||
param.pos: param for param in mixed2
|
||||
}
|
||||
mixed_by_name: dict[str, Function.Parameter] = {
|
||||
param.name: param for param in mixed2
|
||||
}
|
||||
|
||||
def is_arg_subtype(sub: Function.Argument, sup: Function.Argument) -> bool:
|
||||
def is_arg_subtype(sub: Function.Parameter, sup: Function.Parameter) -> bool:
|
||||
if not self.is_subtype(sub.type, sup.type):
|
||||
return False
|
||||
if not sup.required and sub.required:
|
||||
return False
|
||||
return True
|
||||
|
||||
for arg1 in pos1:
|
||||
arg2: Function.Argument
|
||||
if arg1.pos < len(pos2):
|
||||
arg2 = pos2[arg1.pos]
|
||||
elif arg1.pos in mixed_by_pos:
|
||||
arg2 = mixed_by_pos[arg1.pos]
|
||||
elif not arg1.required:
|
||||
for param1 in pos1:
|
||||
param2: Function.Parameter
|
||||
if param1.pos < len(pos2):
|
||||
param2 = pos2[param1.pos]
|
||||
elif param1.pos in mixed_by_pos:
|
||||
param2 = mixed_by_pos[param1.pos]
|
||||
elif not param1.required:
|
||||
continue
|
||||
else:
|
||||
return False
|
||||
if not is_arg_subtype(arg2, arg1):
|
||||
if not is_arg_subtype(param2, param1):
|
||||
return False
|
||||
|
||||
for name, arg1 in kw1.items():
|
||||
arg2: Function.Argument
|
||||
for name, param1 in kw1.items():
|
||||
param2: Function.Parameter
|
||||
if name in kw2:
|
||||
arg2 = kw2[name]
|
||||
param2 = kw2[name]
|
||||
elif name in mixed_by_name:
|
||||
arg2 = mixed_by_name[name]
|
||||
elif not arg1.required:
|
||||
param2 = mixed_by_name[name]
|
||||
elif not param1.required:
|
||||
continue
|
||||
else:
|
||||
return False
|
||||
if not is_arg_subtype(arg2, arg1):
|
||||
if not is_arg_subtype(param2, param1):
|
||||
return False
|
||||
|
||||
for arg1 in mixed1:
|
||||
pos_arg2: Optional[Function.Argument] = None
|
||||
kw_arg2: Optional[Function.Argument] = None
|
||||
if arg1.name in kw2:
|
||||
kw_arg2 = kw2[arg1.name]
|
||||
elif arg1.name in mixed_by_name:
|
||||
kw_arg2 = mixed_by_name[arg1.name]
|
||||
if arg1.pos < len(pos2):
|
||||
pos_arg2 = pos2[arg1.pos]
|
||||
elif arg1.pos in mixed_by_pos:
|
||||
pos_arg2 = mixed_by_pos[arg1.pos]
|
||||
for param1 in mixed1:
|
||||
pos_param2: Optional[Function.Parameter] = None
|
||||
kw_param2: Optional[Function.Parameter] = None
|
||||
if param1.name in kw2:
|
||||
kw_param2 = kw2[param1.name]
|
||||
elif param1.name in mixed_by_name:
|
||||
kw_param2 = mixed_by_name[param1.name]
|
||||
if param1.pos < len(pos2):
|
||||
pos_param2 = pos2[param1.pos]
|
||||
elif param1.pos in mixed_by_pos:
|
||||
pos_param2 = mixed_by_pos[param1.pos]
|
||||
|
||||
# No match in func2 and arg is required
|
||||
if pos_arg2 is None and kw_arg2 is None and arg1.required:
|
||||
if pos_param2 is None and kw_param2 is None and param1.required:
|
||||
return False
|
||||
|
||||
# Matching keyword argument
|
||||
if kw_arg2 is not None and not is_arg_subtype(kw_arg2, arg1):
|
||||
if kw_param2 is not None and not is_arg_subtype(kw_param2, param1):
|
||||
return False
|
||||
|
||||
# Matching positional argument
|
||||
if pos_arg2 is not None and not is_arg_subtype(pos_arg2, arg1):
|
||||
if pos_param2 is not None and not is_arg_subtype(pos_param2, param1):
|
||||
return False
|
||||
|
||||
mixed_positions: set[int] = {a.pos for a in mixed1}
|
||||
mixed_names: set[str] = {a.name for a in mixed1}
|
||||
for arg2 in pos2:
|
||||
if not arg2.required:
|
||||
mixed_positions: set[int] = {param.pos for param in mixed1}
|
||||
mixed_names: set[str] = {param.name for param in mixed1}
|
||||
for param2 in pos2:
|
||||
if not param2.required:
|
||||
continue
|
||||
if arg2.pos >= len(pos1) and arg2.pos not in mixed_positions:
|
||||
if param2.pos >= len(pos1) and param2.pos not in mixed_positions:
|
||||
return False
|
||||
|
||||
for name, arg2 in kw2.items():
|
||||
if not arg2.required:
|
||||
for name, param2 in kw2.items():
|
||||
if not param2.required:
|
||||
continue
|
||||
if name not in kw1 and name not in mixed_names:
|
||||
return False
|
||||
|
||||
for arg2 in mixed2:
|
||||
if arg2.required:
|
||||
for param2 in mixed2:
|
||||
if param2.required:
|
||||
continue
|
||||
pos_match: bool = arg2.pos < len(pos1) or arg2.pos in mixed_positions
|
||||
kw_match: bool = arg2.name in kw1 or arg2.name in mixed_names
|
||||
pos_match: bool = param2.pos < len(pos1) or param2.pos in mixed_positions
|
||||
kw_match: bool = param2.name in kw1 or param2.name in mixed_names
|
||||
if not pos_match or not kw_match:
|
||||
return False
|
||||
|
||||
@@ -323,6 +364,9 @@ class TypesRegistry:
|
||||
body=substitute_typevars(body, substitutions),
|
||||
)
|
||||
|
||||
case BaseType(name="tuple"):
|
||||
return TupleType(items=tuple(args))
|
||||
|
||||
case _:
|
||||
raise ValueError(f"{type} is not a generic type")
|
||||
|
||||
@@ -416,3 +460,29 @@ class TypesRegistry:
|
||||
|
||||
def lookup_predicate(self, name: str) -> Optional[Predicate]:
|
||||
return self._predicates.get(name)
|
||||
|
||||
def _by_name_or_type(self, name_or_type: str | Type) -> Type:
|
||||
if isinstance(name_or_type, str):
|
||||
return self.get_type(name_or_type)
|
||||
return name_or_type
|
||||
|
||||
def list_of(self, item_type: str | Type) -> Type:
|
||||
list_ = self.get_type("list")
|
||||
return self.apply_generic(list_, [self._by_name_or_type(item_type)])
|
||||
|
||||
def tuple_of(self, *item_types: str | Type) -> Type:
|
||||
tuple_ = self.get_type("tuple")
|
||||
return self.apply_generic(
|
||||
tuple_,
|
||||
[self._by_name_or_type(item_type) for item_type in item_types],
|
||||
)
|
||||
|
||||
def dict_of(self, key_type: str | Type, value_type: str | Type) -> Type:
|
||||
dict_ = self.get_type("dict")
|
||||
return self.apply_generic(
|
||||
dict_,
|
||||
[
|
||||
self._by_name_or_type(key_type),
|
||||
self._by_name_or_type(value_type),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -93,7 +93,7 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
||||
function (p.Function): the function to resolve
|
||||
"""
|
||||
self.begin_scope()
|
||||
for param in function.all_args:
|
||||
for param in function.params.all:
|
||||
self.declare(param.name)
|
||||
self.define(param.name)
|
||||
self.resolve(*function.body)
|
||||
@@ -128,6 +128,10 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
||||
|
||||
case p.GetExpr():
|
||||
target.accept(self)
|
||||
|
||||
case p.SubscriptExpr():
|
||||
target.accept(self)
|
||||
|
||||
case _:
|
||||
raise Exception(f"Unsupported assignment to {target}")
|
||||
|
||||
@@ -232,5 +236,9 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
||||
if expr.step is not None:
|
||||
self.resolve(expr.step)
|
||||
|
||||
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
|
||||
for item in expr.items:
|
||||
self.resolve(item)
|
||||
|
||||
def visit_raw_expr(self, expr: p.RawExpr) -> None:
|
||||
pass
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
from typing import Optional, assert_never
|
||||
from typing import Optional, assert_never, cast
|
||||
|
||||
import midas.ast.midas as m
|
||||
from midas.ast.printer import MidasPrinter
|
||||
@@ -45,28 +45,14 @@ class UnitType:
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Function:
|
||||
pos_args: list[Argument] = field(default_factory=list)
|
||||
args: list[Argument] = field(default_factory=list)
|
||||
kw_args: list[Argument] = field(default_factory=list)
|
||||
params: ParamSpec
|
||||
returns: Type
|
||||
|
||||
def __str__(self) -> str:
|
||||
args: list[str] = []
|
||||
if len(self.pos_args) != 0:
|
||||
args += list(map(str, self.pos_args))
|
||||
args.append("/")
|
||||
|
||||
if len(self.args) != 0:
|
||||
args += list(map(str, self.args))
|
||||
|
||||
if len(self.kw_args) != 0:
|
||||
args.append("*")
|
||||
args += list(map(str, self.kw_args))
|
||||
|
||||
return f"({', '.join(args)}) -> {self.returns}"
|
||||
return f"{self.params} -> {self.returns}"
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
class Parameter:
|
||||
pos: int
|
||||
name: str
|
||||
type: Type
|
||||
@@ -77,6 +63,28 @@ class Function:
|
||||
return f"{self.name}: {self.type}{opt}"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ParamSpec:
|
||||
pos: list[Function.Parameter] = field(default_factory=list)
|
||||
mixed: list[Function.Parameter] = field(default_factory=list)
|
||||
kw: list[Function.Parameter] = field(default_factory=list)
|
||||
|
||||
def __str__(self) -> str:
|
||||
params: list[str] = []
|
||||
if len(self.pos) != 0:
|
||||
params += list(map(str, self.pos))
|
||||
params.append("/")
|
||||
|
||||
if len(self.mixed) != 0:
|
||||
params += list(map(str, self.mixed))
|
||||
|
||||
if len(self.kw) != 0:
|
||||
params.append("*")
|
||||
params += list(map(str, self.kw))
|
||||
|
||||
return f"({', '.join(params)})"
|
||||
|
||||
|
||||
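`ParamSpec.__str__` renders a signature using Python's own markers: `/` closes the positional-only group and `*` opens the keyword-only group. A self-contained sketch of that rendering, assuming a simplified hypothetical `Parameter` that stores its type as a string and omits the `pos` and `required` fields:

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class Parameter:
    """Simplified, hypothetical parameter: a name and a type name."""

    name: str
    type: str

    def __str__(self) -> str:
        return f"{self.name}: {self.type}"


@dataclass(frozen=True)
class ParamSpec:
    pos: list[Parameter] = field(default_factory=list)
    mixed: list[Parameter] = field(default_factory=list)
    kw: list[Parameter] = field(default_factory=list)

    def __str__(self) -> str:
        parts: list[str] = []
        if self.pos:
            # Positional-only parameters, closed by "/"
            parts += [str(p) for p in self.pos]
            parts.append("/")
        parts += [str(p) for p in self.mixed]
        if self.kw:
            # Keyword-only parameters, introduced by "*"
            parts.append("*")
            parts += [str(p) for p in self.kw]
        return f"({', '.join(parts)})"


spec = ParamSpec(
    pos=[Parameter("a", "int")],
    mixed=[Parameter("b", "str")],
    kw=[Parameter("c", "bool")],
)
print(spec)  # (a: int, /, b: str, *, c: bool)
```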
@dataclass(frozen=True, kw_only=True)
|
||||
class OverloadedFunction:
|
||||
overloads: list[Type]
|
||||
@@ -156,13 +164,74 @@ class ConstraintType:
|
||||
return f"{self.type} where {printer.print(self.constraint)}"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class TupleType:
|
||||
items: tuple[Type, ...]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"({', '.join(map(str, self.items))})"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ColumnType:
|
||||
type: Type
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Column[{self.type}]"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class DataFrameType:
|
||||
columns: list[Column]
|
||||
|
||||
def __str__(self) -> str:
|
||||
schema: list[str] = [f"{col.name}: {col.type}" for col in self.columns]
|
||||
return f"Frame[{', '.join(schema)}]"
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Column:
|
||||
index: int
|
||||
name: Optional[str]
|
||||
type: ColumnType
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class FrameGroupBy:
|
||||
frame: DataFrameType
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"FrameGroupBy[{self.frame}]"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ColumnGroupBy:
|
||||
column: ColumnType
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ColumnGroupBy[{self.column}]"
|
||||
|
||||
|
||||
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||
def sub_argument(arg: Function.Argument):
|
||||
return Function.Argument(
|
||||
pos=arg.pos,
|
||||
name=arg.name,
|
||||
type=substitute_typevars(arg.type, substitutions),
|
||||
required=arg.required,
|
||||
def sub_parameter(param: Function.Parameter):
|
||||
return Function.Parameter(
|
||||
pos=param.pos,
|
||||
name=param.name,
|
||||
type=substitute_typevars(param.type, substitutions),
|
||||
required=param.required,
|
||||
)
|
||||
|
||||
def sub_param_spec(spec: ParamSpec):
|
||||
return ParamSpec(
|
||||
pos=list(map(sub_parameter, spec.pos)),
|
||||
mixed=list(map(sub_parameter, spec.mixed)),
|
||||
kw=list(map(sub_parameter, spec.kw)),
|
||||
)
|
||||
|
||||
def sub_column(col: DataFrameType.Column):
|
||||
return DataFrameType.Column(
|
||||
index=col.index,
|
||||
name=col.name,
|
||||
type=cast(ColumnType, substitute_typevars(col.type, substitutions)),
|
||||
)
|
||||
|
||||
match type:
|
||||
@@ -181,15 +250,11 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||
)
|
||||
|
||||
case Function(
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
params=params,
|
||||
returns=returns,
|
||||
):
|
||||
return Function(
|
||||
pos_args=list(map(sub_argument, pos_args)),
|
||||
args=list(map(sub_argument, args)),
|
||||
kw_args=list(map(sub_argument, kw_args)),
|
||||
params=sub_param_spec(params),
|
||||
returns=substitute_typevars(returns, substitutions),
|
||||
)
|
||||
|
||||
@@ -252,6 +317,31 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||
body=substitute_typevars(body, substitutions),
|
||||
)
|
||||
|
||||
case TupleType(items=items):
|
||||
return TupleType(
|
||||
items=tuple(substitute_typevars(item, substitutions) for item in items),
|
||||
)
|
||||
|
||||
case ColumnType(type=items_type):
|
||||
return ColumnType(
|
||||
type=substitute_typevars(items_type, substitutions),
|
||||
)
|
||||
|
||||
case DataFrameType(columns=columns):
|
||||
return DataFrameType(
|
||||
columns=list(map(sub_column, columns)),
|
||||
)
|
||||
|
||||
case FrameGroupBy(frame=frame):
|
||||
return FrameGroupBy(
|
||||
frame=cast(DataFrameType, substitute_typevars(frame, substitutions))
|
||||
)
|
||||
|
||||
case ColumnGroupBy(column=column):
|
||||
return ColumnGroupBy(
|
||||
column=cast(ColumnType, substitute_typevars(column, substitutions))
|
||||
)
|
||||
|
||||
case UnknownType() | UnitType():
|
||||
return type
|
||||
|
||||
@@ -272,14 +362,14 @@ def unfold_type(type: Type) -> Type:
|
||||
|
||||
|
||||
def to_annotation(type: Type) -> str:
|
||||
def _args_annotation(func: Function) -> str:
|
||||
if len(func.kw_args) != 0:
|
||||
def _params_annotation(spec: ParamSpec) -> str:
|
||||
if len(spec.kw) != 0:
|
||||
return "..."
|
||||
|
||||
args: str = ", ".join(
|
||||
to_annotation(arg.type) for arg in func.pos_args + func.args
|
||||
params: str = ", ".join(
|
||||
to_annotation(param.type) for param in spec.pos + spec.mixed
|
||||
)
|
||||
return f"[{args}]"
|
||||
return f"[{params}]"
|
||||
|
||||
match type:
|
||||
case TopType():
|
||||
@@ -297,8 +387,8 @@ def to_annotation(type: Type) -> str:
|
||||
case UnitType():
|
||||
return "None"
|
||||
|
||||
case Function(returns=returns):
|
||||
params_annot: str = _args_annotation(type)
|
||||
case Function(params=params, returns=returns):
|
||||
params_annot: str = _params_annotation(params)
|
||||
return f"Callable[{params_annot}, {to_annotation(returns)}]"
|
||||
|
||||
case OverloadedFunction():
|
||||
@@ -319,6 +409,21 @@ def to_annotation(type: Type) -> str:
|
||||
case ConstraintType():
|
||||
return str(type)
|
||||
|
||||
case TupleType(items=items):
|
||||
return f"Tuple[{', '.join(map(to_annotation, items))}]"
|
||||
|
||||
case ColumnType():
|
||||
return "pd.Series"
|
||||
|
||||
case DataFrameType():
|
||||
return "pd.DataFrame"
|
||||
|
||||
case FrameGroupBy():
|
||||
return "pd.api.typing.DataFrameGroupBy"
|
||||
|
||||
case ColumnGroupBy():
|
||||
return "pd.api.typing.SeriesGroupBy"
|
||||
|
||||
case _:
|
||||
assert_never(type)
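Review note: the `_params_annotation` change above can be illustrated with a small standalone sketch. `Parameter`, `ParamSpec`, and `callable_annotation` below are simplified, hypothetical stand-ins (types are plain strings here), not the classes from `midas.checker.types`. The sketch only shows the fallback to `...` when keyword-only parameters are present.

```python
from dataclasses import dataclass, field


@dataclass
class Parameter:  # hypothetical stand-in for Function.Parameter
    name: str
    type: str  # already-rendered annotation, to keep the sketch small


@dataclass
class ParamSpec:  # hypothetical stand-in for the ParamSpec used in the diff
    pos: list[Parameter] = field(default_factory=list)
    mixed: list[Parameter] = field(default_factory=list)
    kw: list[Parameter] = field(default_factory=list)


def params_annotation(spec: ParamSpec) -> str:
    # Callable[[...], R] has no syntax for keyword-only parameters,
    # so fall back to the ellipsis form when any are present.
    if spec.kw:
        return "..."
    return "[" + ", ".join(p.type for p in spec.pos + spec.mixed) + "]"


def callable_annotation(spec: ParamSpec, returns: str) -> str:
    return f"Callable[{params_annotation(spec)}, {returns}]"
```

With one positional-only `int` and one mixed `str` parameter, the sketch renders `Callable[[int, str], bool]` for a `bool` return.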
|
||||
|
||||
@@ -344,4 +449,9 @@ Type = (
|
||||
| GenericType
|
||||
| AppliedType
|
||||
| ConstraintType
|
||||
| TupleType
|
||||
| ColumnType
|
||||
| DataFrameType
|
||||
| FrameGroupBy
|
||||
| ColumnGroupBy
|
||||
)
|
||||
|
||||
@@ -4,8 +4,11 @@ from typing import Optional
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import (
|
||||
AppliedType,
|
||||
ColumnType,
|
||||
DataFrameType,
|
||||
Function,
|
||||
GenericType,
|
||||
ParamSpec,
|
||||
TopType,
|
||||
Type,
|
||||
TypeVar,
|
||||
@@ -27,25 +30,26 @@ class Unifier:
|
||||
keywords: dict[str, Type],
|
||||
) -> Optional[Type]:
|
||||
concrete_func: Function = Function(
|
||||
pos_args=[
|
||||
Function.Argument(
|
||||
pos=i,
|
||||
name=str(i),
|
||||
type=arg,
|
||||
required=True,
|
||||
)
|
||||
for i, arg in enumerate(positional)
|
||||
],
|
||||
args=[],
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
pos=len(positional) + i,
|
||||
name=name,
|
||||
type=arg,
|
||||
required=True,
|
||||
)
|
||||
for i, (name, arg) in enumerate(keywords.items())
|
||||
],
|
||||
params=ParamSpec(
|
||||
pos=[
|
||||
Function.Parameter(
|
||||
pos=i,
|
||||
name=str(i),
|
||||
type=arg,
|
||||
required=True,
|
||||
)
|
||||
for i, arg in enumerate(positional)
|
||||
],
|
||||
kw=[
|
||||
Function.Parameter(
|
||||
pos=len(positional) + i,
|
||||
name=name,
|
||||
type=arg,
|
||||
required=True,
|
||||
)
|
||||
for i, (name, arg) in enumerate(keywords.items())
|
||||
],
|
||||
),
|
||||
returns=TopType(), # TODO: use expected type
|
||||
)
|
||||
return self.unify_generic(type, concrete_func, match_return=False)
|
||||
@@ -98,8 +102,32 @@ class Unifier:
|
||||
|
||||
return substitutions
|
||||
|
||||
case (
|
||||
DataFrameType(columns=template_columns),
|
||||
DataFrameType(columns=concrete_columns),
|
||||
) if len(template_columns) == len(concrete_columns):
|
||||
substitutions: dict[str, Type] = {}
|
||||
for template_column, concrete_column in zip(
|
||||
template_columns, concrete_columns
|
||||
):
|
||||
if template_column.index != concrete_column.index or (
|
||||
template_column.name != concrete_column.name
|
||||
):
|
||||
self.logger.debug(
|
||||
f"Column mismatch: template={template_column}, concrete={concrete_column}"
|
||||
)
|
||||
raise UnificationError
|
||||
new_substitutions: dict[str, Type] = self.match(
|
||||
template_column.type, concrete_column.type
|
||||
)
|
||||
substitutions = self.merge(substitutions, new_substitutions)
|
||||
return substitutions
|
||||
|
||||
case (ColumnType(type=template_column), ColumnType(type=concrete_column)):
|
||||
return self.match(template_column, concrete_column)
|
||||
|
||||
case (Function(), Function()):
|
||||
mapped: list[tuple[Function.Argument, Function.Argument]] = (
|
||||
mapped: list[tuple[Function.Parameter, Function.Parameter]] = (
|
||||
self.map_params(template, concrete)
|
||||
)
|
||||
substitutions: dict[str, Type] = {}
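Review note: a minimal sketch of the column-by-column matching that the `DataFrameType` case performs. Everything here is an assumption made for illustration: `Column` is a simplified stand-in, types are plain strings, and a name starting with `$` plays the role of a type variable. Unlike the hunk above, where a length mismatch simply fails the `case` guard, this sketch raises directly.

```python
from dataclasses import dataclass


class UnificationError(Exception):
    pass


@dataclass(frozen=True)
class Column:  # hypothetical stand-in for DataFrameType.Column
    index: int
    name: str
    type: str  # names starting with "$" act as type variables in this sketch


def match_columns(template: list[Column], concrete: list[Column]) -> dict[str, str]:
    if len(template) != len(concrete):
        raise UnificationError("column count mismatch")
    substitutions: dict[str, str] = {}
    for t, c in zip(template, concrete):
        # Columns must agree on both position and name before types are compared.
        if t.index != c.index or t.name != c.name:
            raise UnificationError(f"column mismatch: {t} vs {c}")
        if t.type.startswith("$"):
            # Bind the variable, or check it against an earlier binding.
            bound = substitutions.setdefault(t.type, c.type)
            if bound != c.type:
                raise UnificationError(f"conflicting binding for {t.type}")
        elif t.type != c.type:
            raise UnificationError(f"type mismatch: {t.type} vs {c.type}")
    return substitutions
```

Matching a template `[id: $T, x: float]` against `[id: int, x: float]` yields the single substitution `$T -> int`.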
|
||||
@@ -135,19 +163,23 @@ class Unifier:
|
||||
|
||||
def map_params(
|
||||
self, func1: Function, func2: Function
|
||||
) -> list[tuple[Function.Argument, Function.Argument]]:
|
||||
pos1: list[Function.Argument] = func1.pos_args
|
||||
mixed1: list[Function.Argument] = func1.args
|
||||
kw1: list[Function.Argument] = func1.kw_args
|
||||
) -> list[tuple[Function.Parameter, Function.Parameter]]:
|
||||
pos1: list[Function.Parameter] = func1.params.pos
|
||||
mixed1: list[Function.Parameter] = func1.params.mixed
|
||||
kw1: list[Function.Parameter] = func1.params.kw
|
||||
|
||||
pos2: list[Function.Argument] = func2.pos_args
|
||||
mixed2: list[Function.Argument] = func2.args
|
||||
kw2: list[Function.Argument] = func2.kw_args
|
||||
pos2: list[Function.Parameter] = func2.params.pos
|
||||
mixed2: list[Function.Parameter] = func2.params.mixed
|
||||
kw2: list[Function.Parameter] = func2.params.kw
|
||||
|
||||
mapped: list[tuple[Function.Argument, Function.Argument]] = []
|
||||
mapped: list[tuple[Function.Parameter, Function.Parameter]] = []
|
||||
|
||||
by_pos2: dict[int, Function.Argument] = {arg.pos: arg for arg in pos2 + mixed2}
|
||||
by_name2: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2 + kw2}
|
||||
by_pos2: dict[int, Function.Parameter] = {
|
||||
param.pos: param for param in pos2 + mixed2
|
||||
}
|
||||
by_name2: dict[str, Function.Parameter] = {
|
||||
param.name: param for param in mixed2 + kw2
|
||||
}
|
||||
|
||||
for arg1 in pos1:
|
||||
if (arg2 := by_pos2.get(arg1.pos)) is not None:
|
||||
|
||||
@@ -77,14 +77,14 @@ class VarianceInferrer:
|
||||
match type:
|
||||
# Arguments are negative positions -> flip polarity
|
||||
# Return is positive position -> keep polarity
|
||||
case Function(pos_args=pos_args, args=mixed_args, kw_args=kw_args):
|
||||
all_args: list[Function.Argument] = pos_args + mixed_args + kw_args
|
||||
for arg in all_args:
|
||||
case Function(params=spec):
|
||||
all_params: list[Function.Parameter] = spec.pos + spec.mixed + spec.kw
|
||||
for param in all_params:
|
||||
self.walk(
|
||||
arg.type,
|
||||
param.type,
|
||||
-polarity,
|
||||
base_name,
|
||||
path + [f"arg:'{arg.name}'"],
|
||||
path + [f"param:'{param.name}'"],
|
||||
)
|
||||
|
||||
self.walk(type.returns, polarity, base_name, path + ["return"])
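Review note: the polarity rule in the comments above (parameters flip polarity, the return keeps it) can be sketched in isolation. `Var`, `Func`, `walk`, and `infer` are hypothetical names for this sketch only. How the real `VarianceInferrer` turns the collected polarities into a `Variance` is not shown in this hunk, so the labelling step is an assumption.

```python
from dataclasses import dataclass


@dataclass
class Var:  # hypothetical: a reference to a type variable
    name: str


@dataclass
class Func:  # hypothetical: a function type with parameter and return types
    params: list
    returns: object


def walk(type_, polarity: int, seen: dict) -> None:
    if isinstance(type_, Var):
        seen.setdefault(type_.name, set()).add(polarity)
    elif isinstance(type_, Func):
        for param in type_.params:
            walk(param, -polarity, seen)  # parameters are negative positions
        walk(type_.returns, polarity, seen)  # return is a positive position


def infer(type_) -> dict:
    seen: dict = {}
    walk(type_, 1, seen)
    labels = {frozenset({1}): "covariant", frozenset({-1}): "contravariant"}
    # A variable seen in both polarities is treated as invariant.
    return {name: labels.get(frozenset(p), "invariant") for name, p in seen.items()}
```

For a function from `T` to `U`, the sketch reports `T` as contravariant and `U` as covariant. A variable that appears on both sides of a nested function parameter ends up invariant.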
|
||||
@@ -109,10 +109,10 @@ class VarianceInferrer:
|
||||
Variance.COVARIANT: 1,
|
||||
Variance.CONTRAVARIANT: -1,
|
||||
}
|
||||
for arg, param in zip(args, params):
|
||||
param_polarity: Polarity = polarities[param.variance]
|
||||
self.walk(
|
||||
arg,
|
||||
cast(Polarity, polarity * param_polarity),
|
||||
base_name,
|
||||
path + [f"applied:'{name}'"],
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import TextIO
|
||||
from typing import Optional, TextIO
|
||||
|
||||
import click
|
||||
|
||||
@@ -19,18 +19,23 @@ from midas.utils import TypedAST
|
||||
@click.command(help="Compile source")
|
||||
@click.argument("file", type=click.File("r"))
|
||||
@click.option("-t", "--types", type=click.File("r"), multiple=True)
|
||||
@click.option("-s", "--stubs", type=str, multiple=True)
|
||||
@click.option("--ignore-errors", is_flag=True)
|
||||
def compile(
|
||||
file: TextIO,
|
||||
types: tuple[TextIO],
|
||||
stubs: tuple[str],
|
||||
ignore_errors: bool,
|
||||
):
|
||||
source: str = file.read()
|
||||
source_path: Path = Path(file.name).resolve()
|
||||
|
||||
checker = TypeChecker()
|
||||
for types_file in types:
|
||||
checker.import_midas(Path(types_file.name).resolve())
|
||||
type_files: list[tuple[Path, Optional[str]]] = []
|
||||
for i, types_file in enumerate(types):
|
||||
in_path: Path = Path(types_file.name).resolve()
|
||||
checker.import_midas(in_path)
|
||||
type_files.append((in_path, stubs[i] if i < len(stubs) else None))
|
||||
|
||||
typed_ast: TypedAST = checker.type_check_source(source, str(source_path))
|
||||
diagnostics: list[Diagnostic] = checker.diagnostics.copy()
|
||||
@@ -43,4 +48,4 @@ def compile(
|
||||
sys.exit(1)
|
||||
|
||||
generator = Generator(workdir=source_path.parent, types=checker.types)
|
||||
generator.generate(typed_ast, source_path)
|
||||
generator.generate(typed_ast, source_path, type_files=type_files)
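Review note: the pairing of `--types` files with `--stubs` names is positional, with a fallback to the file stem applied later in `Generator.generate`. A small sketch of that behaviour, using hypothetical helper names (`pair_type_files`, `stub_filename`) and a hypothetical `.midas` extension, and skipping the `resolve()` call so the result does not depend on the working directory:

```python
from pathlib import Path
from typing import Optional


def pair_type_files(
    types: list[str], stubs: list[str]
) -> list[tuple[Path, Optional[str]]]:
    # The i-th stub name belongs to the i-th types file; extra types files get None.
    return [
        (Path(name), stubs[i] if i < len(stubs) else None)
        for i, name in enumerate(types)
    ]


def stub_filename(in_path: Path, out_name: Optional[str]) -> str:
    # Without an explicit name, fall back to the stem of the types file.
    return f"{out_name if out_name is not None else in_path.stem}.py"
```

Note that surplus `--stubs` values (more stubs than types files) are silently ignored under this scheme.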
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import ast
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import TextIO
|
||||
from typing import Optional, TextIO
|
||||
|
||||
import black
|
||||
import click
|
||||
@@ -38,15 +38,17 @@ class Handler(FileSystemEventHandler):
|
||||
|
||||
@click.command(help="Generate stubs from Midas definitions")
|
||||
@click.argument("file", type=click.File("r"))
|
||||
@click.option("-o", "--output", type=click.File("w"), default="-")
|
||||
@click.option("-o", "--output", type=click.File("w"))
|
||||
@click.option("-w", "--watch", is_flag=True)
|
||||
def stubs(
|
||||
file: TextIO,
|
||||
output: TextIO,
|
||||
output: Optional[TextIO],
|
||||
watch: bool,
|
||||
):
|
||||
source_path: Path = Path(file.name).resolve()
|
||||
out_path: Path = Path(output.name).resolve()
|
||||
out_path: Path = source_path.with_suffix(".pyi")
|
||||
if output is not None:
|
||||
out_path = Path(output.name).resolve()
|
||||
generate_stubs(source_path, out_path)
|
||||
|
||||
if watch:
|
||||
|
||||
@@ -134,9 +134,9 @@ class PythonHighlighter(
|
||||
|
||||
def visit_base_type(self, node: p.BaseType) -> None:
|
||||
self.wrap(node, "base-type")
|
||||
if node.param is not None:
|
||||
self.wrap(node.param, "param")
|
||||
node.param.accept(self)
|
||||
for arg in node.args:
|
||||
self.wrap(arg, "arg")
|
||||
arg.accept(self)
|
||||
|
||||
def visit_constraint_type(self, node: p.ConstraintType) -> None:
|
||||
self.wrap(node, "constraint-type")
|
||||
@@ -157,15 +157,18 @@ class PythonHighlighter(
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> None:
|
||||
self.wrap(stmt, "function")
|
||||
for arg in stmt.posonlyargs + stmt.args + stmt.kwonlyargs:
|
||||
self._highlight_function_argument(arg)
|
||||
self._highlight_param_spec(stmt.params)
|
||||
for body_stmt in stmt.body:
|
||||
body_stmt.accept(self)
|
||||
|
||||
def _highlight_function_argument(self, arg: p.Function.Argument) -> None:
|
||||
self.wrap(arg, "argument")
|
||||
if arg.type is not None:
|
||||
arg.type.accept(self)
|
||||
def _highlight_param_spec(self, spec: p.ParamSpec) -> None:
|
||||
for param in spec.all:
|
||||
self._highlight_function_param(param)
|
||||
|
||||
def _highlight_function_param(self, param: p.Function.Parameter) -> None:
|
||||
self.wrap(param, "parameter")
|
||||
if param.type is not None:
|
||||
param.type.accept(self)
|
||||
|
||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||
stmt.type.accept(self)
|
||||
@@ -247,6 +250,10 @@ class PythonHighlighter(
|
||||
if expr.step is not None:
|
||||
expr.step.accept(self)
|
||||
|
||||
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
|
||||
for item in expr.items:
|
||||
item.accept(self)
|
||||
|
||||
def visit_raw_expr(self, expr: p.RawExpr) -> None: ...
|
||||
|
||||
def visit_raw_stmt(self, stmt: p.RawStmt) -> None: ...
|
||||
@@ -350,6 +357,14 @@ class MidasHighlighter(
|
||||
for param in spec.pos + spec.mixed + spec.kw:
|
||||
param.type.accept(self)
|
||||
|
||||
def visit_frame_type(self, type: m.FrameType) -> None:
|
||||
self.wrap(type, "frame")
|
||||
for column in type.columns:
|
||||
self._visit_frame_column(column)
|
||||
|
||||
def _visit_frame_column(self, column: m.FrameType.Column) -> None:
|
||||
self.wrap(column, "column")
|
||||
|
||||
|
||||
class DiagnosticsHighlighter(Highlighter):
|
||||
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"
|
||||
|
||||
@@ -3,7 +3,7 @@ span {
|
||||
--col: 108, 233, 108;
|
||||
}
|
||||
|
||||
&.param {
|
||||
&.arg {
|
||||
--col: 103, 192, 224;
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ span {
|
||||
--col: 215, 103, 224;
|
||||
}
|
||||
|
||||
&.argument {
|
||||
&.parameter {
|
||||
--col: 103, 192, 224;
|
||||
}
|
||||
}
|
||||
@@ -68,7 +68,7 @@ class DiagnosticPrinter:
|
||||
|
||||
loc: Location = diagnostic.location
|
||||
if loc.lineno != loc.end_lineno:
|
||||
print(diagnostic)
|
||||
self.print_multiline(lines, diagnostic, indent)
|
||||
return
|
||||
|
||||
start_offset: int = loc.col_offset
|
||||
@@ -95,3 +95,27 @@ class DiagnosticPrinter:
|
||||
print(indent_str + before + subject + after)
|
||||
print(indent_str + cursor)
|
||||
print()
|
||||
|
||||
def print_multiline(
|
||||
self, all_lines: list[str], diagnostic: Diagnostic, indent: int = 4
|
||||
):
|
||||
loc: Location = diagnostic.location
|
||||
lines: list[str] = all_lines[loc.lineno - 1 : loc.end_lineno]
|
||||
|
||||
start_offset: int = loc.col_offset
|
||||
end_offset: int = loc.end_col_offset or (start_offset + 1)
|
||||
|
||||
indent_str: str = " " * indent
|
||||
color: int = self.COLORS.get(diagnostic.type, Ansi.WHITE)
|
||||
res: str = indent_str + lines[0][:start_offset]
|
||||
res += Ansi.FG(color) + lines[0][start_offset:]
|
||||
for line in lines[1:-1]:
|
||||
res += "\n" + indent_str + line
|
||||
res += "\n" + indent_str + lines[-1][:end_offset]
|
||||
res += Ansi.RESET + lines[-1][end_offset:]
|
||||
|
||||
print(diagnostic.location_str + ":")
|
||||
print(res)
|
||||
print()
|
||||
print(Ansi.FG(color) + diagnostic.message + Ansi.RESET)
|
||||
print()
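Review note: the slicing logic of `print_multiline` can be tried in isolation. The sketch below is a simplified assumption-laden version: `highlight_span` is a hypothetical free function, the colour is a hard-coded ANSI escape instead of `Ansi.FG(color)`, and it returns the string instead of printing it. Like the method above, it assumes the span covers at least two lines.

```python
RED = "\033[31m"
RESET = "\033[0m"


def highlight_span(
    lines: list[str],
    lineno: int,
    col: int,
    end_lineno: int,
    end_col: int,
    indent: int = 4,
) -> str:
    # Line numbers are 1-based and inclusive, column offsets are 0-based.
    span = lines[lineno - 1 : end_lineno]
    pad = " " * indent
    # First line: colour starts at the start column.
    out = pad + span[0][:col] + RED + span[0][col:]
    # Middle lines are fully inside the coloured region.
    for line in span[1:-1]:
        out += "\n" + pad + line
    # Last line: colour stops at the end column.
    out += "\n" + pad + span[-1][:end_col] + RESET + span[-1][end_col:]
    return out
```

One consequence of this layout is that the indentation of the middle and last lines is emitted while the colour is still active, which is harmless for foreground colours but would be visible with a background colour.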
|
||||
|
||||
59
midas/generator/collector.py
Normal file
@@ -0,0 +1,59 @@
import ast
from dataclasses import dataclass
from typing import Callable

import midas.ast.python as p

AssertionBuilder = Callable[..., ast.expr]


@dataclass
class Assertion:
    bound_expr: p.Expr
    inputs: list[p.Expr]
    builder: AssertionBuilder
    message: str

    def is_bound_to(self, expr: p.Expr) -> bool:
        return expr == self.bound_expr


class AssertionCollector:
    def __init__(self):
        self.assertions: list[Assertion] = []
        self.definitions: dict[str, ast.stmt] = {}

    def add(
        self,
        bound_expr: p.Expr,
        inputs: list[p.Expr],
        builder: AssertionBuilder,
        message: str,
    ):
        self.assertions.append(
            Assertion(
                bound_expr=bound_expr,
                inputs=inputs,
                builder=builder,
                message=message,
            )
        )

    def remove(self, assertion: Assertion):
        try:
            self.assertions.remove(assertion)
        except ValueError:
            pass

    def define(self, name: str, stmt: ast.stmt):
        if name not in self.definitions:
            self.definitions[name] = stmt

    def get_definitions(self) -> list[ast.stmt]:
        return list(self.definitions.values())

    def get_assertions(self) -> list[Assertion]:
        return self.assertions

    def get_assertions_for(self, expr: p.Expr) -> list[Assertion]:
        return list(filter(lambda a: a.is_bound_to(expr), self.assertions))
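Review note: `AssertionBuilder` is typed as `Callable[..., ast.expr]`, so a builder presumably receives already-converted input expressions and returns the test expression of an assertion. The consuming side (`_apply_assertions`) is not shown in this part of the diff, so the sketch below is only a guess at the shape of that interaction. `non_null_builder` and `build_assert` are hypothetical helpers; only standard `ast` calls are used.

```python
import ast


def non_null_builder(value: ast.expr) -> ast.expr:
    # Builds the expression `<value> is not None`.
    return ast.Compare(
        left=value,
        ops=[ast.IsNot()],
        comparators=[ast.Constant(value=None)],
    )


def build_assert(builder, inputs: list[ast.expr], message: str) -> ast.stmt:
    # Call the builder with the input expressions and wrap the result in an assert.
    stmt = ast.Assert(test=builder(*inputs), msg=ast.Constant(value=message))
    return ast.fix_missing_locations(stmt)


stmt = build_assert(
    non_null_builder,
    [ast.Name(id="df", ctx=ast.Load())],
    "df must not be None",
)
source = ast.unparse(stmt)
```

Here `source` holds an `assert df is not None, ...` statement that could be emitted ahead of the expression the assertion is bound to.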
|
||||
@@ -5,6 +5,7 @@ import midas.ast.midas as m
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import (
|
||||
Function,
|
||||
ParamSpec,
|
||||
Predicate,
|
||||
Type,
|
||||
to_annotation,
|
||||
@@ -54,16 +55,16 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
|
||||
return expr.accept(self)
|
||||
case _:
|
||||
func = Function(
|
||||
pos_args=[],
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="_",
|
||||
type=self.types.get_type("Any"),
|
||||
required=True,
|
||||
)
|
||||
],
|
||||
kw_args=[],
|
||||
params=ParamSpec(
|
||||
mixed=[
|
||||
Function.Parameter(
|
||||
pos=0,
|
||||
name="_",
|
||||
type=self.types.get_type("Any"),
|
||||
required=True,
|
||||
)
|
||||
],
|
||||
),
|
||||
returns=self.types.get_type("bool"),
|
||||
)
|
||||
alias: str = self.make_alias(None)
|
||||
@@ -94,28 +95,28 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
|
||||
)
|
||||
return self.make_func(name, [ast.Return(value=body)], predicate.type)
|
||||
|
||||
def make_args(self, func: Function) -> ast.arguments:
|
||||
def make_args(self, params: ParamSpec) -> ast.arguments:
|
||||
return ast.arguments(
|
||||
posonlyargs=[
|
||||
ast.arg(
|
||||
arg=arg.name,
|
||||
annotation=ast.Constant(value=to_annotation(arg.type)),
|
||||
arg=param.name,
|
||||
annotation=ast.Constant(value=to_annotation(param.type)),
|
||||
)
|
||||
for arg in func.pos_args
|
||||
for param in params.pos
|
||||
],
|
||||
args=[
|
||||
ast.arg(
|
||||
arg=arg.name,
|
||||
annotation=ast.Constant(value=to_annotation(arg.type)),
|
||||
arg=param.name,
|
||||
annotation=ast.Constant(value=to_annotation(param.type)),
|
||||
)
|
||||
for arg in func.args
|
||||
for param in params.mixed
|
||||
],
|
||||
kwonlyargs=[
|
||||
ast.arg(
|
||||
arg=arg.name,
|
||||
annotation=ast.Constant(value=to_annotation(arg.type)),
|
||||
arg=param.name,
|
||||
annotation=ast.Constant(value=to_annotation(param.type)),
|
||||
)
|
||||
for arg in func.kw_args
|
||||
for param in params.kw
|
||||
],
|
||||
defaults=[],
|
||||
kw_defaults=[],
|
||||
@@ -125,11 +126,11 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
|
||||
self, name: str, inner_body: list[ast.stmt], type: Type, level: int = 0
|
||||
) -> ast.stmt:
|
||||
match type:
|
||||
case Function(returns=Function()):
|
||||
case Function(params=params, returns=Function()):
|
||||
inner_name: str = f"inner{level}"
|
||||
return ast.FunctionDef(
|
||||
name=name,
|
||||
args=self.make_args(type),
|
||||
args=self.make_args(params),
|
||||
body=[
|
||||
self.make_func(inner_name, inner_body, type.returns, level + 1),
|
||||
ast.Return(value=ast.Name(id=inner_name)),
|
||||
@@ -138,10 +139,10 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
|
||||
decorator_list=[],
|
||||
)
|
||||
|
||||
case Function():
|
||||
case Function(params=params):
|
||||
return ast.FunctionDef(
|
||||
name=name,
|
||||
args=self.make_args(type),
|
||||
args=self.make_args(params),
|
||||
body=inner_body,
|
||||
returns=ast.Constant(value=to_annotation(type.returns)),
|
||||
decorator_list=[],
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import ast
|
||||
import logging
|
||||
import shutil
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
@@ -8,65 +9,96 @@ import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.ast.printer import MidasPrinter
|
||||
from midas.checker.checker import TypeChecker
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import (
|
||||
AppliedType,
|
||||
BaseType,
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
DataFrameType,
|
||||
DerivedType,
|
||||
ExtensionType,
|
||||
FrameGroupBy,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
TopType,
|
||||
TupleType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnitType,
|
||||
UnknownType,
|
||||
)
|
||||
from midas.generator.collector import Assertion, AssertionCollector
|
||||
from midas.generator.constraints import ConstraintGenerator
|
||||
from midas.generator.stubs import StubsGenerator
|
||||
from midas.utils import TypedAST
|
||||
|
||||
|
||||
@dataclass
|
||||
class Scope:
|
||||
pre_assertions: list[ast.stmt] = field(default_factory=list)
|
||||
aliases: list[str] = field(default_factory=list)
|
||||
pre_assertions: list[ast.stmt] = field(default_factory=list[ast.stmt])
|
||||
aliases: list[str] = field(default_factory=list[str])
|
||||
|
||||
|
||||
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
IS_DATAFRAME_FUNC = "__midas_is_dataframe__"
|
||||
IS_COLUMN_FUNC = "__midas_is_column__"
|
||||
|
||||
def __init__(self, workdir: Path, types: TypesRegistry) -> None:
|
||||
self.workdir: Path = workdir.resolve()
|
||||
self.build_dir: Path = self.workdir / "build" / "midas"
|
||||
self.rel_src_path: Path = Path()
|
||||
self.logger: logging.Logger = logging.getLogger("Generator")
|
||||
|
||||
self._typed_ast: TypedAST = TypedAST(
|
||||
stmts=[],
|
||||
judgements=[],
|
||||
evaluated_casts=[],
|
||||
assertions=AssertionCollector(),
|
||||
)
|
||||
self._alias_count: int = 0
|
||||
self._predicate_count: int = 0
|
||||
self._scopes: list[Scope] = []
|
||||
self._aliases: list[tuple[p.Expr, ast.expr]] = []
|
||||
|
||||
self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
|
||||
self._constraints: list[tuple[m.Expr, ast.expr]] = []
|
||||
|
||||
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
|
||||
self.rel_src_path = src_path.resolve().relative_to(self.workdir)
|
||||
self.define_is_dataframe: bool = False
|
||||
self.define_is_column: bool = False
|
||||
|
||||
def set_src_path(self, path: Path):
|
||||
self.rel_src_path = path.resolve().relative_to(self.workdir)
|
||||
|
||||
def generate_ast(self, typed_ast: TypedAST) -> ast.AST:
|
||||
self._typed_ast = typed_ast
|
||||
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
|
||||
body: list[ast.stmt] = self._visit_body(typed_ast.stmts, can_be_empty=True)
|
||||
predicates: list[ast.stmt] = self._constraint_generator.get_definitions()
|
||||
module = ast.Module(body=predicates + body, type_ignores=[])
|
||||
|
||||
body = predicates + body
|
||||
|
||||
if self.define_is_dataframe:
|
||||
body = [self._is_dataframe_definition()] + body
|
||||
|
||||
if self.define_is_column:
|
||||
body = [self._is_column_definition()] + body
|
||||
|
||||
module = ast.Module(body=body, type_ignores=[])
|
||||
module = ast.fix_missing_locations(module)
|
||||
return module
|
||||
|
||||
def generate(
|
||||
self, typed_ast: TypedAST, src_path: Path, out_path: Optional[Path] = None
|
||||
self,
|
||||
typed_ast: TypedAST,
|
||||
src_path: Path,
|
||||
out_path: Optional[Path] = None,
|
||||
type_files: Optional[list[tuple[Path, Optional[str]]]] = None,
|
||||
) -> Path:
|
||||
module: ast.AST = self.generate_ast(typed_ast, src_path)
|
||||
compiled: str = ast.unparse(module)
|
||||
self.set_src_path(src_path)
|
||||
if out_path is None:
|
||||
if self.build_dir.exists():
|
||||
shutil.rmtree(self.build_dir)
|
||||
@@ -78,43 +110,72 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
raise ValueError(
|
||||
f"Directory traversal, {self.rel_src_path} points outside of parent directory"
|
||||
)
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
out_dir: Path = out_path.parent
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if type_files is not None:
|
||||
for in_path, out_name in type_files:
|
||||
if out_name is None:
|
||||
out_name = in_path.stem
|
||||
self.generate_stubs(in_path, out_dir / f"{out_name}.py")
|
||||
|
||||
module: ast.AST = self.generate_ast(typed_ast)
|
||||
compiled: str = ast.unparse(module)
|
||||
|
||||
out_path.write_text(compiled)
|
||||
return out_path
|
||||
|
||||
def generate_stubs(self, in_path: Path, out_path: Path):
|
||||
checker = TypeChecker()
|
||||
checker.import_midas(in_path)
|
||||
generator = StubsGenerator(checker.types)
|
||||
module: ast.Module = generator.generate_stubs()
|
||||
module = ast.fix_missing_locations(module)
|
||||
output: str = ast.unparse(module)
|
||||
out_path.write_text(output)
|
||||
|
||||
def convert(self, expr: p.Expr) -> ast.expr:
|
||||
for expr2, alias in self._aliases:
|
||||
if expr2 == expr:
|
||||
return alias
|
||||
assertions = self._typed_ast.assertions.get_assertions_for(expr)
|
||||
if len(assertions) != 0:
|
||||
return self._apply_assertions(expr, assertions)
|
||||
return expr.accept(self)
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> ast.expr:
|
||||
return ast.BinOp(
|
||||
left=expr.left.accept(self),
|
||||
left=self.convert(expr.left),
|
||||
op=expr.operator,
|
||||
right=expr.right.accept(self),
|
||||
right=self.convert(expr.right),
|
||||
)
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> ast.expr:
|
||||
return ast.Compare(
|
||||
left=expr.left.accept(self),
|
||||
left=self.convert(expr.left),
|
||||
ops=[expr.operator],
|
||||
comparators=[expr.right.accept(self)],
|
||||
comparators=[self.convert(expr.right)],
|
||||
)
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> ast.expr:
|
||||
return ast.UnaryOp(
|
||||
op=expr.operator,
|
||||
operand=expr.right.accept(self),
|
||||
operand=self.convert(expr.right),
|
||||
)
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> ast.expr:
|
||||
return ast.Call(
|
||||
func=expr.callee.accept(self),
|
||||
args=[arg.accept(self) for arg in expr.arguments],
|
||||
func=self.convert(expr.callee),
|
||||
args=[self.convert(arg) for arg in expr.arguments],
|
||||
keywords=[
|
||||
ast.keyword(arg=name, value=arg.accept(self))
|
||||
ast.keyword(arg=name, value=self.convert(arg))
|
||||
for name, arg in expr.keywords.items()
|
||||
],
|
||||
)
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> ast.expr:
|
||||
return ast.Attribute(
|
||||
value=expr.object.accept(self),
|
||||
value=self.convert(expr.object),
|
||||
attr=expr.name,
|
||||
)
|
||||
|
||||
@@ -127,51 +188,58 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> ast.expr:
|
||||
return ast.BoolOp(
|
||||
op=expr.operator,
|
||||
values=[expr.left.accept(self), expr.right.accept(self)],
|
||||
values=[self.convert(expr.left), self.convert(expr.right)],
|
||||
)
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
|
||||
expr2: ast.expr = expr.expr.accept(self)
|
||||
expr2: ast.expr = self.convert(expr.expr)
|
||||
|
||||
if expr in self._typed_ast.evaluated_casts or expr.unsafe:
|
||||
return expr2
|
||||
|
||||
alias: ast.expr = self._make_alias(expr2)
|
||||
alias: ast.expr = self._make_alias(expr.expr, expr2)
|
||||
|
||||
type: Type = self._get_expr_type(expr)
|
||||
self._make_cast_asserts(expr.location, alias, type)
|
||||
asserts: list[ast.stmt] = self._make_cast_asserts(expr.location, alias, type)
|
||||
for assert_ in asserts:
|
||||
self._add_assert(assert_)
|
||||
|
||||
return alias
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> ast.expr:
|
||||
return ast.IfExp(
|
||||
test=expr.test.accept(self),
|
||||
body=expr.if_true.accept(self),
|
||||
orelse=expr.if_false.accept(self),
|
||||
test=self.convert(expr.test),
|
||||
body=self.convert(expr.if_true),
|
||||
orelse=self.convert(expr.if_false),
|
||||
)
|
||||
|
||||
def visit_list_expr(self, expr: p.ListExpr) -> ast.expr:
|
||||
return ast.List(
|
||||
elts=[item.accept(self) for item in expr.items],
|
||||
elts=[self.convert(item) for item in expr.items],
|
||||
)
|
||||
|
||||
def visit_dict_expr(self, expr: p.DictExpr) -> ast.expr:
|
||||
return ast.Dict(
|
||||
keys=[key.accept(self) if key is not None else None for key in expr.keys],
|
||||
values=[value.accept(self) for value in expr.values],
|
||||
keys=[self.convert(key) if key is not None else None for key in expr.keys],
|
||||
values=[self.convert(value) for value in expr.values],
|
||||
)
|
||||
|
||||
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> ast.expr:
|
||||
return ast.Subscript(
|
||||
value=expr.object.accept(self),
|
||||
slice=expr.index.accept(self),
|
||||
value=self.convert(expr.object),
|
||||
slice=self.convert(expr.index),
|
||||
)
|
||||
|
||||
def visit_slice_expr(self, expr: p.SliceExpr) -> ast.expr:
|
||||
return ast.Slice(
|
||||
lower=expr.lower.accept(self) if expr.lower is not None else None,
|
||||
upper=expr.upper.accept(self) if expr.upper is not None else None,
|
||||
step=expr.step.accept(self) if expr.step is not None else None,
|
||||
lower=self.convert(expr.lower) if expr.lower is not None else None,
|
||||
upper=self.convert(expr.upper) if expr.upper is not None else None,
|
||||
step=self.convert(expr.step) if expr.step is not None else None,
|
||||
)
|
||||
|
||||
def visit_tuple_expr(self, expr: p.TupleExpr) -> ast.expr:
|
||||
return ast.Tuple(
|
||||
elts=[self.convert(item) for item in expr.items],
|
||||
)
|
||||
|
||||
def visit_raw_expr(self, expr: p.RawExpr) -> ast.expr:
|
||||
@@ -179,28 +247,29 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
|
||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> ast.stmt:
|
||||
return ast.Expr(
|
||||
value=stmt.expr.accept(self),
|
||||
value=self.convert(stmt.expr),
|
||||
)
|
||||
|
||||
def make_args(self, params: p.ParamSpec) -> ast.arguments:
|
||||
return ast.arguments(
|
||||
posonlyargs=[ast.arg(arg=param.name) for param in params.pos],
|
||||
args=[ast.arg(arg=param.name) for param in params.mixed],
|
||||
kwonlyargs=[ast.arg(arg=param.name) for param in params.kw],
|
||||
defaults=[
|
||||
self.convert(param.default)
|
||||
for param in params.pos + params.mixed
|
||||
if param.default is not None
|
||||
],
|
||||
kw_defaults=[
|
||||
self.convert(param.default) if param.default is not None else None
|
||||
for param in params.kw
|
||||
],
|
||||
)
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> ast.stmt:
|
||||
return ast.FunctionDef(
|
||||
name=stmt.name,
|
||||
args=ast.arguments(
|
||||
posonlyargs=[ast.arg(arg=arg.name) for arg in stmt.posonlyargs],
|
||||
vararg=None,
|
||||
args=[ast.arg(arg=arg.name) for arg in stmt.args],
|
||||
kwonlyargs=[ast.arg(arg=arg.name) for arg in stmt.kwonlyargs],
|
||||
kwarg=None,
|
||||
defaults=[
|
||||
arg.default.accept(self)
|
||||
for arg in stmt.posonlyargs + stmt.args
|
||||
if arg.default is not None
|
||||
],
|
||||
kw_defaults=[
|
||||
arg.default.accept(self) if arg.default is not None else None
|
||||
for arg in stmt.kwonlyargs
|
||||
],
|
||||
),
|
||||
args=self.make_args(stmt.params),
|
||||
body=self._visit_body(stmt.body),
|
||||
decorator_list=[],
|
||||
)
|
||||
@@ -211,20 +280,20 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
|
||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> ast.stmt:
|
||||
return ast.Assign(
|
||||
targets=[target.accept(self) for target in stmt.targets],
|
||||
value=stmt.value.accept(self),
|
||||
targets=[self.convert(target) for target in stmt.targets],
|
||||
value=self.convert(stmt.value),
|
||||
)
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> ast.stmt:
|
||||
return ast.Return(
|
||||
value=stmt.value.accept(self) if stmt.value is not None else None,
|
||||
value=self.convert(stmt.value) if stmt.value is not None else None,
|
||||
)
|
||||
|
||||
def visit_if_stmt(self, stmt: p.IfStmt) -> ast.stmt:
|
||||
return ast.If(
|
||||
test=stmt.test.accept(self),
|
||||
test=self.convert(stmt.test),
|
||||
body=self._visit_body(stmt.body),
|
||||
orelse=self._visit_body(stmt.orelse),
|
||||
orelse=self._visit_body(stmt.orelse, can_be_empty=True),
|
||||
)
|
||||
|
||||
def visit_pass(self, stmt: p.Pass) -> ast.stmt:
|
||||
@@ -232,8 +301,8 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
|
||||
def visit_for_stmt(self, stmt: p.ForStmt) -> ast.stmt:
|
||||
return ast.For(
|
||||
target=stmt.target.accept(self),
|
||||
iter=stmt.iterator.accept(self),
|
||||
target=self.convert(stmt.target),
|
||||
iter=self.convert(stmt.iterator),
|
||||
body=self._visit_body(stmt.body),
|
||||
orelse=[],
|
||||
)
|
||||
@@ -241,7 +310,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
def visit_raw_stmt(self, stmt: p.RawStmt) -> ast.stmt:
|
||||
return stmt.stmt
|
||||
|
||||
def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]:
|
||||
def _visit_body(
|
||||
self, stmts: list[p.Stmt], can_be_empty: bool = False
|
||||
) -> list[ast.stmt]:
|
||||
generated: list[ast.stmt] = []
|
||||
for stmt in stmts:
|
||||
scope = Scope()
|
||||
@@ -259,9 +330,11 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
# Remove redundant pass statements
|
||||
if len(generated) > 1:
|
||||
generated = [stmt for stmt in generated if not isinstance(stmt, ast.Pass)]
|
||||
if len(generated) == 0 and not can_be_empty:
|
||||
generated = [ast.Pass()]
|
||||
return generated
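The pass-handling rule at the end of `_visit_body` can be looked at in isolation. Below is a minimal sketch, assuming a standalone helper named `normalize_body` and a small `render` utility (both hypothetical names, not part of the generator) that work on already-generated `ast.stmt` nodes:

```python
import ast


def normalize_body(stmts: list[ast.stmt], can_be_empty: bool = False) -> list[ast.stmt]:
    """Hypothetical standalone version of the pass-handling rule."""
    generated = list(stmts)
    # A `pass` is redundant as soon as the body holds more than one statement.
    if len(generated) > 1:
        generated = [s for s in generated if not isinstance(s, ast.Pass)]
    # A block must not be empty, unless the caller allows it (e.g. an `orelse`).
    if len(generated) == 0 and not can_be_empty:
        generated = [ast.Pass()]
    return generated


def render(stmts: list[ast.stmt]) -> list[str]:
    """Turn statements back into source text for inspection."""
    return [ast.unparse(s) for s in stmts]
```

With `can_be_empty=True`, an empty list is returned unchanged, which is what an `if` without an `else` branch needs for its `orelse` field.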
|
||||
|
||||
def _make_alias(self, expr: ast.expr) -> ast.expr:
|
||||
def _make_alias(self, node: p.Expr, expr: ast.expr) -> ast.expr:
|
||||
name: str = f"__midas_a{self._alias_count}__"
|
||||
alias = ast.Name(id=name)
|
||||
self._alias_count += 1
|
||||
@@ -272,82 +345,182 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
value=expr,
|
||||
)
|
||||
)
|
||||
self._aliases.append((node, alias))
|
||||
return alias
|
||||
|
||||
def _add_assert(self, expr: ast.expr, message: str | ast.expr):
|
||||
def _build_assert(self, expr: ast.expr, message: str | ast.expr) -> ast.stmt:
|
||||
if isinstance(message, str):
|
||||
message = ast.Constant(value=message)
|
||||
self._scopes[-1].pre_assertions.append(
|
||||
ast.Assert(
|
||||
test=expr,
|
||||
msg=message,
|
||||
)
|
||||
return ast.Assert(
|
||||
test=expr,
|
||||
msg=message,
|
||||
)
|
||||
|
||||
def _add_assert(self, assertion: ast.stmt):
|
||||
self._scopes[-1].pre_assertions.append(assertion)
|
||||
|
||||
def _get_expr_type(self, query: p.Expr) -> Type:
|
||||
for expr, type in self._typed_ast.judgements:
|
||||
if expr == query:
|
||||
return type
|
||||
raise RuntimeError(f"Cannot get type judgement for {query}")
|
||||
|
||||
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type):
|
||||
def _make_cast_asserts(
|
||||
self, src_location: Location, expr: ast.expr, type: Type
|
||||
) -> list[ast.stmt]:
|
||||
match type:
|
||||
case UnknownType():
|
||||
pass
|
||||
case UnknownType() | TopType():
|
||||
return []
|
||||
|
||||
case BaseType(name=name):
|
||||
self._add_assert(
|
||||
ast.Call(
|
||||
func=ast.Name(id="isinstance"),
|
||||
args=[expr, ast.Name(id=name)],
|
||||
keywords=[],
|
||||
),
|
||||
self._make_cast_assert_message(src_location, expr, type),
|
||||
)
|
||||
return [
|
||||
self._build_assert(
|
||||
ast.Call(
|
||||
func=ast.Name(id="isinstance"),
|
||||
args=[expr, ast.Name(id=name)],
|
||||
keywords=[],
|
||||
),
|
||||
self._make_cast_assert_message(src_location, expr, type),
|
||||
)
|
||||
]
|
||||
|
||||
case DerivedType(type=base):
|
||||
self._make_cast_asserts(src_location, expr, base)
|
||||
return self._make_cast_asserts(src_location, expr, base)
|
||||
|
||||
case UnitType():
|
||||
self._add_assert(
|
||||
ast.Compare(
|
||||
left=expr,
|
||||
ops=[ast.Is()],
|
||||
comparators=[
|
||||
ast.Constant(value=None),
|
||||
],
|
||||
return [
|
||||
self._build_assert(
|
||||
ast.Compare(
|
||||
left=expr,
|
||||
ops=[ast.Is()],
|
||||
comparators=[
|
||||
ast.Constant(value=None),
|
||||
],
|
||||
),
|
||||
self._make_cast_assert_message(src_location, expr, type),
|
||||
),
|
||||
self._make_cast_assert_message(src_location, expr, type),
|
||||
)
|
||||
]
|
||||
|
||||
case AppliedType(body=body):
|
||||
self._make_cast_asserts(src_location, expr, body)
|
||||
return self._make_cast_asserts(src_location, expr, body)
|
||||
|
||||
case ConstraintType(type=base, constraint=constraint):
|
||||
self._make_cast_asserts(src_location, expr, base)
|
||||
self._make_constraint_assert(src_location, expr, constraint)
|
||||
asserts: list[ast.stmt] = self._make_cast_asserts(
|
||||
src_location, expr, base
|
||||
)
|
||||
asserts.append(
|
||||
self._make_constraint_assert(src_location, expr, constraint)
|
||||
)
|
||||
return asserts
|
||||
|
||||
case TypeVar(bound=bound):
|
||||
# TODO: check with type from arguments / use call-site context
|
||||
if bound is not None:
|
||||
self._make_cast_asserts(src_location, expr, bound)
|
||||
if bound is None:
|
||||
return []
|
||||
return self._make_cast_asserts(src_location, expr, bound)
|
||||
|
||||
case TupleType(items=items):
|
||||
asserts: list[ast.stmt] = [
|
||||
self._build_assert(
|
||||
ast.Call(
|
||||
func=ast.Name(id="isinstance"),
|
||||
args=[expr, ast.Name(id="tuple")],
|
||||
keywords=[],
|
||||
),
|
||||
self._make_cast_assert_message(src_location, expr, type),
|
||||
),
|
||||
]
|
||||
assert isinstance(expr, ast.Tuple)
|
||||
for item, item_type in zip(expr.elts, items):
|
||||
asserts.extend(
|
||||
self._make_cast_asserts(src_location, item, item_type)
|
||||
)
|
||||
return asserts
|
||||
|
||||
case DataFrameType(columns=columns):
|
||||
self.define_is_dataframe = True
|
||||
asserts: list[ast.stmt] = [
|
||||
self._build_assert(
|
||||
ast.Call(
|
||||
func=ast.Name(id=self.IS_DATAFRAME_FUNC),
|
||||
args=[expr],
|
||||
keywords=[],
|
||||
),
|
||||
self._make_cast_assert_message(
|
||||
src_location, expr, type, ": Not a dataframe"
|
||||
),
|
||||
),
|
||||
]
|
||||
for column in columns:
|
||||
asserts.append(
|
||||
self._build_assert(
|
||||
ast.Compare(
|
||||
left=ast.Constant(value=column.name),
|
||||
ops=[ast.In()],
|
||||
comparators=[expr],
|
||||
),
|
||||
self._make_cast_assert_message(
|
||||
src_location,
|
||||
expr,
|
||||
type,
|
||||
f": Missing column {column.name}",
|
||||
),
|
||||
)
|
||||
)
|
||||
asserts.extend(
|
||||
self._make_cast_asserts(
|
||||
src_location,
|
||||
ast.Subscript(
|
||||
value=expr, slice=ast.Constant(value=column.name)
|
||||
),
|
||||
column.type,
|
||||
)
|
||||
)
|
||||
return asserts
|
||||
|
||||
case ColumnType():
|
||||
self.define_is_column = True
|
||||
asserts: list[ast.stmt] = [
|
||||
self._build_assert(
|
||||
ast.Call(
|
||||
func=ast.Name(id=self.IS_COLUMN_FUNC),
|
||||
args=[expr],
|
||||
keywords=[],
|
||||
),
|
||||
self._make_cast_assert_message(
|
||||
src_location, expr, type, ": Not a column"
|
||||
),
|
||||
),
|
||||
]
|
||||
inner_assert: Optional[ast.stmt] = self._make_column_inner_assert(
|
||||
src_location, expr, type
|
||||
)
|
||||
if inner_assert is not None:
|
||||
asserts.append(inner_assert)
|
||||
return asserts
|
||||
|
||||
case (
|
||||
TopType()
|
||||
| Function()
|
||||
Function()
|
||||
| OverloadedFunction()
|
||||
| ComplexType()
|
||||
| ExtensionType()
|
||||
| GenericType()
|
||||
| FrameGroupBy()
|
||||
| ColumnGroupBy()
|
||||
):
|
||||
raise NotImplementedError(f"Can't make assertion for type {type}")
|
||||
self.logger.warning(f"Can't make assertion for type {type}")
|
||||
return []
|
||||
|
||||
# Ensure exhaustiveness
|
||||
case _:
|
||||
assert_never(type)
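The `BaseType` branch boils down to an `assert isinstance(...)` statement built from `ast` nodes. The following is a minimal sketch of that idea, assuming a hypothetical helper `build_isinstance_assert` that takes plain names and a plain string message instead of the generator's expression, type and location objects:

```python
import ast


def build_isinstance_assert(target: str, type_name: str, message: str) -> ast.Assert:
    """Hypothetical helper: build `assert isinstance(target, type_name), message`."""
    test = ast.Call(
        func=ast.Name(id="isinstance", ctx=ast.Load()),
        args=[
            ast.Name(id=target, ctx=ast.Load()),
            ast.Name(id=type_name, ctx=ast.Load()),
        ],
        keywords=[],
    )
    return ast.Assert(test=test, msg=ast.Constant(value=message))


def run_assert(stmt: ast.Assert, variables: dict) -> None:
    """Compile a single assert statement and execute it with the given variables."""
    module = ast.Module(body=[stmt], type_ignores=[])
    ast.fix_missing_locations(module)  # compile() needs line/column info
    exec(compile(module, "<sketch>", "exec"), dict(variables))


check = build_isinstance_assert("x", "int", "cannot cast x to int")
run_assert(check, {"x": 3})  # passes silently
```

Running the same statement with `{"x": "a"}` raises an `AssertionError` carrying the message, which is the behaviour the generated cast assertions rely on.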
|
||||
|
||||
def _make_cast_assert_message(
|
||||
self, location: Location, expr: ast.expr, type: Type
|
||||
self,
|
||||
location: Location,
|
||||
expr: ast.expr,
|
||||
type: Type,
|
||||
extra: Optional[str] = None,
|
||||
) -> ast.expr:
|
||||
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
|
||||
# f"file.py:L1:1: CastError: Cannot cast {type(expr).__name__} to Type"
|
||||
@@ -365,15 +538,15 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
),
|
||||
conversion=-1,
|
||||
),
|
||||
ast.Constant(f" to {type}"),
|
||||
ast.Constant(f" to {type}{extra or ''}"),
|
||||
]
|
||||
)
|
||||
|
||||
def _make_constraint_assert(
|
||||
self, src_location: Location, expr: ast.expr, constraint: m.Expr
|
||||
):
|
||||
) -> ast.stmt:
|
||||
test_func: ast.expr = self._get_constraint(constraint)
|
||||
self._add_assert(
|
||||
return self._build_assert(
|
||||
ast.Call(
|
||||
func=test_func,
|
||||
args=[expr],
|
||||
@@ -401,3 +574,117 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
constraint: ast.expr = self._constraint_generator.generate(expr)
|
||||
self._constraints.append((expr, constraint))
|
||||
return constraint
|
||||
|
||||
def _is_dataframe_definition(self) -> ast.stmt:
|
||||
"""
|
||||
def IS_DATAFRAME_FUNC(obj) -> bool:
|
||||
import pandas as pd
|
||||
return isinstance(obj, pd.DataFrame)
|
||||
"""
|
||||
|
||||
return ast.FunctionDef(
|
||||
name=self.IS_DATAFRAME_FUNC,
|
||||
args=ast.arguments(
|
||||
posonlyargs=[ast.arg(arg="obj")],
|
||||
args=[],
|
||||
kwonlyargs=[],
|
||||
defaults=[],
|
||||
kw_defaults=[],
|
||||
),
|
||||
body=[
|
||||
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
|
||||
ast.Return(
|
||||
value=ast.Call(
|
||||
func=ast.Name(id="isinstance"),
|
||||
args=[
|
||||
ast.Name(id="obj"),
|
||||
ast.Attribute(
|
||||
value=ast.Name(id="pd"),
|
||||
attr="DataFrame",
|
||||
),
|
||||
],
|
||||
keywords=[],
|
||||
)
|
||||
),
|
||||
],
|
||||
decorator_list=[],
|
||||
returns=ast.Name(id="bool"),
|
||||
)
|
||||
|
||||
def _is_column_definition(self) -> ast.stmt:
|
||||
"""
|
||||
def IS_COLUMN_FUNC(obj) -> bool:
|
||||
import pandas as pd
|
||||
return isinstance(obj, pd.Series)
|
||||
"""
|
||||
|
||||
return ast.FunctionDef(
|
||||
name=self.IS_COLUMN_FUNC,
|
||||
args=ast.arguments(
|
||||
posonlyargs=[ast.arg(arg="obj")],
|
||||
args=[],
|
||||
kwonlyargs=[],
|
||||
defaults=[],
|
||||
kw_defaults=[],
|
||||
),
|
||||
body=[
|
||||
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
|
||||
ast.Return(
|
||||
value=ast.Call(
|
||||
func=ast.Name(id="isinstance"),
|
||||
args=[
|
||||
ast.Name(id="obj"),
|
||||
ast.Attribute(
|
||||
value=ast.Name(id="pd"),
|
||||
attr="Series",
|
||||
),
|
||||
],
|
||||
keywords=[],
|
||||
)
|
||||
),
|
||||
],
|
||||
decorator_list=[],
|
||||
returns=ast.Name(id="bool"),
|
||||
)
|
||||
|
||||
def _make_column_inner_assert(
|
||||
self, src_location: Location, column: ast.expr, type: ColumnType
|
||||
) -> Optional[ast.stmt]:
|
||||
# TODO: improve message, maybe chain contexts
|
||||
col: ast.expr = ast.Name(id="col")
|
||||
body: list[ast.stmt] = self._make_cast_asserts(src_location, col, type.type)
|
||||
if len(body) == 0:
|
||||
return None
|
||||
return ast.For(
|
||||
target=col,
|
||||
iter=column,
|
||||
body=body,
|
||||
orelse=[],
|
||||
)
|
||||
|
||||
def _convert_assertion(self, assertion: Assertion) -> ast.stmt:
|
||||
inputs: list[ast.expr] = []
|
||||
|
||||
for input in assertion.inputs:
|
||||
converted: ast.expr = self.convert(input)
|
||||
alias: ast.expr = self._make_alias(input, converted)
|
||||
inputs.append(alias)
|
||||
|
||||
test: ast.expr = assertion.builder(*inputs)
|
||||
location: Location = assertion.bound_expr.location
|
||||
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
|
||||
return self._build_assert(
|
||||
test, f"{loc_str}: AssertionError: {assertion.message}"
|
||||
)
|
||||
|
||||
def _apply_assertions(self, expr: p.Expr, assertions: list[Assertion]) -> ast.expr:
|
||||
for assertion in assertions:
|
||||
assert_stmt: ast.stmt
|
||||
assert_stmt = self._convert_assertion(assertion)
|
||||
self._add_assert(assert_stmt)
|
||||
|
||||
# Mutating list in frozen dataclass
|
||||
# Not ideal but easiest way to avoid duplicate assertions
|
||||
self._typed_ast.assertions.remove(assertion)
|
||||
|
||||
return expr.accept(self)
|
||||
|
||||
@@ -6,14 +6,20 @@ from midas.checker.registry import Member, TypesRegistry
|
||||
from midas.checker.types import (
|
||||
AppliedType,
|
||||
BaseType,
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
DataFrameType,
|
||||
DerivedType,
|
||||
ExtensionType,
|
||||
FrameGroupBy,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
ParamSpec,
|
||||
TopType,
|
||||
TupleType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnitType,
|
||||
@@ -30,6 +36,7 @@ class StubsGenerator:
|
||||
self.types: TypesRegistry = types
|
||||
self.stubs: list[ast.stmt] = []
|
||||
self.typing_imports: set[str] = set()
|
||||
self.import_pandas: bool = False
|
||||
self.protocol_idx: int = 0
|
||||
self.stub_idx: int = 0
|
||||
self.type_var_idx: int = 0
|
||||
@@ -38,6 +45,7 @@ class StubsGenerator:
|
||||
def generate_stubs(self) -> ast.Module:
|
||||
self.stubs = []
|
||||
self.typing_imports = set()
|
||||
self.import_pandas = False
|
||||
for name, type in self.types._types.items():
|
||||
# Skip builtin types, not just based on name so the user can override
|
||||
# TODO: check if added members on builtin type
|
||||
@@ -53,7 +61,7 @@ class StubsGenerator:
|
||||
continue
|
||||
self.generate_stub(name, type)
|
||||
|
||||
imports = [
|
||||
imports: list[ast.stmt] = [
|
||||
ast.ImportFrom(
|
||||
module="__future__",
|
||||
names=[ast.alias(name="annotations")],
|
||||
@@ -70,11 +78,37 @@ class StubsGenerator:
|
||||
level=0,
|
||||
)
|
||||
)
|
||||
if self.import_pandas:
|
||||
imports.append(
|
||||
ast.Import(
|
||||
names=[
|
||||
ast.alias(
|
||||
name="pandas",
|
||||
asname="pd",
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
return ast.Module(body=imports + self.stubs, type_ignores=[])
|
||||
|
||||
def generate_stub(self, name: str, type: Type):
|
||||
base_type: Type = type
|
||||
|
||||
# TODO: improve
|
||||
match type:
|
||||
case DerivedType(name=name_) | GenericType(name=name_) if name_ == name:
|
||||
pass
|
||||
case UnitType() if name == "None":
|
||||
pass
|
||||
case TopType() if name == "Any":
|
||||
pass
|
||||
case _:
|
||||
alias = ast.Assign(
|
||||
targets=[ast.Name(id=name)], value=self.dump_type(type)
|
||||
)
|
||||
self.add_stub(alias)
|
||||
return
|
||||
|
||||
members: dict[str, Member] = self.types._members.get(name, {})
|
||||
if isinstance(base_type, (BaseType, TopType, UnitType)) and len(members) == 0:
|
||||
return
|
||||
@@ -231,6 +265,57 @@ class StubsGenerator:
|
||||
case ConstraintType():
|
||||
return self.dump_type(type.type)
|
||||
|
||||
case TupleType(items=items):
|
||||
return ast.Subscript(
|
||||
value=ast.Name(id="tuple"),
|
||||
slice=ast.Tuple(
|
||||
elts=[self.dump_type(item) for item in items],
|
||||
),
|
||||
)
|
||||
|
||||
case ColumnType(type=inner):
|
||||
self.import_pandas = True
|
||||
return ast.Subscript(
|
||||
value=ast.Attribute(
|
||||
value=ast.Name(id="pd"),
|
||||
attr="Series",
|
||||
),
|
||||
slice=self.dump_type(inner),
|
||||
)
|
||||
|
||||
case DataFrameType():
|
||||
self.import_pandas = True
|
||||
return ast.Attribute(
|
||||
value=ast.Name(id="pd"),
|
||||
attr="DataFrame",
|
||||
)
|
||||
|
||||
case FrameGroupBy():
|
||||
self.import_pandas = True
|
||||
return ast.Attribute(
|
||||
value=ast.Attribute(
|
||||
value=ast.Attribute(
|
||||
value=ast.Name(id="pd"),
|
||||
attr="api",
|
||||
),
|
||||
attr="typing",
|
||||
),
|
||||
attr="DataFrameGroupBy",
|
||||
)
|
||||
|
||||
case ColumnGroupBy():
|
||||
self.import_pandas = True
|
||||
return ast.Attribute(
|
||||
value=ast.Attribute(
|
||||
value=ast.Attribute(
|
||||
value=ast.Name(id="pd"),
|
||||
attr="api",
|
||||
),
|
||||
attr="typing",
|
||||
),
|
||||
attr="SeriesGroupBy",
|
||||
)
|
||||
|
||||
case _:
|
||||
assert_never(type)
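The annotation nodes produced by the `TupleType` and `ColumnType` branches can be checked by rendering them back to source. This is a minimal sketch, assuming hypothetical helpers (`name`, `tuple_annotation`, `series_annotation`) that take type names as strings rather than Midas `Type` objects:

```python
import ast


def name(identifier: str) -> ast.Name:
    return ast.Name(id=identifier, ctx=ast.Load())


def tuple_annotation(item_names: list[str]) -> ast.Subscript:
    """Build the node for `tuple[<item>, ...]`."""
    return ast.Subscript(
        value=name("tuple"),
        slice=ast.Tuple(elts=[name(n) for n in item_names], ctx=ast.Load()),
        ctx=ast.Load(),
    )


def series_annotation(inner: str) -> ast.Subscript:
    """Build the node for `pd.Series[<inner>]`."""
    return ast.Subscript(
        value=ast.Attribute(value=name("pd"), attr="Series", ctx=ast.Load()),
        slice=name(inner),
        ctx=ast.Load(),
    )


print(ast.unparse(tuple_annotation(["int", "str"])))
print(ast.unparse(series_annotation("float")))
```

Because the second form refers to `pd`, any module containing it also needs `import pandas as pd`, which is presumably why the diff tracks an `import_pandas` flag.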
|
||||
|
||||
@@ -244,7 +329,7 @@ class StubsGenerator:
|
||||
return [
|
||||
ast.FunctionDef(
|
||||
name=name,
|
||||
args=self.dump_args(method, with_self=True),
|
||||
args=self.dump_params(method.params, with_self=True),
|
||||
returns=self.dump_type(method.returns),
|
||||
body=[ast.Expr(value=Empty)],
|
||||
decorator_list=[ast.Name(id="overload")] if overloaded else [],
|
||||
@@ -264,24 +349,33 @@ class StubsGenerator:
|
||||
)
|
||||
]
|
||||
|
||||
def dump_args(self, func: Function, with_self: bool = False) -> ast.arguments:
|
||||
def dump_params(self, params: ParamSpec, with_self: bool = False) -> ast.arguments:
|
||||
pos: list[ast.arg] = [
|
||||
ast.arg(arg=f"_{arg.pos}", annotation=self.dump_type(arg.type))
|
||||
for arg in func.pos_args
|
||||
ast.arg(
|
||||
arg=f"_{param.pos}",
|
||||
annotation=self.dump_type(param.type),
|
||||
)
|
||||
for param in params.pos
|
||||
]
|
||||
mixed: list[ast.arg] = [
|
||||
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type))
|
||||
for arg in func.args
|
||||
ast.arg(
|
||||
arg=param.name,
|
||||
annotation=self.dump_type(param.type),
|
||||
)
|
||||
for param in params.mixed
|
||||
]
|
||||
kw: list[ast.arg] = [
|
||||
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type))
|
||||
for arg in func.kw_args
|
||||
ast.arg(
|
||||
arg=param.name,
|
||||
annotation=self.dump_type(param.type),
|
||||
)
|
||||
for param in params.kw
|
||||
]
|
||||
defaults: list[ast.expr] = [
|
||||
Empty for arg in func.pos_args + func.args if not arg.required
|
||||
Empty for param in params.pos + params.mixed if not param.required
|
||||
]
|
||||
kw_defaults: list[Optional[ast.expr]] = [
|
||||
None if arg.required else Empty for arg in func.kw_args
|
||||
None if param.required else Empty for param in params.kw
|
||||
]
|
||||
if with_self:
|
||||
arg = ast.arg(arg="self", annotation=None)
|
||||
@@ -307,7 +401,7 @@ class StubsGenerator:
|
||||
body=[
|
||||
ast.FunctionDef(
|
||||
name="__call__",
|
||||
args=self.dump_args(func, with_self=True),
|
||||
args=self.dump_params(func.params, with_self=True),
|
||||
returns=self.dump_type(func.returns),
|
||||
body=[ast.Expr(value=Empty)],
|
||||
decorator_list=[],
|
||||
|
||||
@@ -16,9 +16,10 @@ class Lexer(ABC):
|
||||
"""An abstract lexer which provides methods to easily extend it into a concrete one
|
||||
|
||||
This implementation is based on the [_Crafting Interpreters_][1] book by Robert Nystrom,
|
||||
more specifically on my [previous Python implementation](https://git.kb28.ch/HEL/pebble)
|
||||
more specifically on my [previous Python implementation][2]
|
||||
|
||||
[1]: https://craftinginterpreters.com/
|
||||
[2]: https://git.kb28.ch/HEL/pebble
|
||||
"""
|
||||
|
||||
def __init__(self, source: str, file: Optional[str] = None) -> None:
|
||||
@@ -168,6 +169,6 @@ class Lexer(ABC):
|
||||
def scan_token(self) -> None:
|
||||
"""Scan a token
|
||||
|
||||
This function should (at least) consume the current character and produce the appropriate token(s), using `add_token`
|
||||
This function should (at least) consume the current character and produce the appropriate token(s), using :func:`add_token`
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -46,8 +46,8 @@ class MidasLexer(Lexer):
|
||||
self.add_token(TokenType.UNDERSCORE)
|
||||
case "-" if self.match(">"):
|
||||
self.add_token(TokenType.ARROW)
|
||||
# case "+":
|
||||
# self.add_token(TokenType.PLUS)
|
||||
case "+":
|
||||
self.add_token(TokenType.PLUS)
|
||||
case "-":
|
||||
self.add_token(TokenType.MINUS)
|
||||
case "*":
|
||||
@@ -81,6 +81,12 @@ class MidasLexer(Lexer):
|
||||
return None
|
||||
|
||||
def scan_string(self, opening: str):
|
||||
"""Scan the rest of a string and add it as a token
|
||||
|
||||
Args:
|
||||
opening (str): the opening quote or double quote, to be matched
|
||||
at the end of the string
|
||||
"""
|
||||
while self.peek() != opening and not self.is_at_end():
|
||||
self.advance()
|
||||
|
||||
@@ -147,6 +153,18 @@ class MidasLexer(Lexer):
|
||||
self.add_token(TokenType.COMMENT)
|
||||
|
||||
def is_identifier_char(self, char: str, *, start: bool) -> bool:
|
||||
"""Check whether a character is valid as part of an identifier
|
||||
|
||||
Identifiers can contain any alphanumerical character or underscore.
|
||||
They cannot start with a digit.
|
||||
|
||||
Args:
|
||||
char (str): the character to check
|
||||
start (bool): whether this is the first character of the identifier
|
||||
|
||||
Returns:
|
||||
bool: `True` if the character is valid, `False` otherwise
|
||||
"""
|
||||
if char == "_":
|
||||
return True
|
||||
if char.isalpha():
|
||||
|
||||
@@ -25,7 +25,7 @@ class TokenType(Enum):
|
||||
DOT = auto()
|
||||
|
||||
# Operators
|
||||
# PLUS = auto()
|
||||
PLUS = auto()
|
||||
MINUS = auto()
|
||||
STAR = auto()
|
||||
SLASH = auto()
|
||||
@@ -104,6 +104,15 @@ class Token:
|
||||
)
|
||||
|
||||
def location_to(self, to: Token) -> Location:
|
||||
"""Create a new :class:`Location` spanning from this token to another
|
||||
|
||||
Args:
|
||||
to (Token): the end token
|
||||
|
||||
Returns:
|
||||
Location: a new :class:`Location` starting at this token and ending
|
||||
at `to`, both included
|
||||
"""
|
||||
return Location.span(self.get_location(), to.get_location())
|
||||
|
||||
@property
|
||||
|
||||
@@ -16,6 +16,9 @@ class TokenError:
|
||||
def get_report(self) -> str:
|
||||
"""Get a detailed error message
|
||||
|
||||
The error message is formatted as "(<position>) Error at <token>: <message>".
|
||||
For example: "(L2:5) Error at '3': Expected ')' after arguments."
|
||||
|
||||
Returns:
|
||||
str: the complete error message
|
||||
"""
|
||||
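The report format described in the `get_report` docstring can be reproduced with a single f-string. A minimal sketch, assuming a hypothetical free function `format_report` that receives the position and lexeme directly instead of reading them from a token:

```python
def format_report(line: int, column: int, lexeme: str, message: str) -> str:
    """Hypothetical helper: "(<position>) Error at <token>: <message>"."""
    return f"(L{line}:{column}) Error at '{lexeme}': {message}"


print(format_report(2, 5, "3", "Expected ')' after arguments."))
```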
@@ -32,9 +35,10 @@ class Parser(ABC, Generic[T]):
|
||||
"""An abstract parser which provides methods to easily extend it into a concrete one
|
||||
|
||||
This implementation is based on the [_Crafting Interpreters_][1] book by Robert Nystrom,
|
||||
more specifically on my [previous Python implementation](https://git.kb28.ch/HEL/pebble)
|
||||
more specifically on my [previous Python implementation][2]
|
||||
|
||||
[1]: https://craftinginterpreters.com/
|
||||
[2]: https://git.kb28.ch/HEL/pebble
|
||||
"""
|
||||
|
||||
IGNORE: set[TokenType] = {
|
||||
@@ -173,7 +177,7 @@ class Parser(ABC, Generic[T]):
|
||||
error_msg (str): the error message if the token doesn't match
|
||||
|
||||
Raises:
|
||||
SyntaxError: if the current token doesn't match the given type
|
||||
ParsingError: if the current token doesn't match the given type
|
||||
|
||||
Returns:
|
||||
Token: the current token which matched the given type
|
||||
|
||||
@@ -10,6 +10,7 @@ from midas.ast.midas import (
|
||||
Expr,
|
||||
ExtendStmt,
|
||||
ExtensionType,
|
||||
FrameType,
|
||||
FunctionType,
|
||||
GenericType,
|
||||
GetExpr,
|
||||
@@ -34,10 +35,11 @@ from midas.parser.base import Parser
|
||||
from midas.parser.errors import ParsingError
|
||||
|
||||
|
||||
class MidasParser(Parser):
|
||||
class MidasParser(Parser[list[Stmt]]):
|
||||
"""A simple parser for midas type definitions"""
|
||||
|
||||
SYNC_BOUNDARY: set[TokenType] = {
|
||||
TokenType.ALIAS,
|
||||
TokenType.TYPE,
|
||||
TokenType.EXTEND,
|
||||
TokenType.PREDICATE,
|
||||
@@ -72,10 +74,10 @@ class MidasParser(Parser):
|
||||
def declaration(self) -> Optional[Stmt]:
|
||||
"""Try and parse a declaration
|
||||
|
||||
Any parsing error is caught and None is returned
|
||||
Any parsing error is caught and `None` is returned
|
||||
|
||||
Returns:
|
||||
Optional[Stmt]: the parsed Midas statement, or None if a ParsingError was raised
|
||||
Optional[Stmt]: the parsed Midas statement, or `None` if a ParsingError was raised
|
||||
"""
|
||||
try:
|
||||
if self.match(TokenType.TYPE):
|
||||
@@ -94,23 +96,14 @@ class MidasParser(Parser):
|
||||
def type_declaration(self) -> TypeStmt:
|
||||
"""Parse a type declaration
|
||||
|
||||
A type declaration can either be a simple type alias or a new complex type.
|
||||
In either case, it can have an optional template expression after its name, wrapped in brackets.
|
||||
A simple type alias is derived from a base type expression, and can have a optional constraint expression preceded by the `where` keyword.
|
||||
A full simple type alias is thus written:
|
||||
```
|
||||
type Name[Template](TypeExpr) where Condition
|
||||
```
|
||||
A type declaration creates a named subtype of a type expression.
|
||||
It can have an optional template expression after its name, wrapped in brackets, to handle type parameters.
|
||||
|
||||
A new complex type has a set of properties which are named, have a type and an optional constraint expression (also preceded by the `where` keyword).
|
||||
A full complex type definition is thus written:
|
||||
```
|
||||
type Name[Template] {
|
||||
prop1: TypeExpr1 where Condition1
|
||||
prop2: TypeExpr2 where Condition2
|
||||
...
|
||||
}
|
||||
```
|
||||
A type statement consists of:
|
||||
- the `type` keyword
|
||||
- a name (identifier)
|
||||
- (optional) type parameters
|
||||
- a body, a type expression (see :func:`type_expr`)
|
||||
|
||||
Returns:
|
||||
TypeStmt: the parsed type declaration statement
|
||||
@@ -164,11 +157,16 @@ class MidasParser(Parser):
|
||||
def alias_declaration(self) -> AliasStmt:
|
||||
"""Parse an alias declaration
|
||||
|
||||
An alias statement consists of:
|
||||
- the `alias` keyword
|
||||
- a name (identifier)
|
||||
- a body, a type expression (see :func:`type_expr`)
|
||||
|
||||
Returns:
|
||||
AliasStmt: the parsed alias declaration statement
|
||||
"""
|
||||
keyword: Token = self.previous()
|
||||
name: Token = self.consume_identifier("Expected type name")
|
||||
name: Token = self.consume_identifier("Expected alias name")
|
||||
|
||||
self.consume(TokenType.EQUAL, "Expected '=' before alias definition")
|
||||
|
||||
@@ -183,8 +181,8 @@ class MidasParser(Parser):
|
||||
def type_expr(self) -> Type:
|
||||
"""Parse a type expression
|
||||
|
||||
A type is an identifier, optionally followed by a template expression.
|
||||
It can also optionally be followed by a '?' to indicate a nullable type
|
||||
A type expression can either be a function type (see :func:`function`)
|
||||
or a constraint type (see :func:`constraint_type`)
|
||||
|
||||
Returns:
|
||||
Type: the parsed type expression
|
||||
@@ -204,6 +202,15 @@ class MidasParser(Parser):
|
||||
return base
|
||||
|
||||
def constraint_type(self) -> Type:
|
||||
"""Parse a constraint type expression
|
||||
|
||||
A constraint type consists of a base type (see :func:`base_type`),
|
||||
optionally followed by the `where` keyword and a constraint
|
||||
expression (see :func:`constraint`)
|
||||
|
||||
Returns:
|
||||
Type: the parsed constraint type expression
|
||||
"""
|
||||
type: Type = self.base_type()
|
||||
if self.match(TokenType.WHERE):
|
||||
constraint: Expr = self.constraint()
|
||||
@@ -215,6 +222,14 @@ class MidasParser(Parser):
|
||||
return type
|
||||
|
||||
def base_type(self) -> Type:
|
||||
"""Parse a base type expression
|
||||
|
||||
A base type is either a parenthesized type expression (see :func:`type_expr`)
|
||||
or a generic type (see :func:`generic_type`)
|
||||
|
||||
Returns:
|
||||
Type: the parsed base type expression
|
||||
"""
|
||||
if self.match(TokenType.LEFT_PAREN):
|
||||
type: Type = self.type_expr()
|
||||
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
|
||||
@@ -226,8 +241,21 @@ class MidasParser(Parser):
|
||||
return self.generic_type()
|
||||
|
||||
def generic_type(self) -> Type:
|
||||
type: Type = self.named_type()
|
||||
"""Parse a generic type expression
|
||||
|
||||
A generic type consists of a named type (see :func:`named_type`),
|
||||
optionally followed by type arguments in brackets.
|
||||
|
||||
The special `Frame` type accepts a frame schema instead of type
|
||||
arguments (see :func:`frame_type`).
|
||||
|
||||
Returns:
|
||||
Type: the parsed generic type
|
||||
"""
|
||||
type: NamedType = self.named_type()
|
||||
if self.check(TokenType.LEFT_BRACKET):
|
||||
if type.name.lexeme == "Frame":
|
||||
return self.frame_type()
|
||||
args: list[Type] = self.type_args()
|
||||
return GenericType(
|
||||
location=Location.span(type.location, self.previous().get_location()),
|
||||
@@ -237,6 +265,13 @@ class MidasParser(Parser):
|
||||
return type
|
||||
|
||||
def type_args(self) -> list[Type]:
|
||||
"""Parse a list of type arguments
|
||||
|
||||
Type arguments are a comma-separated list of type expressions wrapped in brackets.
|
||||
|
||||
Returns:
|
||||
list[Type]: the list of type arguments, if any, or an empty list
|
||||
"""
|
||||
args: list[Type] = []
|
||||
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic arguments")
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
||||
@@ -246,7 +281,14 @@ class MidasParser(Parser):
|
||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments")
|
||||
return args
|
||||
|
||||
def named_type(self) -> Type:
|
||||
def named_type(self) -> NamedType:
|
||||
"""Parse a named type expression
|
||||
|
||||
A named type is an identifier token
|
||||
|
||||
Returns:
|
||||
NamedType: the parsed named type expression
|
||||
"""
|
||||
name: Token = self.consume_identifier("Expected type name")
|
||||
return NamedType(
|
||||
location=name.get_location(),
|
||||
@@ -254,13 +296,13 @@ class MidasParser(Parser):
|
||||
)
|
||||
|
||||
def complex_type(self) -> ComplexType:
|
||||
"""Parse a type definition body
|
||||
"""Parse a complex type expression
|
||||
|
||||
A type definition body is a set of whitespace-separated
|
||||
property statements enclosed in curly braces
|
||||
A complex type consists of zero or more member statements enclosed in
|
||||
curly braces
|
||||
|
||||
Returns:
|
||||
ComplexType: the parsed complex type
|
||||
ComplexType: the parsed complex type expression
|
||||
"""
|
||||
left: Token = self.consume(
|
||||
TokenType.LEFT_BRACE, "Expected '{' to start type body"
|
||||
@@ -281,10 +323,50 @@ class MidasParser(Parser):
|
||||
members=members,
|
||||
)
|
||||
|
||||
def constraint(self) -> Expr:
|
||||
"""Parse a constraint
|
||||
def frame_type(self) -> FrameType:
|
||||
"""Parse a frame type expression
|
||||
|
||||
A constraint is basically a logical predicate
|
||||
A frame type consists of:
|
||||
- the `Frame` identifier
|
||||
- an opening bracket `[`
|
||||
- a list of comma-separated column expressions, each consisting of:
|
||||
- a name (token)
|
||||
- a colon `:`
|
||||
- a type expression (see :func:`type_expr`)
|
||||
- a closing bracket `]`
|
||||
|
||||
Returns:
|
||||
FrameType: the parsed frame type
|
||||
"""
|
||||
keyword: Token = self.previous()
|
||||
self.consume(TokenType.LEFT_BRACKET, "Expected '[' to start frame schema")
|
||||
|
||||
columns: list[FrameType.Column] = []
|
||||
while not self.check(TokenType.RIGHT_BRACKET) and not self.is_at_end():
|
||||
name: Token = self.advance()
|
||||
self.consume(TokenType.COLON, "Expected ':' between column name and type")
|
||||
type: Type = self.type_expr()
|
||||
columns.append(
|
||||
FrameType.Column(
|
||||
location=name.location_to(self.previous()),
|
||||
name=name,
|
||||
type=type,
|
||||
)
|
||||
)
|
||||
if not self.match(TokenType.COMMA):
|
||||
break
|
||||
|
||||
self.consume(TokenType.RIGHT_BRACKET, "Unclosed frame schema")
|
||||
|
||||
return FrameType(
|
||||
location=keyword.location_to(self.previous()),
|
||||
columns=columns,
|
||||
)
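The column loop in `frame_type` follows a common pattern: read one item, then continue only if a comma follows. Here is a minimal sketch of that loop, assuming a hypothetical function `parse_frame_schema` that works on a regex-split token list and treats every column type as a single identifier (the real parser delegates to `type_expr`):

```python
import re


def parse_frame_schema(source: str) -> list[tuple[str, str]]:
    """Hypothetical parser for `Frame[name: type, ...]` with identifier-only types."""
    tokens = re.findall(r"\w+|[\[\]:,]", source)
    pos = 0

    def consume(expected: str, error: str) -> None:
        nonlocal pos
        if pos >= len(tokens) or tokens[pos] != expected:
            raise SyntaxError(error)
        pos += 1

    consume("Frame", "Expected 'Frame'")
    consume("[", "Expected '[' to start frame schema")

    columns: list[tuple[str, str]] = []
    while pos < len(tokens) and tokens[pos] != "]":
        column_name = tokens[pos]
        pos += 1
        consume(":", "Expected ':' between column name and type")
        if pos >= len(tokens):
            raise SyntaxError("Expected type name")
        type_name = tokens[pos]
        pos += 1
        columns.append((column_name, type_name))
        # Stop at the first column that is not followed by a comma.
        if pos < len(tokens) and tokens[pos] == ",":
            pos += 1
        else:
            break

    consume("]", "Unclosed frame schema")
    return columns
```

Checking for the closing bracket at the top of the loop means a trailing comma, as in `Frame[id: int,]`, is accepted.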
|
||||
|
||||
def constraint(self) -> Expr:
|
||||
"""Parse a constraint expression
|
||||
|
||||
A constraint is an expression (see :func:`expression`)
|
||||
|
||||
Returns:
|
||||
Expr: the parsed constraint expression
|
||||
@@ -292,10 +374,20 @@ class MidasParser(Parser):
|
||||
return self.expression()
|
||||
|
||||
def expression(self) -> Expr:
|
||||
"""Parse an expression
|
||||
|
||||
An expression consists of a logical AND expression (see :func:`and_`)
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
"""
|
||||
return self.and_()
|
||||
|
||||
def and_(self) -> Expr:
|
||||
"""Parse a logical AND expression or a simpler expression
|
||||
"""Parse a logical AND expression
|
||||
|
||||
An AND consists of one or more equality expressions (see :func:`equality`)
|
||||
separated by logical AND operators (`&`)
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
@@ -311,7 +403,10 @@ class MidasParser(Parser):
|
||||
return expr
|
||||
|
||||
def equality(self) -> Expr:
|
||||
"""Parse a logical equality expression or a simpler expression
|
||||
"""Parse an equality expression
|
||||
|
||||
An equality consists of one or more comparison expressions (see :func:`comparison`)
|
||||
separated by equality operators (`==`, `!=`)
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
@@ -327,18 +422,59 @@ class MidasParser(Parser):
|
||||
return expr
|
||||
|
||||
def comparison(self) -> Expr:
|
||||
"""Parse a logical comparison expression or a simpler expression
|
||||
"""Parse a comparison expression
|
||||
|
||||
A comparison consists of one or more term expressions (see :func:`term`)
|
||||
separated by comparison operators (`<`, `<=`, `>`, `>=`)
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
"""
|
||||
expr: Expr = self.unary()
|
||||
expr: Expr = self.term()
|
||||
while self.match(
|
||||
TokenType.LESS,
|
||||
TokenType.LESS_EQUAL,
|
||||
TokenType.GREATER,
|
||||
TokenType.GREATER_EQUAL,
|
||||
):
|
||||
operator: Token = self.previous()
|
||||
right: Expr = self.term()
|
||||
location: Location = Location.span(expr.location, right.location)
|
||||
expr = BinaryExpr(
|
||||
location=location, left=expr, operator=operator, right=right
|
||||
)
|
||||
return expr
|
||||
|
||||
def term(self) -> Expr:
|
||||
"""Parse a term expression
|
||||
|
||||
A term consists of one or more factor expressions (see :func:`factor`)
|
||||
separated by weak arithmetic operators (`+`, `-`)
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
"""
|
||||
expr: Expr = self.factor()
|
||||
while self.match(TokenType.PLUS, TokenType.MINUS):
|
||||
operator: Token = self.previous()
|
||||
right: Expr = self.factor()
|
||||
location: Location = Location.span(expr.location, right.location)
|
||||
expr = BinaryExpr(
|
||||
location=location, left=expr, operator=operator, right=right
|
||||
)
|
||||
return expr
|
||||
|
||||
def factor(self) -> Expr:
|
||||
"""Parse a factor expression
|
||||
|
||||
A factor consists of one or more unary expressions (see :func:`unary`)
|
||||
separated by strong arithmetic operators (`*`, `/`)
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
"""
|
||||
expr: Expr = self.unary()
|
||||
while self.match(TokenType.STAR, TokenType.SLASH):
|
||||
operator: Token = self.previous()
|
||||
right: Expr = self.unary()
|
||||
location: Location = Location.span(expr.location, right.location)
|
||||
@@ -348,12 +484,15 @@ class MidasParser(Parser):
|
||||
return expr
|
||||
|
||||
def unary(self) -> Expr:
|
||||
"""Parse a unary expression or a simpler expression
|
||||
"""Parse a unary expression
|
||||
|
||||
A unary consists of a call expression (see :func:`call`) optionally
|
||||
preceded by zero or more unary operators (`+`, `-`)
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
"""
|
||||
if self.match(TokenType.MINUS):
|
||||
if self.match(TokenType.PLUS, TokenType.MINUS):
|
||||
operator: Token = self.previous()
|
||||
right: Expr = self.unary()
|
||||
location: Location = Location.span(operator.get_location(), right.location)
|
||||
@@ -361,16 +500,48 @@ class MidasParser(Parser):
|
||||
return self.call()
|
||||
|
||||
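The `term`, `factor` and `unary` methods above form a classic recursive-descent precedence chain: each level parses the next-tighter level first, then loops over its own operators. Below is a minimal, self-contained sketch of that idea. It is not the Midas parser: `MiniParser`, `tokenize` and `parse` are hypothetical names, tokens are plain strings, and the tree is built from tuples instead of `BinaryExpr` nodes.

```python
import re

# Hypothetical mini-tokenizer: numbers, parentheses and four operators only.
TOKEN_RE = re.compile(r"\s*(\d+(?:\.\d+)?|[-+*/()])")


def tokenize(source: str) -> list[str]:
    source = source.rstrip()
    tokens: list[str] = []
    pos = 0
    while pos < len(source):
        m = TOKEN_RE.match(source, pos)
        if m is None:
            raise SyntaxError(f"Unexpected character at offset {pos}")
        tokens.append(m.group(1))
        pos = m.end()
    return tokens


class MiniParser:
    def __init__(self, tokens: list[str]):
        self.tokens = tokens
        self.current = 0

    def match(self, *kinds: str) -> bool:
        # Consume the current token if it is one of the given kinds
        if self.current < len(self.tokens) and self.tokens[self.current] in kinds:
            self.current += 1
            return True
        return False

    def previous(self) -> str:
        return self.tokens[self.current - 1]

    def term(self):
        # Weak arithmetic operators: lowest precedence in this sketch
        expr = self.factor()
        while self.match("+", "-"):
            operator = self.previous()
            expr = (operator, expr, self.factor())
        return expr

    def factor(self):
        # Strong arithmetic operators bind tighter than '+' and '-'
        expr = self.unary()
        while self.match("*", "/"):
            operator = self.previous()
            expr = (operator, expr, self.unary())
        return expr

    def unary(self):
        # Prefix operators recurse into unary, so they can be stacked
        if self.match("+", "-"):
            operator = self.previous()
            return (operator, self.unary())
        return self.primary()

    def primary(self):
        if self.match("("):
            expr = self.term()
            if not self.match(")"):
                raise SyntaxError("Expected ')'")
            return expr
        if self.current < len(self.tokens) and self.tokens[self.current][0].isdigit():
            self.current += 1
            return float(self.previous())
        raise SyntaxError("Expected expression")


def parse(source: str):
    parser = MiniParser(tokenize(source))
    expr = parser.term()
    if parser.current != len(parser.tokens):
        raise SyntaxError("Unexpected trailing tokens")
    return expr
```

Because each loop folds the new operand into the existing `expr`, binary operators come out left-associative, which matches the `while self.match(...)` shape used in the diff.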
def call(self) -> Expr:
|
||||
"""Parse a call expression
|
||||
|
||||
A call consists of a reference expression (see :func:`reference`)
|
||||
optionally followed by zero or more argument groups.
|
||||
|
||||
An argument group is a parenthesized, comma-separated list of arguments (see :func:`finish_call`)
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
"""
|
||||
expr: Expr = self.reference()
|
||||
while self.match(TokenType.LEFT_PAREN):
|
||||
expr = self.finish_call(expr)
|
||||
return expr
|
||||
|
||||
def finish_call(self, callee: Expr) -> Expr:
|
||||
"""Parse an argument group, i.e. the arguments of a call
|
||||
|
||||
Arguments are either passed positionally or by name (keyword argument).
|
||||
All positional arguments must come before any keyword argument.
|
||||
Arguments are separated by commas.
|
||||
|
||||
A positional argument simply consists of an expression (see :func:`expression`)
|
||||
|
||||
A keyword argument consists of an identifier, followed by the equal `=`
|
||||
token and an expression (see :func:`expression`).
|
||||
|
||||
Args:
|
||||
callee (Expr): the callee expression
|
||||
|
||||
Raises:
|
||||
ParsingError: if a positional argument is passed after a keyword
|
||||
argument or if a keyword argument's name is invalid (i.e. not
|
||||
an identifier)
|
||||
|
||||
Returns:
|
||||
Expr: the parsed call expression
|
||||
"""
|
||||
pos_args: list[Expr] = []
|
||||
kw_args: dict[str, Expr] = {}
|
||||
keywords: bool = False
|
||||
while not self.match(TokenType.RIGHT_PAREN):
|
||||
while not self.check(TokenType.RIGHT_PAREN):
|
||||
if self.check_identifier() and self.check_next(TokenType.EQUAL):
|
||||
keywords = True
|
||||
keyword: Token = self.advance()
|
||||
@@ -386,13 +557,14 @@ class MidasParser(Parser):
|
||||
else:
|
||||
value = self.expression()
|
||||
if self.check(TokenType.EQUAL):
|
||||
error_msg: str
|
||||
if keywords:
|
||||
raise self.error(self.peek(), "Invalid keyword argument name")
|
||||
error_msg = "Invalid keyword argument name"
|
||||
else:
|
||||
raise self.error(
|
||||
self.peek(),
|
||||
"Cannot pass positional arguments after a keyword argument",
|
||||
error_msg = (
|
||||
"Cannot pass positional arguments after a keyword argument"
|
||||
)
|
||||
raise self.error(self.peek(), error_msg)
|
||||
pos_args.append(value)
|
||||
|
||||
if not self.match(TokenType.COMMA):
|
||||
@@ -409,7 +581,12 @@ class MidasParser(Parser):
|
||||
)
|
||||
|
||||
def reference(self) -> Expr:
|
||||
"""Parse an attribute access expression or a simpler expression
|
||||
"""Parse a reference expression
|
||||
|
||||
A reference consists of a primary expression (see :func:`primary`)
|
||||
optionally followed by zero or more attribute accesses.
|
||||
|
||||
An attribute access consists of a dot `.` token followed by an identifier
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
@@ -424,7 +601,12 @@ class MidasParser(Parser):
|
||||
def primary(self) -> Expr:
|
||||
"""Parse a primary expression
|
||||
|
||||
This includes literals (booleans, numbers, etc.), wildcards, identifiers and grouped expressions
|
||||
This includes literals (booleans, numbers, etc.), wildcards, identifiers
|
||||
and grouped expressions
|
||||
|
||||
Raises:
|
||||
ParsingError: if a primary expression cannot be parsed from the
|
||||
following tokens
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
@@ -457,14 +639,41 @@ class MidasParser(Parser):
|
||||
raise self.error(self.peek(), "Expected expression")
|
||||
|
||||
def consume_identifier(self, message: str = "Expected identifier") -> Token:
|
||||
"""Consume the current token if it is a valid identifier or raise an error (see :func:`check_identifier`)
|
||||
|
||||
If the current token is not a valid identifier, an error is raised
|
||||
with the provided message
|
||||
|
||||
Args:
|
||||
message (str, optional): the error message. Defaults to "Expected identifier".
|
||||
|
||||
Raises:
|
||||
ParsingError: if the current token is not a valid identifier
|
||||
|
||||
Returns:
|
||||
Token: the current token which is a valid identifier
|
||||
"""
|
||||
if not self.match_identifier():
|
||||
raise self.error(self.peek(), message)
|
||||
return self.previous()
|
||||
|
||||
def match_identifier(self) -> bool:
|
||||
"""Consume the next token if it is a valid identifier (see :func:`check_identifier`)
|
||||
|
||||
Returns:
|
||||
bool: whether a token was matched and consumed
|
||||
"""
|
||||
return self.match(TokenType.IDENTIFIER, *KEYWORDS.values())
|
||||
|
||||
def check_identifier(self) -> bool:
|
||||
"""Check whether the current token is a valid identifier
|
||||
|
||||
A valid identifier is either an identifier token or a keyword token.
|
||||
This function always returns False if the parser is at the EOF token
|
||||
|
||||
Returns:
|
||||
bool: True if the current token is a valid identifier and not EOF
|
||||
"""
|
||||
for tt in [TokenType.IDENTIFIER, *KEYWORDS.values()]:
|
||||
if self.check(tt):
|
||||
return True
|
||||
@@ -473,7 +682,14 @@ class MidasParser(Parser):
|
||||
def member_stmt(self) -> MemberStmt:
|
||||
"""Parse a member statement
|
||||
|
||||
A type member statement is written `prop name: Type` or `def name: Type`
|
||||
A member statement consists of:
|
||||
- the `prop` (for a property) or `def` (for a method) keyword
|
||||
- a name (identifier)
|
||||
- a colon `:`
|
||||
- a type expression (see :func:`type_expr`)
|
||||
|
||||
Raises:
|
||||
ParsingError: if the first token is neither `prop` nor `def`
|
||||
|
||||
Returns:
|
||||
MemberStmt: the parsed member statement
|
||||
@@ -500,7 +716,13 @@ class MidasParser(Parser):
|
||||
def extend_declaration(self) -> ExtendStmt:
|
||||
"""Parse an extension definition
|
||||
|
||||
An extension is written `extend Type { operations }` or `extend[S <: T, U] Type { operations }`
|
||||
An extension statement consists of:
|
||||
- the `extend` keyword
|
||||
- a type name (identifier)
|
||||
- (optional) type parameters (see :func:`type_params`)
|
||||
- an opening brace `{`
|
||||
- zero or more member statements (see :func:`member_stmt`)
|
||||
- a closing brace `}`
|
||||
|
||||
Returns:
|
||||
ExtendStmt: the parsed extension statement
|
||||
@@ -525,7 +747,12 @@ class MidasParser(Parser):
|
||||
def predicate_declaration(self) -> PredicateStmt:
|
||||
"""Parse a predicate declaration
|
||||
|
||||
A predicate is written `predicate Name(subject: Type) = constraint_expression`
|
||||
A predicate statement consists of:
|
||||
- the `predicate` keyword
|
||||
- a name (identifier)
|
||||
- (optional) zero or more parameter specs (see :func:`function_params`)
|
||||
- an equal sign `=`
|
||||
- a body, a constraint expression (see :func:`constraint`)
|
||||
|
||||
Returns:
|
||||
PredicateStmt: the parsed predicate declaration statement
|
||||
@@ -536,7 +763,7 @@ class MidasParser(Parser):
|
||||
|
||||
params: list[ParamSpec] = []
|
||||
while self.check(TokenType.LEFT_PAREN):
|
||||
params.append(self.function_args())
|
||||
params.append(self.function_params())
|
||||
|
||||
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
|
||||
body: Expr = self.constraint()
|
||||
@@ -548,7 +775,18 @@ class MidasParser(Parser):
|
||||
)
|
||||
|
||||
def function(self) -> FunctionType:
|
||||
params: ParamSpec = self.function_args()
|
||||
"""Parse a function type expression
|
||||
|
||||
A function consists of:
|
||||
- the `fn` keyword
|
||||
- a parameter spec (see :func:`function_params`)
|
||||
- the arrow keyword `->`
|
||||
- a result type expression (see :func:`type_expr`)
|
||||
|
||||
Returns:
|
||||
FunctionType: the parsed function type expression
|
||||
"""
|
||||
params: ParamSpec = self.function_params()
|
||||
|
||||
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
||||
result: Type = self.type_expr()
|
||||
@@ -559,36 +797,53 @@ class MidasParser(Parser):
|
||||
returns=result,
|
||||
)
|
||||
|
||||
def function_args(self) -> ParamSpec:
|
||||
def function_params(self) -> ParamSpec:
|
||||
"""Parse a parameter spec
|
||||
|
||||
A parameter spec consists of zero or more comma-separated parameters,
|
||||
wrapped in parentheses.
|
||||
|
||||
Like in Python, it can contain positional-only, mixed and keyword-only
|
||||
parameters (separated by `/` and `*`).
|
||||
|
||||
Each parameter has a type (see :func:`type_expr`),
|
||||
preceded by a name (identifier) and a colon `:` (not required for
|
||||
positional-only parameters).
|
||||
|
||||
Returns:
|
||||
ParamSpec: the parsed parameter spec
|
||||
"""
|
||||
l_paren: Token = self.consume(
|
||||
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
|
||||
)
|
||||
pos_args: list[FunctionType.Argument] = []
|
||||
args: list[FunctionType.Argument] = []
|
||||
kw_args: list[FunctionType.Argument] = []
|
||||
pos: list[FunctionType.Parameter] = []
|
||||
mixed: list[FunctionType.Parameter] = []
|
||||
kw: list[FunctionType.Parameter] = []
|
||||
|
||||
args_first_tokens: list[Token] = []
|
||||
mixed_first_tokens: list[Token] = []
|
||||
|
||||
section: int = 0
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_PAREN):
|
||||
match section:
|
||||
case 0 if self.match(TokenType.SLASH):
|
||||
pos_args = args
|
||||
args = []
|
||||
args_first_tokens = []
|
||||
pos = mixed
|
||||
mixed = []
|
||||
mixed_first_tokens = []
|
||||
section = 1
|
||||
case 0 | 1 if self.match(TokenType.STAR):
|
||||
section = 2
|
||||
case _:
|
||||
# Record first token of mixed argument for errors if unnamed
|
||||
# Record first token of mixed parameters for errors if unnamed
|
||||
if section != 2:
|
||||
args_first_tokens.append(self.peek())
|
||||
mixed_first_tokens.append(self.peek())
|
||||
|
||||
name: Optional[Token] = None
|
||||
if section == 2:
|
||||
name = self.consume_identifier("Expected keyword argument name")
|
||||
name = self.consume_identifier(
|
||||
"Expected keyword parameter name"
|
||||
)
|
||||
self.consume(
|
||||
TokenType.COLON, "Expected ':' after argument name"
|
||||
TokenType.COLON, "Expected ':' after parameter name"
|
||||
)
|
||||
elif self.check_identifier() and self.check_next(TokenType.COLON):
|
||||
name = self.advance()
|
||||
@@ -596,24 +851,24 @@ class MidasParser(Parser):
|
||||
|
||||
type: Type = self.type_expr()
|
||||
optional: bool = self.match(TokenType.QMARK)
|
||||
arg = FunctionType.Argument(
|
||||
param = FunctionType.Parameter(
|
||||
location=None,
|
||||
name=name,
|
||||
type=type,
|
||||
required=not optional,
|
||||
)
|
||||
if section == 2:
|
||||
kw_args.append(arg)
|
||||
kw.append(param)
|
||||
else:
|
||||
args.append(arg)
|
||||
mixed.append(param)
|
||||
|
||||
if not self.match(TokenType.COMMA):
|
||||
break
|
||||
|
||||
for arg, token in zip(args, args_first_tokens):
|
||||
if arg.name is None:
|
||||
for param, token in zip(mixed, mixed_first_tokens):
|
||||
if param.name is None:
|
||||
# Not raised because we can keep parsing
|
||||
self.error(token, "Unnamed mixed argument")
|
||||
self.error(token, "Unnamed mixed parameter")
|
||||
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
|
||||
return ParamSpec(l_paren=l_paren, pos=pos_args, mixed=args, kw=kw_args)
|
||||
return ParamSpec(l_paren=l_paren, pos=pos, mixed=mixed, kw=kw)
|
||||
|
||||
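The `function_params` method above sorts parameters into three sections, switching on `/` and `*` the way Python signatures do. The following is a simplified sketch of that bookkeeping only, under loud assumptions: `Param`, `Spec` and `split_params` are hypothetical names, input is a plain string split on commas (so types containing commas are not supported), and problems raise `ValueError` instead of going through the parser's error reporting.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Param:
    name: Optional[str]
    type: str
    required: bool = True


@dataclass
class Spec:
    pos: list[Param] = field(default_factory=list)
    mixed: list[Param] = field(default_factory=list)
    kw: list[Param] = field(default_factory=list)


def split_params(source: str) -> Spec:
    spec = Spec()
    section = 0  # 0: before '/', 1: after '/', 2: after '*'
    for raw in source.split(","):
        item = raw.strip()
        if not item:
            continue
        if item == "/":
            if section != 0:
                raise ValueError("'/' is only allowed once, before '*'")
            # Everything collected so far was positional-only
            spec.pos, spec.mixed = spec.mixed, []
            section = 1
            continue
        if item == "*":
            if section == 2:
                raise ValueError("'*' is only allowed once")
            section = 2
            continue

        name_part, colon, type_part = item.partition(":")
        if colon:
            name, type_ = name_part.strip(), type_part.strip()
        else:
            name, type_ = None, item  # unnamed: the whole item is the type
        if section == 2 and name is None:
            raise ValueError("Expected keyword parameter name")

        # A trailing '?' marks the parameter as optional
        optional = type_.endswith("?")
        param = Param(name, type_.rstrip("?").strip(), required=not optional)
        (spec.kw if section == 2 else spec.mixed).append(param)

    # Names may only be omitted for parameters that ended up positional-only
    for param in spec.mixed:
        if param.name is None:
            raise ValueError("Unnamed mixed parameter")
    return spec
```

As in the diff, parameters are first collected as "mixed" and only reclassified as positional-only once a `/` is seen, which is why the unnamed-parameter check runs after the loop.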
@@ -23,6 +23,7 @@ from midas.ast.python import (
|
||||
LiteralExpr,
|
||||
LogicalExpr,
|
||||
MidasType,
|
||||
ParamSpec,
|
||||
RawExpr,
|
||||
RawStmt,
|
||||
ReturnStmt,
|
||||
@@ -30,6 +31,7 @@ from midas.ast.python import (
|
||||
Stmt,
|
||||
SubscriptExpr,
|
||||
TernaryExpr,
|
||||
TupleExpr,
|
||||
TypeAssign,
|
||||
UnaryExpr,
|
||||
VariableExpr,
|
||||
@@ -48,6 +50,8 @@ class UnsupportedSyntaxError(Exception):
|
||||
|
||||
|
||||
class PythonParser:
|
||||
"""A parser to convert raw Python `ast` nodes into custom IR nodes"""
|
||||
|
||||
CAST_FUNCTION = "cast"
|
||||
UNSAFE_CAST_FUNCTION = "unsafe_cast"
|
||||
|
||||
@@ -211,27 +215,10 @@ class PythonParser:
|
||||
match node:
|
||||
case ast.FunctionDef(
|
||||
name=name,
|
||||
args=ast.arguments(
|
||||
posonlyargs=posonlyargs,
|
||||
args=args,
|
||||
vararg=sink,
|
||||
kwonlyargs=kwonlyargs,
|
||||
kwarg=kw_sink,
|
||||
defaults=defaults,
|
||||
kw_defaults=kw_defaults,
|
||||
),
|
||||
args=args,
|
||||
returns=returns,
|
||||
body=raw_body,
|
||||
):
|
||||
|
||||
def parse_args(
|
||||
args_list: list[ast.arg], defaults: list[Optional[Expr]]
|
||||
) -> list[Function.Argument]:
|
||||
return [
|
||||
self._parse_function_argument(arg, default)
|
||||
for arg, default in zip(args_list, defaults)
|
||||
]
|
||||
|
||||
body: list[Stmt] = []
|
||||
for stmt in raw_body:
|
||||
stmts = self.parse_stmt(stmt)
|
||||
@@ -240,54 +227,58 @@ class PythonParser:
|
||||
elif stmts is not None:
|
||||
body.extend(stmts)
|
||||
|
||||
parsed_defaults: list[Optional[Expr]] = [
|
||||
self.parse_expr(default) for default in defaults
|
||||
]
|
||||
n_posargs: int = len(posonlyargs)
|
||||
n_args: int = len(args)
|
||||
n_all_posargs = n_posargs + n_args
|
||||
parsed_defaults = [
|
||||
None,
|
||||
] * (n_all_posargs - len(defaults)) + parsed_defaults
|
||||
|
||||
posargs_defaults: list[Optional[Expr]] = parsed_defaults[:n_posargs]
|
||||
args_defaults: list[Optional[Expr]] = parsed_defaults[n_posargs:]
|
||||
kwargs_defaults: list[Optional[Expr]] = [
|
||||
self.parse_expr(default) if default is not None else None
|
||||
for default in kw_defaults
|
||||
]
|
||||
|
||||
return Function(
|
||||
location=loc,
|
||||
name=name,
|
||||
posonlyargs=parse_args(posonlyargs, posargs_defaults),
|
||||
args=parse_args(args, args_defaults),
|
||||
sink=(
|
||||
self._parse_function_argument(sink, None)
|
||||
if sink is not None
|
||||
else None
|
||||
),
|
||||
kwonlyargs=parse_args(kwonlyargs, kwargs_defaults),
|
||||
kw_sink=(
|
||||
self._parse_function_argument(kw_sink, None)
|
||||
if kw_sink is not None
|
||||
else None
|
||||
),
|
||||
params=self._parse_param_spec(args),
|
||||
returns=self._parse_type(returns) if returns is not None else None,
|
||||
body=body,
|
||||
)
|
||||
case _:
|
||||
print(f"Unsupported function definition: {ast.unparse(node)}")
|
||||
|
||||
def _parse_function_argument(
|
||||
def _parse_param_spec(self, args: ast.arguments) -> ParamSpec:
|
||||
def parse_params(
|
||||
args_list: list[ast.arg], defaults: list[Optional[Expr]]
|
||||
) -> list[Function.Parameter]:
|
||||
return [
|
||||
self._parse_function_parameter(arg, default)
|
||||
for arg, default in zip(args_list, defaults)
|
||||
]
|
||||
|
||||
defaults: list[ast.expr] = args.defaults
|
||||
parsed_defaults: list[Optional[Expr]] = [
|
||||
self.parse_expr(default) for default in defaults
|
||||
]
|
||||
n_pos: int = len(args.posonlyargs)
|
||||
n_mixed: int = len(args.args)
|
||||
n_all_pos = n_pos + n_mixed
|
||||
parsed_defaults = [
|
||||
None,
|
||||
] * (n_all_pos - len(defaults)) + parsed_defaults
|
||||
|
||||
pos_defaults: list[Optional[Expr]] = parsed_defaults[:n_pos]
|
||||
mixed_defaults: list[Optional[Expr]] = parsed_defaults[n_pos:]
|
||||
kw_defaults: list[Optional[Expr]] = [
|
||||
self.parse_expr(default) if default is not None else None
|
||||
for default in args.kw_defaults
|
||||
]
|
||||
|
||||
return ParamSpec(
|
||||
pos=parse_params(args.posonlyargs, pos_defaults),
|
||||
mixed=parse_params(args.args, mixed_defaults),
|
||||
kw=parse_params(args.kwonlyargs, kw_defaults),
|
||||
)
|
||||
|
||||
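The padding step in `_parse_param_spec` relies on how Python's `ast.arguments` stores defaults: `defaults` is right-aligned against `posonlyargs + args`, while `kw_defaults` has one entry per keyword-only argument with `None` for missing defaults. A small sketch using only the standard `ast` module shows the alignment; `align_defaults` is a hypothetical helper, not part of the diff, and it returns unparsed source strings rather than IR nodes.

```python
import ast
from typing import Optional


def align_defaults(func_source: str) -> dict[str, Optional[str]]:
    func = ast.parse(func_source).body[0]
    if not isinstance(func, ast.FunctionDef):
        raise ValueError("Expected a function definition")
    args = func.args

    # Defaults apply to the LAST len(defaults) positional parameters,
    # so pad the front with None to line both lists up
    positional = args.posonlyargs + args.args
    padding: list[Optional[ast.expr]] = [None] * (len(positional) - len(args.defaults))
    defaults = padding + list(args.defaults)

    result: dict[str, Optional[str]] = {}
    for arg, default in zip(positional, defaults):
        result[arg.arg] = ast.unparse(default) if default is not None else None

    # kw_defaults is already aligned with kwonlyargs
    for arg, default in zip(args.kwonlyargs, args.kw_defaults):
        result[arg.arg] = ast.unparse(default) if default is not None else None
    return result
```

Slicing the padded list at `len(posonlyargs)`, as the diff does, then yields the positional-only and mixed defaults separately.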
def _parse_function_parameter(
|
||||
self, arg: ast.arg, default: Optional[Expr]
|
||||
) -> Function.Argument:
|
||||
) -> Function.Parameter:
|
||||
loc: Location = Location.from_ast(arg)
|
||||
name: str = arg.arg
|
||||
type: Optional[MidasType] = None
|
||||
if arg.annotation is not None:
|
||||
type = self._parse_type(arg.annotation)
|
||||
return Function.Argument(
|
||||
return Function.Parameter(
|
||||
location=loc,
|
||||
name=name,
|
||||
type=type,
|
||||
@@ -300,26 +291,28 @@ class PythonParser:
|
||||
case ast.Subscript(value=ast.Name(id="Frame"), slice=schema):
|
||||
return self._parse_frame_type(schema)
|
||||
|
||||
case ast.Subscript(value=ast.Name(id=name), slice=param):
|
||||
case ast.Subscript(value=ast.Name(id=name), slice=arg):
|
||||
args: tuple[MidasType, ...] = (
|
||||
tuple(self._parse_type(a) for a in arg.elts)
|
||||
if isinstance(arg, ast.Tuple)
|
||||
else (self._parse_type(arg),)
|
||||
)
|
||||
return BaseType(
|
||||
location=loc,
|
||||
base=name,
|
||||
param=self._parse_type(param),
|
||||
args=args,
|
||||
)
|
||||
|
||||
case ast.Name(id=name):
|
||||
return BaseType(
|
||||
location=loc,
|
||||
base=name,
|
||||
param=None,
|
||||
args=(),
|
||||
)
|
||||
|
||||
case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr):
|
||||
left = self._parse_type(left_expr)
|
||||
match left:
|
||||
case None:
|
||||
raise InvalidSyntaxError()
|
||||
|
||||
# If chained constraints, separate base type and rebuild constraint
|
||||
case ConstraintType(type=left_type, constraint=left_constraint):
|
||||
constraint = ast.BinOp(
|
||||
@@ -345,7 +338,7 @@ class PythonParser:
|
||||
return BaseType(
|
||||
location=loc,
|
||||
base="None",
|
||||
param=None,
|
||||
args=(),
|
||||
)
|
||||
|
||||
case _:
|
||||
@@ -477,6 +470,12 @@ class PythonParser:
|
||||
step=self.parse_expr(step) if step is not None else None,
|
||||
)
|
||||
|
||||
case ast.Tuple(elts=items):
|
||||
return TupleExpr(
|
||||
location=location,
|
||||
items=tuple(self.parse_expr(item) for item in items),
|
||||
)
|
||||
|
||||
case _:
|
||||
print(f"Unsupported expression: {ast.unparse(node)}")
|
||||
return RawExpr(location=location, expr=node)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from typing import Generic, TypeVar
|
||||
from typing import cast as typing_cast
|
||||
|
||||
cast = typing_cast
|
||||
@@ -32,3 +33,20 @@ This operation is unsound, use at your own risk!
|
||||
|
||||
_**Internal Python documentation**_
|
||||
"""
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Frame(Generic[T]):
|
||||
"""A `Frame` is the abstract type implemented by `DataFrame`
|
||||
|
||||
A frame contains any number of named columns (see :class:`Column`)
|
||||
"""
|
||||
|
||||
|
||||
class Column(Generic[T]):
|
||||
"""A `Column` is the abstract type implemented by `Series`
|
||||
|
||||
A column contains any number of values of the same type
|
||||
"""
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any, Callable, Optional
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.checker.types import Type
|
||||
from midas.generator.collector import AssertionCollector
|
||||
|
||||
AllowRepeat = Callable[[object], bool]
|
||||
|
||||
@@ -63,3 +64,4 @@ class TypedAST:
|
||||
stmts: list[p.Stmt]
|
||||
judgements: list[tuple[p.Expr, Type]]
|
||||
evaluated_casts: list[p.CastExpr]
|
||||
assertions: AssertionCollector
|
||||
|
||||
43
tests/__main__.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from typing import Type
|
||||
|
||||
from midas.cli.ansi import Ansi
|
||||
from tests.base import Tester
|
||||
from tests.checker import CheckerTester
|
||||
from tests.generator import GeneratorTester
|
||||
from tests.midas import MidasTester
|
||||
from tests.python import PythonTester
|
||||
|
||||
|
||||
def print_banner(name: str):
|
||||
horizontal: str = "+" + "-" * (len(name) + 2) + "+"
|
||||
print(horizontal)
|
||||
print(f"| {name} |")
|
||||
print(horizontal)
|
||||
|
||||
|
||||
def run_tests(tester_cls: Type[Tester]) -> bool:
|
||||
print_banner(tester_cls.__name__)
|
||||
tester: Tester = tester_cls()
|
||||
success: bool = tester.run_all_tests()
|
||||
print()
|
||||
return success
|
||||
|
||||
|
||||
def main():
|
||||
testers: list[Type[Tester]] = [
|
||||
PythonTester,
|
||||
MidasTester,
|
||||
CheckerTester,
|
||||
GeneratorTester,
|
||||
]
|
||||
|
||||
success: bool = all(map(run_tests, testers))
|
||||
|
||||
if success:
|
||||
print(Ansi.FG(Ansi.BRIGHT_GREEN) + "All tests passed!" + Ansi.RESET)
|
||||
else:
|
||||
print(Ansi.FG(Ansi.BRIGHT_RED) + "Some tests failed!" + Ansi.RESET)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
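One property of `all(map(run_tests, testers))` in `main` is worth keeping in mind: `map` is lazy and `all` stops at the first falsy result, so testers listed after a failing one are not run. Whether that is intended is not stated in the diff. The sketch below contrasts the two behaviours; `run_short_circuit`, `run_exhaustive` and `make_runner` are hypothetical names used only for illustration.

```python
from typing import Callable

Runner = Callable[[], bool]


def run_short_circuit(runners: list[Runner]) -> bool:
    # Same shape as `all(map(run_tests, testers))`: stops at the first failure
    return all(map(lambda run: run(), runners))


def run_exhaustive(runners: list[Runner]) -> bool:
    # Materialize every result first so that each runner is executed
    results = [run() for run in runners]
    return all(results)


def make_runner(name: str, outcome: bool, log: list[str]) -> Runner:
    def run() -> bool:
        log.append(name)  # record that this runner was actually executed
        return outcome

    return run
```

If a full report across all test suites is wanted even after a failure, the list-comprehension form is one way to get it.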
@@ -7,6 +7,8 @@ from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Protocol
|
||||
|
||||
from midas.cli.ansi import Ansi
|
||||
|
||||
|
||||
class CaseResult(Protocol):
|
||||
def dumps(self) -> str: ...
|
||||
@@ -44,8 +46,11 @@ class Tester(ABC):
|
||||
|
||||
print(rule)
|
||||
for i, test in enumerate(tests):
|
||||
print(f"Case {i+1}/{n}: {test.resolve().relative_to(self.CASES_DIR)}")
|
||||
path: Path = test.resolve().relative_to(self.CASES_DIR)
|
||||
print(f"{Ansi.FG(Ansi.BRIGHT_CYAN)}Case {i+1}/{n}: {path}{Ansi.RESET}")
|
||||
print(Ansi.DIM, end="")
|
||||
success: bool = self._run_test(test)
|
||||
print(Ansi.RESET, end="")
|
||||
if success:
|
||||
successes += 1
|
||||
else:
|
||||
@@ -146,8 +151,9 @@ class Tester(ABC):
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
case None:
|
||||
print("No subcommand provided. Available subcommands: run, update")
|
||||
sys.exit(1)
|
||||
success: bool = tester.run_all_tests()
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
case _:
|
||||
print(f"Unknown subcommand '{args.subcommand}'")
|
||||
sys.exit(1)
|
||||
|
||||
@@ -4,7 +4,35 @@
|
||||
"type": "Warning",
|
||||
"location": {
|
||||
"start": [
|
||||
6,
|
||||
8,
|
||||
12
|
||||
],
|
||||
"end": [
|
||||
8,
|
||||
43
|
||||
]
|
||||
},
|
||||
"message": "ConstraintType not yet supported"
|
||||
},
|
||||
{
|
||||
"type": "Warning",
|
||||
"location": {
|
||||
"start": [
|
||||
10,
|
||||
10
|
||||
],
|
||||
"end": [
|
||||
10,
|
||||
18
|
||||
]
|
||||
},
|
||||
"message": "Unknown type 'datetime'"
|
||||
},
|
||||
{
|
||||
"type": "Warning",
|
||||
"location": {
|
||||
"start": [
|
||||
13,
|
||||
4
|
||||
],
|
||||
"end": [
|
||||
@@ -12,7 +40,7 @@
|
||||
5
|
||||
]
|
||||
},
|
||||
"message": "FrameType not yet supported"
|
||||
"message": "Unknown type '_'"
|
||||
}
|
||||
],
|
||||
"judgments": []
|
||||
|
||||
File diff suppressed because it is too large
@@ -24,7 +24,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Meter",
|
||||
"param": null
|
||||
"args": []
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
@@ -62,7 +62,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Second",
|
||||
"param": null
|
||||
"args": []
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
|
||||
@@ -100,41 +100,6 @@
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:5",
|
||||
"to": "L11:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "maximum"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [],
|
||||
"args": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "a",
|
||||
"type": {
|
||||
"name": "float"
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"pos": 1,
|
||||
"name": "b",
|
||||
"type": {
|
||||
"name": "float"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [],
|
||||
"returns": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:13",
|
||||
@@ -161,6 +126,43 @@
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:5",
|
||||
"to": "L11:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "maximum"
|
||||
},
|
||||
"type": {
|
||||
"params": {
|
||||
"pos": [],
|
||||
"mixed": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "a",
|
||||
"type": {
|
||||
"name": "float"
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"pos": 1,
|
||||
"name": "b",
|
||||
"type": {
|
||||
"name": "float"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:5",
|
||||
|
||||
@@ -72,29 +72,6 @@
|
||||
}
|
||||
],
|
||||
"judgments": [
|
||||
{
|
||||
"location": {
|
||||
"from": "L26:0",
|
||||
"to": "L26:5"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "print"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "object",
|
||||
"type": {},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"returns": {}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L27:4",
|
||||
@@ -325,6 +302,31 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L26:0",
|
||||
"to": "L26:5"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "print"
|
||||
},
|
||||
"type": {
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "object",
|
||||
"type": {},
|
||||
"required": false
|
||||
}
|
||||
],
|
||||
"mixed": [],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L26:0",
|
||||
|
||||
@@ -63,31 +63,6 @@
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:11",
|
||||
"to": "L6:15"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "bool"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "object",
|
||||
"type": {},
|
||||
"required": false
|
||||
}
|
||||
],
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"returns": {
|
||||
"name": "bool"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:16",
|
||||
@@ -135,6 +110,33 @@
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:11",
|
||||
"to": "L6:15"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "bool"
|
||||
},
|
||||
"type": {
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "object",
|
||||
"type": {},
|
||||
"required": false
|
||||
}
|
||||
],
|
||||
"mixed": [],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"name": "bool"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:11",
|
||||
@@ -367,94 +369,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L12:17",
|
||||
"to": "L12:20"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "map"
|
||||
},
|
||||
"type": {
|
||||
"name": "map",
|
||||
"params": [
|
||||
{
|
||||
"name": "T",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
},
|
||||
{
|
||||
"name": "U",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"pos_args": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "transform",
|
||||
"type": {
|
||||
"pos_args": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "v",
|
||||
"type": {
|
||||
"name": "T",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"returns": {
|
||||
"name": "U",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"pos": 1,
|
||||
"name": "iterable",
|
||||
"type": {
|
||||
"name": "list",
|
||||
"args": [
|
||||
{
|
||||
"name": "T",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"name": "list"
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"returns": {
|
||||
"name": "list",
|
||||
"args": [
|
||||
{
|
||||
"name": "U",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"name": "list"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L12:21",
|
||||
@@ -465,18 +379,20 @@
|
||||
"name": "double"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [],
|
||||
"args": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "value",
|
||||
"type": {
|
||||
"name": "float"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [],
|
||||
"params": {
|
||||
"pos": [],
|
||||
"mixed": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "value",
|
||||
"type": {
|
||||
"name": "float"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"name": "float"
|
||||
}
|
||||
@@ -503,6 +419,98 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L12:17",
|
||||
"to": "L12:20"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "map"
|
||||
},
|
||||
"type": {
|
||||
"name": "map",
|
||||
"params": [
|
||||
{
|
||||
"name": "T",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
},
|
||||
{
|
||||
"name": "U",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "transform",
|
||||
"type": {
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "v",
|
||||
"type": {
|
||||
"name": "T",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"mixed": [],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"name": "U",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"pos": 1,
|
||||
"name": "iterable",
|
||||
"type": {
|
||||
"name": "list",
|
||||
"args": [
|
||||
{
|
||||
"name": "T",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"name": "list"
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"mixed": [],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"name": "list",
|
||||
"args": [
|
||||
{
|
||||
"name": "U",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"name": "list"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L12:17",
|
||||
@@ -538,94 +546,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L13:15",
|
||||
"to": "L13:18"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "map"
|
||||
},
|
||||
"type": {
|
||||
"name": "map",
|
||||
"params": [
|
||||
{
|
||||
"name": "T",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
},
|
||||
{
|
||||
"name": "U",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"pos_args": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "transform",
|
||||
"type": {
|
||||
"pos_args": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "v",
|
||||
"type": {
|
||||
"name": "T",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"returns": {
|
||||
"name": "U",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"pos": 1,
|
||||
"name": "iterable",
|
||||
"type": {
|
||||
"name": "list",
|
||||
"args": [
|
||||
{
|
||||
"name": "T",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"name": "list"
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"returns": {
|
||||
"name": "list",
|
||||
"args": [
|
||||
{
|
||||
"name": "U",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"name": "list"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L13:19",
|
||||
@@ -636,18 +556,20 @@
|
||||
"name": "double"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [],
|
||||
"args": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "value",
|
||||
"type": {
|
||||
"name": "float"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [],
|
||||
"params": {
|
||||
"pos": [],
|
||||
"mixed": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "value",
|
||||
"type": {
|
||||
"name": "float"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"name": "float"
|
||||
}
|
||||
@@ -674,6 +596,98 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L13:15",
|
||||
"to": "L13:18"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "map"
|
||||
},
|
||||
"type": {
|
||||
"name": "map",
|
||||
"params": [
|
||||
{
|
||||
"name": "T",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
},
|
||||
{
|
||||
"name": "U",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "transform",
|
||||
"type": {
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "v",
|
||||
"type": {
|
||||
"name": "T",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"mixed": [],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"name": "U",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"pos": 1,
|
||||
"name": "iterable",
|
||||
"type": {
|
||||
"name": "list",
|
||||
"args": [
|
||||
{
|
||||
"name": "T",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"name": "list"
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"mixed": [],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"name": "list",
|
||||
"args": [
|
||||
{
|
||||
"name": "U",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"name": "list"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L13:15",
|
||||
@@ -699,94 +713,6 @@
|
||||
},
|
||||
"type": {}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L14:11",
|
||||
"to": "L14:14"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "map"
|
||||
},
|
||||
"type": {
|
||||
"name": "map",
|
||||
"params": [
|
||||
{
|
||||
"name": "T",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
},
|
||||
{
|
||||
"name": "U",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"pos_args": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "transform",
|
||||
"type": {
|
||||
"pos_args": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "v",
|
||||
"type": {
|
||||
"name": "T",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"returns": {
|
||||
"name": "U",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"pos": 1,
|
||||
"name": "iterable",
|
||||
"type": {
|
||||
"name": "list",
|
||||
"args": [
|
||||
{
|
||||
"name": "T",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"name": "list"
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"returns": {
|
||||
"name": "list",
|
||||
"args": [
|
||||
{
|
||||
"name": "U",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"name": "list"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L14:15",
|
||||
@@ -797,18 +723,20 @@
|
||||
"name": "is_odd"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [],
|
||||
"args": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "value",
|
||||
"type": {
|
||||
"name": "int"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [],
|
||||
"params": {
|
||||
"pos": [],
|
||||
"mixed": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "value",
|
||||
"type": {
|
||||
"name": "int"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"name": "bool"
|
||||
}
|
||||
@@ -835,6 +763,98 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L14:11",
|
||||
"to": "L14:14"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "map"
|
||||
},
|
||||
"type": {
|
||||
"name": "map",
|
||||
"params": [
|
||||
{
|
||||
"name": "T",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
},
|
||||
{
|
||||
"name": "U",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "transform",
|
||||
"type": {
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "v",
|
||||
"type": {
|
||||
"name": "T",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"mixed": [],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"name": "U",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"pos": 1,
|
||||
"name": "iterable",
|
||||
"type": {
|
||||
"name": "list",
|
||||
"args": [
|
||||
{
|
||||
"name": "T",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"name": "list"
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"mixed": [],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"name": "list",
|
||||
"args": [
|
||||
{
|
||||
"name": "U",
|
||||
"bound": null,
|
||||
"variance": "INVARIANT"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"name": "list"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L14:11",
|
||||
|
||||
117
tests/cases/checker/09_frame_ops.py
Normal file
@@ -0,0 +1,117 @@
# type: ignore
# ruff: disable [F821]

df1: Frame[a:int, b:float]
df2: Frame[a:int, b:float]

_: Any

# Arithmetic
_ = df1 + df2
_ = df1 - df2
_ = df1 * df2
_ = df1 / df2
_ = df1 // df2
_ = df1 % df2
_ = df1**df2

# Comparisons
_ = df1 < df2
_ = df1 > df2
_ = df1 <= df2
_ = df1 >= df2
_ = df1 != df2
_ = df1 == df2

# Aggregate
_ = df1.kurt()
_ = df1.kurtosis()
_ = df1.max()
_ = df1.mean()
_ = df1.median()
_ = df1.min()
_ = df1.mode()
_ = df1.prod()
_ = df1.product()
_ = df1.std()
_ = df1.sum()
_ = df1.var()

# Groupby
df_gb = df1.groupby(by="a")

_ = df_gb.kurt()
_ = df_gb.max()
_ = df_gb.mean()
_ = df_gb.median()
_ = df_gb.min()
_ = df_gb.prod()
_ = df_gb.std()
_ = df_gb.sum()
_ = df_gb.var()


# Columns

col1 = df1["a"]
col2 = df1["a"]

# Arithmetic
_ = col1 + col2
_ = col1 - col2
_ = col1 * col2
_ = col1 / col2
_ = col1 // col2
_ = col1 % col2
_ = col1**col2

# Comparisons
_ = col1 < col2
_ = col1 > col2
_ = col1 <= col2
_ = col1 >= col2
_ = col1 != col2
_ = col1 == col2

# Aggregate
_ = col1.kurt()
_ = col1.kurtosis()
_ = col1.max()
_ = col1.mean()
_ = col1.median()
_ = col1.min()
_ = col1.mode()
_ = col1.prod()
_ = col1.product()
_ = col1.std()
_ = col1.sum()
_ = col1.var()

# Groupby
col_gb = col1.groupby(level=0)

_ = col_gb.kurt()
_ = col_gb.max()
_ = col_gb.mean()
_ = col_gb.median()
_ = col_gb.min()
_ = col_gb.prod()
_ = col_gb.std()
_ = col_gb.sum()
_ = col_gb.var()

# Attributes
_ = df1.ndim  # int
_ = df1.size  # int
_ = df1.shape  # (int, int)
_ = col1.ndim  # int
_ = col1.size  # int
_ = col1.shape  # (int)
_ = col1.T  # Column[int]


# Misc
_ = df1.head()
_ = df1.tail()
_ = col1.head()
_ = col1.tail()
4924
tests/cases/checker/09_frame_ops.py.ref.json
Normal file
File diff suppressed because it is too large
@@ -16,7 +16,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "bool",
|
||||
"param": null
|
||||
"args": []
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -25,7 +25,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "int",
|
||||
"param": null
|
||||
"args": []
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -36,7 +36,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
"args": []
|
||||
},
|
||||
"constraint": "(_ > 0) + (_ < 250)"
|
||||
}
|
||||
@@ -47,7 +47,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "str",
|
||||
"param": null
|
||||
"args": []
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -56,7 +56,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "datetime",
|
||||
"param": null
|
||||
"args": []
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -65,7 +65,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
"args": []
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -79,7 +79,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "_",
|
||||
"param": null
|
||||
"args": []
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "GeoLocation",
|
||||
"param": null
|
||||
"args": []
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -28,11 +28,13 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "BaseType",
|
||||
"base": "GeoLocation",
|
||||
"param": null
|
||||
}
|
||||
"args": [
|
||||
{
|
||||
"_type": "BaseType",
|
||||
"base": "GeoLocation",
|
||||
"args": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -65,11 +67,13 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "BaseType",
|
||||
"base": "GeoLocation",
|
||||
"param": null
|
||||
}
|
||||
"args": [
|
||||
{
|
||||
"_type": "BaseType",
|
||||
"base": "GeoLocation",
|
||||
"args": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -117,7 +121,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Latitude",
|
||||
"param": null
|
||||
"args": []
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -146,7 +150,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Latitude",
|
||||
"param": null
|
||||
"args": []
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -175,11 +179,13 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Difference",
|
||||
"param": {
|
||||
"_type": "BaseType",
|
||||
"base": "Latitude",
|
||||
"param": null
|
||||
}
|
||||
"args": [
|
||||
{
|
||||
"_type": "BaseType",
|
||||
"base": "Latitude",
|
||||
"args": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -217,7 +223,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "int",
|
||||
"param": null
|
||||
"args": []
|
||||
},
|
||||
"constraint": "_ >= 0"
|
||||
}
|
||||
@@ -230,7 +236,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
"args": []
|
||||
},
|
||||
"constraint": "_ >= 0"
|
||||
}
|
||||
@@ -252,7 +258,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "int",
|
||||
"param": null
|
||||
"args": []
|
||||
},
|
||||
"constraint": "Positive"
|
||||
}
|
||||
@@ -265,7 +271,7 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
"args": []
|
||||
},
|
||||
"constraint": "Positive"
|
||||
}
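The hunks above all make the same change to the reference JSON: a `BaseType` node's single optional `param` field becomes an `args` list, so `"param": null` turns into `"args": []` and a nested `param` object becomes a one-element list. A minimal sketch of that rewrite follows, assuming only the node shapes visible in the diff. The helper name `migrate_base_type` is hypothetical and not part of the repository.

```python
from typing import Any


def migrate_base_type(node: Any) -> Any:
    """Recursively rewrite the old `param` field into the new `args` list."""
    if isinstance(node, list):
        return [migrate_base_type(item) for item in node]
    if not isinstance(node, dict):
        return node

    # Migrate children first so nested BaseType nodes are rewritten too.
    migrated = {key: migrate_base_type(value) for key, value in node.items()}
    if migrated.get("_type") == "BaseType" and "param" in migrated:
        param = migrated.pop("param")
        # `null` becomes an empty list; a single parameter becomes a 1-item list.
        migrated["args"] = [] if param is None else [param]
    return migrated


# Old-style node, shaped like the removed lines of the `Column` hunk.
old_column = {
    "_type": "BaseType",
    "base": "Column",
    "param": {"_type": "BaseType", "base": "GeoLocation", "param": None},
}
new_column = migrate_base_type(old_column)
```

Nodes of other kinds, such as `ConstraintType`, pass through unchanged apart from any `BaseType` nested inside them.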
|
||||
|
||||
@@ -7,58 +7,65 @@
|
||||
{
|
||||
"_type": "Function",
|
||||
"name": "func",
|
||||
"posonlyargs": [],
|
||||
"args": [
|
||||
{
|
||||
"name": "col1",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "0 <= _ <= 1"
|
||||
}
|
||||
"params": {
|
||||
"_type": "ParamSpec",
|
||||
"pos": [],
|
||||
"mixed": [
|
||||
{
|
||||
"name": "col1",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"args": [
|
||||
{
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"args": []
|
||||
},
|
||||
"constraint": "0 <= _ <= 1"
|
||||
}
|
||||
]
|
||||
},
|
||||
"default": null
|
||||
},
|
||||
"default": null
|
||||
},
|
||||
{
|
||||
"name": "col2",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "0 <= _ <= 1"
|
||||
}
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
],
|
||||
"sink": null,
|
||||
"kwonlyargs": [],
|
||||
"kw_sink": null,
|
||||
{
|
||||
"name": "col2",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"args": [
|
||||
{
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"args": []
|
||||
},
|
||||
"constraint": "0 <= _ <= 1"
|
||||
}
|
||||
]
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "0 <= _ <= 2"
|
||||
}
|
||||
"args": [
|
||||
{
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"args": []
|
||||
},
|
||||
"constraint": "0 <= _ <= 2"
|
||||
}
|
||||
]
|
||||
},
|
||||
"body": [
|
||||
{
|
||||
@@ -67,15 +74,17 @@
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "0 <= _ <= 2"
|
||||
}
|
||||
"args": [
|
||||
{
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"args": []
|
||||
},
|
||||
"constraint": "0 <= _ <= 2"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -111,41 +120,42 @@
|
||||
{
|
||||
"_type": "Function",
|
||||
"name": "func2",
|
||||
"posonlyargs": [
|
||||
{
|
||||
"name": "a",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "int",
|
||||
"param": null
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
],
|
||||
"args": [
|
||||
{
|
||||
"name": "b",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
],
|
||||
"sink": null,
|
||||
"kwonlyargs": [
|
||||
{
|
||||
"name": "c",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "str",
|
||||
"param": null
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
],
|
||||
"kw_sink": null,
|
||||
"params": {
|
||||
"_type": "ParamSpec",
|
||||
"pos": [
|
||||
{
|
||||
"name": "a",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "int",
|
||||
"args": []
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
],
|
||||
"mixed": [
|
||||
{
|
||||
"name": "b",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"args": []
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
],
|
||||
"kw": [
|
||||
{
|
||||
"name": "c",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "str",
|
||||
"args": []
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
]
|
||||
},
|
||||
"returns": null,
|
||||
"body": []
|
||||
}
|
||||
|
||||
@@ -46,7 +46,8 @@ class GeneratorTester(Tester):
 
         if not any(d.type == DiagnosticType.ERROR for d in checker.diagnostics):
             generator = Generator(workdir=path.parent, types=checker.types)
-            result.compiled_ast = generator.generate_ast(typed_ast, path)
+            generator.set_src_path(path)
+            result.compiled_ast = generator.generate_ast(typed_ast)
 
         return result
 
@@ -1,6 +1,7 @@
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from midas.ast.midas import (
|
||||
AliasStmt,
|
||||
BinaryExpr,
|
||||
CallExpr,
|
||||
ComplexType,
|
||||
@@ -8,6 +9,7 @@ from midas.ast.midas import (
|
||||
Expr,
|
||||
ExtendStmt,
|
||||
ExtensionType,
|
||||
FrameType,
|
||||
FunctionType,
|
||||
GenericType,
|
||||
GetExpr,
|
||||
@@ -60,6 +62,13 @@ class MidasAstJsonSerializer(
|
||||
"bound": self._serialize_optional(param.bound),
|
||||
}
|
||||
|
||||
def visit_alias_stmt(self, stmt: AliasStmt) -> dict:
|
||||
return {
|
||||
"_type": "AliasStmt",
|
||||
"name": stmt.name.lexeme,
|
||||
"type": stmt.type.accept(self),
|
||||
}
|
||||
|
||||
def visit_member_stmt(self, stmt: MemberStmt) -> dict:
|
||||
return {
|
||||
"_type": "MemberStmt",
|
||||
@@ -179,16 +188,16 @@ class MidasAstJsonSerializer(
|
||||
def _serialize_param_spec(self, spec: ParamSpec) -> dict:
|
||||
return {
|
||||
"_type": "ParamSpec",
|
||||
"pos": [self._serialize_func_arg(arg) for arg in spec.pos],
|
||||
"mixed": [self._serialize_func_arg(arg) for arg in spec.mixed],
|
||||
"kw": [self._serialize_func_arg(arg) for arg in spec.kw],
|
||||
"pos": [self._serialize_func_param(arg) for arg in spec.pos],
|
||||
"mixed": [self._serialize_func_param(arg) for arg in spec.mixed],
|
||||
"kw": [self._serialize_func_param(arg) for arg in spec.kw],
|
||||
}
|
||||
|
||||
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
|
||||
def _serialize_func_param(self, param: FunctionType.Parameter) -> dict:
|
||||
return {
|
||||
"name": arg.name.lexeme if arg.name is not None else None,
|
||||
"type": arg.type.accept(self),
|
||||
"required": arg.required,
|
||||
"name": param.name.lexeme if param.name is not None else None,
|
||||
"type": param.type.accept(self),
|
||||
"required": param.required,
|
||||
}
|
||||
|
||||
def visit_extension_type(self, type: ExtensionType) -> dict:
|
||||
@@ -197,3 +206,15 @@ class MidasAstJsonSerializer(
|
||||
"base": type.base.accept(self),
|
||||
"extension": type.extension.accept(self),
|
||||
}
|
||||
|
||||
def visit_frame_type(self, type: FrameType) -> dict:
|
||||
return {
|
||||
"_type": "FrameType",
|
||||
"columns": [self._serialize_column(col) for col in type.columns],
|
||||
}
|
||||
|
||||
def _serialize_column(self, column: FrameType.Column):
|
||||
return {
|
||||
"name": column.name.lexeme,
|
||||
"type": column.type.accept(self),
|
||||
}
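The new `visit_frame_type` and `_serialize_column` methods follow the usual visitor pattern: the frame serializes its columns, and each column delegates its type back to the visitor through `accept`. A self-contained sketch of that flow is below. `Token`, `StubType`, `Column`, `FrameType` and `SketchSerializer` are stand-ins written for this example, since the real `midas.ast` classes are not shown in the diff.

```python
from dataclasses import dataclass


@dataclass
class Token:
    """Stand-in for a lexer token; only `lexeme` is used here."""
    lexeme: str


@dataclass
class StubType:
    """Hypothetical leaf type, used so the example has something to visit."""
    name: Token

    def accept(self, visitor: "SketchSerializer") -> dict:
        return visitor.visit_stub_type(self)


@dataclass
class Column:
    name: Token
    type: StubType


@dataclass
class FrameType:
    columns: list[Column]

    def accept(self, visitor: "SketchSerializer") -> dict:
        return visitor.visit_frame_type(self)


class SketchSerializer:
    def visit_stub_type(self, type: StubType) -> dict:
        return {"_type": "StubType", "name": type.name.lexeme}

    def visit_frame_type(self, type: FrameType) -> dict:
        # Same shape as the diff: a tagged dict with one entry per column.
        return {
            "_type": "FrameType",
            "columns": [self._serialize_column(col) for col in type.columns],
        }

    def _serialize_column(self, column: Column) -> dict:
        # The column's type is serialized by dispatching back into the visitor.
        return {"name": column.name.lexeme, "type": column.type.accept(self)}


frame = FrameType(
    columns=[
        Column(Token("a"), StubType(Token("int"))),
        Column(Token("b"), StubType(Token("float"))),
    ]
)
serialized = frame.accept(SketchSerializer())
```

The sample frame mirrors the `Frame[a:int, b:float]` annotation used in `09_frame_ops.py`.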
|
||||
|
||||
@@ -22,6 +22,7 @@ from midas.ast.python import (
|
||||
LiteralExpr,
|
||||
LogicalExpr,
|
||||
MidasType,
|
||||
ParamSpec,
|
||||
Pass,
|
||||
RawExpr,
|
||||
RawStmt,
|
||||
@@ -30,6 +31,7 @@ from midas.ast.python import (
|
||||
Stmt,
|
||||
SubscriptExpr,
|
||||
TernaryExpr,
|
||||
TupleExpr,
|
||||
TypeAssign,
|
||||
UnaryExpr,
|
||||
VariableExpr,
|
||||
@@ -98,7 +100,7 @@ class PythonAstJsonSerializer(
|
||||
return {
|
||||
"_type": "BaseType",
|
||||
"base": node.base,
|
||||
"param": self._serialize_optional(node.param),
|
||||
"args": self._serialize_list(node.args),
|
||||
}
|
||||
|
||||
def visit_constraint_type(self, node: ConstraintType) -> dict:
|
||||
@@ -127,32 +129,30 @@ class PythonAstJsonSerializer(
|
||||
"expr": stmt.expr.accept(self),
|
||||
}
|
||||
|
||||
def _serialize_argument(self, arg: Function.Argument) -> dict:
|
||||
return {
|
||||
"name": arg.name,
|
||||
"type": self._serialize_optional(arg.type),
|
||||
"default": self._serialize_optional(arg.default),
|
||||
}
|
||||
|
||||
def visit_function(self, stmt: Function) -> dict:
|
||||
return {
|
||||
"_type": "Function",
|
||||
"name": stmt.name,
|
||||
"posonlyargs": [self._serialize_argument(arg) for arg in stmt.posonlyargs],
|
||||
"args": [self._serialize_argument(arg) for arg in stmt.args],
|
||||
"sink": (
|
||||
self._serialize_argument(stmt.sink) if stmt.sink is not None else None
|
||||
),
|
||||
"kwonlyargs": [self._serialize_argument(arg) for arg in stmt.kwonlyargs],
|
||||
"kw_sink": (
|
||||
self._serialize_argument(stmt.kw_sink)
|
||||
if stmt.kw_sink is not None
|
||||
else None
|
||||
),
|
||||
"params": self._serialize_param_spec(stmt.params),
|
||||
"returns": self._serialize_optional(stmt.returns),
|
||||
"body": self._serialize_list(stmt.body),
|
||||
}
|
||||
|
||||
def _serialize_param_spec(self, spec: ParamSpec) -> dict:
|
||||
return {
|
||||
"_type": "ParamSpec",
|
||||
"pos": [self._serialize_func_param(arg) for arg in spec.pos],
|
||||
"mixed": [self._serialize_func_param(arg) for arg in spec.mixed],
|
||||
"kw": [self._serialize_func_param(arg) for arg in spec.kw],
|
||||
}
|
||||
|
||||
def _serialize_func_param(self, param: Function.Parameter) -> dict:
|
||||
return {
|
||||
"name": param.name,
|
||||
"type": self._serialize_optional(param.type),
|
||||
"default": self._serialize_optional(param.default),
|
||||
}
|
||||
|
||||
def visit_type_assign(self, stmt: TypeAssign) -> dict:
|
||||
return {
|
||||
"_type": "TypeAssign",
|
||||
@@ -302,6 +302,12 @@ class PythonAstJsonSerializer(
|
||||
"step": self._serialize_optional(expr.step),
|
||||
}
|
||||
|
||||
def visit_tuple_expr(self, expr: TupleExpr) -> dict:
|
||||
return {
|
||||
"_type": "TupleExpr",
|
||||
"items": [item.accept(self) for item in expr.items],
|
||||
}
|
||||
|
||||
def visit_raw_expr(self, expr: RawExpr) -> dict:
|
||||
return {
|
||||
"_type": "RawExpr",
|
||||
|
||||