Compare commits
34 Commits
2a8b7d559c
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 6eea0c02e0 | |||
|
3205e7b961
|
|||
|
0aba134290
|
|||
|
1f0bcab2ca
|
|||
|
db8d88ef35
|
|||
|
7695d50537
|
|||
|
8461d05fa6
|
|||
|
43d2118db7
|
|||
|
6a87b5396f
|
|||
|
e6a581ba6e
|
|||
|
2a7aac69ed
|
|||
|
eb5bf19c61
|
|||
|
657406ea01
|
|||
|
2974386110
|
|||
|
92ca6b6732
|
|||
|
6aacdb98b7
|
|||
|
1b100b6ceb
|
|||
|
6b4c7d27bc
|
|||
|
2523d638f7
|
|||
|
5fc7461e29
|
|||
|
c5154bde81
|
|||
|
d07e8ac0ca
|
|||
|
3380995082
|
|||
|
7efc44c496
|
|||
|
ca94443699
|
|||
|
c513a85cf2
|
|||
|
2a106c5d07
|
|||
| 9672dfd588 | |||
|
7639ccc94d
|
|||
| a4a2ed5d64 | |||
|
e5cb90aff6
|
|||
|
75f8e4af53
|
|||
|
42c2d7a098
|
|||
| 5ce3b4abed |
90
README.md
90
README.md
@@ -1,4 +1,4 @@
|
|||||||
# Midas
|
<h1>Midas</h1>
|
||||||
|
|
||||||
*Midas* is a type system to _Maintain Integrity of Data with Annotated Structures_. In Greek mythology, [Midas](https://en.wikipedia.org/wiki/Midas) was a Phrygian king who was blessed with the gift of turning everything he touched into gold.
|
*Midas* is a type system to _Maintain Integrity of Data with Annotated Structures_. In Greek mythology, [Midas](https://en.wikipedia.org/wiki/Midas) was a Phrygian king who was blessed with the gift of turning everything he touched into gold.
|
||||||
|
|
||||||
@@ -6,6 +6,24 @@
|
|||||||
|
|
||||||
This framework is being developed as part of a Bachelor's Thesis by Louis Heredero at HEI Sion.
|
This framework is being developed as part of a Bachelor's Thesis by Louis Heredero at HEI Sion.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>Table of Contents</strong></summary>
|
||||||
|
|
||||||
|
- [Requirements](#requirements)
|
||||||
|
- [Installation](#installation)
|
||||||
|
- [Commands](#commands)
|
||||||
|
- [Type Checking](#type-checking)
|
||||||
|
- [Compiling](#compiling)
|
||||||
|
- [Formatting](#formatting)
|
||||||
|
- [Highlighting](#highlighting)
|
||||||
|
- [Dumping the AST](#dumping-the-ast)
|
||||||
|
- [Dumping the Registry](#dumping-the-registry)
|
||||||
|
- [Showing Type Judgements](#showing-type-judgements)
|
||||||
|
- [Validating Definitions](#validating-definitions)
|
||||||
|
- [Tests](#tests)
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
|
|
||||||
- Python 3.11+
|
- Python 3.11+
|
||||||
@@ -32,10 +50,26 @@ This framework is being developed as part of a Bachelor's Thesis by Louis Herede
|
|||||||
|
|
||||||
## Commands
|
## Commands
|
||||||
|
|
||||||
### Compiling
|
<!--
|
||||||
|
check
|
||||||
|
compile
|
||||||
|
format
|
||||||
|
highlight
|
||||||
|
parse
|
||||||
|
dump_registry
|
||||||
|
types
|
||||||
|
validate
|
||||||
|
-->
|
||||||
|
|
||||||
> [!NOTE]
|
### Type Checking
|
||||||
> In the current state of the project, the `compile` command doesn't generate any runnable code, it only runs the parsers and type checker on the provided files
|
|
||||||
|
```shell
|
||||||
|
midas check -t types.midas source.py
|
||||||
|
```
|
||||||
|
|
||||||
|
This command parses the given files and run the type checkers against the Midas definitions and Python program. Diagnostics are then printed showing warnings and errors.
|
||||||
|
|
||||||
|
### Compiling
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
midas compile -t types.midas source.py
|
midas compile -t types.midas source.py
|
||||||
@@ -43,14 +77,22 @@ midas compile -t types.midas source.py
|
|||||||
|
|
||||||
With the `compile` command, you can process a source Python file, with any number of custom type definition files (`-t FILE` option), and the type checker will verify the coherence of your program and generate the runnable code with valid syntax and runtime assertions.
|
With the `compile` command, you can process a source Python file, with any number of custom type definition files (`-t FILE` option), and the type checker will verify the coherence of your program and generate the runnable code with valid syntax and runtime assertions.
|
||||||
|
|
||||||
The optional `-l FILE` option lets you produce a highlighted version of the source code showing diagnostics from the type checker (see [Highlighting](#highlighting))
|
### Formatting
|
||||||
|
|
||||||
|
```shell
|
||||||
|
midas format types.midas
|
||||||
|
midas format types.midas -o formatted.midas
|
||||||
|
```
|
||||||
|
|
||||||
|
This command parses the given Midas file and outputs a pretty printed file from the AST.
|
||||||
|
|
||||||
### Highlighting
|
### Highlighting
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
midas utils highlight source.py
|
midas highlight source.py
|
||||||
# or
|
midas highlight source.py -o highlighted.html
|
||||||
midas utils highlight types.midas
|
midas highlight types.midas
|
||||||
|
midas highlight types.midas -o highlighted.html
|
||||||
```
|
```
|
||||||
|
|
||||||
The `highlight` command takes in a source file (Python or Midas), runs the appropriate parser and outputs an HTML file containing the source code with added highlighting. This highlighting takes the form of hoverable annotations showing some of the parsed structures (e.g. a function definition, an assignment, a generic type, etc.)
|
The `highlight` command takes in a source file (Python or Midas), runs the appropriate parser and outputs an HTML file containing the source code with added highlighting. This highlighting takes the form of hoverable annotations showing some of the parsed structures (e.g. a function definition, an assignment, a generic type, etc.)
|
||||||
@@ -60,14 +102,35 @@ The optional `-o FILE` option can be used to specify an output path. By default,
|
|||||||
### Dumping the AST
|
### Dumping the AST
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
midas utils dump-ast source.py
|
midas parse source.py
|
||||||
# or
|
midas parse types.midas
|
||||||
midas utils dump-ast types.midas
|
|
||||||
```
|
```
|
||||||
|
|
||||||
For debugging purposes, you can output the AST parsed from a Python or Midas file. For Python files, the `-p` flags lets you toggle the custom AST parsing. Without `-p`, the raw AST is returned, as produced by the builtin `ast` module. This flag has no effect on Midas files.
|
For debugging purposes, you can output the AST parsed from a Python or Midas file. For Python files, the `--raw` flags lets you toggle the custom AST parsing. With `--raw`, the raw AST is returned, as produced by the builtin `ast` module. This flag has no effect on Midas files.
|
||||||
|
|
||||||
The optional `-o FILE` option can be used to specify an output path. By default, the file is printed in stdout (equivalent to `-o -`).
|
### Dumping the Registry
|
||||||
|
|
||||||
|
```shell
|
||||||
|
midas dump-registry -t types.midas
|
||||||
|
```
|
||||||
|
|
||||||
|
This command processes the given Midas definitions and dumps the contents of the types registry.
|
||||||
|
|
||||||
|
### Showing Type Judgements
|
||||||
|
|
||||||
|
```shell
|
||||||
|
midas types -t types.midas source.py
|
||||||
|
```
|
||||||
|
|
||||||
|
This command type checks the given Python source file and logs all typing judgements made by the type checker.
|
||||||
|
|
||||||
|
### Validating Definitions
|
||||||
|
|
||||||
|
```shell
|
||||||
|
midas validate types.midas
|
||||||
|
```
|
||||||
|
|
||||||
|
This command lets you validate a Midas definition file by running the parser and type checker, verifying syntax and references.
|
||||||
|
|
||||||
## Tests
|
## Tests
|
||||||
|
|
||||||
@@ -77,6 +140,7 @@ Several snapshot tests are available to assert the good behaviour of the parsers
|
|||||||
uv run -m tests.midas run -a
|
uv run -m tests.midas run -a
|
||||||
uv run -m tests.python run -a
|
uv run -m tests.python run -a
|
||||||
uv run -m tests.checker run -a
|
uv run -m tests.checker run -a
|
||||||
|
uv run -m tests.generator run -a
|
||||||
```
|
```
|
||||||
|
|
||||||
**Available subcommands:**
|
**Available subcommands:**
|
||||||
|
|||||||
23
gen/midas.py
23
gen/midas.py
@@ -26,6 +26,14 @@ class MemberKind(Enum):
|
|||||||
METHOD = auto()
|
METHOD = auto()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class ParamSpec:
|
||||||
|
l_paren: Token
|
||||||
|
pos: list[FunctionType.Argument]
|
||||||
|
mixed: list[FunctionType.Argument]
|
||||||
|
kw: list[FunctionType.Argument]
|
||||||
|
|
||||||
|
|
||||||
###<
|
###<
|
||||||
|
|
||||||
|
|
||||||
@@ -50,9 +58,8 @@ class ExtendStmt:
|
|||||||
|
|
||||||
class PredicateStmt:
|
class PredicateStmt:
|
||||||
name: Token
|
name: Token
|
||||||
subject: Token
|
params: list[ParamSpec]
|
||||||
type: Type
|
body: Expr
|
||||||
condition: Expr
|
|
||||||
|
|
||||||
|
|
||||||
###<
|
###<
|
||||||
@@ -78,6 +85,12 @@ class UnaryExpr:
|
|||||||
right: Expr
|
right: Expr
|
||||||
|
|
||||||
|
|
||||||
|
class CallExpr:
|
||||||
|
callee: Expr
|
||||||
|
arguments: list[Expr]
|
||||||
|
keywords: dict[str, Expr]
|
||||||
|
|
||||||
|
|
||||||
class GetExpr:
|
class GetExpr:
|
||||||
expr: Expr
|
expr: Expr
|
||||||
name: Token
|
name: Token
|
||||||
@@ -128,9 +141,7 @@ class ExtensionType:
|
|||||||
|
|
||||||
|
|
||||||
class FunctionType:
|
class FunctionType:
|
||||||
pos_args: list[Argument]
|
params: ParamSpec
|
||||||
args: list[Argument]
|
|
||||||
kw_args: list[Argument]
|
|
||||||
returns: Type
|
returns: Type
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
|||||||
@@ -157,6 +157,11 @@ class ListExpr:
|
|||||||
items: list[Expr]
|
items: list[Expr]
|
||||||
|
|
||||||
|
|
||||||
|
class DictExpr:
|
||||||
|
keys: list[Optional[Expr]]
|
||||||
|
values: list[Expr]
|
||||||
|
|
||||||
|
|
||||||
class SubscriptExpr:
|
class SubscriptExpr:
|
||||||
object: Expr
|
object: Expr
|
||||||
index: Expr
|
index: Expr
|
||||||
|
|||||||
@@ -27,6 +27,14 @@ class MemberKind(Enum):
|
|||||||
METHOD = auto()
|
METHOD = auto()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class ParamSpec:
|
||||||
|
l_paren: Token
|
||||||
|
pos: list[FunctionType.Argument]
|
||||||
|
mixed: list[FunctionType.Argument]
|
||||||
|
kw: list[FunctionType.Argument]
|
||||||
|
|
||||||
|
|
||||||
##############
|
##############
|
||||||
# Statements #
|
# Statements #
|
||||||
##############
|
##############
|
||||||
@@ -86,9 +94,8 @@ class ExtendStmt(Stmt):
|
|||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class PredicateStmt(Stmt):
|
class PredicateStmt(Stmt):
|
||||||
name: Token
|
name: Token
|
||||||
subject: Token
|
params: list[ParamSpec]
|
||||||
type: Type
|
body: Expr
|
||||||
condition: Expr
|
|
||||||
|
|
||||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
return visitor.visit_predicate_stmt(self)
|
return visitor.visit_predicate_stmt(self)
|
||||||
@@ -116,6 +123,9 @@ class Expr(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_unary_expr(self, expr: UnaryExpr) -> T: ...
|
def visit_unary_expr(self, expr: UnaryExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_call_expr(self, expr: CallExpr) -> T: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_get_expr(self, expr: GetExpr) -> T: ...
|
def visit_get_expr(self, expr: GetExpr) -> T: ...
|
||||||
|
|
||||||
@@ -161,6 +171,16 @@ class UnaryExpr(Expr):
|
|||||||
return visitor.visit_unary_expr(self)
|
return visitor.visit_unary_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class CallExpr(Expr):
|
||||||
|
callee: Expr
|
||||||
|
arguments: list[Expr]
|
||||||
|
keywords: dict[str, Expr]
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_call_expr(self)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class GetExpr(Expr):
|
class GetExpr(Expr):
|
||||||
expr: Expr
|
expr: Expr
|
||||||
@@ -279,9 +299,7 @@ class ExtensionType(Type):
|
|||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class FunctionType(Type):
|
class FunctionType(Type):
|
||||||
pos_args: list[Argument]
|
params: ParamSpec
|
||||||
args: list[Argument]
|
|
||||||
kw_args: list[Argument]
|
|
||||||
returns: Type
|
returns: Type
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
|||||||
@@ -150,13 +150,17 @@ class MidasAstPrinter(
|
|||||||
self._write_line("PredicateStmt")
|
self._write_line("PredicateStmt")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||||
self._write_line(f'subject: "{stmt.subject.lexeme}"')
|
self._write_line("params")
|
||||||
self._write_line("type")
|
with self._child_level():
|
||||||
|
for i, spec in enumerate(stmt.params):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(stmt.params) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._visit_param_spec(spec)
|
||||||
|
|
||||||
|
self._write_line("body", last=True)
|
||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
stmt.type.accept(self)
|
stmt.body.accept(self)
|
||||||
self._write_line("condition", last=True)
|
|
||||||
with self._child_level(single=True):
|
|
||||||
stmt.condition.accept(self)
|
|
||||||
|
|
||||||
# Expressions
|
# Expressions
|
||||||
|
|
||||||
@@ -195,6 +199,29 @@ class MidasAstPrinter(
|
|||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
expr.right.accept(self)
|
expr.right.accept(self)
|
||||||
|
|
||||||
|
def visit_call_expr(self, expr: m.CallExpr) -> None:
|
||||||
|
self._write_line("CallExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("callee")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.callee.accept(self)
|
||||||
|
self._write_line("arguments")
|
||||||
|
with self._child_level():
|
||||||
|
for i, arg in enumerate(expr.arguments):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(expr.arguments) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
arg.accept(self)
|
||||||
|
self._write_line("keywords", last=True)
|
||||||
|
with self._child_level():
|
||||||
|
for i, (name, arg) in enumerate(expr.keywords.items()):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(expr.keywords) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._write_line(name)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
arg.accept(self)
|
||||||
|
|
||||||
def visit_get_expr(self, expr: m.GetExpr):
|
def visit_get_expr(self, expr: m.GetExpr):
|
||||||
self._write_line("GetExpr")
|
self._write_line("GetExpr")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
@@ -276,34 +303,41 @@ class MidasAstPrinter(
|
|||||||
def visit_function_type(self, type: m.FunctionType) -> None:
|
def visit_function_type(self, type: m.FunctionType) -> None:
|
||||||
self._write_line("FunctionType")
|
self._write_line("FunctionType")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
self._write_line("pos_args")
|
self._write_line("params")
|
||||||
with self._child_level():
|
with self._child_level(single=True):
|
||||||
for i, arg in enumerate(type.pos_args):
|
self._visit_param_spec(type.params)
|
||||||
self._idx = i
|
|
||||||
if i == len(type.pos_args) - 1:
|
|
||||||
self._mark_last()
|
|
||||||
self._print_function_arg(arg)
|
|
||||||
|
|
||||||
self._write_line("args")
|
|
||||||
with self._child_level():
|
|
||||||
for i, arg in enumerate(type.args):
|
|
||||||
self._idx = i
|
|
||||||
if i == len(type.args) - 1:
|
|
||||||
self._mark_last()
|
|
||||||
self._print_function_arg(arg)
|
|
||||||
|
|
||||||
self._write_line("kw_args")
|
|
||||||
with self._child_level():
|
|
||||||
for i, arg in enumerate(type.kw_args):
|
|
||||||
self._idx = i
|
|
||||||
if i == len(type.kw_args) - 1:
|
|
||||||
self._mark_last()
|
|
||||||
self._print_function_arg(arg)
|
|
||||||
|
|
||||||
self._write_line("returns", last=True)
|
self._write_line("returns", last=True)
|
||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
type.returns.accept(self)
|
type.returns.accept(self)
|
||||||
|
|
||||||
|
def _visit_param_spec(self, spec: m.ParamSpec) -> None:
|
||||||
|
self._write_line("ParamSpec")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("pos")
|
||||||
|
with self._child_level():
|
||||||
|
for i, arg in enumerate(spec.pos):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(spec.pos) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._print_function_arg(arg)
|
||||||
|
|
||||||
|
self._write_line("mixed")
|
||||||
|
with self._child_level():
|
||||||
|
for i, arg in enumerate(spec.mixed):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(spec.mixed) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._print_function_arg(arg)
|
||||||
|
|
||||||
|
self._write_line("kw", last=True)
|
||||||
|
with self._child_level():
|
||||||
|
for i, arg in enumerate(spec.kw):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(spec.kw) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._print_function_arg(arg)
|
||||||
|
|
||||||
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
|
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
|
||||||
self._write_line("Argument")
|
self._write_line("Argument")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
@@ -367,10 +401,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
|||||||
|
|
||||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
||||||
name: str = stmt.name.lexeme
|
name: str = stmt.name.lexeme
|
||||||
subject: str = stmt.subject.lexeme
|
sig: str = "".join(self._visit_param_spec(spec) for spec in stmt.params)
|
||||||
type: str = stmt.type.accept(self)
|
body: str = stmt.body.accept(self)
|
||||||
condition: str = stmt.condition.accept(self)
|
return self.indented(f"predicate {name}{sig} = {body}")
|
||||||
return self.indented(f"predicate {name}({subject}: {type}) = {condition}")
|
|
||||||
|
|
||||||
def visit_logical_expr(self, expr: m.LogicalExpr):
|
def visit_logical_expr(self, expr: m.LogicalExpr):
|
||||||
left: str = expr.left.accept(self)
|
left: str = expr.left.accept(self)
|
||||||
@@ -389,6 +422,12 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
|||||||
right: str = expr.right.accept(self)
|
right: str = expr.right.accept(self)
|
||||||
return f"{operator}{right}"
|
return f"{operator}{right}"
|
||||||
|
|
||||||
|
def visit_call_expr(self, expr: m.CallExpr) -> str:
|
||||||
|
args: list[str] = [arg.accept(self) for arg in expr.arguments] + [
|
||||||
|
f"{name}={arg.accept(self)}" for name, arg in expr.keywords.items()
|
||||||
|
]
|
||||||
|
return f"{expr.callee.accept(self)}({', '.join(args)})"
|
||||||
|
|
||||||
def visit_get_expr(self, expr: m.GetExpr):
|
def visit_get_expr(self, expr: m.GetExpr):
|
||||||
expr_: str = expr.expr.accept(self)
|
expr_: str = expr.expr.accept(self)
|
||||||
name: str = expr.name.lexeme
|
name: str = expr.name.lexeme
|
||||||
@@ -436,9 +475,13 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
|||||||
return f"{type.base.accept(self)} & {type.extension.accept(self)}"
|
return f"{type.base.accept(self)} & {type.extension.accept(self)}"
|
||||||
|
|
||||||
def visit_function_type(self, type: m.FunctionType) -> str:
|
def visit_function_type(self, type: m.FunctionType) -> str:
|
||||||
pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
|
spec: str = self._visit_param_spec(type.params)
|
||||||
mixed_args: list[str] = [self._print_arg(arg) for arg in type.args]
|
return f"fn {spec} -> {type.returns.accept(self)}"
|
||||||
kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args]
|
|
||||||
|
def _visit_param_spec(self, spec: m.ParamSpec) -> str:
|
||||||
|
pos_args: list[str] = [self._print_arg(arg) for arg in spec.pos]
|
||||||
|
mixed_args: list[str] = [self._print_arg(arg) for arg in spec.mixed]
|
||||||
|
kw_args: list[str] = [self._print_arg(arg) for arg in spec.kw]
|
||||||
args: list[str] = pos_args
|
args: list[str] = pos_args
|
||||||
|
|
||||||
if len(pos_args) != 0:
|
if len(pos_args) != 0:
|
||||||
@@ -447,8 +490,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
|||||||
if len(kw_args) != 0:
|
if len(kw_args) != 0:
|
||||||
args.append("*")
|
args.append("*")
|
||||||
args += kw_args
|
args += kw_args
|
||||||
|
return f"({', '.join(args)})"
|
||||||
return f"fn ({', '.join(args)}) -> {type.returns.accept(self)}"
|
|
||||||
|
|
||||||
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
|
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
|
||||||
res: str = ""
|
res: str = ""
|
||||||
@@ -745,6 +787,27 @@ class PythonAstPrinter(
|
|||||||
self._mark_last()
|
self._mark_last()
|
||||||
item.accept(self)
|
item.accept(self)
|
||||||
|
|
||||||
|
def visit_dict_expr(self, expr: p.DictExpr) -> None:
|
||||||
|
self._write_line("DictExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("keys")
|
||||||
|
with self._child_level():
|
||||||
|
for i, key in enumerate(expr.keys):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(expr.keys) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
if key is None:
|
||||||
|
self._write_line("None")
|
||||||
|
else:
|
||||||
|
key.accept(self)
|
||||||
|
self._write_line("values", last=True)
|
||||||
|
with self._child_level():
|
||||||
|
for i, value in enumerate(expr.values):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(expr.values) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
value.accept(self)
|
||||||
|
|
||||||
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
|
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
|
||||||
self._write_line("SubscriptExpr")
|
self._write_line("SubscriptExpr")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
|
|||||||
@@ -259,6 +259,9 @@ class Expr(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_list_expr(self, expr: ListExpr) -> T: ...
|
def visit_list_expr(self, expr: ListExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_dict_expr(self, expr: DictExpr) -> T: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_subscript_expr(self, expr: SubscriptExpr) -> T: ...
|
def visit_subscript_expr(self, expr: SubscriptExpr) -> T: ...
|
||||||
|
|
||||||
@@ -370,6 +373,15 @@ class ListExpr(Expr):
|
|||||||
return visitor.visit_list_expr(self)
|
return visitor.visit_list_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class DictExpr(Expr):
|
||||||
|
keys: list[Optional[Expr]]
|
||||||
|
values: list[Expr]
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_dict_expr(self)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class SubscriptExpr(Expr):
|
class SubscriptExpr(Expr):
|
||||||
object: Expr
|
object: Expr
|
||||||
|
|||||||
@@ -150,3 +150,32 @@ extend list[T] {
|
|||||||
|
|
||||||
prop __doc__: str
|
prop __doc__: str
|
||||||
}
|
}
|
||||||
|
|
||||||
|
extend dict[K, V] {
|
||||||
|
def copy: fn() -> dict[K, V]
|
||||||
|
def keys: fn() -> list[K] // TODO: use builtin types
|
||||||
|
def values: fn() -> list[V] // TODO: use builtin types
|
||||||
|
// def items: fn() -> list[tuple[K, V]] // TODO: use builtin types
|
||||||
|
|
||||||
|
// def get: fn(key: K, default: None = None, /) -> V | None
|
||||||
|
def get: fn(key: K, default: V, /) -> V
|
||||||
|
// def get: fn[T](key: K, default: T, /) -> V | T
|
||||||
|
def pop: fn(key: K, /) -> V
|
||||||
|
def pop: fn(key: K, default: V, /) -> V
|
||||||
|
// def pop: fn[T](key: K, default: T, /) -> V | T
|
||||||
|
def __len__: fn() -> int
|
||||||
|
def __getitem__: fn(key: K, /) -> V
|
||||||
|
def __setitem__: fn(key: K, value: V, /) -> None
|
||||||
|
def __delitem__: fn(key: K, /) -> None
|
||||||
|
// def __iter__: fn() -> Iterator[K]
|
||||||
|
def __eq__: fn(value: object, /) -> bool
|
||||||
|
// def __reversed__: fn() -> Iterator[K]
|
||||||
|
|
||||||
|
def __or__: fn(value: dict[K, V], /) -> dict[K, V]
|
||||||
|
// def __or__: fn[K2, V2](value: dict[K2, V2], /) -> dict[K | K2, V | V2]
|
||||||
|
def __ror__: fn(value: dict[K, V], /) -> dict[K, V]
|
||||||
|
// def __ror__: fn[K2, V2](value: dict[K2, V2], /) -> dict[K | K2, V | V2]
|
||||||
|
// def __ior__: fn(value: SupportsKeysAndGetItem[K, V], /) -> dict[K, V]
|
||||||
|
// def __ior__: fn(value: Iterable[tuple[K, V]], /) -> dict[K, V]
|
||||||
|
|
||||||
|
}
|
||||||
@@ -15,6 +15,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
BUILTIN_SUBTYPES: dict[str, set[str]] = {
|
BUILTIN_SUBTYPES: dict[str, set[str]] = {
|
||||||
|
"object": {"float", "list", "dict"},
|
||||||
"float": {"int"},
|
"float": {"int"},
|
||||||
"int": {"bool"},
|
"int": {"bool"},
|
||||||
}
|
}
|
||||||
@@ -39,3 +40,14 @@ def define_builtins(reg: TypesRegistry):
|
|||||||
body=BaseType(name="list"),
|
body=BaseType(name="list"),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
dict = reg.define_type(
|
||||||
|
"dict",
|
||||||
|
GenericType(
|
||||||
|
name="dict",
|
||||||
|
params=[
|
||||||
|
TypeVar(name="K", bound=None),
|
||||||
|
TypeVar(name="V", bound=None),
|
||||||
|
],
|
||||||
|
body=BaseType(name="dict"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,27 +1,64 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import midas.ast.midas as m
|
import midas.ast.midas as m
|
||||||
|
from midas.ast.location import Location
|
||||||
from midas.checker.builtins import define_builtins
|
from midas.checker.builtins import define_builtins
|
||||||
|
from midas.checker.environment import Environment
|
||||||
|
from midas.checker.operators import MIDAS_BINARY_METHODS, MIDAS_UNARY_METHODS
|
||||||
|
from midas.checker.preamble import Preamble
|
||||||
from midas.checker.registry import TypesRegistry
|
from midas.checker.registry import TypesRegistry
|
||||||
from midas.checker.reporter import FileReporter, Reporter
|
from midas.checker.reporter import FileReporter, Reporter
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AliasType,
|
AliasType,
|
||||||
|
AppliedType,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
|
ConstraintType,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
|
OverloadedFunction,
|
||||||
|
Predicate,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
UnknownType,
|
UnknownType,
|
||||||
|
unfold_type,
|
||||||
)
|
)
|
||||||
from midas.lexer.midas import MidasLexer
|
from midas.lexer.midas import MidasLexer
|
||||||
from midas.lexer.token import Token
|
from midas.lexer.token import Token
|
||||||
from midas.parser.midas import MidasParser
|
from midas.parser.midas import MidasParser
|
||||||
|
|
||||||
|
|
||||||
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class TypedParamSpec:
|
||||||
|
pos: list[Function.Argument]
|
||||||
|
mixed: list[Function.Argument]
|
||||||
|
kw: list[Function.Argument]
|
||||||
|
|
||||||
|
|
||||||
|
TypedExpr = tuple[m.Expr, Type]
|
||||||
|
|
||||||
|
|
||||||
|
class ReturnException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class MappedArgument:
|
||||||
|
expr: m.Expr
|
||||||
|
type: Type
|
||||||
|
argument: Function.Argument
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class OverloadCandidate:
|
||||||
|
function: Function
|
||||||
|
mapped: list[MappedArgument]
|
||||||
|
|
||||||
|
|
||||||
|
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type]):
|
||||||
"""A resolver which evaluates Midas type definitions and build a registry"""
|
"""A resolver which evaluates Midas type definitions and build a registry"""
|
||||||
|
|
||||||
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
|
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
|
||||||
@@ -31,12 +68,18 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
self.types: TypesRegistry = types
|
self.types: TypesRegistry = types
|
||||||
self._local_variables: dict[str, TypeVar] = {}
|
self._local_variables: dict[str, TypeVar] = {}
|
||||||
|
|
||||||
|
self._predicate_params: dict[str, Type] = {}
|
||||||
|
|
||||||
self._current_name: Optional[str] = None
|
self._current_name: Optional[str] = None
|
||||||
|
|
||||||
define_builtins(self.types)
|
define_builtins(self.types)
|
||||||
builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve()
|
builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve()
|
||||||
self.process(builtins_path.read_text(), str(builtins_path))
|
self.process(builtins_path.read_text(), str(builtins_path))
|
||||||
|
|
||||||
|
self._bool: Type = self.get_type("bool")
|
||||||
|
|
||||||
|
self._preamble: Environment = Preamble(self.types)
|
||||||
|
|
||||||
def process(self, source: str, path: Optional[str]):
|
def process(self, source: str, path: Optional[str]):
|
||||||
self.reporter = self.reporter.for_file(path)
|
self.reporter = self.reporter.for_file(path)
|
||||||
lexer: MidasLexer = MidasLexer(source)
|
lexer: MidasLexer = MidasLexer(source)
|
||||||
@@ -47,6 +90,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
self.reporter.error(error.token.get_location(), error.message)
|
self.reporter.error(error.token.get_location(), error.message)
|
||||||
self.resolve(stmts)
|
self.resolve(stmts)
|
||||||
|
|
||||||
|
def type_of(self, expr: m.Expr) -> Type:
|
||||||
|
type: Type = expr.accept(self)
|
||||||
|
return type
|
||||||
|
|
||||||
def get_type(self, name: str) -> Type:
|
def get_type(self, name: str) -> Type:
|
||||||
"""Get a type from its name
|
"""Get a type from its name
|
||||||
|
|
||||||
@@ -63,6 +110,19 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
return self._local_variables[name]
|
return self._local_variables[name]
|
||||||
return self.types.get_type(name)
|
return self.types.get_type(name)
|
||||||
|
|
||||||
|
def get_variable(self, name: str) -> Type:
|
||||||
|
if name in self._predicate_params:
|
||||||
|
return self._predicate_params[name]
|
||||||
|
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
|
||||||
|
if predicate is not None:
|
||||||
|
return predicate.type
|
||||||
|
|
||||||
|
global_: Optional[Type] = self._preamble.get(name)
|
||||||
|
if global_ is not None:
|
||||||
|
return global_
|
||||||
|
|
||||||
|
raise NameError(f"Unknown variable '{name}'")
|
||||||
|
|
||||||
def resolve(self, stmts: list[m.Stmt]):
|
def resolve(self, stmts: list[m.Stmt]):
|
||||||
"""Process a sequence of statements
|
"""Process a sequence of statements
|
||||||
|
|
||||||
@@ -72,6 +132,11 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
for stmt in stmts:
|
for stmt in stmts:
|
||||||
stmt.accept(self)
|
stmt.accept(self)
|
||||||
|
|
||||||
|
def assert_bool(self, expr: m.Expr):
|
||||||
|
type: Type = self.type_of(expr)
|
||||||
|
if not self.types.is_subtype(type, self._bool):
|
||||||
|
self.reporter.error(expr.location, f"Must be a boolean but is {type}")
|
||||||
|
|
||||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||||
name: str = stmt.name.lexeme
|
name: str = stmt.name.lexeme
|
||||||
self._current_name = name
|
self._current_name = name
|
||||||
@@ -106,31 +171,163 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
)
|
)
|
||||||
|
|
||||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
||||||
self.reporter.warning(stmt.location, "PredicateStmt not yet supported")
|
for spec in stmt.params:
|
||||||
|
for param in spec.mixed:
|
||||||
|
assert param.name is not None
|
||||||
|
self._predicate_params[param.name.lexeme] = param.type.accept(self)
|
||||||
|
|
||||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
|
type: Type = self.type_of(stmt.body)
|
||||||
self.reporter.warning(expr.location, "LogicalExpr not yet supported")
|
params: list[TypedParamSpec] = [
|
||||||
|
self._visit_param_spec(spec) for spec in stmt.params
|
||||||
|
]
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> None:
|
if not self._is_valid_predicate(type):
|
||||||
self.reporter.warning(expr.location, "BinaryExpr not yet supported")
|
self.reporter.error(
|
||||||
|
stmt.body.location,
|
||||||
|
f"Predicate function body must evaluate to a boolean, got {type}",
|
||||||
|
)
|
||||||
|
if len(params) != 0:
|
||||||
|
type = self._bool
|
||||||
|
for spec in reversed(params):
|
||||||
|
type = Function(
|
||||||
|
pos_args=spec.pos,
|
||||||
|
args=spec.mixed,
|
||||||
|
kw_args=spec.kw,
|
||||||
|
returns=type,
|
||||||
|
)
|
||||||
|
self._predicate_params = {}
|
||||||
|
self.types.define_predicate(
|
||||||
|
stmt.name.lexeme,
|
||||||
|
Predicate(
|
||||||
|
type=type,
|
||||||
|
body=stmt.body,
|
||||||
|
alias=len(params) == 0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
|
def _is_valid_predicate(self, body: Type) -> bool:
|
||||||
self.reporter.warning(expr.location, "UnaryExpr not yet supported")
|
match body:
|
||||||
|
case Function(returns=returns):
|
||||||
|
return self._is_valid_predicate(returns)
|
||||||
|
case _ if self.types.is_subtype(body, self._bool):
|
||||||
|
return True
|
||||||
|
case _:
|
||||||
|
return False
|
||||||
|
|
||||||
def visit_get_expr(self, expr: m.GetExpr) -> None:
|
def visit_logical_expr(self, expr: m.LogicalExpr) -> Type:
|
||||||
self.reporter.warning(expr.location, "GetExpr not yet supported")
|
self.assert_bool(expr.left)
|
||||||
|
self.assert_bool(expr.right)
|
||||||
|
return self._bool
|
||||||
|
|
||||||
def visit_variable_expr(self, expr: m.VariableExpr) -> None:
|
def visit_binary_expr(self, expr: m.BinaryExpr) -> Type:
|
||||||
self.reporter.warning(expr.location, "VariableExpr not yet supported")
|
method: Optional[str] = MIDAS_BINARY_METHODS.get(expr.operator.type)
|
||||||
|
if method is None:
|
||||||
|
self.logger.warning(f"Unsupported operator {expr.operator.lexeme}")
|
||||||
|
self.reporter.warning(
|
||||||
|
expr.location, f"Unsupported operator {expr.operator.lexeme}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
|
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
|
||||||
|
|
||||||
|
def _visit_binary_expr(
|
||||||
|
self, location: Location, left_expr: m.Expr, right_expr: m.Expr, method: str
|
||||||
|
) -> Type:
|
||||||
|
left: Type = self.type_of(left_expr)
|
||||||
|
right: Type = self.type_of(right_expr)
|
||||||
|
|
||||||
|
operation: Optional[Type] = self.types.lookup_member(left, method)
|
||||||
|
if operation is None:
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Undefined operation {method} between {left} and {right}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
result: Optional[Type] = self._get_call_result(
|
||||||
|
location,
|
||||||
|
operation,
|
||||||
|
[(right_expr, right)],
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
return result or UnknownType()
|
||||||
|
|
||||||
|
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
|
||||||
|
method: Optional[str] = MIDAS_UNARY_METHODS.get(expr.operator.type)
|
||||||
|
if method is None:
|
||||||
|
self.logger.warning(f"Unsupported operator {expr.operator.lexeme}")
|
||||||
|
self.reporter.warning(
|
||||||
|
expr.location, f"Unsupported operator {expr.operator.lexeme}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
operand: Type = self.type_of(expr.right)
|
||||||
|
operation: Optional[Type] = self.types.lookup_member(operand, method)
|
||||||
|
if operation is None:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Undefined operation {method} for {operand}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
result: Optional[Type] = self._get_call_result(
|
||||||
|
expr.location,
|
||||||
|
operation,
|
||||||
|
[],
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
return result or UnknownType()
|
||||||
|
|
||||||
|
def visit_call_expr(self, expr: m.CallExpr) -> Type:
|
||||||
|
callee: Type = expr.callee.accept(self)
|
||||||
|
positional: list[TypedExpr] = [
|
||||||
|
(arg, self.type_of(arg)) for arg in expr.arguments
|
||||||
|
]
|
||||||
|
keywords: dict[str, TypedExpr] = {
|
||||||
|
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
self._get_call_result(
|
||||||
|
expr.location,
|
||||||
|
callee,
|
||||||
|
positional,
|
||||||
|
keywords,
|
||||||
|
)
|
||||||
|
or UnknownType()
|
||||||
|
)
|
||||||
|
|
||||||
|
def visit_get_expr(self, expr: m.GetExpr) -> Type:
|
||||||
|
object: Type = expr.expr.accept(self)
|
||||||
|
member: Optional[Type] = self.types.lookup_member(object, expr.name.lexeme)
|
||||||
|
if member is None:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location, f"Unknown member '{expr.name.lexeme}' of {object}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
return member
|
||||||
|
|
||||||
|
def visit_variable_expr(self, expr: m.VariableExpr) -> Type:
|
||||||
|
return self.get_variable(expr.name.lexeme)
|
||||||
|
|
||||||
|
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type:
|
||||||
return expr.expr.accept(self)
|
return expr.expr.accept(self)
|
||||||
|
|
||||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
|
def visit_literal_expr(self, expr: m.LiteralExpr) -> Type:
|
||||||
self.reporter.warning(expr.location, "LiteralExpr not yet supported")
|
match expr.value:
|
||||||
|
case bool(): # Must be before int
|
||||||
|
return self.types.get_type("bool")
|
||||||
|
case int():
|
||||||
|
return self.types.get_type("int")
|
||||||
|
case float():
|
||||||
|
return self.types.get_type("float")
|
||||||
|
case str():
|
||||||
|
return self.types.get_type("str")
|
||||||
|
case _:
|
||||||
|
self.reporter.warning(expr.location, f"Unknown literal {expr}")
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
|
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type:
|
||||||
self.reporter.warning(expr.location, "WildcardExpr not yet supported")
|
return self.get_variable("_")
|
||||||
|
|
||||||
def visit_named_type(self, type: m.NamedType) -> Type:
|
def visit_named_type(self, type: m.NamedType) -> Type:
|
||||||
name: str = type.name.lexeme
|
name: str = type.name.lexeme
|
||||||
@@ -153,10 +350,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
|
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
|
||||||
type_: Type = type.type.accept(self)
|
return ConstraintType(
|
||||||
type.constraint.accept(self)
|
type=type.type.accept(self),
|
||||||
# TODO
|
constraint=type.constraint,
|
||||||
return UnknownType()
|
)
|
||||||
|
|
||||||
def visit_complex_type(self, type: m.ComplexType) -> ComplexType:
|
def visit_complex_type(self, type: m.ComplexType) -> ComplexType:
|
||||||
return ComplexType(
|
return ComplexType(
|
||||||
@@ -172,8 +369,17 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
)
|
)
|
||||||
|
|
||||||
def visit_function_type(self, type: m.FunctionType) -> Type:
|
def visit_function_type(self, type: m.FunctionType) -> Type:
|
||||||
n_pos_args: int = len(type.pos_args)
|
params: TypedParamSpec = self._visit_param_spec(type.params)
|
||||||
n_args: int = len(type.args)
|
return Function(
|
||||||
|
pos_args=params.pos,
|
||||||
|
args=params.mixed,
|
||||||
|
kw_args=params.kw,
|
||||||
|
returns=type.returns.accept(self),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _visit_param_spec(self, spec: m.ParamSpec) -> TypedParamSpec:
|
||||||
|
n_pos: int = len(spec.pos)
|
||||||
|
n_mixed: int = len(spec.mixed)
|
||||||
|
|
||||||
def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument:
|
def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument:
|
||||||
return Function.Argument(
|
return Function.Argument(
|
||||||
@@ -183,14 +389,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
required=arg.required,
|
required=arg.required,
|
||||||
)
|
)
|
||||||
|
|
||||||
return Function(
|
return TypedParamSpec(
|
||||||
pos_args=[process_arg(arg, i) for i, arg in enumerate(type.pos_args)],
|
pos=[process_arg(arg, i) for i, arg in enumerate(spec.pos)],
|
||||||
args=[process_arg(arg, i + n_pos_args) for i, arg in enumerate(type.args)],
|
mixed=[process_arg(arg, i + n_pos) for i, arg in enumerate(spec.mixed)],
|
||||||
kw_args=[
|
kw=[process_arg(arg, i + n_pos + n_mixed) for i, arg in enumerate(spec.kw)],
|
||||||
process_arg(arg, i + n_pos_args + n_args)
|
|
||||||
for i, arg in enumerate(type.kw_args)
|
|
||||||
],
|
|
||||||
returns=type.returns.accept(self),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _resolve_type_params(self, params: list[m.TypeParam]):
|
def _resolve_type_params(self, params: list[m.TypeParam]):
|
||||||
@@ -204,3 +406,343 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
self._local_variables[name] = var
|
self._local_variables[name] = var
|
||||||
vars.append(var)
|
vars.append(var)
|
||||||
return vars
|
return vars
|
||||||
|
|
||||||
|
def _get_call_result(
|
||||||
|
self,
|
||||||
|
location: Location,
|
||||||
|
callee: Type,
|
||||||
|
positional: list[TypedExpr],
|
||||||
|
keywords: dict[str, TypedExpr],
|
||||||
|
report_errors: bool = True,
|
||||||
|
) -> Optional[Type]:
|
||||||
|
"""Get the result type of a function call
|
||||||
|
|
||||||
|
If the function has overloads, the function will try to resolve the
|
||||||
|
appropriate signature.
|
||||||
|
Argument types are matched to the defined parameters.
|
||||||
|
The function doesn't take the raw expression as a parameter to accommodate
|
||||||
|
for desugared calls such as for operators.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
location (Location): the call location
|
||||||
|
callee (Type): the called function
|
||||||
|
positional (list[TypedExpr]): the list positional arguments
|
||||||
|
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||||
|
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Type: the return type of the call, or `None` if either
|
||||||
|
the call is invalid or no overload matched the arguments uniquely
|
||||||
|
"""
|
||||||
|
match callee:
|
||||||
|
case Function() as function:
|
||||||
|
valid: bool
|
||||||
|
mapped: list[MappedArgument]
|
||||||
|
valid, mapped = self.map_call_arguments(
|
||||||
|
function, location, positional, keywords
|
||||||
|
)
|
||||||
|
valid = valid and self._are_arguments_valid(mapped, report_errors)
|
||||||
|
if not valid:
|
||||||
|
return None
|
||||||
|
return function.returns
|
||||||
|
|
||||||
|
case OverloadedFunction(overloads=overloads):
|
||||||
|
function = self._match_overload(
|
||||||
|
overloads, location, positional, keywords, report_errors
|
||||||
|
)
|
||||||
|
if function is None:
|
||||||
|
return None
|
||||||
|
return function.returns
|
||||||
|
|
||||||
|
case AppliedType(body=body):
|
||||||
|
return self._get_call_result(
|
||||||
|
location, body, positional, keywords, report_errors
|
||||||
|
)
|
||||||
|
|
||||||
|
case UnknownType():
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
case _:
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(location, f"{callee} is not callable")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _are_arguments_valid(
|
||||||
|
self,
|
||||||
|
arguments: list[MappedArgument],
|
||||||
|
report_errors: bool = True,
|
||||||
|
) -> bool:
|
||||||
|
"""Check whether the passed argument types correspond to their matched parameter definitions
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arguments (list[MappedArgument]): the list of argument/parameter pairs
|
||||||
|
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if all arguments fit the matching parameter definitions, False otherwise
|
||||||
|
"""
|
||||||
|
valid: bool = True
|
||||||
|
for arg in arguments:
|
||||||
|
if not self.types.is_subtype(arg.type, arg.argument.type):
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(
|
||||||
|
arg.expr.location,
|
||||||
|
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||||
|
)
|
||||||
|
valid = False
|
||||||
|
return valid
|
||||||
|
|
||||||
|
def _match_overload(
|
||||||
|
self,
|
||||||
|
overloads: list[Type],
|
||||||
|
location: Location,
|
||||||
|
positional: list[TypedExpr],
|
||||||
|
keywords: dict[str, TypedExpr],
|
||||||
|
report_errors: bool = True,
|
||||||
|
) -> Optional[Function]:
|
||||||
|
"""Try and resolve the appropriate overload for the given arguments
|
||||||
|
|
||||||
|
Args:
|
||||||
|
overloads (list[Type]): the list of possible overloads
|
||||||
|
location (Location): the call location
|
||||||
|
positional (list[TypedExpr]): the list of positional arguments
|
||||||
|
keywords (dict[str, TypedExpr]): the map of keywords arguments
|
||||||
|
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Function]: the resolved function signature if it can be
|
||||||
|
determined unambiguously, or `None`.
|
||||||
|
"""
|
||||||
|
candidates: list[OverloadCandidate] = []
|
||||||
|
for overload in overloads:
|
||||||
|
function: Type = unfold_type(overload)
|
||||||
|
if not isinstance(function, Function):
|
||||||
|
if report_errors:
|
||||||
|
self.logger.error(
|
||||||
|
f"Overload is not a function: {overload} is {function}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
valid, mapped = self.map_call_arguments(
|
||||||
|
function=function,
|
||||||
|
location=location,
|
||||||
|
positional=positional,
|
||||||
|
keywords=keywords,
|
||||||
|
report_errors=False,
|
||||||
|
)
|
||||||
|
if valid and self._are_arguments_valid(mapped, report_errors=False):
|
||||||
|
candidates.append(
|
||||||
|
OverloadCandidate(
|
||||||
|
function=function,
|
||||||
|
mapped=mapped,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
pos_types: str = ", ".join(str(type) for _, type in positional)
|
||||||
|
kw_types: str = ", ".join(
|
||||||
|
f"{name}: {type}" for name, (_, type) in keywords.items()
|
||||||
|
)
|
||||||
|
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
|
||||||
|
|
||||||
|
n_candidates: int = len(candidates)
|
||||||
|
|
||||||
|
# Exactly 1 match -> return it
|
||||||
|
if n_candidates == 1:
|
||||||
|
return candidates[0].function
|
||||||
|
|
||||||
|
# No match -> invalid call
|
||||||
|
if n_candidates == 0:
|
||||||
|
overloads_str: str = ", ".join(map(str, overloads))
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"No matching overload in [{overloads_str}] {for_args}",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Multiple matches -> see if one <: all others (more specific)
|
||||||
|
for i1, c1 in enumerate(candidates):
|
||||||
|
mapped1: list[MappedArgument] = c1.mapped
|
||||||
|
best_match: bool = True
|
||||||
|
for i2, c2 in enumerate(candidates):
|
||||||
|
if i1 == i2:
|
||||||
|
continue
|
||||||
|
mapped2: list[MappedArgument] = c2.mapped
|
||||||
|
if not self._are_mapped_subtypes(mapped1, mapped2):
|
||||||
|
best_match = False
|
||||||
|
break
|
||||||
|
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
|
||||||
|
if best_match:
|
||||||
|
return c1.function
|
||||||
|
|
||||||
|
candidates_str: str = ", ".join(
|
||||||
|
str(candidate.function) for candidate in candidates
|
||||||
|
)
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Multiple matching overloads {for_args}: {candidates_str}",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def map_call_arguments(
|
||||||
|
self,
|
||||||
|
function: Function,
|
||||||
|
location: Location,
|
||||||
|
positional: list[TypedExpr],
|
||||||
|
keywords: dict[str, TypedExpr],
|
||||||
|
report_errors: bool = True,
|
||||||
|
) -> tuple[bool, list[MappedArgument]]:
|
||||||
|
"""Map call arguments to a function's parameters as defined in its signature
|
||||||
|
|
||||||
|
This method maps positional-only, keyword-only and mixed parameter definitions
|
||||||
|
with the arguments passed at the call site
|
||||||
|
|
||||||
|
Any mismatched, missing or unexpected argument is reported as a diagnostic,
|
||||||
|
unless `report_errors` is set to `False`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
function (Function): the function definition
|
||||||
|
location (Location): the call location
|
||||||
|
positional (list[TypedExpr]): the list of positional arguments
|
||||||
|
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||||
|
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[bool, list[MappedArgument]]: a boolean reporting whether
|
||||||
|
the call is valid and the list of mapped arguments
|
||||||
|
"""
|
||||||
|
set_args: set[str] = set()
|
||||||
|
|
||||||
|
required_positional: list[str] = [
|
||||||
|
arg.name for arg in function.pos_args + function.args if arg.required
|
||||||
|
]
|
||||||
|
required_keyword: list[str] = [
|
||||||
|
arg.name for arg in function.kw_args if arg.required
|
||||||
|
]
|
||||||
|
|
||||||
|
mapped: list[MappedArgument] = []
|
||||||
|
|
||||||
|
pos_params: list[Function.Argument] = list(function.pos_args)
|
||||||
|
mixed_params: list[Function.Argument] = list(function.args)
|
||||||
|
kw_params: dict[str, Function.Argument] = {
|
||||||
|
arg.name: arg for arg in function.kw_args
|
||||||
|
}
|
||||||
|
|
||||||
|
valid_call: bool = True
|
||||||
|
|
||||||
|
# TODO: handle *args and **kwargs sinks
|
||||||
|
for arg in positional:
|
||||||
|
param: Function.Argument
|
||||||
|
if len(pos_params) != 0:
|
||||||
|
param = pos_params.pop(0)
|
||||||
|
elif len(mixed_params) != 0:
|
||||||
|
param = mixed_params.pop(0)
|
||||||
|
else:
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(
|
||||||
|
arg[0].location, "Too many positional arguments"
|
||||||
|
)
|
||||||
|
valid_call = False
|
||||||
|
break
|
||||||
|
name: str = param.name
|
||||||
|
if name in required_positional:
|
||||||
|
required_positional.remove(name)
|
||||||
|
if name in required_keyword:
|
||||||
|
required_keyword.remove(name)
|
||||||
|
set_args.add(name)
|
||||||
|
mapped.append(
|
||||||
|
MappedArgument(
|
||||||
|
expr=arg[0],
|
||||||
|
type=arg[1],
|
||||||
|
argument=param,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
kw_params.update({arg.name: arg for arg in mixed_params})
|
||||||
|
for name, arg in keywords.items():
|
||||||
|
param: Function.Argument
|
||||||
|
if name not in kw_params:
|
||||||
|
if report_errors:
|
||||||
|
if name in set_args:
|
||||||
|
self.reporter.error(
|
||||||
|
arg[0].location, f"Multiple values for argument '{name}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.reporter.error(
|
||||||
|
arg[0].location, f"Unknown keyword argument '{name}'"
|
||||||
|
)
|
||||||
|
valid_call = False
|
||||||
|
continue
|
||||||
|
param = kw_params.pop(name)
|
||||||
|
if name in required_positional:
|
||||||
|
required_positional.remove(name)
|
||||||
|
if name in required_keyword:
|
||||||
|
required_keyword.remove(name)
|
||||||
|
set_args.add(name)
|
||||||
|
mapped.append(
|
||||||
|
MappedArgument(
|
||||||
|
expr=arg[0],
|
||||||
|
type=arg[1],
|
||||||
|
argument=param,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def join_args(args: list[str]) -> str:
|
||||||
|
args = list(map(lambda a: f"'{a}'", args))
|
||||||
|
if len(args) == 0:
|
||||||
|
return ""
|
||||||
|
if len(args) == 1:
|
||||||
|
return args[0]
|
||||||
|
return ", ".join(args[:-1]) + " and " + args[-1]
|
||||||
|
|
||||||
|
if len(required_positional) != 0:
|
||||||
|
plural: str = "" if len(required_positional) == 1 else "s"
|
||||||
|
args: str = join_args(required_positional)
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Missing required positional argument{plural}: {args}",
|
||||||
|
)
|
||||||
|
valid_call = False
|
||||||
|
|
||||||
|
if len(required_keyword) != 0:
|
||||||
|
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||||
|
args: str = join_args(required_keyword)
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Missing required keyword argument{plural}: {args}",
|
||||||
|
)
|
||||||
|
valid_call = False
|
||||||
|
|
||||||
|
return valid_call, mapped
|
||||||
|
|
||||||
|
def _are_mapped_subtypes(
|
||||||
|
self, mapped1: list[MappedArgument], mapped2: list[MappedArgument]
|
||||||
|
) -> bool:
|
||||||
|
"""Check whether the given argument mappings are subtype/supertype of one another
|
||||||
|
|
||||||
|
This function checks whether the argument mappings `mapped1` are subtypes
|
||||||
|
of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
|
||||||
|
of the corresponding parameter in `mapped2`, `False` is returned.
|
||||||
|
|
||||||
|
This is used to check whether a given overload is
|
||||||
|
a more specific function/ a subtype of another.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
|
||||||
|
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
|
||||||
|
"""
|
||||||
|
by_expr: dict[m.Expr, Type] = {}
|
||||||
|
for arg in mapped1:
|
||||||
|
by_expr[arg.expr] = arg.argument.type
|
||||||
|
|
||||||
|
for arg in mapped2:
|
||||||
|
type2: Type = arg.argument.type
|
||||||
|
type1: Type = by_expr[arg.expr]
|
||||||
|
if not self.types.is_subtype(type1, type2):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
import ast
|
import ast
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
OPERATOR_METHODS: dict[Type[ast.operator], str] = {
|
from midas.lexer.token import TokenType
|
||||||
|
|
||||||
|
PY_OPERATOR_METHODS: dict[Type[ast.operator], str] = {
|
||||||
ast.Add: "__add__",
|
ast.Add: "__add__",
|
||||||
ast.Sub: "__sub__",
|
ast.Sub: "__sub__",
|
||||||
ast.Mult: "__mul__",
|
ast.Mult: "__mul__",
|
||||||
@@ -17,9 +19,9 @@ OPERATOR_METHODS: dict[Type[ast.operator], str] = {
|
|||||||
ast.FloorDiv: "__floordiv__",
|
ast.FloorDiv: "__floordiv__",
|
||||||
}
|
}
|
||||||
|
|
||||||
COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
|
PY_COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
|
||||||
ast.Eq: "__eq__",
|
ast.Eq: "__eq__",
|
||||||
# ast.NotEq: "__noteq__",
|
ast.NotEq: "__eq__",
|
||||||
ast.Lt: "__lt__",
|
ast.Lt: "__lt__",
|
||||||
ast.LtE: "__le__",
|
ast.LtE: "__le__",
|
||||||
ast.Gt: "__gt__",
|
ast.Gt: "__gt__",
|
||||||
@@ -30,9 +32,40 @@ COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
|
|||||||
# ast.NotIn: "__notin__",
|
# ast.NotIn: "__notin__",
|
||||||
}
|
}
|
||||||
|
|
||||||
UNARY_METHODS: dict[Type[ast.unaryop], str] = {
|
PY_UNARY_METHODS: dict[Type[ast.unaryop], str] = {
|
||||||
ast.Invert: "__invert__",
|
ast.Invert: "__invert__",
|
||||||
# ast.Not: "",
|
# ast.Not: "",
|
||||||
ast.UAdd: "__pos__",
|
ast.UAdd: "__pos__",
|
||||||
ast.USub: "__neg__",
|
ast.USub: "__neg__",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
MIDAS_BINARY_METHODS: dict[TokenType, str] = {
|
||||||
|
# TokenType.PLUS: "__add__",
|
||||||
|
TokenType.MINUS: "__sub__",
|
||||||
|
TokenType.STAR: "__mul__",
|
||||||
|
TokenType.SLASH: "__truediv__",
|
||||||
|
# TokenType.MODULO: "__mod__",
|
||||||
|
# TokenType.POW: "__pow__",
|
||||||
|
# ast.BitOr: "__or__",
|
||||||
|
# ast.BitXor: "__xor__",
|
||||||
|
# ast.BitAnd: "__and__",
|
||||||
|
# ast.FloorDiv: "__floordiv__",
|
||||||
|
TokenType.EQUAL_EQUAL: "__eq__",
|
||||||
|
TokenType.BANG_EQUAL: "__eq__",
|
||||||
|
TokenType.LESS: "__lt__",
|
||||||
|
TokenType.LESS_EQUAL: "__le__",
|
||||||
|
TokenType.GREATER: "__gt__",
|
||||||
|
TokenType.GREATER_EQUAL: "__ge__",
|
||||||
|
# ast.Is: "__is__",
|
||||||
|
# ast.IsNot: "__isnot__",
|
||||||
|
# ast.In: "__in__",
|
||||||
|
# ast.NotIn: "__notin__",
|
||||||
|
}
|
||||||
|
|
||||||
|
MIDAS_UNARY_METHODS: dict[TokenType, str] = {
|
||||||
|
# ast.Invert: "__invert__",
|
||||||
|
# ast.Not: "",
|
||||||
|
# TokenType.PLUS: "__pos__",
|
||||||
|
TokenType.MINUS: "__neg__",
|
||||||
|
}
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ class Preamble(Environment):
|
|||||||
# TODO: more specific arg types
|
# TODO: more specific arg types
|
||||||
self._def_function(
|
self._def_function(
|
||||||
name=name,
|
name=name,
|
||||||
pos=[Param("object", TopType())],
|
pos=[Param("object", TopType(), required=False)],
|
||||||
returns=self._types.get_type(name),
|
returns=self._types.get_type(name),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,11 @@ from typing import Optional
|
|||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
from midas.checker.environment import Environment
|
from midas.checker.environment import Environment
|
||||||
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS
|
from midas.checker.operators import (
|
||||||
|
PY_COMPARATOR_METHODS,
|
||||||
|
PY_OPERATOR_METHODS,
|
||||||
|
PY_UNARY_METHODS,
|
||||||
|
)
|
||||||
from midas.checker.preamble import Preamble
|
from midas.checker.preamble import Preamble
|
||||||
from midas.checker.registry import TypesRegistry
|
from midas.checker.registry import TypesRegistry
|
||||||
from midas.checker.reporter import FileReporter, Reporter
|
from midas.checker.reporter import FileReporter, Reporter
|
||||||
@@ -376,7 +380,7 @@ class PythonTyper(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
||||||
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
method: Optional[str] = PY_OPERATOR_METHODS.get(expr.operator.__class__)
|
||||||
if method is None:
|
if method is None:
|
||||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||||
self.reporter.warning(
|
self.reporter.warning(
|
||||||
@@ -387,7 +391,7 @@ class PythonTyper(
|
|||||||
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
|
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
|
||||||
|
|
||||||
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
||||||
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
|
method: Optional[str] = PY_COMPARATOR_METHODS.get(expr.operator.__class__)
|
||||||
if method is None:
|
if method is None:
|
||||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||||
self.reporter.warning(
|
self.reporter.warning(
|
||||||
@@ -420,7 +424,7 @@ class PythonTyper(
|
|||||||
return result or UnknownType()
|
return result or UnknownType()
|
||||||
|
|
||||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
|
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
|
||||||
method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__)
|
method: Optional[str] = PY_UNARY_METHODS.get(expr.operator.__class__)
|
||||||
if method is None:
|
if method is None:
|
||||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||||
self.reporter.warning(
|
self.reporter.warning(
|
||||||
@@ -552,6 +556,46 @@ class PythonTyper(
|
|||||||
)
|
)
|
||||||
return self.types.apply_generic(list_type, [UnknownType()])
|
return self.types.apply_generic(list_type, [UnknownType()])
|
||||||
|
|
||||||
|
def visit_dict_expr(self, expr: p.DictExpr) -> Type:
|
||||||
|
dict_type: Type = self.types.get_type("dict")
|
||||||
|
|
||||||
|
key_types: list[Type] = []
|
||||||
|
value_types: list[Type] = []
|
||||||
|
for key, value in zip(expr.keys, expr.values):
|
||||||
|
if key is None:
|
||||||
|
self.reporter.warning(
|
||||||
|
value.location, "Dictionary unpacking not supported"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
key_types.append(self.type_of(key))
|
||||||
|
value_types.append(self.type_of(value))
|
||||||
|
|
||||||
|
key_types = self.types.reduce_types(key_types)
|
||||||
|
value_types = self.types.reduce_types(value_types)
|
||||||
|
|
||||||
|
if len(key_types) == 0 or len(value_types) == 0:
|
||||||
|
return dict_type
|
||||||
|
|
||||||
|
key_type: Type = UnknownType()
|
||||||
|
value_type: Type = UnknownType()
|
||||||
|
|
||||||
|
if len(key_types) == 1:
|
||||||
|
key_type = key_types[0]
|
||||||
|
else:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Heterogeneous dict keys: {key_types}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(value_types) == 1:
|
||||||
|
value_type = value_types[0]
|
||||||
|
else:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Heterogeneous dict values: {value_types}",
|
||||||
|
)
|
||||||
|
return self.types.apply_generic(dict_type, [key_type, value_type])
|
||||||
|
|
||||||
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> Type:
|
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> Type:
|
||||||
object: Type = self.type_of(expr.object)
|
object: Type = self.type_of(expr.object)
|
||||||
operation: Optional[Type] = self.types.lookup_member(object, "__getitem__")
|
operation: Optional[Type] = self.types.lookup_member(object, "__getitem__")
|
||||||
@@ -612,7 +656,7 @@ class PythonTyper(
|
|||||||
If the function has overloads, the function will try to resolve the
|
If the function has overloads, the function will try to resolve the
|
||||||
appropriate signature.
|
appropriate signature.
|
||||||
Argument types are matched to the defined parameters.
|
Argument types are matched to the defined parameters.
|
||||||
The function doesn't take the raw expression as a parameter to accomodate
|
The function doesn't take the raw expression as a parameter to accommodate
|
||||||
for desugared calls such as for operators.
|
for desugared calls such as for operators.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -703,7 +747,7 @@ class PythonTyper(
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[Function]: the resolved function signature if it can be
|
Optional[Function]: the resolved function signature if it can be
|
||||||
determined unambigously, or `None`.
|
determined unambiguously, or `None`.
|
||||||
"""
|
"""
|
||||||
candidates: list[OverloadCandidate] = []
|
candidates: list[OverloadCandidate] = []
|
||||||
for overload in overloads:
|
for overload in overloads:
|
||||||
|
|||||||
@@ -7,10 +7,12 @@ from midas.checker.types import (
|
|||||||
AppliedType,
|
AppliedType,
|
||||||
BaseType,
|
BaseType,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
|
ConstraintType,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
OverloadedFunction,
|
OverloadedFunction,
|
||||||
|
Predicate,
|
||||||
TopType,
|
TopType,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
@@ -24,6 +26,7 @@ class TypesRegistry:
|
|||||||
self.logger: logging.Logger = logging.getLogger("TypesRegistry")
|
self.logger: logging.Logger = logging.getLogger("TypesRegistry")
|
||||||
self._types: dict[str, Type] = {}
|
self._types: dict[str, Type] = {}
|
||||||
self._members: dict[str, dict[str, Type]] = {}
|
self._members: dict[str, dict[str, Type]] = {}
|
||||||
|
self._predicates: dict[str, Predicate] = {}
|
||||||
|
|
||||||
def get_type(self, name: str) -> Type:
|
def get_type(self, name: str) -> Type:
|
||||||
"""Get a type from its name
|
"""Get a type from its name
|
||||||
@@ -81,6 +84,11 @@ class TypesRegistry:
|
|||||||
else:
|
else:
|
||||||
members[member_name] = member_type
|
members[member_name] = member_type
|
||||||
|
|
||||||
|
def define_predicate(self, name: str, predicate: Predicate):
|
||||||
|
if name in self._predicates:
|
||||||
|
raise ValueError(f"Predicate {name} already defined")
|
||||||
|
self._predicates[name] = predicate
|
||||||
|
|
||||||
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||||
"""Check whether `type1` is a subtype of `type2`
|
"""Check whether `type1` is a subtype of `type2`
|
||||||
|
|
||||||
@@ -123,6 +131,9 @@ class TypesRegistry:
|
|||||||
return False
|
return False
|
||||||
return self.is_subtype(bound, type2)
|
return self.is_subtype(bound, type2)
|
||||||
|
|
||||||
|
case (ConstraintType(type=base1), _):
|
||||||
|
return self.is_subtype(base1, type2)
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# TODO: verify the logic in here
|
# TODO: verify the logic in here
|
||||||
@@ -345,3 +356,6 @@ class TypesRegistry:
|
|||||||
case _:
|
case _:
|
||||||
self.logger.debug(f"Can't get member on {type}")
|
self.logger.debug(f"Can't get member on {type}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def lookup_predicate(self, name: str) -> Optional[Predicate]:
|
||||||
|
return self._predicates.get(name)
|
||||||
|
|||||||
@@ -213,6 +213,13 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
|||||||
for item in expr.items:
|
for item in expr.items:
|
||||||
self.resolve(item)
|
self.resolve(item)
|
||||||
|
|
||||||
|
def visit_dict_expr(self, expr: p.DictExpr) -> None:
|
||||||
|
for key in expr.keys:
|
||||||
|
if key is not None:
|
||||||
|
self.resolve(key)
|
||||||
|
for value in expr.values:
|
||||||
|
self.resolve(value)
|
||||||
|
|
||||||
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
|
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
|
||||||
self.resolve(expr.object)
|
self.resolve(expr.object)
|
||||||
self.resolve(expr.index)
|
self.resolve(expr.index)
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional, assert_never
|
||||||
|
|
||||||
|
import midas.ast.midas as m
|
||||||
|
from midas.ast.printer import MidasPrinter
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
@@ -130,6 +133,16 @@ class AppliedType:
|
|||||||
return f"{self.name}[{', '.join(map(str, self.args))}]"
|
return f"{self.name}[{', '.join(map(str, self.args))}]"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class ConstraintType:
|
||||||
|
type: Type
|
||||||
|
constraint: m.Expr
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
printer = MidasPrinter()
|
||||||
|
return f"{self.type} where {printer.print(self.constraint)}"
|
||||||
|
|
||||||
|
|
||||||
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||||
def sub_argument(arg: Function.Argument):
|
def sub_argument(arg: Function.Argument):
|
||||||
return Function.Argument(
|
return Function.Argument(
|
||||||
@@ -195,6 +208,12 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
|||||||
body=substitute_typevars(body, substitutions),
|
body=substitute_typevars(body, substitutions),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
case ConstraintType():
|
||||||
|
return ConstraintType(
|
||||||
|
type=substitute_typevars(type.type, substitutions),
|
||||||
|
constraint=type.constraint,
|
||||||
|
)
|
||||||
|
|
||||||
case TypeVar(name=name):
|
case TypeVar(name=name):
|
||||||
if name in substitutions:
|
if name in substitutions:
|
||||||
return substitutions[name]
|
return substitutions[name]
|
||||||
@@ -203,9 +222,13 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
|||||||
case UnknownType() | UnitType():
|
case UnknownType() | UnitType():
|
||||||
return type
|
return type
|
||||||
|
|
||||||
case _:
|
case TopType() | GenericType():
|
||||||
raise NotImplementedError(f"Unsupported type {type}")
|
raise NotImplementedError(f"Unsupported type {type}")
|
||||||
|
|
||||||
|
# Ensure exhaustiveness
|
||||||
|
case _:
|
||||||
|
assert_never(type)
|
||||||
|
|
||||||
|
|
||||||
def unfold_type(type: Type) -> Type:
|
def unfold_type(type: Type) -> Type:
|
||||||
match type:
|
match type:
|
||||||
@@ -215,6 +238,65 @@ def unfold_type(type: Type) -> Type:
|
|||||||
return type
|
return type
|
||||||
|
|
||||||
|
|
||||||
|
def to_annotation(type: Type) -> str:
|
||||||
|
def _args_annotation(func: Function) -> str:
|
||||||
|
if len(func.kw_args) != 0:
|
||||||
|
return "..."
|
||||||
|
|
||||||
|
args: str = ", ".join(
|
||||||
|
to_annotation(arg.type) for arg in func.pos_args + func.args
|
||||||
|
)
|
||||||
|
return f"[{args}]"
|
||||||
|
|
||||||
|
match type:
|
||||||
|
case TopType():
|
||||||
|
return "Any"
|
||||||
|
|
||||||
|
case BaseType(name=name):
|
||||||
|
return name
|
||||||
|
|
||||||
|
case AliasType(name=name):
|
||||||
|
return name
|
||||||
|
|
||||||
|
case UnknownType():
|
||||||
|
return "Any"
|
||||||
|
|
||||||
|
case UnitType():
|
||||||
|
return "None"
|
||||||
|
|
||||||
|
case Function(returns=returns):
|
||||||
|
params_annot: str = _args_annotation(type)
|
||||||
|
return f"Callable[{params_annot}, {to_annotation(returns)}]"
|
||||||
|
|
||||||
|
case OverloadedFunction():
|
||||||
|
return "Callable"
|
||||||
|
|
||||||
|
case ComplexType() | ExtensionType():
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
case TypeVar(name=name):
|
||||||
|
return name
|
||||||
|
|
||||||
|
case GenericType(name=name, params=params):
|
||||||
|
return f"{name}[{', '.join(map(to_annotation, params))}]"
|
||||||
|
|
||||||
|
case AppliedType(name=name, args=args):
|
||||||
|
return f"{name}[{', '.join(map(to_annotation, args))}]"
|
||||||
|
|
||||||
|
case ConstraintType():
|
||||||
|
return str(type)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
assert_never(type)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Predicate:
|
||||||
|
type: Type
|
||||||
|
body: m.Expr
|
||||||
|
alias: bool
|
||||||
|
|
||||||
|
|
||||||
Type = (
|
Type = (
|
||||||
TopType
|
TopType
|
||||||
| BaseType
|
| BaseType
|
||||||
@@ -228,4 +310,5 @@ Type = (
|
|||||||
| TypeVar
|
| TypeVar
|
||||||
| GenericType
|
| GenericType
|
||||||
| AppliedType
|
| AppliedType
|
||||||
|
| ConstraintType
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -38,5 +38,5 @@ def compile(
|
|||||||
if any(map(lambda d: d.type == DiagnosticType.ERROR, diagnostics)):
|
if any(map(lambda d: d.type == DiagnosticType.ERROR, diagnostics)):
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
generator = Generator(workdir=source_path.parent)
|
generator = Generator(workdir=source_path.parent, types=checker.types)
|
||||||
generator.generate(typed_ast, source_path)
|
generator.generate(typed_ast, source_path)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from typing import TextIO
|
|||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
|
from midas.ast.printer import MidasPrinter
|
||||||
from midas.checker.checker import TypeChecker
|
from midas.checker.checker import TypeChecker
|
||||||
from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type
|
from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type
|
||||||
|
|
||||||
@@ -35,6 +36,7 @@ def dump_registry(
|
|||||||
for types_file in types:
|
for types_file in types:
|
||||||
checker.import_midas(Path(types_file.name).resolve())
|
checker.import_midas(Path(types_file.name).resolve())
|
||||||
|
|
||||||
|
print("##### Types #####")
|
||||||
for name, type in checker.types._types.items():
|
for name, type in checker.types._types.items():
|
||||||
members: dict[str, Type] = checker.types._members.get(name, {})
|
members: dict[str, Type] = checker.types._members.get(name, {})
|
||||||
print(f"{name} = {base_type(type)}")
|
print(f"{name} = {base_type(type)}")
|
||||||
@@ -42,3 +44,17 @@ def dump_registry(
|
|||||||
print(" " * 4 + "Members:")
|
print(" " * 4 + "Members:")
|
||||||
for member_name, member_type in members.items():
|
for member_name, member_type in members.items():
|
||||||
print(" " * 8 + f"{member_name}: {member_type}")
|
print(" " * 8 + f"{member_name}: {member_type}")
|
||||||
|
|
||||||
|
print("##### Predicates #####")
|
||||||
|
printer = MidasPrinter()
|
||||||
|
for name, predicate in checker.types._predicates.items():
|
||||||
|
body: str = printer.print(predicate.body)
|
||||||
|
if predicate.alias:
|
||||||
|
print(f"{name}: {predicate.type} = {body}")
|
||||||
|
else:
|
||||||
|
print(f"{name}{predicate.type}:")
|
||||||
|
body = "\n".join(
|
||||||
|
" " + ("return " if i == 0 else "") + line
|
||||||
|
for i, line in enumerate(body.split("\n"))
|
||||||
|
)
|
||||||
|
print(body)
|
||||||
|
|||||||
224
midas/generator/constraints.py
Normal file
224
midas/generator/constraints.py
Normal file
@@ -0,0 +1,224 @@
|
|||||||
|
import ast
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import midas.ast.midas as m
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
|
from midas.checker.types import (
|
||||||
|
Function,
|
||||||
|
Predicate,
|
||||||
|
Type,
|
||||||
|
to_annotation,
|
||||||
|
)
|
||||||
|
from midas.lexer.token import TokenType
|
||||||
|
|
||||||
|
LOGICAL_OPERATORS: dict[TokenType, type[ast.boolop]] = {
|
||||||
|
TokenType.AND: ast.And,
|
||||||
|
# TokenType.OR: ast.Or,
|
||||||
|
}
|
||||||
|
|
||||||
|
BINARY_OPERATORS: dict[TokenType, type[ast.operator]] = {
|
||||||
|
# TokenType.PLUS: ast.Add,
|
||||||
|
TokenType.MINUS: ast.Sub,
|
||||||
|
TokenType.STAR: ast.Mult,
|
||||||
|
TokenType.SLASH: ast.Div,
|
||||||
|
}
|
||||||
|
|
||||||
|
UNARY_OPERATORS: dict[TokenType, type[ast.unaryop]] = {
|
||||||
|
# TokenType.PLUS: ast.UAdd,
|
||||||
|
TokenType.MINUS: ast.USub,
|
||||||
|
}
|
||||||
|
|
||||||
|
COMPARISON_OPERATORS: dict[TokenType, type[ast.cmpop]] = {
|
||||||
|
TokenType.GREATER: ast.Gt,
|
||||||
|
TokenType.GREATER_EQUAL: ast.GtE,
|
||||||
|
TokenType.LESS: ast.Lt,
|
||||||
|
TokenType.LESS_EQUAL: ast.LtE,
|
||||||
|
TokenType.EQUAL_EQUAL: ast.Eq,
|
||||||
|
TokenType.BANG_EQUAL: ast.NotEq,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
|
||||||
|
def __init__(self, types: TypesRegistry):
|
||||||
|
self.types: TypesRegistry = types
|
||||||
|
self._id: int = 0
|
||||||
|
self._definitions: list[ast.stmt] = []
|
||||||
|
self._aliases: dict[str, str] = {}
|
||||||
|
|
||||||
|
def get_definitions(self) -> list[ast.stmt]:
|
||||||
|
return self._definitions
|
||||||
|
|
||||||
|
def generate(self, expr: m.Expr) -> ast.expr:
|
||||||
|
match expr:
|
||||||
|
case m.VariableExpr():
|
||||||
|
return expr.accept(self)
|
||||||
|
case _:
|
||||||
|
func = Function(
|
||||||
|
pos_args=[],
|
||||||
|
args=[
|
||||||
|
Function.Argument(
|
||||||
|
pos=0,
|
||||||
|
name="_",
|
||||||
|
type=self.types.get_type("Any"),
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
kw_args=[],
|
||||||
|
returns=self.types.get_type("bool"),
|
||||||
|
)
|
||||||
|
alias: str = self.make_alias(None)
|
||||||
|
definition: ast.stmt = self.make_definition(
|
||||||
|
alias, Predicate(type=func, body=expr, alias=False)
|
||||||
|
)
|
||||||
|
self._definitions.append(definition)
|
||||||
|
return ast.Name(id=alias)
|
||||||
|
|
||||||
|
def make_alias(self, name: Optional[str]) -> str:
|
||||||
|
suffix: str
|
||||||
|
if name is None:
|
||||||
|
suffix = f"p{self._id}"
|
||||||
|
self._id += 1
|
||||||
|
else:
|
||||||
|
suffix = name
|
||||||
|
alias: str = f"__midas_{suffix}__"
|
||||||
|
return alias
|
||||||
|
|
||||||
|
def make_definition(self, name: str, predicate: Predicate) -> ast.stmt:
|
||||||
|
body: ast.expr = predicate.body.accept(self)
|
||||||
|
if predicate.alias:
|
||||||
|
return ast.Assign(
|
||||||
|
targets=[
|
||||||
|
ast.Name(id=name),
|
||||||
|
],
|
||||||
|
value=body,
|
||||||
|
)
|
||||||
|
return self.make_func(name, [ast.Return(value=body)], predicate.type)
|
||||||
|
|
||||||
|
def make_args(self, func: Function) -> ast.arguments:
|
||||||
|
return ast.arguments(
|
||||||
|
posonlyargs=[
|
||||||
|
ast.arg(
|
||||||
|
arg=arg.name,
|
||||||
|
annotation=ast.Constant(value=to_annotation(arg.type)),
|
||||||
|
)
|
||||||
|
for arg in func.pos_args
|
||||||
|
],
|
||||||
|
args=[
|
||||||
|
ast.arg(
|
||||||
|
arg=arg.name,
|
||||||
|
annotation=ast.Constant(value=to_annotation(arg.type)),
|
||||||
|
)
|
||||||
|
for arg in func.args
|
||||||
|
],
|
||||||
|
kwonlyargs=[
|
||||||
|
ast.arg(
|
||||||
|
arg=arg.name,
|
||||||
|
annotation=ast.Constant(value=to_annotation(arg.type)),
|
||||||
|
)
|
||||||
|
for arg in func.kw_args
|
||||||
|
],
|
||||||
|
defaults=[],
|
||||||
|
kw_defaults=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
def make_func(
|
||||||
|
self, name: str, inner_body: list[ast.stmt], type: Type, level: int = 0
|
||||||
|
) -> ast.stmt:
|
||||||
|
match type:
|
||||||
|
case Function(returns=Function()):
|
||||||
|
inner_name: str = f"inner{level}"
|
||||||
|
return ast.FunctionDef(
|
||||||
|
name=name,
|
||||||
|
args=self.make_args(type),
|
||||||
|
body=[
|
||||||
|
self.make_func(inner_name, inner_body, type.returns, level + 1),
|
||||||
|
ast.Return(value=ast.Name(id=inner_name)),
|
||||||
|
],
|
||||||
|
returns=ast.Constant(value=to_annotation(type.returns)),
|
||||||
|
decorator_list=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
case Function():
|
||||||
|
return ast.FunctionDef(
|
||||||
|
name=name,
|
||||||
|
args=self.make_args(type),
|
||||||
|
body=inner_body,
|
||||||
|
returns=ast.Constant(value=to_annotation(type.returns)),
|
||||||
|
decorator_list=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Expected function, got {type!r}")
|
||||||
|
|
||||||
|
def get_predicate(self, name: str) -> Optional[ast.expr]:
|
||||||
|
if name not in self._aliases:
|
||||||
|
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
|
||||||
|
if predicate is None:
|
||||||
|
return None
|
||||||
|
alias: str = self.make_alias(name)
|
||||||
|
self._aliases[name] = alias
|
||||||
|
self._definitions.append(self.make_definition(alias, predicate))
|
||||||
|
|
||||||
|
return ast.Name(id=self._aliases[name])
|
||||||
|
|
||||||
|
def visit_logical_expr(self, expr: m.LogicalExpr) -> ast.expr:
|
||||||
|
return ast.BoolOp(
|
||||||
|
op=LOGICAL_OPERATORS[expr.operator.type](),
|
||||||
|
values=[
|
||||||
|
expr.left.accept(self),
|
||||||
|
expr.right.accept(self),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
def visit_binary_expr(self, expr: m.BinaryExpr) -> ast.expr:
|
||||||
|
op: TokenType = expr.operator.type
|
||||||
|
if op in BINARY_OPERATORS:
|
||||||
|
return ast.BinOp(
|
||||||
|
left=expr.left.accept(self),
|
||||||
|
op=BINARY_OPERATORS[op](),
|
||||||
|
right=expr.right.accept(self),
|
||||||
|
)
|
||||||
|
if op in COMPARISON_OPERATORS:
|
||||||
|
return ast.Compare(
|
||||||
|
left=expr.left.accept(self),
|
||||||
|
ops=[COMPARISON_OPERATORS[op]()],
|
||||||
|
comparators=[expr.right.accept(self)],
|
||||||
|
)
|
||||||
|
raise ValueError(f"Unexpected binary operator {op}")
|
||||||
|
|
||||||
|
def visit_unary_expr(self, expr: m.UnaryExpr) -> ast.expr:
|
||||||
|
return ast.UnaryOp(
|
||||||
|
op=UNARY_OPERATORS[expr.operator.type](),
|
||||||
|
operand=expr.right.accept(self),
|
||||||
|
)
|
||||||
|
|
||||||
|
def visit_call_expr(self, expr: m.CallExpr) -> ast.expr:
|
||||||
|
return ast.Call(
|
||||||
|
func=expr.callee.accept(self),
|
||||||
|
args=[arg.accept(self) for arg in expr.arguments],
|
||||||
|
keywords=[
|
||||||
|
ast.keyword(arg=name, value=arg.accept(self))
|
||||||
|
for name, arg in expr.keywords.items()
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
def visit_get_expr(self, expr: m.GetExpr) -> ast.expr:
|
||||||
|
return ast.Attribute(
|
||||||
|
value=expr.expr.accept(self),
|
||||||
|
attr=expr.name.lexeme,
|
||||||
|
)
|
||||||
|
|
||||||
|
def visit_variable_expr(self, expr: m.VariableExpr) -> ast.expr:
|
||||||
|
name: str = expr.name.lexeme
|
||||||
|
if (p := self.get_predicate(name)) is not None:
|
||||||
|
return p
|
||||||
|
return ast.Name(id=name)
|
||||||
|
|
||||||
|
def visit_grouping_expr(self, expr: m.GroupingExpr) -> ast.expr:
|
||||||
|
return expr.accept(self)
|
||||||
|
|
||||||
|
def visit_literal_expr(self, expr: m.LiteralExpr) -> ast.expr:
|
||||||
|
return ast.Constant(value=expr.value)
|
||||||
|
|
||||||
|
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> ast.expr:
|
||||||
|
return ast.Name(id="_")
|
||||||
@@ -2,15 +2,19 @@ import ast
|
|||||||
import shutil
|
import shutil
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional, assert_never
|
||||||
|
|
||||||
from midas.ast.location import Location
|
import midas.ast.midas as m
|
||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.ast.printer import MidasPrinter
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AliasType,
|
AliasType,
|
||||||
AppliedType,
|
AppliedType,
|
||||||
BaseType,
|
BaseType,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
|
ConstraintType,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
@@ -19,7 +23,9 @@ from midas.checker.types import (
|
|||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
UnitType,
|
UnitType,
|
||||||
|
UnknownType,
|
||||||
)
|
)
|
||||||
|
from midas.generator.constraints import ConstraintGenerator
|
||||||
from midas.utils import TypedAST
|
from midas.utils import TypedAST
|
||||||
|
|
||||||
|
|
||||||
@@ -30,12 +36,9 @@ class Scope:
|
|||||||
|
|
||||||
|
|
||||||
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||||
def __init__(self, workdir: Path) -> None:
|
def __init__(self, workdir: Path, types: TypesRegistry) -> None:
|
||||||
self.workdir: Path = workdir.resolve()
|
self.workdir: Path = workdir.resolve()
|
||||||
self.build_dir: Path = self.workdir / "build" / "midas"
|
self.build_dir: Path = self.workdir / "build" / "midas"
|
||||||
if self.build_dir.exists():
|
|
||||||
shutil.rmtree(self.build_dir)
|
|
||||||
self.build_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
self.rel_src_path: Path = Path()
|
self.rel_src_path: Path = Path()
|
||||||
|
|
||||||
self._typed_ast: TypedAST = TypedAST(
|
self._typed_ast: TypedAST = TypedAST(
|
||||||
@@ -43,13 +46,18 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
judgements=[],
|
judgements=[],
|
||||||
)
|
)
|
||||||
self._alias_count: int = 0
|
self._alias_count: int = 0
|
||||||
|
self._predicate_count: int = 0
|
||||||
self._scopes: list[Scope] = []
|
self._scopes: list[Scope] = []
|
||||||
|
|
||||||
|
self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
|
||||||
|
self._constraints: list[tuple[m.Expr, ast.expr]] = []
|
||||||
|
|
||||||
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
|
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
|
||||||
self.rel_src_path = src_path.relative_to(self.workdir)
|
self.rel_src_path = src_path.resolve().relative_to(self.workdir)
|
||||||
self._typed_ast = typed_ast
|
self._typed_ast = typed_ast
|
||||||
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
|
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
|
||||||
module = ast.Module(body=body, type_ignores=[])
|
predicates: list[ast.stmt] = self._constraint_generator.get_definitions()
|
||||||
|
module = ast.Module(body=predicates + body, type_ignores=[])
|
||||||
module = ast.fix_missing_locations(module)
|
module = ast.fix_missing_locations(module)
|
||||||
return module
|
return module
|
||||||
|
|
||||||
@@ -59,6 +67,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
module: ast.AST = self.generate_ast(typed_ast, src_path)
|
module: ast.AST = self.generate_ast(typed_ast, src_path)
|
||||||
compiled: str = ast.unparse(module)
|
compiled: str = ast.unparse(module)
|
||||||
if out_path is None:
|
if out_path is None:
|
||||||
|
if self.build_dir.exists():
|
||||||
|
shutil.rmtree(self.build_dir)
|
||||||
|
self.build_dir.mkdir(parents=True, exist_ok=True)
|
||||||
out_path = (self.build_dir / self.rel_src_path).resolve()
|
out_path = (self.build_dir / self.rel_src_path).resolve()
|
||||||
try:
|
try:
|
||||||
_ = out_path.relative_to(self.build_dir)
|
_ = out_path.relative_to(self.build_dir)
|
||||||
@@ -139,6 +150,12 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
elts=[item.accept(self) for item in expr.items],
|
elts=[item.accept(self) for item in expr.items],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def visit_dict_expr(self, expr: p.DictExpr) -> ast.expr:
|
||||||
|
return ast.Dict(
|
||||||
|
keys=[key.accept(self) if key is not None else None for key in expr.keys],
|
||||||
|
values=[value.accept(self) for value in expr.values],
|
||||||
|
)
|
||||||
|
|
||||||
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> ast.expr:
|
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> ast.expr:
|
||||||
return ast.Subscript(
|
return ast.Subscript(
|
||||||
value=expr.object.accept(self),
|
value=expr.object.accept(self),
|
||||||
@@ -240,7 +257,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
return generated
|
return generated
|
||||||
|
|
||||||
def _make_alias(self, expr: ast.expr) -> ast.expr:
|
def _make_alias(self, expr: ast.expr) -> ast.expr:
|
||||||
name: str = f"__midas_alias_{self._alias_count}__"
|
name: str = f"__midas_a{self._alias_count}__"
|
||||||
alias = ast.Name(id=name)
|
alias = ast.Name(id=name)
|
||||||
self._alias_count += 1
|
self._alias_count += 1
|
||||||
self._scopes[-1].aliases.append(name)
|
self._scopes[-1].aliases.append(name)
|
||||||
@@ -270,6 +287,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
|
|
||||||
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type):
|
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type):
|
||||||
match type:
|
match type:
|
||||||
|
case UnknownType():
|
||||||
|
pass
|
||||||
|
|
||||||
case BaseType(name=name):
|
case BaseType(name=name):
|
||||||
self._add_assert(
|
self._add_assert(
|
||||||
ast.Call(
|
ast.Call(
|
||||||
@@ -295,8 +315,15 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
self._make_cast_assert_message(src_location, expr, type),
|
self._make_cast_assert_message(src_location, expr, type),
|
||||||
)
|
)
|
||||||
|
|
||||||
case AppliedType():
|
case AppliedType(body=body):
|
||||||
self._make_cast_asserts(src_location, expr, type.body)
|
self._make_cast_asserts(src_location, expr, body)
|
||||||
|
|
||||||
|
case ConstraintType(type=base, constraint=constraint):
|
||||||
|
self._make_cast_asserts(src_location, expr, base)
|
||||||
|
self._make_constraint_assert(src_location, expr, constraint)
|
||||||
|
|
||||||
|
case TypeVar():
|
||||||
|
raise RuntimeError("Unexpected TypeVar")
|
||||||
|
|
||||||
case (
|
case (
|
||||||
TopType()
|
TopType()
|
||||||
@@ -308,8 +335,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
):
|
):
|
||||||
raise NotImplementedError(f"Can't make assertion for type {type}")
|
raise NotImplementedError(f"Can't make assertion for type {type}")
|
||||||
|
|
||||||
case TypeVar():
|
# Ensure exhaustiveness
|
||||||
raise RuntimeError("Unexpected TypeVar")
|
case _:
|
||||||
|
assert_never(type)
|
||||||
|
|
||||||
def _make_cast_assert_message(
|
def _make_cast_assert_message(
|
||||||
self, location: Location, expr: ast.expr, type: Type
|
self, location: Location, expr: ast.expr, type: Type
|
||||||
@@ -333,3 +361,36 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
ast.Constant(f" to {type}"),
|
ast.Constant(f" to {type}"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _make_constraint_assert(
|
||||||
|
self, src_location: Location, expr: ast.expr, constraint: m.Expr
|
||||||
|
):
|
||||||
|
test_func: ast.expr = self._get_constraint(constraint)
|
||||||
|
self._add_assert(
|
||||||
|
ast.Call(
|
||||||
|
func=test_func,
|
||||||
|
args=[expr],
|
||||||
|
keywords=[],
|
||||||
|
),
|
||||||
|
self._make_constraint_assert_message(src_location, expr, constraint),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _make_constraint_assert_message(
|
||||||
|
self, location: Location, expr: ast.expr, constraint: m.Expr
|
||||||
|
) -> ast.expr:
|
||||||
|
printer = MidasPrinter()
|
||||||
|
constraint_str: str = printer.print(constraint)
|
||||||
|
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
|
||||||
|
# f"file.py:L1:1: ConstraintError: Value does not fit constraint 'v > 0'"
|
||||||
|
return ast.Constant(
|
||||||
|
f"{loc_str}: ConstraintError: Value does not fit constraint '{constraint_str}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_constraint(self, expr: m.Expr) -> ast.expr:
|
||||||
|
for expr2, constraint in self._constraints:
|
||||||
|
if expr2 == expr:
|
||||||
|
return constraint
|
||||||
|
|
||||||
|
constraint: ast.expr = self._constraint_generator.generate(expr)
|
||||||
|
self._constraints.append((expr, constraint))
|
||||||
|
return constraint
|
||||||
|
|||||||
@@ -69,6 +69,8 @@ class MidasLexer(Lexer):
|
|||||||
):
|
):
|
||||||
self.advance()
|
self.advance()
|
||||||
self.add_token(TokenType.WHITESPACE)
|
self.add_token(TokenType.WHITESPACE)
|
||||||
|
case '"' | "'":
|
||||||
|
self.scan_string(char)
|
||||||
case _:
|
case _:
|
||||||
if char.isdigit():
|
if char.isdigit():
|
||||||
self.scan_number()
|
self.scan_number()
|
||||||
@@ -78,6 +80,17 @@ class MidasLexer(Lexer):
|
|||||||
self.error("Unexpected character")
|
self.error("Unexpected character")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def scan_string(self, opening: str):
|
||||||
|
while self.peek() != opening and not self.is_at_end():
|
||||||
|
self.advance()
|
||||||
|
|
||||||
|
if self.is_at_end():
|
||||||
|
self.error("Unterminated string")
|
||||||
|
|
||||||
|
self.advance()
|
||||||
|
value: str = self.source[self.start + 1 : self.idx - 1]
|
||||||
|
self.add_token(TokenType.STRING, value)
|
||||||
|
|
||||||
def scan_number(self):
|
def scan_number(self):
|
||||||
"""Scan the rest of number and add it as a token
|
"""Scan the rest of number and add it as a token
|
||||||
|
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ class TokenType(Enum):
|
|||||||
TRUE = auto()
|
TRUE = auto()
|
||||||
FALSE = auto()
|
FALSE = auto()
|
||||||
NONE = auto()
|
NONE = auto()
|
||||||
|
STRING = auto()
|
||||||
|
|
||||||
# Keywords
|
# Keywords
|
||||||
TYPE = auto()
|
TYPE = auto()
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from typing import Optional
|
|||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
from midas.ast.midas import (
|
from midas.ast.midas import (
|
||||||
BinaryExpr,
|
BinaryExpr,
|
||||||
|
CallExpr,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
ConstraintType,
|
ConstraintType,
|
||||||
Expr,
|
Expr,
|
||||||
@@ -17,6 +18,7 @@ from midas.ast.midas import (
|
|||||||
MemberKind,
|
MemberKind,
|
||||||
MemberStmt,
|
MemberStmt,
|
||||||
NamedType,
|
NamedType,
|
||||||
|
ParamSpec,
|
||||||
PredicateStmt,
|
PredicateStmt,
|
||||||
Stmt,
|
Stmt,
|
||||||
Type,
|
Type,
|
||||||
@@ -265,6 +267,9 @@ class MidasParser(Parser):
|
|||||||
Returns:
|
Returns:
|
||||||
Expr: the parsed constraint expression
|
Expr: the parsed constraint expression
|
||||||
"""
|
"""
|
||||||
|
return self.expression()
|
||||||
|
|
||||||
|
def expression(self) -> Expr:
|
||||||
return self.and_()
|
return self.and_()
|
||||||
|
|
||||||
def and_(self) -> Expr:
|
def and_(self) -> Expr:
|
||||||
@@ -331,7 +336,55 @@ class MidasParser(Parser):
|
|||||||
right: Expr = self.unary()
|
right: Expr = self.unary()
|
||||||
location: Location = Location.span(operator.get_location(), right.location)
|
location: Location = Location.span(operator.get_location(), right.location)
|
||||||
return UnaryExpr(location=location, operator=operator, right=right)
|
return UnaryExpr(location=location, operator=operator, right=right)
|
||||||
return self.reference()
|
return self.call()
|
||||||
|
|
||||||
|
def call(self) -> Expr:
|
||||||
|
expr: Expr = self.reference()
|
||||||
|
while self.match(TokenType.LEFT_PAREN):
|
||||||
|
expr = self.finish_call(expr)
|
||||||
|
return expr
|
||||||
|
|
||||||
|
def finish_call(self, callee: Expr) -> Expr:
|
||||||
|
pos_args: list[Expr] = []
|
||||||
|
kw_args: dict[str, Expr] = {}
|
||||||
|
keywords: bool = False
|
||||||
|
while not self.match(TokenType.RIGHT_PAREN):
|
||||||
|
if self.check_identifier() and self.check_next(TokenType.EQUAL):
|
||||||
|
keywords = True
|
||||||
|
keyword: Token = self.advance()
|
||||||
|
self.advance()
|
||||||
|
value: Expr = self.expression()
|
||||||
|
name: str = keyword.lexeme
|
||||||
|
if name in kw_args:
|
||||||
|
self.error(
|
||||||
|
self.peek(),
|
||||||
|
f"Multiple values passed for '{name}', only the last occurrence will be used",
|
||||||
|
)
|
||||||
|
kw_args[name] = value
|
||||||
|
else:
|
||||||
|
value = self.expression()
|
||||||
|
if self.check(TokenType.EQUAL):
|
||||||
|
if keywords:
|
||||||
|
raise self.error(self.peek(), "Invalid keyword argument name")
|
||||||
|
else:
|
||||||
|
raise self.error(
|
||||||
|
self.peek(),
|
||||||
|
"Cannot pass positional arguments after a keyword argument",
|
||||||
|
)
|
||||||
|
pos_args.append(value)
|
||||||
|
|
||||||
|
if not self.match(TokenType.COMMA):
|
||||||
|
break
|
||||||
|
|
||||||
|
r_paren: Token = self.consume(
|
||||||
|
TokenType.RIGHT_PAREN, "Expected ')' after arguments."
|
||||||
|
)
|
||||||
|
return CallExpr(
|
||||||
|
location=Location.span(callee.location, r_paren.get_location()),
|
||||||
|
callee=callee,
|
||||||
|
arguments=pos_args,
|
||||||
|
keywords=kw_args,
|
||||||
|
)
|
||||||
|
|
||||||
def reference(self) -> Expr:
|
def reference(self) -> Expr:
|
||||||
"""Parse an attribute access expression or a simpler expression
|
"""Parse an attribute access expression or a simpler expression
|
||||||
@@ -365,6 +418,9 @@ class MidasParser(Parser):
|
|||||||
if self.match(TokenType.NUMBER):
|
if self.match(TokenType.NUMBER):
|
||||||
return LiteralExpr(location=token.get_location(), value=token.value)
|
return LiteralExpr(location=token.get_location(), value=token.value)
|
||||||
|
|
||||||
|
if self.match(TokenType.STRING):
|
||||||
|
return LiteralExpr(location=token.get_location(), value=token.value)
|
||||||
|
|
||||||
if self.match_identifier():
|
if self.match_identifier():
|
||||||
return VariableExpr(location=token.get_location(), name=token)
|
return VariableExpr(location=token.get_location(), name=token)
|
||||||
|
|
||||||
@@ -453,23 +509,35 @@ class MidasParser(Parser):
|
|||||||
PredicateStmt: the parsed predicate declaration statement
|
PredicateStmt: the parsed predicate declaration statement
|
||||||
"""
|
"""
|
||||||
keyword: Token = self.previous()
|
keyword: Token = self.previous()
|
||||||
|
|
||||||
name: Token = self.consume_identifier("Expected predicate name")
|
name: Token = self.consume_identifier("Expected predicate name")
|
||||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
|
|
||||||
subject: Token = self.consume_identifier("Expected subject name")
|
params: list[ParamSpec] = []
|
||||||
self.consume(TokenType.COLON, "Expected ':' after subject name")
|
while self.check(TokenType.LEFT_PAREN):
|
||||||
type: Type = self.type_expr()
|
params.append(self.function_args())
|
||||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
|
|
||||||
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
|
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
|
||||||
condition: Expr = self.constraint()
|
body: Expr = self.constraint()
|
||||||
return PredicateStmt(
|
return PredicateStmt(
|
||||||
location=keyword.location_to(self.previous()),
|
location=keyword.location_to(self.previous()),
|
||||||
name=name,
|
name=name,
|
||||||
subject=subject,
|
params=params,
|
||||||
type=type,
|
body=body,
|
||||||
condition=condition,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def function(self) -> FunctionType:
|
def function(self) -> FunctionType:
|
||||||
|
params: ParamSpec = self.function_args()
|
||||||
|
|
||||||
|
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
||||||
|
result: Type = self.type_expr()
|
||||||
|
|
||||||
|
return FunctionType(
|
||||||
|
location=params.l_paren.location_to(self.previous()),
|
||||||
|
params=params,
|
||||||
|
returns=result,
|
||||||
|
)
|
||||||
|
|
||||||
|
def function_args(self) -> ParamSpec:
|
||||||
l_paren: Token = self.consume(
|
l_paren: Token = self.consume(
|
||||||
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
|
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
|
||||||
)
|
)
|
||||||
@@ -526,14 +594,4 @@ class MidasParser(Parser):
|
|||||||
self.error(token, "Unnamed mixed argument")
|
self.error(token, "Unnamed mixed argument")
|
||||||
|
|
||||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
|
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
|
||||||
|
return ParamSpec(l_paren=l_paren, pos=pos_args, mixed=args, kw=kw_args)
|
||||||
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
|
||||||
result: Type = self.type_expr()
|
|
||||||
|
|
||||||
return FunctionType(
|
|
||||||
location=l_paren.location_to(self.previous()),
|
|
||||||
pos_args=pos_args,
|
|
||||||
args=args,
|
|
||||||
kw_args=kw_args,
|
|
||||||
returns=result,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from midas.ast.python import (
|
|||||||
CastExpr,
|
CastExpr,
|
||||||
CompareExpr,
|
CompareExpr,
|
||||||
ConstraintType,
|
ConstraintType,
|
||||||
|
DictExpr,
|
||||||
Expr,
|
Expr,
|
||||||
ExpressionStmt,
|
ExpressionStmt,
|
||||||
ForStmt,
|
ForStmt,
|
||||||
@@ -447,6 +448,16 @@ class PythonParser:
|
|||||||
items=[self.parse_expr(item) for item in items],
|
items=[self.parse_expr(item) for item in items],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
case ast.Dict(keys=keys, values=values):
|
||||||
|
return DictExpr(
|
||||||
|
location=location,
|
||||||
|
keys=[
|
||||||
|
self.parse_expr(key) if key is not None else None
|
||||||
|
for key in keys
|
||||||
|
],
|
||||||
|
values=[self.parse_expr(value) for value in values],
|
||||||
|
)
|
||||||
|
|
||||||
case ast.Subscript(value=value, slice=index):
|
case ast.Subscript(value=value, slice=index):
|
||||||
return SubscriptExpr(
|
return SubscriptExpr(
|
||||||
location=location,
|
location=location,
|
||||||
|
|||||||
@@ -9,13 +9,13 @@ Module(
|
|||||||
level=0),
|
level=0),
|
||||||
Assign(
|
Assign(
|
||||||
targets=[
|
targets=[
|
||||||
Name(id='__midas_alias_0__')],
|
Name(id='__midas_a0__')],
|
||||||
value=Constant(value=123.45)),
|
value=Constant(value=123.45)),
|
||||||
Assert(
|
Assert(
|
||||||
test=Call(
|
test=Call(
|
||||||
func=Name(id='isinstance'),
|
func=Name(id='isinstance'),
|
||||||
args=[
|
args=[
|
||||||
Name(id='__midas_alias_0__'),
|
Name(id='__midas_a0__'),
|
||||||
Name(id='float')],
|
Name(id='float')],
|
||||||
keywords=[]),
|
keywords=[]),
|
||||||
msg=JoinedStr(
|
msg=JoinedStr(
|
||||||
@@ -26,7 +26,7 @@ Module(
|
|||||||
value=Call(
|
value=Call(
|
||||||
func=Name(id='type'),
|
func=Name(id='type'),
|
||||||
args=[
|
args=[
|
||||||
Name(id='__midas_alias_0__')],
|
Name(id='__midas_a0__')],
|
||||||
keywords=[]),
|
keywords=[]),
|
||||||
attr='__name__'),
|
attr='__name__'),
|
||||||
conversion=-1),
|
conversion=-1),
|
||||||
@@ -34,19 +34,19 @@ Module(
|
|||||||
Assign(
|
Assign(
|
||||||
targets=[
|
targets=[
|
||||||
Name(id='distance')],
|
Name(id='distance')],
|
||||||
value=Name(id='__midas_alias_0__')),
|
value=Name(id='__midas_a0__')),
|
||||||
Delete(
|
Delete(
|
||||||
targets=[
|
targets=[
|
||||||
Name(id='__midas_alias_0__')]),
|
Name(id='__midas_a0__')]),
|
||||||
Assign(
|
Assign(
|
||||||
targets=[
|
targets=[
|
||||||
Name(id='__midas_alias_1__')],
|
Name(id='__midas_a1__')],
|
||||||
value=Constant(value=6.7)),
|
value=Constant(value=6.7)),
|
||||||
Assert(
|
Assert(
|
||||||
test=Call(
|
test=Call(
|
||||||
func=Name(id='isinstance'),
|
func=Name(id='isinstance'),
|
||||||
args=[
|
args=[
|
||||||
Name(id='__midas_alias_1__'),
|
Name(id='__midas_a1__'),
|
||||||
Name(id='float')],
|
Name(id='float')],
|
||||||
keywords=[]),
|
keywords=[]),
|
||||||
msg=JoinedStr(
|
msg=JoinedStr(
|
||||||
@@ -57,7 +57,7 @@ Module(
|
|||||||
value=Call(
|
value=Call(
|
||||||
func=Name(id='type'),
|
func=Name(id='type'),
|
||||||
args=[
|
args=[
|
||||||
Name(id='__midas_alias_1__')],
|
Name(id='__midas_a1__')],
|
||||||
keywords=[]),
|
keywords=[]),
|
||||||
attr='__name__'),
|
attr='__name__'),
|
||||||
conversion=-1),
|
conversion=-1),
|
||||||
@@ -65,10 +65,10 @@ Module(
|
|||||||
Assign(
|
Assign(
|
||||||
targets=[
|
targets=[
|
||||||
Name(id='time')],
|
Name(id='time')],
|
||||||
value=Name(id='__midas_alias_1__')),
|
value=Name(id='__midas_a1__')),
|
||||||
Delete(
|
Delete(
|
||||||
targets=[
|
targets=[
|
||||||
Name(id='__midas_alias_1__')]),
|
Name(id='__midas_a1__')]),
|
||||||
Assign(
|
Assign(
|
||||||
targets=[
|
targets=[
|
||||||
Name(id='speed')],
|
Name(id='speed')],
|
||||||
|
|||||||
14
tests/cases/generator/02_constraints.midas
Normal file
14
tests/cases/generator/02_constraints.midas
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
// Inline
|
||||||
|
type T1 = float where _ > 0
|
||||||
|
|
||||||
|
// Named
|
||||||
|
predicate is_positive(v: float) = v > 0
|
||||||
|
type T2 = float where is_positive(_)
|
||||||
|
|
||||||
|
// Curried
|
||||||
|
predicate in_range(mn: float, mx: float)(v: float) = v >= mn & v < mx
|
||||||
|
type T3 = float where in_range(100, 200)(_)
|
||||||
|
|
||||||
|
// Alias
|
||||||
|
predicate minor = in_range(0, 18)
|
||||||
|
type T4 = float where minor(_)
|
||||||
8
tests/cases/generator/02_constraints.py
Normal file
8
tests/cases/generator/02_constraints.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
from midas import T1, T2, T3, T4, cast
|
||||||
|
|
||||||
|
t: float = 12.5
|
||||||
|
|
||||||
|
t1: T1 = cast(T1, t)
|
||||||
|
t2: T2 = cast(T2, t)
|
||||||
|
t3: T3 = cast(T3, t)
|
||||||
|
t4: T4 = cast(T4, t)
|
||||||
333
tests/cases/generator/02_constraints.py.ref.txt
Normal file
333
tests/cases/generator/02_constraints.py.ref.txt
Normal file
@@ -0,0 +1,333 @@
|
|||||||
|
Module(
|
||||||
|
body=[
|
||||||
|
FunctionDef(
|
||||||
|
name='__midas_p0__',
|
||||||
|
args=arguments(
|
||||||
|
posonlyargs=[],
|
||||||
|
args=[
|
||||||
|
arg(
|
||||||
|
arg='_',
|
||||||
|
annotation=Constant(value='Any'))],
|
||||||
|
kwonlyargs=[],
|
||||||
|
kw_defaults=[],
|
||||||
|
defaults=[]),
|
||||||
|
body=[
|
||||||
|
Return(
|
||||||
|
value=Compare(
|
||||||
|
left=Name(id='_'),
|
||||||
|
ops=[
|
||||||
|
Gt()],
|
||||||
|
comparators=[
|
||||||
|
Constant(value=0.0)]))],
|
||||||
|
decorator_list=[],
|
||||||
|
returns=Constant(value='bool')),
|
||||||
|
FunctionDef(
|
||||||
|
name='__midas_is_positive__',
|
||||||
|
args=arguments(
|
||||||
|
posonlyargs=[],
|
||||||
|
args=[
|
||||||
|
arg(
|
||||||
|
arg='v',
|
||||||
|
annotation=Constant(value='float'))],
|
||||||
|
kwonlyargs=[],
|
||||||
|
kw_defaults=[],
|
||||||
|
defaults=[]),
|
||||||
|
body=[
|
||||||
|
Return(
|
||||||
|
value=Compare(
|
||||||
|
left=Name(id='v'),
|
||||||
|
ops=[
|
||||||
|
Gt()],
|
||||||
|
comparators=[
|
||||||
|
Constant(value=0.0)]))],
|
||||||
|
decorator_list=[],
|
||||||
|
returns=Constant(value='bool')),
|
||||||
|
FunctionDef(
|
||||||
|
name='__midas_p1__',
|
||||||
|
args=arguments(
|
||||||
|
posonlyargs=[],
|
||||||
|
args=[
|
||||||
|
arg(
|
||||||
|
arg='_',
|
||||||
|
annotation=Constant(value='Any'))],
|
||||||
|
kwonlyargs=[],
|
||||||
|
kw_defaults=[],
|
||||||
|
defaults=[]),
|
||||||
|
body=[
|
||||||
|
Return(
|
||||||
|
value=Call(
|
||||||
|
func=Name(id='__midas_is_positive__'),
|
||||||
|
args=[
|
||||||
|
Name(id='_')],
|
||||||
|
keywords=[]))],
|
||||||
|
decorator_list=[],
|
||||||
|
returns=Constant(value='bool')),
|
||||||
|
FunctionDef(
|
||||||
|
name='__midas_in_range__',
|
||||||
|
args=arguments(
|
||||||
|
posonlyargs=[],
|
||||||
|
args=[
|
||||||
|
arg(
|
||||||
|
arg='mn',
|
||||||
|
annotation=Constant(value='float')),
|
||||||
|
arg(
|
||||||
|
arg='mx',
|
||||||
|
annotation=Constant(value='float'))],
|
||||||
|
kwonlyargs=[],
|
||||||
|
kw_defaults=[],
|
||||||
|
defaults=[]),
|
||||||
|
body=[
|
||||||
|
FunctionDef(
|
||||||
|
name='inner0',
|
||||||
|
args=arguments(
|
||||||
|
posonlyargs=[],
|
||||||
|
args=[
|
||||||
|
arg(
|
||||||
|
arg='v',
|
||||||
|
annotation=Constant(value='float'))],
|
||||||
|
kwonlyargs=[],
|
||||||
|
kw_defaults=[],
|
||||||
|
defaults=[]),
|
||||||
|
body=[
|
||||||
|
Return(
|
||||||
|
value=BoolOp(
|
||||||
|
op=And(),
|
||||||
|
values=[
|
||||||
|
Compare(
|
||||||
|
left=Name(id='v'),
|
||||||
|
ops=[
|
||||||
|
GtE()],
|
||||||
|
comparators=[
|
||||||
|
Name(id='mn')]),
|
||||||
|
Compare(
|
||||||
|
left=Name(id='v'),
|
||||||
|
ops=[
|
||||||
|
Lt()],
|
||||||
|
comparators=[
|
||||||
|
Name(id='mx')])]))],
|
||||||
|
decorator_list=[],
|
||||||
|
returns=Constant(value='bool')),
|
||||||
|
Return(
|
||||||
|
value=Name(id='inner0'))],
|
||||||
|
decorator_list=[],
|
||||||
|
returns=Constant(value='Callable[[float], bool]')),
|
||||||
|
FunctionDef(
|
||||||
|
name='__midas_p2__',
|
||||||
|
args=arguments(
|
||||||
|
posonlyargs=[],
|
||||||
|
args=[
|
||||||
|
arg(
|
||||||
|
arg='_',
|
||||||
|
annotation=Constant(value='Any'))],
|
||||||
|
kwonlyargs=[],
|
||||||
|
kw_defaults=[],
|
||||||
|
defaults=[]),
|
||||||
|
body=[
|
||||||
|
Return(
|
||||||
|
value=Call(
|
||||||
|
func=Call(
|
||||||
|
func=Name(id='__midas_in_range__'),
|
||||||
|
args=[
|
||||||
|
Constant(value=100.0),
|
||||||
|
Constant(value=200.0)],
|
||||||
|
keywords=[]),
|
||||||
|
args=[
|
||||||
|
Name(id='_')],
|
||||||
|
keywords=[]))],
|
||||||
|
decorator_list=[],
|
||||||
|
returns=Constant(value='bool')),
|
||||||
|
Assign(
|
||||||
|
targets=[
|
||||||
|
Name(id='__midas_minor__')],
|
||||||
|
value=Call(
|
||||||
|
func=Name(id='__midas_in_range__'),
|
||||||
|
args=[
|
||||||
|
Constant(value=0.0),
|
||||||
|
Constant(value=18.0)],
|
||||||
|
keywords=[])),
|
||||||
|
FunctionDef(
|
||||||
|
name='__midas_p3__',
|
||||||
|
args=arguments(
|
||||||
|
posonlyargs=[],
|
||||||
|
args=[
|
||||||
|
arg(
|
||||||
|
arg='_',
|
||||||
|
annotation=Constant(value='Any'))],
|
||||||
|
kwonlyargs=[],
|
||||||
|
kw_defaults=[],
|
||||||
|
defaults=[]),
|
||||||
|
body=[
|
||||||
|
Return(
|
||||||
|
value=Call(
|
||||||
|
func=Name(id='__midas_minor__'),
|
||||||
|
args=[
|
||||||
|
Name(id='_')],
|
||||||
|
keywords=[]))],
|
||||||
|
decorator_list=[],
|
||||||
|
returns=Constant(value='bool')),
|
||||||
|
ImportFrom(
|
||||||
|
module='midas',
|
||||||
|
names=[
|
||||||
|
alias(name='T1'),
|
||||||
|
alias(name='T2'),
|
||||||
|
alias(name='T3'),
|
||||||
|
alias(name='T4'),
|
||||||
|
alias(name='cast')],
|
||||||
|
level=0),
|
||||||
|
Assign(
|
||||||
|
targets=[
|
||||||
|
Name(id='t')],
|
||||||
|
value=Constant(value=12.5)),
|
||||||
|
Assign(
|
||||||
|
targets=[
|
||||||
|
Name(id='__midas_a0__')],
|
||||||
|
value=Name(id='t')),
|
||||||
|
Assert(
|
||||||
|
test=Call(
|
||||||
|
func=Name(id='isinstance'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_a0__'),
|
||||||
|
Name(id='float')],
|
||||||
|
keywords=[]),
|
||||||
|
msg=JoinedStr(
|
||||||
|
values=[
|
||||||
|
Constant(value='02_constraints.py:L5:10: CastError: Cannot cast '),
|
||||||
|
FormattedValue(
|
||||||
|
value=Attribute(
|
||||||
|
value=Call(
|
||||||
|
func=Name(id='type'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_a0__')],
|
||||||
|
keywords=[]),
|
||||||
|
attr='__name__'),
|
||||||
|
conversion=-1),
|
||||||
|
Constant(value=' to float')])),
|
||||||
|
Assert(
|
||||||
|
test=Call(
|
||||||
|
func=Name(id='__midas_p0__'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_a0__')],
|
||||||
|
keywords=[]),
|
||||||
|
msg=Constant(value="02_constraints.py:L5:10: ConstraintError: Value does not fit constraint '_ > 0.0'")),
|
||||||
|
Assign(
|
||||||
|
targets=[
|
||||||
|
Name(id='t1')],
|
||||||
|
value=Name(id='__midas_a0__')),
|
||||||
|
Delete(
|
||||||
|
targets=[
|
||||||
|
Name(id='__midas_a0__')]),
|
||||||
|
Assign(
|
||||||
|
targets=[
|
||||||
|
Name(id='__midas_a1__')],
|
||||||
|
value=Name(id='t')),
|
||||||
|
Assert(
|
||||||
|
test=Call(
|
||||||
|
func=Name(id='isinstance'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_a1__'),
|
||||||
|
Name(id='float')],
|
||||||
|
keywords=[]),
|
||||||
|
msg=JoinedStr(
|
||||||
|
values=[
|
||||||
|
Constant(value='02_constraints.py:L6:10: CastError: Cannot cast '),
|
||||||
|
FormattedValue(
|
||||||
|
value=Attribute(
|
||||||
|
value=Call(
|
||||||
|
func=Name(id='type'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_a1__')],
|
||||||
|
keywords=[]),
|
||||||
|
attr='__name__'),
|
||||||
|
conversion=-1),
|
||||||
|
Constant(value=' to float')])),
|
||||||
|
Assert(
|
||||||
|
test=Call(
|
||||||
|
func=Name(id='__midas_p1__'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_a1__')],
|
||||||
|
keywords=[]),
|
||||||
|
msg=Constant(value="02_constraints.py:L6:10: ConstraintError: Value does not fit constraint 'is_positive(_)'")),
|
||||||
|
Assign(
|
||||||
|
targets=[
|
||||||
|
Name(id='t2')],
|
||||||
|
value=Name(id='__midas_a1__')),
|
||||||
|
Delete(
|
||||||
|
targets=[
|
||||||
|
Name(id='__midas_a1__')]),
|
||||||
|
Assign(
|
||||||
|
targets=[
|
||||||
|
Name(id='__midas_a2__')],
|
||||||
|
value=Name(id='t')),
|
||||||
|
Assert(
|
||||||
|
test=Call(
|
||||||
|
func=Name(id='isinstance'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_a2__'),
|
||||||
|
Name(id='float')],
|
||||||
|
keywords=[]),
|
||||||
|
msg=JoinedStr(
|
||||||
|
values=[
|
||||||
|
Constant(value='02_constraints.py:L7:10: CastError: Cannot cast '),
|
||||||
|
FormattedValue(
|
||||||
|
value=Attribute(
|
||||||
|
value=Call(
|
||||||
|
func=Name(id='type'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_a2__')],
|
||||||
|
keywords=[]),
|
||||||
|
attr='__name__'),
|
||||||
|
conversion=-1),
|
||||||
|
Constant(value=' to float')])),
|
||||||
|
Assert(
|
||||||
|
test=Call(
|
||||||
|
func=Name(id='__midas_p2__'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_a2__')],
|
||||||
|
keywords=[]),
|
||||||
|
msg=Constant(value="02_constraints.py:L7:10: ConstraintError: Value does not fit constraint 'in_range(100.0, 200.0)(_)'")),
|
||||||
|
Assign(
|
||||||
|
targets=[
|
||||||
|
Name(id='t3')],
|
||||||
|
value=Name(id='__midas_a2__')),
|
||||||
|
Delete(
|
||||||
|
targets=[
|
||||||
|
Name(id='__midas_a2__')]),
|
||||||
|
Assign(
|
||||||
|
targets=[
|
||||||
|
Name(id='__midas_a3__')],
|
||||||
|
value=Name(id='t')),
|
||||||
|
Assert(
|
||||||
|
test=Call(
|
||||||
|
func=Name(id='isinstance'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_a3__'),
|
||||||
|
Name(id='float')],
|
||||||
|
keywords=[]),
|
||||||
|
msg=JoinedStr(
|
||||||
|
values=[
|
||||||
|
Constant(value='02_constraints.py:L8:10: CastError: Cannot cast '),
|
||||||
|
FormattedValue(
|
||||||
|
value=Attribute(
|
||||||
|
value=Call(
|
||||||
|
func=Name(id='type'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_a3__')],
|
||||||
|
keywords=[]),
|
||||||
|
attr='__name__'),
|
||||||
|
conversion=-1),
|
||||||
|
Constant(value=' to float')])),
|
||||||
|
Assert(
|
||||||
|
test=Call(
|
||||||
|
func=Name(id='__midas_p3__'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_a3__')],
|
||||||
|
keywords=[]),
|
||||||
|
msg=Constant(value="02_constraints.py:L8:10: ConstraintError: Value does not fit constraint 'minor(_)'")),
|
||||||
|
Assign(
|
||||||
|
targets=[
|
||||||
|
Name(id='t4')],
|
||||||
|
value=Name(id='__midas_a3__')),
|
||||||
|
Delete(
|
||||||
|
targets=[
|
||||||
|
Name(id='__midas_a3__')])],
|
||||||
|
type_ignores=[])
|
||||||
@@ -2582,18 +2582,21 @@
|
|||||||
"name": "__sub__",
|
"name": "__sub__",
|
||||||
"type": {
|
"type": {
|
||||||
"_type": "FunctionType",
|
"_type": "FunctionType",
|
||||||
"pos_args": [
|
"params": {
|
||||||
{
|
"_type": "ParamSpec",
|
||||||
"name": null,
|
"pos": [
|
||||||
"type": {
|
{
|
||||||
"_type": "NamedType",
|
"name": null,
|
||||||
"name": "GeoLocation"
|
"type": {
|
||||||
},
|
"_type": "NamedType",
|
||||||
"required": true
|
"name": "GeoLocation"
|
||||||
}
|
},
|
||||||
],
|
"required": true
|
||||||
"args": [],
|
}
|
||||||
"kw_args": [],
|
],
|
||||||
|
"mixed": [],
|
||||||
|
"kw": []
|
||||||
|
},
|
||||||
"returns": {
|
"returns": {
|
||||||
"_type": "GenericType",
|
"_type": "GenericType",
|
||||||
"type": {
|
"type": {
|
||||||
@@ -2673,18 +2676,21 @@
|
|||||||
"name": "__sub__",
|
"name": "__sub__",
|
||||||
"type": {
|
"type": {
|
||||||
"_type": "FunctionType",
|
"_type": "FunctionType",
|
||||||
"pos_args": [
|
"params": {
|
||||||
{
|
"_type": "ParamSpec",
|
||||||
"name": null,
|
"pos": [
|
||||||
"type": {
|
{
|
||||||
"_type": "NamedType",
|
"name": null,
|
||||||
"name": "Latitude"
|
"type": {
|
||||||
},
|
"_type": "NamedType",
|
||||||
"required": true
|
"name": "Latitude"
|
||||||
}
|
},
|
||||||
],
|
"required": true
|
||||||
"args": [],
|
}
|
||||||
"kw_args": [],
|
],
|
||||||
|
"mixed": [],
|
||||||
|
"kw": []
|
||||||
|
},
|
||||||
"returns": {
|
"returns": {
|
||||||
"_type": "GenericType",
|
"_type": "GenericType",
|
||||||
"type": {
|
"type": {
|
||||||
@@ -2713,18 +2719,21 @@
|
|||||||
"name": "__sub__",
|
"name": "__sub__",
|
||||||
"type": {
|
"type": {
|
||||||
"_type": "FunctionType",
|
"_type": "FunctionType",
|
||||||
"pos_args": [
|
"params": {
|
||||||
{
|
"_type": "ParamSpec",
|
||||||
"name": null,
|
"pos": [
|
||||||
"type": {
|
{
|
||||||
"_type": "NamedType",
|
"name": null,
|
||||||
"name": "Longitude"
|
"type": {
|
||||||
},
|
"_type": "NamedType",
|
||||||
"required": true
|
"name": "Longitude"
|
||||||
}
|
},
|
||||||
],
|
"required": true
|
||||||
"args": [],
|
}
|
||||||
"kw_args": [],
|
],
|
||||||
|
"mixed": [],
|
||||||
|
"kw": []
|
||||||
|
},
|
||||||
"returns": {
|
"returns": {
|
||||||
"_type": "GenericType",
|
"_type": "GenericType",
|
||||||
"type": {
|
"type": {
|
||||||
@@ -2745,12 +2754,24 @@
|
|||||||
{
|
{
|
||||||
"_type": "PredicateStmt",
|
"_type": "PredicateStmt",
|
||||||
"name": "Positive",
|
"name": "Positive",
|
||||||
"subject": "v",
|
"params": [
|
||||||
"type": {
|
{
|
||||||
"_type": "NamedType",
|
"_type": "ParamSpec",
|
||||||
"name": "float"
|
"pos": [],
|
||||||
},
|
"mixed": [
|
||||||
"condition": {
|
{
|
||||||
|
"name": "v",
|
||||||
|
"type": {
|
||||||
|
"_type": "NamedType",
|
||||||
|
"name": "float"
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"kw": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
"_type": "BinaryExpr",
|
"_type": "BinaryExpr",
|
||||||
"left": {
|
"left": {
|
||||||
"_type": "VariableExpr",
|
"_type": "VariableExpr",
|
||||||
@@ -2766,12 +2787,24 @@
|
|||||||
{
|
{
|
||||||
"_type": "PredicateStmt",
|
"_type": "PredicateStmt",
|
||||||
"name": "StrictlyPositive",
|
"name": "StrictlyPositive",
|
||||||
"subject": "v",
|
"params": [
|
||||||
"type": {
|
{
|
||||||
"_type": "NamedType",
|
"_type": "ParamSpec",
|
||||||
"name": "float"
|
"pos": [],
|
||||||
},
|
"mixed": [
|
||||||
"condition": {
|
{
|
||||||
|
"name": "v",
|
||||||
|
"type": {
|
||||||
|
"_type": "NamedType",
|
||||||
|
"name": "float"
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"kw": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
"_type": "BinaryExpr",
|
"_type": "BinaryExpr",
|
||||||
"left": {
|
"left": {
|
||||||
"_type": "VariableExpr",
|
"_type": "VariableExpr",
|
||||||
@@ -2787,12 +2820,24 @@
|
|||||||
{
|
{
|
||||||
"_type": "PredicateStmt",
|
"_type": "PredicateStmt",
|
||||||
"name": "Equatorial",
|
"name": "Equatorial",
|
||||||
"subject": "loc",
|
"params": [
|
||||||
"type": {
|
{
|
||||||
"_type": "NamedType",
|
"_type": "ParamSpec",
|
||||||
"name": "GeoLocation"
|
"pos": [],
|
||||||
},
|
"mixed": [
|
||||||
"condition": {
|
{
|
||||||
|
"name": "loc",
|
||||||
|
"type": {
|
||||||
|
"_type": "NamedType",
|
||||||
|
"name": "GeoLocation"
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"kw": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
"_type": "GroupingExpr",
|
"_type": "GroupingExpr",
|
||||||
"expr": {
|
"expr": {
|
||||||
"_type": "BinaryExpr",
|
"_type": "BinaryExpr",
|
||||||
@@ -2827,12 +2872,24 @@
|
|||||||
{
|
{
|
||||||
"_type": "PredicateStmt",
|
"_type": "PredicateStmt",
|
||||||
"name": "Arctic",
|
"name": "Arctic",
|
||||||
"subject": "loc",
|
"params": [
|
||||||
"type": {
|
{
|
||||||
"_type": "NamedType",
|
"_type": "ParamSpec",
|
||||||
"name": "GeoLocation"
|
"pos": [],
|
||||||
},
|
"mixed": [
|
||||||
"condition": {
|
{
|
||||||
|
"name": "loc",
|
||||||
|
"type": {
|
||||||
|
"_type": "NamedType",
|
||||||
|
"name": "GeoLocation"
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"kw": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
"_type": "GroupingExpr",
|
"_type": "GroupingExpr",
|
||||||
"expr": {
|
"expr": {
|
||||||
"_type": "BinaryExpr",
|
"_type": "BinaryExpr",
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ class GeneratorTester(Tester):
|
|||||||
typed_ast: TypedAST = checker.type_check(path)
|
typed_ast: TypedAST = checker.type_check(path)
|
||||||
|
|
||||||
if not any(d.type == DiagnosticType.ERROR for d in checker.diagnostics):
|
if not any(d.type == DiagnosticType.ERROR for d in checker.diagnostics):
|
||||||
generator = Generator(workdir=path.parent)
|
generator = Generator(workdir=path.parent, types=checker.types)
|
||||||
result.compiled_ast = generator.generate_ast(typed_ast, path)
|
result.compiled_ast = generator.generate_ast(typed_ast, path)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from typing import Optional, Sequence
|
|||||||
|
|
||||||
from midas.ast.midas import (
|
from midas.ast.midas import (
|
||||||
BinaryExpr,
|
BinaryExpr,
|
||||||
|
CallExpr,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
ConstraintType,
|
ConstraintType,
|
||||||
Expr,
|
Expr,
|
||||||
@@ -15,6 +16,7 @@ from midas.ast.midas import (
|
|||||||
LogicalExpr,
|
LogicalExpr,
|
||||||
MemberStmt,
|
MemberStmt,
|
||||||
NamedType,
|
NamedType,
|
||||||
|
ParamSpec,
|
||||||
PredicateStmt,
|
PredicateStmt,
|
||||||
Stmt,
|
Stmt,
|
||||||
Type,
|
Type,
|
||||||
@@ -78,9 +80,8 @@ class MidasAstJsonSerializer(
|
|||||||
return {
|
return {
|
||||||
"_type": "PredicateStmt",
|
"_type": "PredicateStmt",
|
||||||
"name": stmt.name.lexeme,
|
"name": stmt.name.lexeme,
|
||||||
"subject": stmt.subject.lexeme,
|
"params": [self._serialize_param_spec(spec) for spec in stmt.params],
|
||||||
"type": stmt.type.accept(self),
|
"body": stmt.body.accept(self),
|
||||||
"condition": stmt.condition.accept(self),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
|
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
|
||||||
@@ -106,6 +107,14 @@ class MidasAstJsonSerializer(
|
|||||||
"right": expr.right.accept(self),
|
"right": expr.right.accept(self),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def visit_call_expr(self, expr: CallExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "CallExpr",
|
||||||
|
"callee": expr.callee.accept(self),
|
||||||
|
"arguments": self._serialize_list(expr.arguments),
|
||||||
|
"keywords": {name: arg.accept(self) for name, arg in expr.keywords.items()},
|
||||||
|
}
|
||||||
|
|
||||||
def visit_get_expr(self, expr: GetExpr) -> dict:
|
def visit_get_expr(self, expr: GetExpr) -> dict:
|
||||||
return {
|
return {
|
||||||
"_type": "GetExpr",
|
"_type": "GetExpr",
|
||||||
@@ -163,15 +172,21 @@ class MidasAstJsonSerializer(
|
|||||||
def visit_function_type(self, type: FunctionType) -> dict:
|
def visit_function_type(self, type: FunctionType) -> dict:
|
||||||
return {
|
return {
|
||||||
"_type": "FunctionType",
|
"_type": "FunctionType",
|
||||||
"pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args],
|
"params": self._serialize_param_spec(type.params),
|
||||||
"args": [self._serialize_func_arg(arg) for arg in type.args],
|
|
||||||
"kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args],
|
|
||||||
"returns": type.returns.accept(self),
|
"returns": type.returns.accept(self),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _serialize_param_spec(self, spec: ParamSpec) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "ParamSpec",
|
||||||
|
"pos": [self._serialize_func_arg(arg) for arg in spec.pos],
|
||||||
|
"mixed": [self._serialize_func_arg(arg) for arg in spec.mixed],
|
||||||
|
"kw": [self._serialize_func_arg(arg) for arg in spec.kw],
|
||||||
|
}
|
||||||
|
|
||||||
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
|
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
|
||||||
return {
|
return {
|
||||||
"name": arg.name,
|
"name": arg.name.lexeme if arg.name is not None else None,
|
||||||
"type": arg.type.accept(self),
|
"type": arg.type.accept(self),
|
||||||
"required": arg.required,
|
"required": arg.required,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from midas.ast.python import (
|
|||||||
CastExpr,
|
CastExpr,
|
||||||
CompareExpr,
|
CompareExpr,
|
||||||
ConstraintType,
|
ConstraintType,
|
||||||
|
DictExpr,
|
||||||
Expr,
|
Expr,
|
||||||
ExpressionStmt,
|
ExpressionStmt,
|
||||||
ForStmt,
|
ForStmt,
|
||||||
@@ -278,6 +279,13 @@ class PythonAstJsonSerializer(
|
|||||||
"items": [item.accept(self) for item in expr.items],
|
"items": [item.accept(self) for item in expr.items],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def visit_dict_expr(self, expr: DictExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "DictExpr",
|
||||||
|
"keys": [self._serialize_optional(key) for key in expr.keys],
|
||||||
|
"values": self._serialize_list(expr.values),
|
||||||
|
}
|
||||||
|
|
||||||
def visit_subscript_expr(self, expr: SubscriptExpr) -> dict:
|
def visit_subscript_expr(self, expr: SubscriptExpr) -> dict:
|
||||||
return {
|
return {
|
||||||
"_type": "SubscriptExpr",
|
"_type": "SubscriptExpr",
|
||||||
|
|||||||
Reference in New Issue
Block a user