Merge pull request 'Basic type checker' (#6) from feat/basic-type-checker into main
Reviewed-on: #6
This commit was merged in pull request #6.
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -6,3 +6,4 @@ venv
|
||||
*.pyc
|
||||
uv.lock
|
||||
.python-version
|
||||
/out
|
||||
150
docs/architecture.typ
Normal file
150
docs/architecture.typ
Normal file
@@ -0,0 +1,150 @@
|
||||
#import "@preview/cetz:0.5.2": canvas, draw
|
||||
|
||||
#let diagram-only = false
|
||||
|
||||
#set document(
|
||||
title: [Midas Architecture],
|
||||
//author: "Louis Heredero",
|
||||
)
|
||||
|
||||
#set text(
|
||||
font: "Source Sans 3",
|
||||
)
|
||||
|
||||
#let diagram = canvas({
|
||||
let framed = draw.content.with(
|
||||
padding: (x: .8em, y: 1em),
|
||||
frame: "rect",
|
||||
stroke: black,
|
||||
)
|
||||
let arrow = draw.line.with(mark: (end: ">", fill: black))
|
||||
framed(
|
||||
(0, 0),
|
||||
name: "python-parser",
|
||||
)[Python parser]
|
||||
|
||||
draw.content(
|
||||
(rel: (0, 1), to: "python-parser.north"),
|
||||
padding: 5pt,
|
||||
anchor: "south",
|
||||
name: "source-py",
|
||||
)[_`source.py`_]
|
||||
arrow("source-py", "python-parser")
|
||||
|
||||
framed(
|
||||
(rel: (3, 0), to: "python-parser.east"),
|
||||
anchor: "west",
|
||||
name: "custom-parser",
|
||||
align(center)[Custom python\ parser],
|
||||
)
|
||||
|
||||
arrow("python-parser", "custom-parser", name: "arrow-python-ast")
|
||||
draw.content(
|
||||
"arrow-python-ast",
|
||||
anchor: "south",
|
||||
padding: 5pt,
|
||||
)[`ast.Module`]
|
||||
|
||||
framed(
|
||||
(rel: (-3, -2), to: "custom-parser.south"),
|
||||
anchor: "east",
|
||||
name: "python-resolver",
|
||||
)[Python Resolver]
|
||||
arrow(
|
||||
"custom-parser",
|
||||
((), "|-", "python-resolver.east"),
|
||||
"python-resolver",
|
||||
name: "arrow-python-custom-ast",
|
||||
)
|
||||
draw.content(
|
||||
(rel: (1.5, 0), to: "arrow-python-custom-ast.end"),
|
||||
padding: 5pt,
|
||||
anchor: "south",
|
||||
)[P-AST#footnote[#strong[P]ython *AST*]<fn-past>]
|
||||
draw.content(
|
||||
"python-resolver.west",
|
||||
padding: 5pt,
|
||||
anchor: "south-east",
|
||||
)[Resolved P-AST@fn-past]
|
||||
|
||||
draw.circle(
|
||||
(rel: (1, -2), to: "custom-parser.south-east"),
|
||||
radius: .4,
|
||||
name: "midas-loader",
|
||||
)
|
||||
arrow(
|
||||
"custom-parser",
|
||||
"midas-loader",
|
||||
name: "arrow-load-midas",
|
||||
mark: (end: (symbol: ">", fill: black), start: "o"),
|
||||
)
|
||||
draw.content(
|
||||
"arrow-load-midas",
|
||||
anchor: "west",
|
||||
padding: 5pt,
|
||||
)[```python midas.using("types.midas")```]
|
||||
|
||||
framed(
|
||||
(rel: (0, -2), to: "midas-loader.south"),
|
||||
name: "midas-parser",
|
||||
)[Midas lexer/parser]
|
||||
arrow("midas-loader", "midas-parser", name: "arrow-midas-source")
|
||||
draw.content(
|
||||
"arrow-midas-source",
|
||||
anchor: "west",
|
||||
padding: 5pt,
|
||||
)[_`types.midas`_]
|
||||
|
||||
|
||||
framed(
|
||||
(rel: (-2, 0), to: "midas-parser.west"),
|
||||
anchor: "east",
|
||||
name: "midas-resolver",
|
||||
)[Midas Resolver]
|
||||
arrow("midas-parser", "midas-resolver", name: "arrow-midas-ast")
|
||||
draw.content(
|
||||
"arrow-midas-ast",
|
||||
anchor: "south",
|
||||
padding: 5pt,
|
||||
)[M-AST#footnote[#strong[M]idas *AST*]<fn-mast>]
|
||||
|
||||
framed(
|
||||
(rel: (-3, 0), to: "midas-resolver.west"),
|
||||
anchor: "east",
|
||||
name: "checker",
|
||||
)[Checker]
|
||||
arrow("midas-resolver", "checker", name: "arrow-type-ctx")
|
||||
arrow(
|
||||
"python-resolver",
|
||||
((), "-|", "checker.north"),
|
||||
"checker",
|
||||
)
|
||||
draw.content(
|
||||
"arrow-type-ctx",
|
||||
anchor: "south",
|
||||
padding: 5pt,
|
||||
)[Types context]
|
||||
})
|
||||
|
||||
#show: doc => if diagram-only {
|
||||
set page(width: auto, height: auto, margin: .5cm)
|
||||
diagram
|
||||
} else { doc }
|
||||
|
||||
#align(center, title())
|
||||
|
||||
#v(1cm)
|
||||
|
||||
#figure(
|
||||
diagram,
|
||||
caption: [Midas type-checker architecture],
|
||||
)
|
||||
|
||||
== Components
|
||||
|
||||
- *Python parser*: builtin Python AST parser, extracts abstract syntax from the raw Python source (```python ast.parse(...)```)
|
||||
- *Custom python parser*: converts the raw Python AST into custom, more suitable constructs, especially for type annotations
|
||||
- *Python resolver*: resolves bindings and references, tracks binding scopes
|
||||
- *Midas lexer/parser*: parses a Midas type definition file and extracts its AST
|
||||
- *Midas resolver*: walks the AST and fills the environment with the defined types and operations
|
||||
- *Checker*: evaluates expressions and checks type coherence
|
||||
@@ -2,10 +2,6 @@
|
||||
# ruff: disable[F821]
|
||||
from __future__ import annotations
|
||||
|
||||
# Prototype of custom type import to use valid Python syntax
|
||||
import midas
|
||||
midas.using("02_custom_types.midas")
|
||||
|
||||
# A data-frame using a custom type
|
||||
df: Frame[
|
||||
location: GeoLocation
|
||||
|
||||
33
examples/00_syntax_prototype/05_custom_types_v3.midas
Normal file
33
examples/00_syntax_prototype/05_custom_types_v3.midas
Normal file
@@ -0,0 +1,33 @@
|
||||
type Foo1 = float
|
||||
type Foo2 = float where (_ > 3)
|
||||
type Foo3 = int | float
|
||||
type Foo4 = int where (_ > 3) | float where (_ > 3)
|
||||
type Foo5 = (int | float) where (_ > 3)
|
||||
type Foo6 = {
|
||||
foo: float
|
||||
bar: float where (_ > 3)
|
||||
}
|
||||
|
||||
type Foo7[T] = T where (_ > 3)
|
||||
type Foo8[A, B<:int] = {
|
||||
a: A
|
||||
b: B
|
||||
}
|
||||
|
||||
type Complex = {
|
||||
a: int
|
||||
b: int
|
||||
}
|
||||
type Complex2 = Complex where (_.a > 3 & _.b < 5)
|
||||
|
||||
predicate Positive(n: int) = n >= 0
|
||||
|
||||
extend Foo1 {
|
||||
op __add__(Foo1) -> Foo1
|
||||
}
|
||||
|
||||
extend Foo7[T] {
|
||||
op __add__(Foo7[T]) -> Foo7[T]
|
||||
}
|
||||
|
||||
type Optional[T] = None | T
|
||||
11
examples/01_simple_type_checking/01_simple_operations.py
Normal file
11
examples/01_simple_type_checking/01_simple_operations.py
Normal file
@@ -0,0 +1,11 @@
|
||||
a: int = 3
|
||||
b: int = 4
|
||||
|
||||
c = a + b # -> int
|
||||
|
||||
c = "invalid" # -> can't assign str to int variable
|
||||
|
||||
d = True
|
||||
e = d + d
|
||||
|
||||
f: float = a
|
||||
14
examples/01_simple_type_checking/02_simple_types.midas
Normal file
14
examples/01_simple_type_checking/02_simple_types.midas
Normal file
@@ -0,0 +1,14 @@
|
||||
type Meter = float
|
||||
type Second = float
|
||||
type MeterPerSecond = float
|
||||
|
||||
extend Meter {
|
||||
op __add__(Meter) -> Meter
|
||||
op __sub__(Meter) -> Meter
|
||||
op __truediv__(Second) -> MeterPerSecond
|
||||
}
|
||||
|
||||
extend Second {
|
||||
op __add__(Second) -> Second
|
||||
op __sub__(Second) -> Second
|
||||
}
|
||||
6
examples/01_simple_type_checking/02_simple_types.py
Normal file
6
examples/01_simple_type_checking/02_simple_types.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# type: ignore
|
||||
# ruff: disable [F821]
|
||||
|
||||
distance: Meter = cast(Meter, 123.45)
|
||||
time: Second = cast(Second, 6.7)
|
||||
speed = distance / time
|
||||
23
examples/01_simple_type_checking/03_control_flow.py
Normal file
23
examples/01_simple_type_checking/03_control_flow.py
Normal file
@@ -0,0 +1,23 @@
|
||||
def minimum(x: int, y: int):
|
||||
if x < y:
|
||||
return x
|
||||
else:
|
||||
return y
|
||||
|
||||
|
||||
a = 15
|
||||
b = 72
|
||||
c = minimum(a, b)
|
||||
|
||||
|
||||
def factorial(n: int) -> int:
|
||||
if n <= 1:
|
||||
return 1
|
||||
return n * factorial(n - 1)
|
||||
|
||||
|
||||
category = "Category 1" if a < 10 else "Category 2"
|
||||
|
||||
|
||||
def foo() -> None:
|
||||
pass
|
||||
@@ -1,5 +1,5 @@
|
||||
from pathlib import Path
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
HEADER = '''"""
|
||||
This file was generated by a script. Any manual changes might be overwritten.
|
||||
@@ -11,7 +11,7 @@ SECTION_TEMPLATE = """{banner}
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class {base}(ABC):
|
||||
location: Optional[Location] = None
|
||||
location: Location
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
52
gen/midas.py
52
gen/midas.py
@@ -13,40 +13,38 @@ from midas.lexer.token import Token
|
||||
|
||||
|
||||
###> Stmt | Statements
|
||||
class SimpleTypeStmt:
|
||||
class TypeStmt:
|
||||
name: Token
|
||||
template: Optional[TemplateExpr]
|
||||
base: TypeExpr
|
||||
constraint: Optional[Expr]
|
||||
params: list[Param]
|
||||
type: Type
|
||||
|
||||
|
||||
class ComplexTypeStmt:
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Param:
|
||||
location: Location
|
||||
name: Token
|
||||
template: Optional[TemplateExpr]
|
||||
properties: list[PropertyStmt]
|
||||
bound: Optional[Type]
|
||||
|
||||
|
||||
class PropertyStmt:
|
||||
name: Token
|
||||
type: TypeExpr
|
||||
constraint: Optional[Expr]
|
||||
type: Type
|
||||
|
||||
|
||||
class ExtendStmt:
|
||||
type: TypeExpr
|
||||
type: Type
|
||||
operations: list[OpStmt]
|
||||
|
||||
|
||||
class OpStmt:
|
||||
name: Token
|
||||
operand: TypeExpr
|
||||
result: TypeExpr
|
||||
operand: Type
|
||||
result: Type
|
||||
|
||||
|
||||
class PredicateStmt:
|
||||
name: Token
|
||||
subject: Token
|
||||
type: TypeExpr
|
||||
type: Type
|
||||
condition: Expr
|
||||
|
||||
|
||||
@@ -54,9 +52,6 @@ class PredicateStmt:
|
||||
|
||||
|
||||
###> Expr | Expressions
|
||||
class SimpleTypeExpr:
|
||||
name: Token
|
||||
optional: bool
|
||||
|
||||
|
||||
class LogicalExpr:
|
||||
@@ -97,14 +92,27 @@ class WildcardExpr:
|
||||
token: Token
|
||||
|
||||
|
||||
class TemplateExpr:
|
||||
type: TypeExpr
|
||||
###<
|
||||
|
||||
###> Type | Types
|
||||
|
||||
|
||||
class TypeExpr:
|
||||
class NamedType:
|
||||
name: Token
|
||||
template: Optional[TemplateExpr]
|
||||
optional: bool
|
||||
|
||||
|
||||
class GenericType:
|
||||
type: Type
|
||||
params: list[Type]
|
||||
|
||||
|
||||
class ConstraintType:
|
||||
type: Type
|
||||
constraint: Expr
|
||||
|
||||
|
||||
class ComplexType:
|
||||
properties: list[PropertyStmt]
|
||||
|
||||
|
||||
###<
|
||||
|
||||
@@ -44,14 +44,22 @@ class Function:
|
||||
name: str
|
||||
posonlyargs: list[Argument]
|
||||
args: list[Argument]
|
||||
sink: Optional[Argument]
|
||||
kwonlyargs: list[Argument]
|
||||
kw_sink: Optional[Argument]
|
||||
returns: Optional[MidasType]
|
||||
body: list[Stmt]
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
location: Optional[Location] = None
|
||||
name: Optional[str]
|
||||
name: str
|
||||
type: Optional[MidasType]
|
||||
default: Optional[Expr]
|
||||
|
||||
@property
|
||||
def all_args(self) -> list[Argument]:
|
||||
return self.posonlyargs + self.args + self.kwonlyargs
|
||||
|
||||
|
||||
class TypeAssign:
|
||||
@@ -64,6 +72,16 @@ class AssignStmt:
|
||||
value: Expr
|
||||
|
||||
|
||||
class ReturnStmt:
|
||||
value: Optional[Expr]
|
||||
|
||||
|
||||
class IfStmt:
|
||||
test: Expr
|
||||
body: list[Stmt]
|
||||
orelse: list[Stmt]
|
||||
|
||||
|
||||
###<
|
||||
|
||||
|
||||
@@ -116,4 +134,15 @@ class SetExpr:
|
||||
value: Expr
|
||||
|
||||
|
||||
class CastExpr:
|
||||
type: MidasType
|
||||
expr: Expr
|
||||
|
||||
|
||||
class TernaryExpr:
|
||||
test: Expr
|
||||
if_true: Expr
|
||||
if_false: Expr
|
||||
|
||||
|
||||
###<
|
||||
|
||||
@@ -21,17 +21,14 @@ T = TypeVar("T")
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Stmt(ABC):
|
||||
location: Optional[Location] = None
|
||||
location: Location
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> T: ...
|
||||
def visit_type_stmt(self, stmt: TypeStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_property_stmt(self, stmt: PropertyStmt) -> T: ...
|
||||
@@ -47,31 +44,25 @@ class Stmt(ABC):
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SimpleTypeStmt(Stmt):
|
||||
class TypeStmt(Stmt):
|
||||
name: Token
|
||||
template: Optional[TemplateExpr]
|
||||
base: TypeExpr
|
||||
constraint: Optional[Expr]
|
||||
params: list[Param]
|
||||
type: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Param:
|
||||
location: Location
|
||||
name: Token
|
||||
bound: Optional[Type]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_simple_type_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ComplexTypeStmt(Stmt):
|
||||
name: Token
|
||||
template: Optional[TemplateExpr]
|
||||
properties: list[PropertyStmt]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_complex_type_stmt(self)
|
||||
return visitor.visit_type_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PropertyStmt(Stmt):
|
||||
name: Token
|
||||
type: TypeExpr
|
||||
constraint: Optional[Expr]
|
||||
type: Type
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_property_stmt(self)
|
||||
@@ -79,7 +70,7 @@ class PropertyStmt(Stmt):
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExtendStmt(Stmt):
|
||||
type: TypeExpr
|
||||
type: Type
|
||||
operations: list[OpStmt]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
@@ -89,8 +80,8 @@ class ExtendStmt(Stmt):
|
||||
@dataclass(frozen=True)
|
||||
class OpStmt(Stmt):
|
||||
name: Token
|
||||
operand: TypeExpr
|
||||
result: TypeExpr
|
||||
operand: Type
|
||||
result: Type
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_op_stmt(self)
|
||||
@@ -100,7 +91,7 @@ class OpStmt(Stmt):
|
||||
class PredicateStmt(Stmt):
|
||||
name: Token
|
||||
subject: Token
|
||||
type: TypeExpr
|
||||
type: Type
|
||||
condition: Expr
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
@@ -114,15 +105,12 @@ class PredicateStmt(Stmt):
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Expr(ABC):
|
||||
location: Optional[Location] = None
|
||||
location: Location
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
|
||||
|
||||
@@ -147,21 +135,6 @@ class Expr(ABC):
|
||||
@abstractmethod
|
||||
def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_template_expr(self, expr: TemplateExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_type_expr(self, expr: TypeExpr) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SimpleTypeExpr(Expr):
|
||||
name: Token
|
||||
optional: bool
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_simple_type_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LogicalExpr(Expr):
|
||||
@@ -233,19 +206,61 @@ class WildcardExpr(Expr):
|
||||
return visitor.visit_wildcard_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TemplateExpr(Expr):
|
||||
type: TypeExpr
|
||||
#########
|
||||
# Types #
|
||||
#########
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_template_expr(self)
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Type(ABC):
|
||||
location: Location
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_named_type(self, type: NamedType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_generic_type(self, type: GenericType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_constraint_type(self, type: ConstraintType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_complex_type(self, type: ComplexType) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TypeExpr(Expr):
|
||||
class NamedType(Type):
|
||||
name: Token
|
||||
template: Optional[TemplateExpr]
|
||||
optional: bool
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_type_expr(self)
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_named_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GenericType(Type):
|
||||
type: Type
|
||||
params: list[Type]
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_generic_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConstraintType(Type):
|
||||
type: Type
|
||||
constraint: Expr
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_constraint_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ComplexType(Type):
|
||||
properties: list[PropertyStmt]
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_complex_type(self)
|
||||
|
||||
@@ -85,40 +85,39 @@ class AstPrinter(Generic[T]):
|
||||
child.accept(self)
|
||||
|
||||
|
||||
class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
|
||||
class MidasAstPrinter(
|
||||
AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None], m.Type.Visitor[None]
|
||||
):
|
||||
# Statements
|
||||
|
||||
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt):
|
||||
self._write_line("SimpleTypeStmt")
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||
self._write_line("TypeStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_optional_child("template", stmt.template)
|
||||
self._write_line("base")
|
||||
with self._child_level(single=True):
|
||||
stmt.base.accept(self)
|
||||
self._write_optional_child("constraint", stmt.constraint, last=True)
|
||||
|
||||
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt):
|
||||
self._write_line("ComplexTypeStmt")
|
||||
self._write_line("params")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_optional_child("template", stmt.template)
|
||||
self._write_line("properties", last=True)
|
||||
with self._child_level():
|
||||
for i, prop in enumerate(stmt.properties):
|
||||
for i, param in enumerate(stmt.params):
|
||||
self._idx = i
|
||||
if i == len(stmt.properties) - 1:
|
||||
if i == len(stmt.params) - 1:
|
||||
self._mark_last()
|
||||
prop.accept(self)
|
||||
self._print_type_stmt_param(param)
|
||||
self._write_line("type", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
|
||||
def _print_type_stmt_param(self, param: m.TypeStmt.Param) -> None:
|
||||
self._write_line("Param")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{param.name.lexeme}"')
|
||||
self._write_optional_child("bound", param.bound, last=True)
|
||||
|
||||
def visit_property_stmt(self, stmt: m.PropertyStmt):
|
||||
self._write_line("PropertyStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line("type")
|
||||
self._write_line("type", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
self._write_optional_child("constraint", stmt.constraint, last=True)
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||
self._write_line("ExtendStmt")
|
||||
@@ -161,12 +160,6 @@ class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
|
||||
|
||||
# Expressions
|
||||
|
||||
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr):
|
||||
self._write_line("SimpleTypeExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{expr.name.lexeme}"')
|
||||
self._write_line(f"optional: {expr.optional}", last=True)
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr):
|
||||
self._write_line("LogicalExpr")
|
||||
with self._child_level():
|
||||
@@ -230,22 +223,48 @@ class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
|
||||
self._write_line("WildcardExpr")
|
||||
|
||||
def visit_template_expr(self, expr: m.TemplateExpr) -> None:
|
||||
self._write_line("TemplateExpr")
|
||||
with self._child_level(single=True):
|
||||
def visit_named_type(self, type: m.NamedType) -> None:
|
||||
self._write_line("NamedType")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{type.name.lexeme}"', last=True)
|
||||
|
||||
def visit_generic_type(self, type: m.GenericType) -> None:
|
||||
self._write_line("GenericType")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level():
|
||||
type.type.accept(self)
|
||||
self._write_line("params", last=True)
|
||||
with self._child_level():
|
||||
for i, param in enumerate(type.params):
|
||||
self._idx = i
|
||||
if i == len(type.params) - 1:
|
||||
self._mark_last()
|
||||
param.accept(self)
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> None:
|
||||
self._write_line("ConstraintType")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
expr.type.accept(self)
|
||||
type.type.accept(self)
|
||||
self._write_line("constraint", last=True)
|
||||
with self._child_level(single=True):
|
||||
type.constraint.accept(self)
|
||||
|
||||
def visit_type_expr(self, expr: m.TypeExpr):
|
||||
self._write_line("TypeExpr")
|
||||
def visit_complex_type(self, type: m.ComplexType) -> None:
|
||||
self._write_line("ComplexType")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{expr.name.lexeme}"')
|
||||
self._write_optional_child("template", expr.template)
|
||||
self._write_line(f"optional: {expr.optional}", last=True)
|
||||
self._write_line("properties", last=True)
|
||||
with self._child_level():
|
||||
for i, prop in enumerate(type.properties):
|
||||
self._idx = i
|
||||
if i == len(type.properties) - 1:
|
||||
self._mark_last()
|
||||
prop.accept(self)
|
||||
|
||||
|
||||
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
|
||||
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
|
||||
def __init__(self, indent: int = 4):
|
||||
self.indent: int = indent
|
||||
self.level: int = 0
|
||||
@@ -253,33 +272,28 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
|
||||
def indented(self, text: str) -> str:
|
||||
return " " * (self.level * self.indent) + text
|
||||
|
||||
def print(self, expr: m.Expr | m.Stmt):
|
||||
def print(self, expr: m.Expr | m.Stmt | m.Type) -> str:
|
||||
self.level = 0
|
||||
return expr.accept(self)
|
||||
|
||||
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt):
|
||||
template: str = stmt.template.accept(self) if stmt.template is not None else ""
|
||||
res: str = f"type {stmt.name.lexeme}{template}({stmt.base.accept(self)})"
|
||||
if stmt.constraint is not None:
|
||||
res += " where " + stmt.constraint.accept(self)
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
|
||||
template: str = ""
|
||||
if len(stmt.params) != 0:
|
||||
params: list[str] = [
|
||||
self._print_type_template_param(param) for param in stmt.params
|
||||
]
|
||||
template = f"[{', '.join(params)}]"
|
||||
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
|
||||
return self.indented(res)
|
||||
|
||||
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt):
|
||||
template: str = stmt.template.accept(self) if stmt.template is not None else ""
|
||||
res: str = self.indented(f"type {stmt.name.lexeme}{template}")
|
||||
res += " {\n"
|
||||
self.level += 1
|
||||
for prop in stmt.properties:
|
||||
res += prop.accept(self)
|
||||
res += "\n"
|
||||
self.level -= 1
|
||||
res += self.indented("}")
|
||||
def _print_type_template_param(self, param: m.TypeStmt.Param) -> str:
|
||||
res: str = param.name.lexeme
|
||||
if param.bound is not None:
|
||||
res += "<:" + param.bound.accept(self)
|
||||
return res
|
||||
|
||||
def visit_property_stmt(self, stmt: m.PropertyStmt):
|
||||
res: str = f"{stmt.name.lexeme}: {stmt.type.accept(self)}"
|
||||
if stmt.constraint is not None:
|
||||
res += " where " + stmt.constraint.accept(self)
|
||||
return self.indented(res)
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt):
|
||||
@@ -289,13 +303,13 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
|
||||
for op in stmt.operations:
|
||||
res += op.accept(self)
|
||||
self.level -= 1
|
||||
res += "\n" + self.indented("}")
|
||||
res += self.indented("}")
|
||||
return res
|
||||
|
||||
def visit_op_stmt(self, stmt: m.OpStmt):
|
||||
operand: str = stmt.operand.accept(self)
|
||||
result: str = stmt.result.accept(self)
|
||||
return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}")
|
||||
return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}\n")
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
||||
name: str = stmt.name.lexeme
|
||||
@@ -304,9 +318,6 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
|
||||
condition: str = stmt.condition.accept(self)
|
||||
return self.indented(f"predicate {name}({subject}: {type}) = {condition}")
|
||||
|
||||
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr):
|
||||
return f"{expr.name.lexeme}{'?' if expr.optional else ''}"
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr):
|
||||
left: str = expr.left.accept(self)
|
||||
operator: str = expr.operator.lexeme
|
||||
@@ -342,12 +353,30 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr):
|
||||
return "_"
|
||||
|
||||
def visit_template_expr(self, expr: m.TemplateExpr):
|
||||
return f"[{expr.type.accept(self)}]"
|
||||
def visit_named_type(self, type: m.NamedType) -> str:
|
||||
return type.name.lexeme
|
||||
|
||||
def visit_type_expr(self, expr: m.TypeExpr):
|
||||
template: str = expr.template.accept(self) if expr.template is not None else ""
|
||||
return f"{expr.name.lexeme}{template}{'?' if expr.optional else ''}"
|
||||
def visit_generic_type(self, type: m.GenericType) -> str:
|
||||
res: str = type.type.accept(self)
|
||||
if len(type.params) != 0:
|
||||
params: list[str] = [param.accept(self) for param in type.params]
|
||||
res += f"[{', '.join(params)}]"
|
||||
return res
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> str:
|
||||
res: str = type.type.accept(self)
|
||||
res += " where " + type.constraint.accept(self)
|
||||
return res
|
||||
|
||||
def visit_complex_type(self, type: m.ComplexType) -> str:
|
||||
res: str = "{\n"
|
||||
self.level += 1
|
||||
for prop in type.properties:
|
||||
res += prop.accept(self)
|
||||
res += "\n"
|
||||
self.level -= 1
|
||||
res += self.indented("}")
|
||||
return res
|
||||
|
||||
|
||||
class PythonAstPrinter(
|
||||
@@ -419,7 +448,14 @@ class PythonAstPrinter(
|
||||
self._mark_last()
|
||||
self._print_argument(arg)
|
||||
|
||||
self._write_optional_child("returns", stmt.returns, last=True)
|
||||
self._write_optional_child("returns", stmt.returns)
|
||||
self._write_line("body", last=True)
|
||||
with self._child_level():
|
||||
for i, body_stmt in enumerate(stmt.body):
|
||||
self._idx = i
|
||||
if i == len(stmt.body) - 1:
|
||||
self._mark_last()
|
||||
body_stmt.accept(self)
|
||||
|
||||
def _print_argument(self, arg: p.Function.Argument) -> None:
|
||||
self._write_line("FunctionArgument")
|
||||
@@ -449,6 +485,32 @@ class PythonAstPrinter(
|
||||
with self._child_level(single=True):
|
||||
stmt.value.accept(self)
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||
self._write_line("ReturnStmt")
|
||||
with self._child_level():
|
||||
self._write_optional_child("value", stmt.value, last=True)
|
||||
|
||||
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||
self._write_line("IfStmt")
|
||||
with self._child_level():
|
||||
self._write_line("test")
|
||||
with self._child_level(single=True):
|
||||
stmt.test.accept(self)
|
||||
self._write_line("body")
|
||||
with self._child_level():
|
||||
for i, body_stmt in enumerate(stmt.body):
|
||||
self._idx = i
|
||||
if i == len(stmt.body) - 1:
|
||||
self._mark_last()
|
||||
body_stmt.accept(self)
|
||||
self._write_line("orelse", last=True)
|
||||
with self._child_level():
|
||||
for i, else_stmt in enumerate(stmt.orelse):
|
||||
self._idx = i
|
||||
if i == len(stmt.orelse) - 1:
|
||||
self._mark_last()
|
||||
else_stmt.accept(self)
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
|
||||
self._write_line("BinaryExpr")
|
||||
with self._child_level():
|
||||
@@ -550,3 +612,28 @@ class PythonAstPrinter(
|
||||
self._write_line("value", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.value.accept(self)
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> None:
|
||||
self._write_line("CastExpr")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
expr.type.accept(self)
|
||||
self._write_line("expr", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.expr.accept(self)
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
|
||||
self._write_line("TernaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line("test")
|
||||
with self._child_level(single=True):
|
||||
expr.test.accept(self)
|
||||
|
||||
self._write_line("if_true")
|
||||
with self._child_level(single=True):
|
||||
expr.if_true.accept(self)
|
||||
|
||||
self._write_line("if_false", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.if_false.accept(self)
|
||||
|
||||
@@ -21,7 +21,7 @@ T = TypeVar("T")
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class MidasType(ABC):
|
||||
location: Optional[Location] = None
|
||||
location: Location
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
@@ -82,7 +82,7 @@ class FrameType(MidasType):
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Stmt(ABC):
|
||||
location: Optional[Location] = None
|
||||
location: Location
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
@@ -100,6 +100,12 @@ class Stmt(ABC):
|
||||
@abstractmethod
|
||||
def visit_assign_stmt(self, stmt: AssignStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_return_stmt(self, stmt: ReturnStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_if_stmt(self, stmt: IfStmt) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExpressionStmt(Stmt):
|
||||
@@ -114,14 +120,22 @@ class Function(Stmt):
|
||||
name: str
|
||||
posonlyargs: list[Argument]
|
||||
args: list[Argument]
|
||||
sink: Optional[Argument]
|
||||
kwonlyargs: list[Argument]
|
||||
kw_sink: Optional[Argument]
|
||||
returns: Optional[MidasType]
|
||||
body: list[Stmt]
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
location: Optional[Location] = None
|
||||
name: Optional[str]
|
||||
name: str
|
||||
type: Optional[MidasType]
|
||||
default: Optional[Expr]
|
||||
|
||||
@property
|
||||
def all_args(self) -> list[Argument]:
|
||||
return self.posonlyargs + self.args + self.kwonlyargs
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_function(self)
|
||||
@@ -145,6 +159,24 @@ class AssignStmt(Stmt):
|
||||
return visitor.visit_assign_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ReturnStmt(Stmt):
|
||||
value: Optional[Expr]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_return_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IfStmt(Stmt):
|
||||
test: Expr
|
||||
body: list[Stmt]
|
||||
orelse: list[Stmt]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_if_stmt(self)
|
||||
|
||||
|
||||
###############
|
||||
# Expressions #
|
||||
###############
|
||||
@@ -152,7 +184,7 @@ class AssignStmt(Stmt):
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Expr(ABC):
|
||||
location: Optional[Location] = None
|
||||
location: Location
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
@@ -185,6 +217,12 @@ class Expr(ABC):
|
||||
@abstractmethod
|
||||
def visit_set_expr(self, expr: SetExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_cast_expr(self, expr: CastExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_ternary_expr(self, expr: TernaryExpr) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BinaryExpr(Expr):
|
||||
@@ -268,3 +306,22 @@ class SetExpr(Expr):
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_set_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CastExpr(Expr):
|
||||
type: MidasType
|
||||
expr: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_cast_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TernaryExpr(Expr):
|
||||
test: Expr
|
||||
if_true: Expr
|
||||
if_false: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_ternary_expr(self)
|
||||
|
||||
540
midas/checker/checker.py
Normal file
540
midas/checker/checker.py
Normal file
@@ -0,0 +1,540 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||
from midas.checker.environment import Environment
|
||||
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
|
||||
from midas.checker.types import Function, Type, UnitType, UnknownType
|
||||
from midas.lexer.midas import MidasLexer
|
||||
from midas.lexer.token import Token
|
||||
from midas.parser.midas import MidasParser
|
||||
from midas.resolver.midas import MidasResolver
|
||||
|
||||
|
||||
class ReturnException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class MappedArgument:
|
||||
expr: p.Expr
|
||||
type: Type
|
||||
argument: Function.Argument
|
||||
|
||||
|
||||
class Checker(
|
||||
p.Stmt.Visitor[None],
|
||||
p.Expr.Visitor[Type],
|
||||
p.MidasType.Visitor[Type],
|
||||
):
|
||||
"""A type checker which can use custom type definitions"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
locals: dict[p.Expr, int],
|
||||
source_path: Path,
|
||||
types_paths: list[Path],
|
||||
):
|
||||
self.logger: logging.Logger = logging.getLogger("Checker")
|
||||
self.source_path: Path = source_path
|
||||
self.types_paths: list[Path] = types_paths
|
||||
self.ctx: MidasResolver = MidasResolver()
|
||||
self.global_env: Environment = Environment()
|
||||
self.env: Environment = self.global_env
|
||||
self.locals: dict[p.Expr, int] = locals
|
||||
self.diagnostics: list[Diagnostic] = []
|
||||
|
||||
def diagnostic(self, type: DiagnosticType, location: Location, message: str):
|
||||
self.diagnostics.append(
|
||||
Diagnostic(
|
||||
file_path=self.source_path,
|
||||
location=location,
|
||||
type=type,
|
||||
message=message,
|
||||
)
|
||||
)
|
||||
|
||||
def error(self, location: Location, message: str):
|
||||
self.diagnostic(
|
||||
type=DiagnosticType.ERROR,
|
||||
location=location,
|
||||
message=message,
|
||||
)
|
||||
|
||||
def warning(self, location: Location, message: str):
|
||||
self.diagnostic(
|
||||
type=DiagnosticType.WARNING,
|
||||
location=location,
|
||||
message=message,
|
||||
)
|
||||
|
||||
def info(self, location: Location, message: str):
|
||||
self.diagnostic(
|
||||
type=DiagnosticType.INFO,
|
||||
location=location,
|
||||
message=message,
|
||||
)
|
||||
|
||||
def type_of(self, expr: p.Expr) -> Type:
|
||||
"""Evaluate the type of an expression
|
||||
|
||||
Args:
|
||||
expr (p.Expr): the expression to evaluate
|
||||
|
||||
Returns:
|
||||
Type: the type of the given expression
|
||||
"""
|
||||
return expr.accept(self)
|
||||
|
||||
def process_block(self, block: list[p.Stmt], env: Environment) -> bool:
|
||||
"""Evaluate a sequence of statements
|
||||
|
||||
Args:
|
||||
block (list[p.Stmt]): the statements to evaluate
|
||||
env (Environment): the environment in which to evaluate
|
||||
|
||||
Returns:
|
||||
bool: whether a return statement is present in the block
|
||||
"""
|
||||
previous_env: Environment = self.env
|
||||
self.env = env
|
||||
returned: bool = False
|
||||
for i, stmt in enumerate(block):
|
||||
try:
|
||||
stmt.accept(self)
|
||||
except ReturnException:
|
||||
returned = True
|
||||
if i < len(block) - 1:
|
||||
self.warning(block[i + 1].location, "Unreachable statement")
|
||||
break
|
||||
self.env = previous_env
|
||||
return returned
|
||||
|
||||
def check(self, statements: list[p.Stmt]) -> list[Diagnostic]:
|
||||
"""Type check a sequence of statements and returns diagnostics
|
||||
|
||||
Args:
|
||||
statements (list[p.Stmt]): the statements to evaluate and check
|
||||
|
||||
Returns:
|
||||
list[Diagnostic]: the list of diagnostics (errors, warning, etc.)
|
||||
"""
|
||||
self.diagnostics = []
|
||||
|
||||
for path in self.types_paths:
|
||||
self.import_midas(path)
|
||||
self.logger.debug(f"Midas types: {self.ctx._types}")
|
||||
self.logger.debug(f"Midas operations: {self.ctx._operations}")
|
||||
|
||||
for stmt in statements:
|
||||
stmt.accept(self)
|
||||
|
||||
self.logger.debug(f"Final environment: {self.env.flat_dict()}")
|
||||
return self.diagnostics
|
||||
|
||||
def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]:
|
||||
"""Look up a variable in the environment it was declared
|
||||
|
||||
Args:
|
||||
name (str): the name of the variable
|
||||
expr (p.Expr): the variable expression, used to lookup the scope distance
|
||||
|
||||
Returns:
|
||||
Optional[Type]: the type of the variable, or None if it was not found
|
||||
"""
|
||||
distance: Optional[int] = self.locals.get(expr)
|
||||
if distance is not None:
|
||||
return self.env.get_at(distance, name)
|
||||
return self.global_env.get(name)
|
||||
|
||||
def import_midas(self, path: Path) -> None:
|
||||
"""Import Midas definitions from a path
|
||||
|
||||
Args:
|
||||
path (Path): the import path
|
||||
"""
|
||||
self.logger.debug(f"Importing type definitions from {path}")
|
||||
lexer: MidasLexer = MidasLexer(path.read_text())
|
||||
tokens: list[Token] = lexer.process()
|
||||
parser: MidasParser = MidasParser(tokens)
|
||||
stmts: list[m.Stmt] = parser.parse()
|
||||
self.ctx.resolve(stmts)
|
||||
|
||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||
self.type_of(stmt.expr)
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> None:
|
||||
env: Environment = Environment(self.env)
|
||||
pos_args: list[Function.Argument] = []
|
||||
args: list[Function.Argument] = []
|
||||
kw_args: list[Function.Argument] = []
|
||||
|
||||
def eval_arg_type(arg: p.Function.Argument) -> Type:
|
||||
if arg.type is not None:
|
||||
return arg.type.accept(self)
|
||||
if arg.default is not None:
|
||||
return arg.default.accept(self)
|
||||
return UnknownType()
|
||||
|
||||
for arg in stmt.posonlyargs:
|
||||
pos_args.append(
|
||||
Function.Argument(
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
)
|
||||
)
|
||||
for arg in stmt.args:
|
||||
args.append(
|
||||
Function.Argument(
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
)
|
||||
)
|
||||
for arg in stmt.kwonlyargs:
|
||||
kw_args.append(
|
||||
Function.Argument(
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
)
|
||||
)
|
||||
|
||||
for arg in pos_args + args + kw_args:
|
||||
env.define(arg.name, arg.type)
|
||||
|
||||
returns_hint: Optional[Type] = None
|
||||
if stmt.returns is not None:
|
||||
returns_hint = stmt.returns.accept(self)
|
||||
# Early define to handle simple fully-typed recursion
|
||||
inside_function: Function = Function(
|
||||
name=stmt.name,
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
returns=returns_hint,
|
||||
)
|
||||
self.env.define(stmt.name, inside_function)
|
||||
|
||||
returned: bool = self.process_block(stmt.body, env)
|
||||
inferred_return: Type = UnknownType()
|
||||
if not returned:
|
||||
env.return_types.append(UnitType())
|
||||
return_types: set[Type] = set(env.return_types)
|
||||
if len(return_types) == 1:
|
||||
inferred_return = list(return_types)[0]
|
||||
elif len(return_types) > 1:
|
||||
self.error(
|
||||
stmt.location,
|
||||
f"Mixed return types: {env.return_types}",
|
||||
)
|
||||
|
||||
returns: Type = UnknownType()
|
||||
if returns_hint is not None:
|
||||
assert stmt.returns is not None
|
||||
returns = returns_hint
|
||||
if returns != inferred_return:
|
||||
self.error(
|
||||
stmt.returns.location,
|
||||
f"Return type mismatch, annotated {returns} but returns {inferred_return}",
|
||||
)
|
||||
else:
|
||||
returns = inferred_return
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
function: Function = Function(
|
||||
name=stmt.name,
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
returns=returns,
|
||||
)
|
||||
self.env.define(stmt.name, function)
|
||||
|
||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||
# TODO check not yet defined locally
|
||||
type: Type = stmt.type.accept(self)
|
||||
self.env.define(stmt.name, type)
|
||||
|
||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||
value: Type = self.type_of(stmt.value)
|
||||
for target in stmt.targets:
|
||||
if not isinstance(target, p.VariableExpr):
|
||||
self.logger.warning(f"Unsupported assignment to {target}")
|
||||
self.warning(target.location, f"Unsupported assignment to {target}")
|
||||
continue
|
||||
name: str = target.name
|
||||
var_type: Optional[Type] = self.look_up_variable(name, target)
|
||||
|
||||
if var_type is None:
|
||||
self.env.define(name, value)
|
||||
else:
|
||||
# TODO: implement real comparison method
|
||||
if var_type != value:
|
||||
self.error(
|
||||
stmt.location,
|
||||
f"Cannot assign {value} to {name} of type {var_type}",
|
||||
)
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||
type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType()
|
||||
self.env.return_types.append(type)
|
||||
raise ReturnException()
|
||||
|
||||
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||
# Not evaluated in sub-environment because assignments in the test leak out of the if
|
||||
# For example:
|
||||
# if (m := 1 + 1) < 2:
|
||||
# ...
|
||||
# print(m) # <- m is still defined
|
||||
test_type: Type = stmt.test.accept(self)
|
||||
|
||||
# TODO Allow subtypes or any type
|
||||
if test_type != self.ctx.get_type("bool"):
|
||||
self.error(
|
||||
stmt.test.location, f"If test must be a boolean, got {test_type}"
|
||||
)
|
||||
|
||||
env: Environment = Environment(self.env)
|
||||
body_returned: bool = self.process_block(stmt.body, env)
|
||||
else_returned: bool = self.process_block(stmt.orelse, env)
|
||||
self.env.return_types.extend(env.return_types)
|
||||
if body_returned and else_returned:
|
||||
raise ReturnException()
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
||||
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
||||
if method is None:
|
||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||
self.warning(expr.location, f"Unsupported operator {expr.operator}")
|
||||
return UnknownType()
|
||||
left: Type = self.type_of(expr.left)
|
||||
right: Type = self.type_of(expr.right)
|
||||
|
||||
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
|
||||
if result is None:
|
||||
self.error(
|
||||
expr.location,
|
||||
f"Undefined operation {method} between {left} and {right}",
|
||||
)
|
||||
return UnknownType()
|
||||
return result
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
||||
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
|
||||
if method is None:
|
||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||
self.warning(expr.location, f"Unsupported operator {expr.operator}")
|
||||
return UnknownType()
|
||||
left: Type = self.type_of(expr.left)
|
||||
right: Type = self.type_of(expr.right)
|
||||
|
||||
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
|
||||
if result is None:
|
||||
self.error(
|
||||
expr.location,
|
||||
f"Undefined operation {method} between {left} and {right}",
|
||||
)
|
||||
return UnknownType()
|
||||
return result
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ...
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
||||
callee: Type = self.type_of(expr.callee)
|
||||
if not isinstance(callee, Function):
|
||||
self.error(expr.callee.location, "Callee is not a function")
|
||||
return UnknownType()
|
||||
function: Function = callee
|
||||
mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
|
||||
for arg in mapped:
|
||||
if arg.type != arg.argument.type:
|
||||
self.error(
|
||||
arg.expr.location,
|
||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||
)
|
||||
return function.returns
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> Type: ...
|
||||
|
||||
def visit_literal_expr(self, expr: p.LiteralExpr) -> Type:
|
||||
match expr.value:
|
||||
case bool(): # Must be before int
|
||||
return self.ctx.get_type("bool")
|
||||
case int():
|
||||
return self.ctx.get_type("int")
|
||||
case float():
|
||||
return self.ctx.get_type("float")
|
||||
case str():
|
||||
return self.ctx.get_type("str")
|
||||
case _:
|
||||
self.warning(expr.location, f"Unknown literal {expr}")
|
||||
return UnknownType()
|
||||
|
||||
def visit_variable_expr(self, expr: p.VariableExpr) -> Type:
|
||||
return self.look_up_variable(expr.name, expr) or UnknownType()
|
||||
|
||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
|
||||
left: Type = expr.left.accept(self)
|
||||
right: Type = expr.right.accept(self)
|
||||
# TODO: union type
|
||||
if left != right:
|
||||
self.error(
|
||||
expr.location,
|
||||
f"Operands must be of the same type, left={left} != right={right}",
|
||||
)
|
||||
return left
|
||||
|
||||
def visit_set_expr(self, expr: p.SetExpr) -> Type: ...
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
|
||||
return expr.type.accept(self)
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
|
||||
test_type: Type = expr.test.accept(self)
|
||||
|
||||
# TODO Allow subtypes or any type
|
||||
if test_type != self.ctx.get_type("bool"):
|
||||
self.error(
|
||||
expr.test.location, f"If test must be a boolean, got {test_type}"
|
||||
)
|
||||
|
||||
true_type: Type = expr.if_true.accept(self)
|
||||
false_type: Type = expr.if_false.accept(self)
|
||||
if true_type != false_type:
|
||||
self.error(
|
||||
expr.location,
|
||||
f"Type mismatch in ternary if branches: true={true_type} != false={false_type}",
|
||||
)
|
||||
return UnknownType()
|
||||
return true_type
|
||||
|
||||
def visit_base_type(self, node: p.BaseType) -> Type:
|
||||
return self.ctx.get_type(node.base)
|
||||
|
||||
def visit_constraint_type(self, node: p.ConstraintType) -> Type: ...
|
||||
|
||||
def visit_frame_column(self, node: p.FrameColumn) -> Type: ...
|
||||
|
||||
def visit_frame_type(self, node: p.FrameType) -> Type: ...
|
||||
|
||||
def map_call_arguments(
|
||||
self, function: Function, call: p.CallExpr
|
||||
) -> list[MappedArgument]:
|
||||
"""Map call arguments to function parameters as defined in its signature
|
||||
|
||||
This method maps positional-only, keyword-only and mixed parameter definitions
|
||||
with the arguments passed at the call site
|
||||
|
||||
Any mismatched, missing or unexpected argument is reported as a diagnostic
|
||||
|
||||
Args:
|
||||
function (Function): the function definition
|
||||
call (p.CallExpr): the call expression
|
||||
|
||||
Returns:
|
||||
list[MappedArgument]: the list of mapped arguments
|
||||
"""
|
||||
positional: list[tuple[p.Expr, Type]] = [
|
||||
(arg, self.type_of(arg)) for arg in call.arguments
|
||||
]
|
||||
keywords: dict[str, tuple[p.Expr, Type]] = {
|
||||
name: (arg, self.type_of(arg)) for name, arg in call.keywords.items()
|
||||
}
|
||||
set_args: set[str] = set()
|
||||
|
||||
required_positional: list[str] = [
|
||||
arg.name for arg in function.pos_args + function.args if arg.required
|
||||
]
|
||||
required_keyword: list[str] = [
|
||||
arg.name for arg in function.kw_args if arg.required
|
||||
]
|
||||
|
||||
mapped: list[MappedArgument] = []
|
||||
|
||||
pos_params: list[Function.Argument] = list(function.pos_args)
|
||||
mixed_params: list[Function.Argument] = list(function.args)
|
||||
kw_params: dict[str, Function.Argument] = {
|
||||
arg.name: arg for arg in function.kw_args
|
||||
}
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
for arg in positional:
|
||||
param: Function.Argument
|
||||
if len(pos_params) != 0:
|
||||
param = pos_params.pop(0)
|
||||
elif len(mixed_params) != 0:
|
||||
param = mixed_params.pop(0)
|
||||
else:
|
||||
self.error(arg[0].location, "Too many positional arguments")
|
||||
break
|
||||
name: str = param.name
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
kw_params.update({arg.name: arg for arg in mixed_params})
|
||||
for name, arg in keywords.items():
|
||||
param: Function.Argument
|
||||
if name not in kw_params:
|
||||
if name in set_args:
|
||||
self.error(
|
||||
arg[0].location, f"Multiple values for argument '{name}'"
|
||||
)
|
||||
else:
|
||||
self.error(arg[0].location, f"Unknown keyword argument '{name}'")
|
||||
continue
|
||||
param = kw_params.pop(name)
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
def join_args(args: list[str]) -> str:
|
||||
args = list(map(lambda a: f"'{a}'", args))
|
||||
if len(args) == 0:
|
||||
return ""
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return ", ".join(args[:-1]) + " and " + args[-1]
|
||||
|
||||
if len(required_positional) != 0:
|
||||
plural: str = "" if len(required_positional) == 1 else "s"
|
||||
args: str = join_args(required_positional)
|
||||
self.error(
|
||||
call.location,
|
||||
f"Missing required positional argument{plural}: {args}",
|
||||
)
|
||||
|
||||
if len(required_keyword) != 0:
|
||||
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||
args: str = join_args(required_keyword)
|
||||
self.error(
|
||||
call.location,
|
||||
f"Missing required keyword argument{plural}: {args}",
|
||||
)
|
||||
|
||||
return mapped
|
||||
33
midas/checker/diagnostic.py
Normal file
33
midas/checker/diagnostic.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.location import Location
|
||||
|
||||
|
||||
class DiagnosticType(StrEnum):
|
||||
ERROR = "Error"
|
||||
WARNING = "Warning"
|
||||
INFO = "Info"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Diagnostic:
|
||||
file_path: Path
|
||||
location: Location
|
||||
type: DiagnosticType
|
||||
message: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
start_loc: str = f"L{self.location.lineno}:{self.location.col_offset+1}"
|
||||
end_loc: Optional[str] = ""
|
||||
if (
|
||||
self.location.end_lineno is not None
|
||||
and self.location.end_col_offset is not None
|
||||
):
|
||||
end_loc = f"L{self.location.end_lineno}:{self.location.end_col_offset+1}"
|
||||
loc: str = (
|
||||
f"at {start_loc}" if end_loc is None else f"from {start_loc} to {end_loc}"
|
||||
)
|
||||
return f"{self.type} in {self.file_path} {loc}: {self.message}"
|
||||
142
midas/checker/environment.py
Normal file
142
midas/checker/environment.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from midas.checker.types import Type
|
||||
|
||||
|
||||
class Environment:
|
||||
"""
|
||||
A scoped environment in which variables are defined
|
||||
|
||||
Each environment can inherit from a parent/enclosing environment.
|
||||
"""
|
||||
|
||||
def __init__(self, enclosing: Optional[Environment] = None) -> None:
|
||||
self.enclosing: Optional[Environment] = enclosing
|
||||
self.values: dict[str, Type] = {}
|
||||
self.return_types: list[Type] = []
|
||||
|
||||
self._children: list[Environment] = []
|
||||
if enclosing is not None:
|
||||
enclosing._children.append(self)
|
||||
|
||||
def define(self, name: str, value: Type) -> None:
|
||||
"""Define a variable in this environment
|
||||
|
||||
Args:
|
||||
name (str): the name of the variable
|
||||
value (Type): the value
|
||||
"""
|
||||
self.values[name] = value
|
||||
|
||||
def get(self, name: str) -> Optional[Type]:
|
||||
"""Get a variable in the closest environment which has a definition for it
|
||||
|
||||
Args:
|
||||
name (str): the name of the variable
|
||||
|
||||
Returns:
|
||||
Optional[Type]: the value of the variable, or None if it was not found
|
||||
"""
|
||||
if name in self.values:
|
||||
return self.values[name]
|
||||
if self.enclosing is not None:
|
||||
return self.enclosing.get(name)
|
||||
# raise NameError(f"Undefined variable '{name}'")
|
||||
return None
|
||||
|
||||
def assign(self, name: str, value: Type) -> bool:
|
||||
"""Assign a new value to a variable in the environment it was defined in
|
||||
|
||||
Args:
|
||||
name (str): the name of the variable
|
||||
value (Type): the new value
|
||||
|
||||
Returns:
|
||||
bool: True if the variable was assigned in this environment or an ancestor, False otherwise
|
||||
"""
|
||||
if name not in self.values:
|
||||
if self.enclosing is None:
|
||||
return False
|
||||
if self.enclosing.assign(name, value):
|
||||
return True
|
||||
self.values[name] = value
|
||||
return True
|
||||
|
||||
def clear(self):
|
||||
"""Clear all definitions in this environment"""
|
||||
self.values = {}
|
||||
|
||||
def get_at(self, distance: int, name: str) -> Optional[Type]:
|
||||
"""Get the value of a variable at a given distance
|
||||
|
||||
A distance of 0 looks up in this environment, 1 in the parent environment, etc.
|
||||
This methods expects `distance` to be valid. An error will be raised if
|
||||
the stack does not extend far enough to reach `distance`
|
||||
|
||||
Args:
|
||||
distance (int): the scope distance
|
||||
name (str): the name of the variable
|
||||
|
||||
Returns:
|
||||
Optional[Type]: the value at the given distance, or None if it is not defined in that environment
|
||||
|
||||
Raises:
|
||||
AssertionError: if the stack does not extend far enough to reach `distance`
|
||||
"""
|
||||
return self.ancestor(distance).values.get(name)
|
||||
|
||||
def assign_at(self, distance: int, name: str, value: Type) -> None:
|
||||
"""Assign a new value to a variable at a given distance
|
||||
|
||||
A distance of 0 assigns in this environment, 1 in the parent environment, etc.
|
||||
|
||||
Args:
|
||||
distance (int): the scope distance
|
||||
name (str): the name of the variable
|
||||
value (Type): the new value
|
||||
|
||||
Raises:
|
||||
AssertionError: if the stack does not extend far enough to reach `distance`
|
||||
"""
|
||||
self.ancestor(distance).values[name] = value
|
||||
|
||||
def ancestor(self, distance: int) -> Environment:
|
||||
"""Get the ancestor at a given distance
|
||||
|
||||
A distance of 0 references this environment, 1 the parent environment, etc.
|
||||
|
||||
Args:
|
||||
distance (int): the scope distance
|
||||
|
||||
Returns:
|
||||
Environment: the environment
|
||||
|
||||
Raises:
|
||||
AssertionError: if the stack does not extend far enough to reach `distance`
|
||||
"""
|
||||
env: Environment = self
|
||||
for _ in range(distance):
|
||||
assert env.enclosing is not None
|
||||
env = env.enclosing
|
||||
return env
|
||||
|
||||
def flat_dict(self) -> dict[str, Type]:
|
||||
"""Get the current environment including definitions in its ancestor as a flat dictionary
|
||||
|
||||
This method recursively combines this environment definitions with its ancestor's
|
||||
|
||||
Returns:
|
||||
dict: the combined environment
|
||||
"""
|
||||
if self.enclosing is None:
|
||||
return self.values
|
||||
return self.enclosing.flat_dict() | self.values
|
||||
|
||||
def dump(self) -> dict:
|
||||
return {
|
||||
"values": self.values,
|
||||
"return_types": self.return_types,
|
||||
"children": [child.dump() for child in self._children],
|
||||
}
|
||||
31
midas/checker/operators.py
Normal file
31
midas/checker/operators.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import ast
|
||||
from typing import Type
|
||||
|
||||
OPERATOR_METHODS: dict[Type[ast.operator], str] = {
|
||||
ast.Add: "__add__",
|
||||
ast.Sub: "__sub__",
|
||||
ast.Mult: "__mul__",
|
||||
ast.MatMult: "__matmul__",
|
||||
ast.Div: "__truediv__",
|
||||
ast.Mod: "__mod__",
|
||||
ast.Pow: "__pow__",
|
||||
ast.LShift: "__lshift__",
|
||||
ast.RShift: "__rshift__",
|
||||
ast.BitOr: "__or__",
|
||||
ast.BitXor: "__xor__",
|
||||
ast.BitAnd: "__and__",
|
||||
ast.FloorDiv: "__floordiv__",
|
||||
}
|
||||
|
||||
COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
|
||||
ast.Eq: "__eq__",
|
||||
# ast.NotEq: "__noteq__",
|
||||
ast.Lt: "__lt__",
|
||||
ast.LtE: "__le__",
|
||||
ast.Gt: "__gt__",
|
||||
ast.GtE: "__ge__",
|
||||
# ast.Is: "__is__",
|
||||
# ast.IsNot: "__isnot__",
|
||||
# ast.In: "__in__",
|
||||
# ast.NotIn: "__notin__",
|
||||
}
|
||||
47
midas/checker/types.py
Normal file
47
midas/checker/types.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class BaseType:
|
||||
name: str
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class AliasType:
|
||||
name: str
|
||||
type: Type
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class UnknownType:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class UnitType:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Function:
|
||||
name: str
|
||||
pos_args: list[Argument]
|
||||
args: list[Argument]
|
||||
kw_args: list[Argument]
|
||||
returns: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
name: str
|
||||
type: Type
|
||||
required: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ComplexType:
|
||||
properties: dict[str, Type]
|
||||
|
||||
|
||||
Type = BaseType | AliasType | UnknownType | UnitType | Function | ComplexType
|
||||
@@ -53,5 +53,6 @@ span {
|
||||
|
||||
&.keyword {
|
||||
color: rgb(211, 72, 9);
|
||||
pointer-events: none;
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Generic, Optional, Protocol, TextIO, TypeVar
|
||||
|
||||
from midas.ast.location import Location
|
||||
import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.diagnostic import Diagnostic
|
||||
from midas.lexer.token import Token
|
||||
|
||||
H = TypeVar("H", bound="Highlighter", contravariant=True)
|
||||
|
||||
@@ -21,6 +24,15 @@ class Locatable(Protocol):
|
||||
def location(self) -> Optional[Location]: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LocatableToken:
|
||||
token: Token
|
||||
|
||||
@property
|
||||
def location(self) -> Location:
|
||||
return self.token.get_location()
|
||||
|
||||
|
||||
class Highlighter(ABC):
|
||||
BASE_CSS_PATH: Path = Path(__file__).parent / "highlight.css"
|
||||
EXTRA_CSS_PATH: Optional[Path] = None
|
||||
@@ -71,6 +83,7 @@ class Highlighter(ABC):
|
||||
openings: list[str] = self.openings.get(pos, [])
|
||||
line_buf += "".join(closings + openings)
|
||||
line_buf += char
|
||||
line_buf += "".join(self.closings.get((lineno, len(line)), []))
|
||||
line_buf += "</div></div>"
|
||||
lines.append(" " + line_buf)
|
||||
lines.extend(
|
||||
@@ -83,7 +96,7 @@ class Highlighter(ABC):
|
||||
|
||||
buf.write("\n".join(lines))
|
||||
|
||||
def wrap(self, node: Locatable, cls: str):
|
||||
def wrap(self, node: Locatable, cls: str, message: Optional[str] = None):
|
||||
if node.location is None:
|
||||
return
|
||||
if node.location.end_lineno is None or node.location.end_col_offset is None:
|
||||
@@ -95,6 +108,10 @@ class Highlighter(ABC):
|
||||
)
|
||||
opening: str = f'<span class="{cls}" title="{cls}">'
|
||||
closing: str = "</span>"
|
||||
if message is not None:
|
||||
opening = f'<span class="with-msg">{opening}'
|
||||
closing = f'{closing}<span class="message">{message}</span></span>'
|
||||
|
||||
self.openings.setdefault(start_pos, []).append(opening)
|
||||
self.closings.setdefault(end_pos, []).insert(0, closing)
|
||||
if start_pos[0] != end_pos[0]:
|
||||
@@ -142,6 +159,8 @@ class PythonHighlighter(
|
||||
self.wrap(stmt, "function")
|
||||
for arg in stmt.posonlyargs + stmt.args + stmt.kwonlyargs:
|
||||
self._highlight_function_argument(arg)
|
||||
for body_stmt in stmt.body:
|
||||
body_stmt.accept(self)
|
||||
|
||||
def _highlight_function_argument(self, arg: p.Function.Argument) -> None:
|
||||
self.wrap(arg, "argument")
|
||||
@@ -151,7 +170,23 @@ class PythonHighlighter(
|
||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||
stmt.type.accept(self)
|
||||
|
||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: ...
|
||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||
for target in stmt.targets:
|
||||
target.accept(self)
|
||||
stmt.value.accept(self)
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||
self.wrap(stmt, "return")
|
||||
if stmt.value is not None:
|
||||
stmt.value.accept(self)
|
||||
|
||||
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||
self.wrap(stmt, "if")
|
||||
stmt.test.accept(self)
|
||||
for body_stmt in stmt.body:
|
||||
body_stmt.accept(self)
|
||||
for else_stmt in stmt.orelse:
|
||||
else_stmt.accept(self)
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ...
|
||||
|
||||
@@ -159,7 +194,13 @@ class PythonHighlighter(
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> None: ...
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> None: ...
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> None:
|
||||
self.wrap(expr, "call")
|
||||
expr.callee.accept(self)
|
||||
for arg in expr.arguments:
|
||||
arg.accept(self)
|
||||
for arg in expr.keywords.values():
|
||||
arg.accept(self)
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> None: ...
|
||||
|
||||
@@ -171,35 +212,27 @@ class PythonHighlighter(
|
||||
|
||||
def visit_set_expr(self, expr: p.SetExpr) -> None: ...
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> None: ...
|
||||
|
||||
class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ...
|
||||
|
||||
|
||||
class MidasHighlighter(
|
||||
Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None]
|
||||
):
|
||||
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_midas.css"
|
||||
|
||||
def highlight(self, node: Highlightable[MidasHighlighter]):
|
||||
node.accept(self)
|
||||
|
||||
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt) -> None:
|
||||
self.wrap(stmt, "simple-type")
|
||||
if stmt.template is not None:
|
||||
stmt.template.accept(self)
|
||||
stmt.base.accept(self)
|
||||
if stmt.constraint is not None:
|
||||
self.wrap(stmt.constraint, "constraint")
|
||||
stmt.constraint.accept(self)
|
||||
|
||||
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt) -> None:
|
||||
self.wrap(stmt, "complex-type")
|
||||
if stmt.template is not None:
|
||||
stmt.template.accept(self)
|
||||
for prop in stmt.properties:
|
||||
prop.accept(self)
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||
self.wrap(stmt, "type-stmt")
|
||||
self.wrap(LocatableToken(stmt.name), "type-name")
|
||||
stmt.type.accept(self)
|
||||
|
||||
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None:
|
||||
self.wrap(stmt, "property")
|
||||
stmt.type.accept(self)
|
||||
if stmt.constraint is not None:
|
||||
self.wrap(stmt.constraint, "constraint")
|
||||
stmt.constraint.accept(self)
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||
self.wrap(stmt, "extend")
|
||||
@@ -209,17 +242,16 @@ class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
|
||||
|
||||
def visit_op_stmt(self, stmt: m.OpStmt) -> None:
|
||||
self.wrap(stmt, "op")
|
||||
self.wrap(LocatableToken(stmt.name), "op-name")
|
||||
stmt.operand.accept(self)
|
||||
stmt.result.accept(self)
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
||||
self.wrap(stmt, "predicate")
|
||||
self.wrap(LocatableToken(stmt.name), "predicate-name")
|
||||
stmt.type.accept(self)
|
||||
stmt.condition.accept(self)
|
||||
|
||||
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr) -> None:
|
||||
self.wrap(expr, "simple-type-expr")
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
|
||||
self.wrap(expr, "logical-expr")
|
||||
expr.left.accept(self)
|
||||
@@ -248,11 +280,29 @@ class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
|
||||
|
||||
def visit_template_expr(self, expr: m.TemplateExpr) -> None:
|
||||
self.wrap(expr, "template")
|
||||
expr.type.accept(self)
|
||||
def visit_named_type(self, type: m.NamedType) -> None:
|
||||
self.wrap(type, "named-type")
|
||||
|
||||
def visit_type_expr(self, expr: m.TypeExpr) -> None:
|
||||
self.wrap(expr, "type")
|
||||
if expr.template is not None:
|
||||
expr.template.accept(self)
|
||||
def visit_generic_type(self, type: m.GenericType) -> None:
|
||||
self.wrap(type, "generic-type")
|
||||
type.type.accept(self)
|
||||
for param in type.params:
|
||||
param.accept(self)
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> None:
|
||||
self.wrap(type, "constraint-type")
|
||||
type.type.accept(self)
|
||||
type.constraint.accept(self)
|
||||
|
||||
def visit_complex_type(self, type: m.ComplexType) -> None:
|
||||
self.wrap(type, "complex-type")
|
||||
for prop in type.properties:
|
||||
prop.accept(self)
|
||||
|
||||
|
||||
class DiagnosticsHighlighter(Highlighter):
|
||||
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"
|
||||
|
||||
def highlight(self, diagnostics: list[Diagnostic]):
|
||||
for diagnostic in diagnostics:
|
||||
self.wrap(diagnostic, str(diagnostic.type).lower(), diagnostic.message)
|
||||
|
||||
39
midas/cli/hl_diagnostic.css
Normal file
39
midas/cli/hl_diagnostic.css
Normal file
@@ -0,0 +1,39 @@
|
||||
span {
|
||||
--opacity: 0.4;
|
||||
|
||||
&.error {
|
||||
--col: 255, 0, 0;
|
||||
}
|
||||
&.warning {
|
||||
--col: 250, 160, 0;
|
||||
}
|
||||
&.info {
|
||||
--col: 150, 190, 250;
|
||||
}
|
||||
|
||||
&.with-msg {
|
||||
position: relative;
|
||||
|
||||
.message {
|
||||
display: none;
|
||||
}
|
||||
|
||||
&:hover:not(:has(.with-msg:hover)) {
|
||||
.message {
|
||||
display: inline-block;
|
||||
}
|
||||
}
|
||||
|
||||
.message {
|
||||
position: absolute;
|
||||
top: calc(100% + 0.2em);
|
||||
left: -.2em;
|
||||
background-color: black;
|
||||
color: white;
|
||||
padding: 0.2em 0.4em;
|
||||
border-radius: .2em;
|
||||
z-index: 10;
|
||||
width: 300%;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -5,12 +5,11 @@ span {
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
&.simple-type {
|
||||
--col: 108, 233, 108;
|
||||
}
|
||||
|
||||
&.named-type,
|
||||
&.generic-type,
|
||||
&.constraint-type,
|
||||
&.complex-type {
|
||||
--col: 233, 206, 108;
|
||||
--col: 150, 150, 150;
|
||||
}
|
||||
|
||||
&.constraint {
|
||||
@@ -33,10 +32,6 @@ span {
|
||||
--col: 193, 108, 233;
|
||||
}
|
||||
|
||||
&.simple-type-expr {
|
||||
--col: 150, 150, 150;
|
||||
}
|
||||
|
||||
&.logical-expr,
|
||||
&.binary-expr,
|
||||
&.unary-expr,
|
||||
@@ -48,7 +43,9 @@ span {
|
||||
--col: 163, 117, 71;
|
||||
}
|
||||
|
||||
&.type {
|
||||
&.type-name,
|
||||
&.op-name,
|
||||
&.predicate-name {
|
||||
--col: 200, 200, 200;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
@@ -1,18 +1,30 @@
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, TextIO
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional, TextIO, get_args
|
||||
|
||||
import click
|
||||
|
||||
import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.ast.printer import PythonAstPrinter
|
||||
from midas.cli.highlighter import Highlighter, MidasHighlighter, PythonHighlighter
|
||||
from midas.ast.printer import MidasAstPrinter, MidasPrinter, PythonAstPrinter
|
||||
from midas.checker.checker import Checker
|
||||
from midas.checker.diagnostic import Diagnostic
|
||||
from midas.checker.types import Type
|
||||
from midas.cli.highlighter import (
|
||||
DiagnosticsHighlighter,
|
||||
Highlighter,
|
||||
LocatableToken,
|
||||
MidasHighlighter,
|
||||
PythonHighlighter,
|
||||
)
|
||||
from midas.lexer.midas import MidasLexer
|
||||
from midas.lexer.token import Token, TokenType
|
||||
from midas.parser.midas import MidasParser
|
||||
from midas.parser.python import PythonParser
|
||||
from midas.resolver.resolver import Resolver
|
||||
from midas.utils import UniversalJSONDumper
|
||||
|
||||
|
||||
@click.group()
|
||||
@@ -21,9 +33,41 @@ def midas():
|
||||
|
||||
|
||||
@midas.command()
|
||||
@click.option("-l", "--highlight", type=click.File("w"))
|
||||
@click.option("-t", "--types", type=click.File("r"), multiple=True)
|
||||
@click.argument("file", type=click.File("r"))
|
||||
def compile(file: TextIO):
|
||||
raise NotImplementedError
|
||||
def compile(highlight: Optional[TextIO], file: TextIO, types: tuple[TextIO]):
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
source: str = file.read()
|
||||
tree: ast.Module = ast.parse(source, filename=file.name)
|
||||
parser = PythonParser()
|
||||
stmts: list[p.Stmt] = parser.parse_module(tree)
|
||||
resolver = Resolver()
|
||||
resolver.resolve(*stmts)
|
||||
types_paths: list[Path] = [Path(t.name).resolve() for t in types]
|
||||
checker = Checker(
|
||||
resolver.locals,
|
||||
source_path=Path(file.name).resolve(),
|
||||
types_paths=types_paths,
|
||||
)
|
||||
diagnostics: list[Diagnostic] = checker.check(stmts)
|
||||
for diagnostic in diagnostics:
|
||||
print(diagnostic)
|
||||
|
||||
print(
|
||||
json.dumps(
|
||||
UniversalJSONDumper.dump(
|
||||
checker.global_env,
|
||||
[("Environment", "_children")],
|
||||
lambda obj: isinstance(obj, get_args(Type)),
|
||||
),
|
||||
indent=4,
|
||||
)
|
||||
)
|
||||
if highlight is not None:
|
||||
highlighter = DiagnosticsHighlighter(source)
|
||||
highlighter.highlight(diagnostics)
|
||||
highlighter.dump(highlight)
|
||||
|
||||
|
||||
@midas.group()
|
||||
@@ -31,26 +75,52 @@ def utils():
|
||||
pass
|
||||
|
||||
|
||||
def dump_python_ast(tree: ast.Module) -> str:
|
||||
parser = PythonParser()
|
||||
stmts: list[p.Stmt] = parser.parse_module(tree)
|
||||
printer = PythonAstPrinter()
|
||||
dump: str = ""
|
||||
for stmt in stmts:
|
||||
dump += printer.print(stmt)
|
||||
dump += "\n"
|
||||
return dump
|
||||
|
||||
|
||||
def dump_midas_ast(source: str, filename: str) -> str:
|
||||
lexer = MidasLexer(source, file=filename)
|
||||
tokens: list[Token] = lexer.process()
|
||||
parser = MidasParser(tokens)
|
||||
stmts: list[m.Stmt] = parser.parse()
|
||||
if len(parser.errors) != 0:
|
||||
for err in parser.errors:
|
||||
print(err.get_report())
|
||||
raise RuntimeError("A parsing error occurred")
|
||||
printer = MidasAstPrinter()
|
||||
dump: str = ""
|
||||
for stmt in stmts:
|
||||
dump += printer.print(stmt)
|
||||
dump += "\n"
|
||||
return dump
|
||||
|
||||
|
||||
@utils.command()
|
||||
@click.option("-o", "--output", type=click.File("w"))
|
||||
@click.option("-p", "--parse", is_flag=True)
|
||||
@click.argument("file", type=click.File("r"))
|
||||
def dump_ast(output: Optional[TextIO], parse: bool, file: TextIO):
|
||||
source: str = file.read()
|
||||
tree: ast.Module = ast.parse(source, filename=file.name)
|
||||
|
||||
dump: str
|
||||
|
||||
if file.name.endswith(".py"):
|
||||
tree: ast.Module = ast.parse(source, filename=file.name)
|
||||
if parse:
|
||||
parser = PythonParser()
|
||||
stmts: list[p.Stmt] = parser.parse_module(tree)
|
||||
printer = PythonAstPrinter()
|
||||
dump = ""
|
||||
for stmt in stmts:
|
||||
dump += printer.print(stmt)
|
||||
dump += "\n"
|
||||
|
||||
dump = dump_python_ast(tree)
|
||||
else:
|
||||
dump = ast.dump(tree, indent=4)
|
||||
elif file.name.endswith(".midas"):
|
||||
dump = dump_midas_ast(source, file.name)
|
||||
else:
|
||||
raise ValueError("Unsupported file type")
|
||||
|
||||
if output is None:
|
||||
click.echo(dump)
|
||||
@@ -77,14 +147,6 @@ def highlight_midas(source: str, path: str) -> Highlighter:
|
||||
for err in parser.errors:
|
||||
print(err.get_report())
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LocatableToken:
|
||||
token: Token
|
||||
|
||||
@property
|
||||
def location(self) -> Location:
|
||||
return self.token.get_location()
|
||||
|
||||
for stmt in stmts:
|
||||
highlighter.highlight(stmt)
|
||||
for token in tokens:
|
||||
@@ -109,3 +171,23 @@ def highlight(output: TextIO, file: TextIO):
|
||||
else:
|
||||
raise ValueError("Unsupported file type")
|
||||
highlighter.dump(output)
|
||||
|
||||
|
||||
@midas.command()
|
||||
@click.option("-o", "--output", type=click.File("w"), default="-")
|
||||
@click.argument("file", type=click.File("r"))
|
||||
def format(output: TextIO, file: TextIO):
|
||||
source: str = file.read()
|
||||
printer = MidasPrinter()
|
||||
lexer = MidasLexer(source, file=file.name)
|
||||
tokens: list[Token] = lexer.process()
|
||||
parser = MidasParser(tokens)
|
||||
stmts: list[m.Stmt] = parser.parse()
|
||||
for err in parser.errors:
|
||||
print(err.get_report())
|
||||
for stmt in stmts:
|
||||
output.write(printer.print(stmt) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
midas()
|
||||
|
||||
@@ -40,8 +40,8 @@ class MidasLexer(Lexer):
|
||||
self.add_token(TokenType.AND)
|
||||
case "?":
|
||||
self.add_token(TokenType.QMARK)
|
||||
# case ",":
|
||||
# self.add_token(TokenType.COMMA)
|
||||
case ",":
|
||||
self.add_token(TokenType.COMMA)
|
||||
case "_" if not self.is_identifier_char(self.peek_next(), start=False):
|
||||
self.add_token(TokenType.UNDERSCORE)
|
||||
case "-" if self.match(">"):
|
||||
|
||||
@@ -17,7 +17,7 @@ class TokenType(Enum):
|
||||
LEFT_BRACE = auto()
|
||||
RIGHT_BRACE = auto()
|
||||
COLON = auto()
|
||||
# COMMA = auto()
|
||||
COMMA = auto()
|
||||
UNDERSCORE = auto()
|
||||
ARROW = auto()
|
||||
AND = auto()
|
||||
|
||||
@@ -3,21 +3,22 @@ from typing import Optional
|
||||
from midas.ast.location import Location
|
||||
from midas.ast.midas import (
|
||||
BinaryExpr,
|
||||
ComplexTypeStmt,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
Expr,
|
||||
ExtendStmt,
|
||||
GenericType,
|
||||
GetExpr,
|
||||
GroupingExpr,
|
||||
LiteralExpr,
|
||||
LogicalExpr,
|
||||
NamedType,
|
||||
OpStmt,
|
||||
PredicateStmt,
|
||||
PropertyStmt,
|
||||
SimpleTypeExpr,
|
||||
SimpleTypeStmt,
|
||||
Stmt,
|
||||
TemplateExpr,
|
||||
TypeExpr,
|
||||
Type,
|
||||
TypeStmt,
|
||||
UnaryExpr,
|
||||
VariableExpr,
|
||||
WildcardExpr,
|
||||
@@ -81,7 +82,7 @@ class MidasParser(Parser):
|
||||
self.synchronize()
|
||||
return None
|
||||
|
||||
def type_declaration(self) -> SimpleTypeStmt | ComplexTypeStmt:
|
||||
def type_declaration(self) -> TypeStmt:
|
||||
"""Parse a type declaration
|
||||
|
||||
A type declaration can either be a simple type alias or a new complex type.
|
||||
@@ -107,33 +108,22 @@ class MidasParser(Parser):
|
||||
"""
|
||||
keyword: Token = self.previous()
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
template: Optional[TemplateExpr] = None
|
||||
params: list[TypeStmt.Param] = []
|
||||
if self.check(TokenType.LEFT_BRACKET):
|
||||
template = self.template_expr()
|
||||
params = self.type_stmt_params()
|
||||
|
||||
if self.match(TokenType.LEFT_PAREN):
|
||||
base: TypeExpr = self.type_expr()
|
||||
self.consume(TokenType.RIGHT_PAREN, "Unclosed base type parenthesis")
|
||||
constraint: Optional[Expr] = None
|
||||
if self.match(TokenType.WHERE):
|
||||
constraint = self.constraint()
|
||||
return SimpleTypeStmt(
|
||||
self.consume(TokenType.EQUAL, "Expected '=' before type definition")
|
||||
|
||||
type: Type = self.type_expr()
|
||||
|
||||
return TypeStmt(
|
||||
location=keyword.location_to(self.previous()),
|
||||
name=name,
|
||||
template=template,
|
||||
base=base,
|
||||
constraint=constraint,
|
||||
)
|
||||
else:
|
||||
properties: list[PropertyStmt] = self.type_properties()
|
||||
return ComplexTypeStmt(
|
||||
location=keyword.location_to(self.previous()),
|
||||
name=name,
|
||||
template=template,
|
||||
properties=properties,
|
||||
params=params,
|
||||
type=type,
|
||||
)
|
||||
|
||||
def template_expr(self) -> TemplateExpr:
|
||||
def type_stmt_params(self) -> list[TypeStmt.Param]:
|
||||
"""Parse a generic template expression
|
||||
|
||||
A template is written `[TypeExpr]`
|
||||
@@ -141,16 +131,27 @@ class MidasParser(Parser):
|
||||
Returns:
|
||||
TemplateExpr: the parsed template expression
|
||||
"""
|
||||
left: Token = self.consume(
|
||||
TokenType.LEFT_BRACKET, "Missing '[' before template expression"
|
||||
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression")
|
||||
params: list[TypeStmt.Param] = []
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable")
|
||||
bound: Optional[Type] = None
|
||||
if self.match(TokenType.LESS):
|
||||
self.consume(TokenType.COLON, "Expected ':' after '<'")
|
||||
bound = self.type_expr()
|
||||
params.append(
|
||||
TypeStmt.Param(
|
||||
location=name.location_to(self.previous()),
|
||||
name=name,
|
||||
bound=bound,
|
||||
)
|
||||
type: TypeExpr = self.type_expr()
|
||||
right: Token = self.consume(
|
||||
TokenType.RIGHT_BRACKET, "Missing ']' after template expression"
|
||||
)
|
||||
return TemplateExpr(location=left.location_to(right), type=type)
|
||||
if not self.match(TokenType.COMMA):
|
||||
break
|
||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression")
|
||||
return params
|
||||
|
||||
def type_expr(self) -> TypeExpr:
|
||||
def type_expr(self) -> Type:
|
||||
"""Parse a type expression
|
||||
|
||||
A type is an identifier, optionally followed by a template expression.
|
||||
@@ -159,30 +160,82 @@ class MidasParser(Parser):
|
||||
Returns:
|
||||
TypeExpr: the parsed type expression
|
||||
"""
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
template: Optional[TemplateExpr] = None
|
||||
return self.constraint_type()
|
||||
|
||||
def constraint_type(self) -> Type:
|
||||
type: Type = self.base_type()
|
||||
if self.match(TokenType.WHERE):
|
||||
constraint: Expr = self.constraint()
|
||||
return ConstraintType(
|
||||
location=Location.span(type.location, constraint.location),
|
||||
type=type,
|
||||
constraint=constraint,
|
||||
)
|
||||
return type
|
||||
|
||||
def base_type(self) -> Type:
|
||||
if self.match(TokenType.LEFT_PAREN):
|
||||
type: Type = self.type_expr()
|
||||
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
|
||||
return type
|
||||
|
||||
if self.check(TokenType.LEFT_BRACE):
|
||||
return self.complex_type()
|
||||
|
||||
return self.generic_type()
|
||||
|
||||
def generic_type(self) -> Type:
|
||||
type: Type = self.named_type()
|
||||
if self.check(TokenType.LEFT_BRACKET):
|
||||
template = self.template_expr()
|
||||
optional: bool = self.match(TokenType.QMARK)
|
||||
return TypeExpr(
|
||||
location=name.location_to(self.previous()),
|
||||
params: list[Type] = self.type_params()
|
||||
return GenericType(
|
||||
location=Location.span(type.location, self.previous().get_location()),
|
||||
type=type,
|
||||
params=params,
|
||||
)
|
||||
return type
|
||||
|
||||
def type_params(self) -> list[Type]:
|
||||
params: list[Type] = []
|
||||
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic parameters")
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
||||
params.append(self.type_expr())
|
||||
if not self.match(TokenType.COMMA):
|
||||
break
|
||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic parameters")
|
||||
return params
|
||||
|
||||
def named_type(self) -> Type:
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
return NamedType(
|
||||
location=name.get_location(),
|
||||
name=name,
|
||||
template=template,
|
||||
optional=optional,
|
||||
)
|
||||
|
||||
def simple_type_expr(self) -> SimpleTypeExpr:
|
||||
"""Parse a simple type expression
|
||||
def complex_type(self) -> Type:
|
||||
"""Parse a type definition body
|
||||
|
||||
A simple type is just an identifier optionally followed by a '?'
|
||||
A type definition body is a set of whitespace-separated
|
||||
property statements enclosed in curly braces
|
||||
|
||||
Returns:
|
||||
SimpleTypeExpr: the parsed simple type expression
|
||||
list[PropertyStmt]: the parsed type properties
|
||||
"""
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
optional: bool = self.match(TokenType.QMARK)
|
||||
return SimpleTypeExpr(
|
||||
location=name.location_to(self.previous()), name=name, optional=optional
|
||||
left: Token = self.consume(
|
||||
TokenType.LEFT_BRACE, "Expected '{' to start type body"
|
||||
)
|
||||
properties: list[PropertyStmt] = []
|
||||
names: set[str] = set()
|
||||
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
|
||||
prop: PropertyStmt = self.property_stmt()
|
||||
if prop.name.lexeme in names:
|
||||
raise self.error(prop.name, "Duplicate property")
|
||||
names.add(prop.name.lexeme)
|
||||
properties.append(prop)
|
||||
right: Token = self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
|
||||
return ComplexType(
|
||||
location=left.location_to(right),
|
||||
properties=properties,
|
||||
)
|
||||
|
||||
def constraint(self) -> Expr:
|
||||
@@ -205,9 +258,7 @@ class MidasParser(Parser):
|
||||
while self.match(TokenType.AND):
|
||||
operator: Token = self.previous()
|
||||
right: Expr = self.equality()
|
||||
location: Optional[Location] = None
|
||||
if expr.location and right.location:
|
||||
location = Location.span(expr.location, right.location)
|
||||
location: Location = Location.span(expr.location, right.location)
|
||||
expr = LogicalExpr(
|
||||
location=location, left=expr, operator=operator, right=right
|
||||
)
|
||||
@@ -223,9 +274,7 @@ class MidasParser(Parser):
|
||||
while self.match(TokenType.BANG_EQUAL, TokenType.EQUAL_EQUAL):
|
||||
operator: Token = self.previous()
|
||||
right: Expr = self.comparison()
|
||||
location: Optional[Location] = None
|
||||
if expr.location and right.location:
|
||||
location = Location.span(expr.location, right.location)
|
||||
location: Location = Location.span(expr.location, right.location)
|
||||
expr = BinaryExpr(
|
||||
location=location, left=expr, operator=operator, right=right
|
||||
)
|
||||
@@ -246,9 +295,7 @@ class MidasParser(Parser):
|
||||
):
|
||||
operator: Token = self.previous()
|
||||
right: Expr = self.unary()
|
||||
location: Optional[Location] = None
|
||||
if expr.location and right.location:
|
||||
location = Location.span(expr.location, right.location)
|
||||
location: Location = Location.span(expr.location, right.location)
|
||||
expr = BinaryExpr(
|
||||
location=location, left=expr, operator=operator, right=right
|
||||
)
|
||||
@@ -263,9 +310,7 @@ class MidasParser(Parser):
|
||||
if self.match(TokenType.MINUS):
|
||||
operator: Token = self.previous()
|
||||
right: Expr = self.unary()
|
||||
location: Optional[Location] = None
|
||||
if right.location:
|
||||
location = Location.span(operator.get_location(), right.location)
|
||||
location: Location = Location.span(operator.get_location(), right.location)
|
||||
return UnaryExpr(location=location, operator=operator, right=right)
|
||||
return self.reference()
|
||||
|
||||
@@ -280,9 +325,7 @@ class MidasParser(Parser):
|
||||
name: Token = self.consume(
|
||||
TokenType.IDENTIFIER, "Expected property name after '.'"
|
||||
)
|
||||
location: Optional[Location] = None
|
||||
if expr.location:
|
||||
location = Location.span(expr.location, name.get_location())
|
||||
location: Location = Location.span(expr.location, name.get_location())
|
||||
expr = GetExpr(location=location, expr=expr, name=name)
|
||||
return expr
|
||||
|
||||
@@ -318,22 +361,6 @@ class MidasParser(Parser):
|
||||
|
||||
raise self.error(self.peek(), "Expected expression")
|
||||
|
||||
def type_properties(self) -> list[PropertyStmt]:
|
||||
"""Parse a type definition body
|
||||
|
||||
A type definition body is a set of whitespace-separated
|
||||
property statements enclosed in curly braces
|
||||
|
||||
Returns:
|
||||
list[PropertyStmt]: the parsed type properties
|
||||
"""
|
||||
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start type body")
|
||||
properties: list[PropertyStmt] = []
|
||||
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
|
||||
properties.append(self.property_stmt())
|
||||
self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
|
||||
return properties
|
||||
|
||||
def property_stmt(self) -> PropertyStmt:
|
||||
"""Parse a property statement
|
||||
|
||||
@@ -344,15 +371,11 @@ class MidasParser(Parser):
|
||||
"""
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name")
|
||||
self.consume(TokenType.COLON, "Expected ':' after property name")
|
||||
type: TypeExpr = self.type_expr()
|
||||
constraint: Optional[Expr] = None
|
||||
if self.match(TokenType.WHERE):
|
||||
constraint = self.constraint()
|
||||
type: Type = self.type_expr()
|
||||
return PropertyStmt(
|
||||
location=name.location_to(self.previous()),
|
||||
name=name,
|
||||
type=type,
|
||||
constraint=constraint,
|
||||
)
|
||||
|
||||
def extend_declaration(self) -> ExtendStmt:
|
||||
@@ -364,15 +387,13 @@ class MidasParser(Parser):
|
||||
ExtendStmt: the parsed extension statement
|
||||
"""
|
||||
keyword: Token = self.previous()
|
||||
type: TypeExpr = self.type_expr()
|
||||
type: Type = self.type_expr()
|
||||
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
|
||||
operations: list[OpStmt] = []
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
|
||||
operations.append(self.op_declaration())
|
||||
self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body")
|
||||
location: Optional[Location] = None
|
||||
if type.location:
|
||||
location = keyword.location_to(self.previous())
|
||||
location: Location = keyword.location_to(self.previous())
|
||||
return ExtendStmt(location=location, type=type, operations=operations)
|
||||
|
||||
def op_declaration(self) -> OpStmt:
|
||||
@@ -387,11 +408,11 @@ class MidasParser(Parser):
|
||||
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name")
|
||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type")
|
||||
operand: TypeExpr = self.type_expr()
|
||||
operand: Type = self.type_expr()
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after operand type")
|
||||
|
||||
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
||||
result: TypeExpr = self.type_expr()
|
||||
result: Type = self.type_expr()
|
||||
|
||||
return OpStmt(
|
||||
location=keyword.location_to(self.previous()),
|
||||
@@ -413,7 +434,7 @@ class MidasParser(Parser):
|
||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
|
||||
subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name")
|
||||
self.consume(TokenType.COLON, "Expected ':' after subject name")
|
||||
type: TypeExpr = self.type_expr()
|
||||
type: Type = self.type_expr()
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
|
||||
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
|
||||
condition: Expr = self.constraint()
|
||||
|
||||
@@ -2,12 +2,12 @@ import ast
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.location import Location
|
||||
|
||||
from midas.ast.python import (
|
||||
AssignStmt,
|
||||
BaseType,
|
||||
BinaryExpr,
|
||||
CallExpr,
|
||||
CastExpr,
|
||||
CompareExpr,
|
||||
ConstraintType,
|
||||
Expr,
|
||||
@@ -16,10 +16,13 @@ from midas.ast.python import (
|
||||
FrameType,
|
||||
Function,
|
||||
GetExpr,
|
||||
IfStmt,
|
||||
LiteralExpr,
|
||||
LogicalExpr,
|
||||
MidasType,
|
||||
ReturnStmt,
|
||||
Stmt,
|
||||
TernaryExpr,
|
||||
TypeAssign,
|
||||
UnaryExpr,
|
||||
VariableExpr,
|
||||
@@ -38,6 +41,8 @@ class UnsupportedSyntaxError(Exception):
|
||||
|
||||
|
||||
class PythonParser:
|
||||
CAST_FUNCTION = "cast"
|
||||
|
||||
def parse_module(self, node: ast.Module) -> list[Stmt]:
|
||||
statements: list[Stmt] = []
|
||||
for stmt in node.body:
|
||||
@@ -53,6 +58,7 @@ class PythonParser:
|
||||
return statements
|
||||
|
||||
def parse_stmt(self, node: ast.stmt) -> None | Stmt | list[Stmt]:
|
||||
location: Location = Location.from_ast(node)
|
||||
match node:
|
||||
case ast.AnnAssign():
|
||||
return self.parse_annotation_assign(node)
|
||||
@@ -60,11 +66,29 @@ class PythonParser:
|
||||
case ast.Assign():
|
||||
return self.parse_assign(node)
|
||||
|
||||
case ast.AugAssign():
|
||||
return self.parse_aug_assign(node)
|
||||
|
||||
case ast.FunctionDef():
|
||||
return self.parse_function(node)
|
||||
|
||||
case ast.Expr(value=expr):
|
||||
return ExpressionStmt(expr=self.parse_expr(expr))
|
||||
return ExpressionStmt(
|
||||
location=location,
|
||||
expr=self.parse_expr(expr),
|
||||
)
|
||||
|
||||
case ast.Return(value=value):
|
||||
return ReturnStmt(
|
||||
location=location,
|
||||
value=self.parse_expr(value) if value is not None else None,
|
||||
)
|
||||
|
||||
case ast.If():
|
||||
return self.parse_if(node)
|
||||
|
||||
case ast.Pass():
|
||||
return None
|
||||
|
||||
case _:
|
||||
print(f"Unsupported statement: {ast.unparse(node)}")
|
||||
@@ -80,8 +104,7 @@ class PythonParser:
|
||||
value=value,
|
||||
simple=1,
|
||||
):
|
||||
type = self._parse_type(annotation, root=True)
|
||||
if type is not None:
|
||||
type = self._parse_type(annotation)
|
||||
statements.append(
|
||||
TypeAssign(
|
||||
location=loc,
|
||||
@@ -117,6 +140,45 @@ class PythonParser:
|
||||
value=value,
|
||||
)
|
||||
|
||||
def parse_aug_assign(self, node: ast.AugAssign) -> AssignStmt:
|
||||
location: Location = Location.from_ast(node)
|
||||
target: Expr = self.parse_expr(node.target)
|
||||
value: Expr = self.parse_expr(node.value)
|
||||
return AssignStmt(
|
||||
location=location,
|
||||
targets=[target],
|
||||
value=BinaryExpr(
|
||||
location=location,
|
||||
left=target,
|
||||
operator=node.op,
|
||||
right=value,
|
||||
),
|
||||
)
|
||||
|
||||
def parse_if(self, node: ast.If) -> IfStmt:
|
||||
body: list[Stmt] = []
|
||||
for stmt in node.body:
|
||||
stmts = self.parse_stmt(stmt)
|
||||
if isinstance(stmts, Stmt):
|
||||
body.append(stmts)
|
||||
elif stmts is not None:
|
||||
body.extend(stmts)
|
||||
|
||||
orelse: list[Stmt] = []
|
||||
for stmt in node.orelse:
|
||||
stmts = self.parse_stmt(stmt)
|
||||
if isinstance(stmts, Stmt):
|
||||
orelse.append(stmts)
|
||||
elif stmts is not None:
|
||||
orelse.extend(stmts)
|
||||
|
||||
return IfStmt(
|
||||
location=Location.from_ast(node),
|
||||
test=self.parse_expr(node.test),
|
||||
body=body,
|
||||
orelse=orelse,
|
||||
)
|
||||
|
||||
def parse_function(self, node: ast.FunctionDef) -> Function:
|
||||
loc: Location = Location.from_ast(node)
|
||||
match node:
|
||||
@@ -125,26 +187,74 @@ class PythonParser:
|
||||
args=ast.arguments(
|
||||
posonlyargs=posonlyargs,
|
||||
args=args,
|
||||
vararg=sink,
|
||||
kwonlyargs=kwonlyargs,
|
||||
kwarg=kw_sink,
|
||||
defaults=defaults,
|
||||
kw_defaults=kw_defaults,
|
||||
),
|
||||
returns=returns,
|
||||
body=raw_body,
|
||||
):
|
||||
|
||||
def parse_args(args_list: list[ast.arg]) -> list[Function.Argument]:
|
||||
return [self._parse_function_argument(arg) for arg in args_list]
|
||||
def parse_args(
|
||||
args_list: list[ast.arg], defaults: list[Optional[Expr]]
|
||||
) -> list[Function.Argument]:
|
||||
return [
|
||||
self._parse_function_argument(arg, default)
|
||||
for arg, default in zip(args_list, defaults)
|
||||
]
|
||||
|
||||
body: list[Stmt] = []
|
||||
for stmt in raw_body:
|
||||
stmts = self.parse_stmt(stmt)
|
||||
if isinstance(stmts, Stmt):
|
||||
body.append(stmts)
|
||||
elif stmts is not None:
|
||||
body.extend(stmts)
|
||||
|
||||
parsed_defaults: list[Optional[Expr]] = [
|
||||
self.parse_expr(default) for default in defaults
|
||||
]
|
||||
n_posargs: int = len(posonlyargs)
|
||||
n_args: int = len(args)
|
||||
n_all_posargs = n_posargs + n_args
|
||||
parsed_defaults = [
|
||||
None,
|
||||
] * (n_all_posargs - len(defaults)) + parsed_defaults
|
||||
|
||||
posargs_defaults: list[Optional[Expr]] = parsed_defaults[:n_posargs]
|
||||
args_defaults: list[Optional[Expr]] = parsed_defaults[n_posargs:]
|
||||
kwargs_defaults: list[Optional[Expr]] = [
|
||||
self.parse_expr(default) if default is not None else None
|
||||
for default in kw_defaults
|
||||
]
|
||||
|
||||
return Function(
|
||||
location=loc,
|
||||
name=name,
|
||||
posonlyargs=parse_args(posonlyargs),
|
||||
args=parse_args(args),
|
||||
kwonlyargs=parse_args(kwonlyargs),
|
||||
posonlyargs=parse_args(posonlyargs, posargs_defaults),
|
||||
args=parse_args(args, args_defaults),
|
||||
sink=(
|
||||
self._parse_function_argument(sink, None)
|
||||
if sink is not None
|
||||
else None
|
||||
),
|
||||
kwonlyargs=parse_args(kwonlyargs, kwargs_defaults),
|
||||
kw_sink=(
|
||||
self._parse_function_argument(kw_sink, None)
|
||||
if kw_sink is not None
|
||||
else None
|
||||
),
|
||||
returns=self._parse_type(returns) if returns is not None else None,
|
||||
body=body,
|
||||
)
|
||||
case _:
|
||||
print(f"Unsupported function definition: {ast.unparse(node)}")
|
||||
|
||||
def _parse_function_argument(self, arg: ast.arg) -> Function.Argument:
|
||||
def _parse_function_argument(
|
||||
self, arg: ast.arg, default: Optional[Expr]
|
||||
) -> Function.Argument:
|
||||
loc: Location = Location.from_ast(arg)
|
||||
name: str = arg.arg
|
||||
type: Optional[MidasType] = None
|
||||
@@ -154,11 +264,10 @@ class PythonParser:
|
||||
location=loc,
|
||||
name=name,
|
||||
type=type,
|
||||
default=default,
|
||||
)
|
||||
|
||||
def _parse_type(
|
||||
self, type_expr: ast.expr, root: bool = False
|
||||
) -> Optional[MidasType]:
|
||||
def _parse_type(self, type_expr: ast.expr) -> MidasType:
|
||||
loc: Location = Location.from_ast(type_expr)
|
||||
match type_expr:
|
||||
case ast.Subscript(value=ast.Name(id="Frame"), slice=schema):
|
||||
@@ -205,9 +314,14 @@ class PythonParser:
|
||||
constraint=right_expr,
|
||||
)
|
||||
|
||||
case ast.Constant(value=None):
|
||||
return BaseType(
|
||||
location=loc,
|
||||
base="None",
|
||||
param=None,
|
||||
)
|
||||
|
||||
case _:
|
||||
if root:
|
||||
return None
|
||||
raise UnsupportedSyntaxError(type_expr)
|
||||
|
||||
def _parse_frame_type(self, schema: ast.expr) -> FrameType:
|
||||
@@ -257,12 +371,14 @@ class PythonParser:
|
||||
raise UnsupportedSyntaxError(column)
|
||||
|
||||
def parse_expr(self, node: ast.expr) -> Expr:
|
||||
location: Location = Location.from_ast(node)
|
||||
match node:
|
||||
case ast.BoolOp():
|
||||
return self.parse_bool_op(node)
|
||||
|
||||
case ast.BinOp(left=left, op=op, right=right):
|
||||
return BinaryExpr(
|
||||
location=location,
|
||||
left=self.parse_expr(left),
|
||||
operator=op,
|
||||
right=self.parse_expr(right),
|
||||
@@ -270,6 +386,7 @@ class PythonParser:
|
||||
|
||||
case ast.UnaryOp(op=op, operand=right):
|
||||
return UnaryExpr(
|
||||
location=location,
|
||||
operator=op,
|
||||
right=self.parse_expr(right),
|
||||
)
|
||||
@@ -277,62 +394,96 @@ class PythonParser:
|
||||
case ast.Compare():
|
||||
return self.parse_compare(node)
|
||||
|
||||
case ast.Call(func=ast.Name(id=self.CAST_FUNCTION)):
|
||||
return self.parse_cast(node)
|
||||
|
||||
case ast.Call():
|
||||
return self.parse_call(node)
|
||||
|
||||
case ast.IfExp():
|
||||
return self.parse_ternary(node)
|
||||
|
||||
case ast.Constant(value=value):
|
||||
return LiteralExpr(value=value)
|
||||
return LiteralExpr(location=location, value=value)
|
||||
|
||||
case ast.Attribute(value=object, attr=name):
|
||||
return GetExpr(
|
||||
location=location,
|
||||
object=self.parse_expr(object),
|
||||
name=name,
|
||||
)
|
||||
|
||||
case ast.Name(id=name):
|
||||
return VariableExpr(name=name)
|
||||
return VariableExpr(location=location, name=name)
|
||||
|
||||
case _:
|
||||
raise UnsupportedSyntaxError(node)
|
||||
|
||||
def parse_bool_op(self, node: ast.BoolOp) -> LogicalExpr:
|
||||
op: ast.boolop = node.op
|
||||
values: list[ast.expr] = node.values
|
||||
rights: list[Expr] = [self.parse_expr(expr) for expr in node.values]
|
||||
expr: LogicalExpr = LogicalExpr(
|
||||
left=self.parse_expr(values[0]),
|
||||
location=Location.span(
|
||||
rights[0].location,
|
||||
rights[1].location,
|
||||
),
|
||||
left=rights[0],
|
||||
operator=op,
|
||||
right=self.parse_expr(values[1]),
|
||||
right=rights[1],
|
||||
)
|
||||
for value in values[2:]:
|
||||
for right in rights[2:]:
|
||||
expr = LogicalExpr(
|
||||
location=Location.span(expr.location, right.location),
|
||||
left=expr,
|
||||
operator=op,
|
||||
right=self.parse_expr(value),
|
||||
right=right,
|
||||
)
|
||||
return expr
|
||||
|
||||
def parse_compare(self, node: ast.Compare) -> Expr:
|
||||
ops: list[ast.cmpop] = node.ops
|
||||
left: Expr = self.parse_expr(node.left)
|
||||
rights: list[Expr] = [self.parse_expr(expr) for expr in node.comparators]
|
||||
expr: Expr = CompareExpr(
|
||||
left=self.parse_expr(node.left),
|
||||
location=Location.span(
|
||||
left.location,
|
||||
rights[0].location,
|
||||
),
|
||||
left=left,
|
||||
operator=ops[0],
|
||||
right=rights[0],
|
||||
)
|
||||
for i, right in enumerate(rights[1:]):
|
||||
expr = LogicalExpr(
|
||||
left=expr,
|
||||
operator=ast.And(),
|
||||
right=CompareExpr(
|
||||
comparison = CompareExpr(
|
||||
location=Location.span(rights[i].location, right.location),
|
||||
left=rights[i],
|
||||
operator=ops[i],
|
||||
right=right,
|
||||
),
|
||||
)
|
||||
expr = LogicalExpr(
|
||||
location=Location.span(expr.location, comparison.location),
|
||||
left=expr,
|
||||
operator=ast.And(),
|
||||
right=comparison,
|
||||
)
|
||||
return expr
|
||||
|
||||
def parse_cast(self, node: ast.Call) -> CastExpr:
|
||||
match node:
|
||||
case ast.Call(args=[type, expr], keywords=[]):
|
||||
return CastExpr(
|
||||
location=Location.from_ast(node),
|
||||
type=self._parse_type(type),
|
||||
expr=self.parse_expr(expr),
|
||||
)
|
||||
case _:
|
||||
raise InvalidSyntaxError(
|
||||
f"Invalid call to {self.CAST_FUNCTION}, expected type and expression"
|
||||
)
|
||||
|
||||
def parse_call(self, node: ast.Call) -> CallExpr:
|
||||
return CallExpr(
|
||||
location=Location.from_ast(node),
|
||||
callee=self.parse_expr(node.func),
|
||||
arguments=[self.parse_expr(arg) for arg in node.args],
|
||||
keywords={
|
||||
@@ -341,3 +492,11 @@ class PythonParser:
|
||||
if arg.arg is not None # Should always be True, type checker happy
|
||||
},
|
||||
)
|
||||
|
||||
def parse_ternary(self, node: ast.IfExp) -> TernaryExpr:
|
||||
return TernaryExpr(
|
||||
location=Location.from_ast(node),
|
||||
test=self.parse_expr(node.test),
|
||||
if_true=self.parse_expr(node.body),
|
||||
if_false=self.parse_expr(node.orelse),
|
||||
)
|
||||
|
||||
72
midas/resolver/builtin.py
Normal file
72
midas/resolver/builtin.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from midas.checker.types import BaseType, Type, UnitType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.resolver.midas import MidasResolver
|
||||
|
||||
|
||||
def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type):
|
||||
ctx.define_operation(
|
||||
left=t1,
|
||||
operator=operator,
|
||||
right=t2,
|
||||
result=t3,
|
||||
)
|
||||
|
||||
|
||||
def basic_op(ctx: MidasResolver, type: Type, op: str):
|
||||
ctx.define_operation(
|
||||
left=type,
|
||||
operator=op,
|
||||
right=type,
|
||||
result=type,
|
||||
)
|
||||
|
||||
|
||||
def define_builtins(ctx: MidasResolver):
|
||||
"""Define builtin types and operations"""
|
||||
unit = ctx.define_type("None", UnitType())
|
||||
bool = ctx.define_type("bool", BaseType(name="bool"))
|
||||
int = ctx.define_type("int", BaseType(name="int"))
|
||||
float = ctx.define_type("float", BaseType(name="float"))
|
||||
str = ctx.define_type("str", BaseType(name="str"))
|
||||
|
||||
basic_op(ctx, int, "__add__") # int + int = int
|
||||
basic_op(ctx, int, "__sub__") # int - int = int
|
||||
basic_op(ctx, int, "__mul__") # int * int = int
|
||||
basic_op(ctx, int, "__pow__") # int ** int = int
|
||||
basic_op(ctx, int, "__mod__") # int % int = int
|
||||
basic_op(ctx, int, "__and__") # int & int = int
|
||||
basic_op(ctx, int, "__or__") # int | int = int
|
||||
basic_op(ctx, int, "__xor__") # int ^ int = int
|
||||
op(ctx, int, "__lt__", int, bool) # int < int = bool
|
||||
op(ctx, int, "__gt__", int, bool) # int > int = bool
|
||||
op(ctx, int, "__le__", int, bool) # int <= int = bool
|
||||
op(ctx, int, "__ge__", int, bool) # int >= int = bool
|
||||
op(ctx, int, "__eq__", int, bool) # int == int = bool
|
||||
basic_op(ctx, float, "__add__") # float + float = float
|
||||
basic_op(ctx, float, "__sub__") # float - float = float
|
||||
basic_op(ctx, float, "__mul__") # float * float = float
|
||||
basic_op(ctx, float, "__truediv__") # float / float = float
|
||||
op(ctx, float, "__lt__", float, bool) # float < float = bool
|
||||
op(ctx, float, "__gt__", float, bool) # float > float = bool
|
||||
op(ctx, float, "__le__", float, bool) # float <= float = bool
|
||||
op(ctx, float, "__ge__", float, bool) # float >= float = bool
|
||||
op(ctx, float, "__eq__", float, bool) # float == float = bool
|
||||
basic_op(ctx, str, "__add__") # str + str = str
|
||||
op(ctx, str, "__eq__", str, bool) # str == str = bool
|
||||
|
||||
op(ctx, int, "__lt__", float, bool) # int < float = bool
|
||||
op(ctx, int, "__gt__", float, bool) # int > float = bool
|
||||
op(ctx, int, "__le__", float, bool) # int <= float = bool
|
||||
op(ctx, int, "__ge__", float, bool) # int >= float = bool
|
||||
op(ctx, int, "__eq__", float, bool) # int == float = bool
|
||||
|
||||
op(ctx, float, "__lt__", int, bool) # float < int = bool
|
||||
op(ctx, float, "__gt__", int, bool) # float > int = bool
|
||||
op(ctx, float, "__le__", int, bool) # float <= int = bool
|
||||
op(ctx, float, "__ge__", int, bool) # float >= int = bool
|
||||
op(ctx, float, "__eq__", int, bool) # float == int = bool
|
||||
163
midas/resolver/midas.py
Normal file
163
midas/resolver/midas.py
Normal file
@@ -0,0 +1,163 @@
|
||||
from typing import Optional
|
||||
|
||||
import midas.ast.midas as m
|
||||
from midas.checker.types import (
|
||||
AliasType,
|
||||
Type,
|
||||
UnknownType,
|
||||
)
|
||||
from midas.resolver.builtin import define_builtins
|
||||
|
||||
|
||||
class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
|
||||
"""A resolver which evaluates Midas type definitions and build a registry"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._types: dict[str, Type] = {}
|
||||
self._operations: dict[tuple[Type, str, Type], Type] = {}
|
||||
|
||||
define_builtins(self)
|
||||
|
||||
def get_type(self, name: str) -> Type:
|
||||
"""Get a type from its name
|
||||
|
||||
Args:
|
||||
name (str): the name of the type
|
||||
|
||||
Raises:
|
||||
NameError: if the type is not defined
|
||||
|
||||
Returns:
|
||||
Type: the type
|
||||
"""
|
||||
type: Optional[Type] = self._types.get(name)
|
||||
if type is None:
|
||||
raise NameError(f"Undefined type {name}")
|
||||
return type
|
||||
|
||||
def get_operation_result(
|
||||
self, left: Type, operator: str, right: Type
|
||||
) -> Optional[Type]:
|
||||
"""Get the resulting type of an operation
|
||||
|
||||
Args:
|
||||
left (Type): the type of the left operand
|
||||
operator (str): the operation name
|
||||
right (Type): the type of the right operand
|
||||
|
||||
Returns:
|
||||
Optional[Type]: the result type, or None if no matching operation was found
|
||||
"""
|
||||
operation: tuple[Type, str, Type] = (left, operator, right)
|
||||
result: Optional[Type] = self._operations.get(operation)
|
||||
return result
|
||||
|
||||
def define_type(self, name: str, type: Type) -> Type:
|
||||
"""Define a type in the registry
|
||||
|
||||
Args:
|
||||
name (str): the name of the type
|
||||
type (Type): the type to define
|
||||
|
||||
Raises:
|
||||
ValueError: if a type is already defined with that name
|
||||
|
||||
Returns:
|
||||
Type: the defined type
|
||||
"""
|
||||
if name in self._types:
|
||||
raise ValueError(f"Type {name} already defined")
|
||||
self._types[name] = type
|
||||
return type
|
||||
|
||||
def define_operation(self, left: Type, operator: str, right: Type, result: Type):
|
||||
"""Define an operation in the registry
|
||||
|
||||
Args:
|
||||
left (Type): the type of the left operand
|
||||
operator (str): the operation name
|
||||
right (Type): the type of the right operand
|
||||
result (Type): the result type
|
||||
|
||||
Raises:
|
||||
ValueError: if an operation is already defined with these operands and name
|
||||
"""
|
||||
operation: tuple[Type, str, Type] = (left, operator, right)
|
||||
if operation in self._operations:
|
||||
raise ValueError(
|
||||
f"Operation {operator} already defined between {left} and {right}"
|
||||
)
|
||||
self._operations[operation] = result
|
||||
|
||||
def resolve(self, stmts: list[m.Stmt]):
|
||||
"""Process a sequence of statements
|
||||
|
||||
Args:
|
||||
stmts (list[m.Stmt]): the statements
|
||||
"""
|
||||
for stmt in stmts:
|
||||
stmt.accept(self)
|
||||
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||
type: Type = stmt.type.accept(self)
|
||||
for param in stmt.params:
|
||||
if param.bound is not None:
|
||||
param.bound.accept(self)
|
||||
name: str = stmt.name.lexeme
|
||||
self.define_type(name, AliasType(name=name, type=type))
|
||||
|
||||
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||
base: Type = stmt.type.accept(self)
|
||||
for op in stmt.operations:
|
||||
right: Type = op.operand.accept(self)
|
||||
result: Type = op.result.accept(self)
|
||||
self.define_operation(
|
||||
left=base,
|
||||
operator=op.name.lexeme,
|
||||
right=right,
|
||||
result=result,
|
||||
)
|
||||
|
||||
def visit_op_stmt(self, stmt: m.OpStmt) -> None: ...
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ...
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ...
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> None: ...
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> None: ...
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> None: ...
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr) -> None: ...
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
|
||||
return expr.expr.accept(self)
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ...
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
|
||||
|
||||
def visit_named_type(self, type: m.NamedType) -> Type:
|
||||
return self.get_type(type.name.lexeme)
|
||||
|
||||
def visit_generic_type(self, type: m.GenericType) -> Type:
|
||||
type_: Type = type.type.accept(self)
|
||||
params: list[Type] = [param.accept(self) for param in type.params]
|
||||
# TODO
|
||||
return UnknownType()
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
|
||||
type_: Type = type.type.accept(self)
|
||||
type.constraint.accept(self)
|
||||
# TODO
|
||||
return UnknownType()
|
||||
|
||||
def visit_complex_type(self, type: m.ComplexType) -> Type:
|
||||
for prop in type.properties:
|
||||
prop.accept(self)
|
||||
# TODO
|
||||
return UnknownType()
|
||||
187
midas/resolver/resolver.py
Normal file
187
midas/resolver/resolver.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import midas.ast.python as p
|
||||
|
||||
|
||||
class ResolverError(Exception): ...
|
||||
|
||||
|
||||
class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
||||
"""A variable assignment and reference resolver
|
||||
|
||||
This class keeps track of which scope a variable is defined in and which
|
||||
scope is referred to when a variable is referenced
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.locals: dict[p.Expr, int] = {}
|
||||
self.scopes: list[dict[str, bool]] = []
|
||||
|
||||
def resolve(self, *objects: p.Stmt | p.Expr) -> None:
|
||||
"""Resolve the given statements or expressions"""
|
||||
|
||||
for obj in objects:
|
||||
obj.accept(self)
|
||||
|
||||
def begin_scope(self):
|
||||
"""Begin a new scope inside the current one"""
|
||||
self.scopes.append({})
|
||||
|
||||
def end_scope(self):
|
||||
"""Close the current scope"""
|
||||
self.scopes.pop()
|
||||
|
||||
def declare(self, name: str) -> None:
|
||||
"""Declare a variable in the current scope
|
||||
|
||||
This method must be called *before* evaluating the variable initializer
|
||||
|
||||
Args:
|
||||
name (str): the name of the variable
|
||||
|
||||
Raises:
|
||||
ResolverError: if the variable has already been declared in the current scope
|
||||
"""
|
||||
if len(self.scopes) == 0:
|
||||
return
|
||||
scope: dict[str, bool] = self.scopes[-1]
|
||||
if name in scope:
|
||||
raise ResolverError(
|
||||
f"A variable with the name {name} is already declared in this scope"
|
||||
)
|
||||
scope[name] = False
|
||||
|
||||
def define(self, name: str) -> None:
|
||||
"""Define a variable in the current scope
|
||||
|
||||
This method must be called *after* evaluating the variable initializer
|
||||
|
||||
Args:
|
||||
name (str): the name of the variable
|
||||
"""
|
||||
if len(self.scopes) == 0:
|
||||
return
|
||||
self.scopes[-1][name] = True
|
||||
|
||||
def resolve_local(self, expr: p.Expr, name: str) -> None:
|
||||
"""Resolve a variable reference and store the scope distance
|
||||
|
||||
This method associates to the variable expression a number representing
|
||||
the "distance" of the variable declaration, i.e. the number of scope
|
||||
levels to go "up" to find the closest declaration for that variable.
|
||||
|
||||
Args:
|
||||
expr (p.Expr): the variable expression
|
||||
name (str): the name of the variable
|
||||
"""
|
||||
for i, scope in enumerate(reversed(self.scopes)):
|
||||
if name in scope:
|
||||
self.locals[expr] = i
|
||||
return
|
||||
|
||||
def resolve_function(self, function: p.Function) -> None:
|
||||
"""Resolve a function definition
|
||||
|
||||
This method creates a new scope for the function, resolves all the
|
||||
parameter declarations and then the body.
|
||||
|
||||
Args:
|
||||
function (p.Function): the function to resolve
|
||||
"""
|
||||
self.begin_scope()
|
||||
for param in function.all_args:
|
||||
self.declare(param.name)
|
||||
self.define(param.name)
|
||||
self.resolve(*function.body)
|
||||
self.end_scope()
|
||||
|
||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||
stmt.expr.accept(self)
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> None:
|
||||
# Declare before resolving body to allow recursion
|
||||
self.declare(stmt.name)
|
||||
self.define(stmt.name)
|
||||
self.resolve_function(stmt)
|
||||
|
||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||
self.declare(stmt.name)
|
||||
# NOTE: resolve type here?
|
||||
self.define(stmt.name)
|
||||
|
||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||
self.resolve(stmt.value)
|
||||
for target in stmt.targets:
|
||||
match target:
|
||||
case p.VariableExpr(name=name):
|
||||
self.resolve_local(target, name)
|
||||
# TODO: declare if not found
|
||||
case _:
|
||||
raise Exception(f"Unsupported assignment to {target}")
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||
if stmt.value is not None:
|
||||
self.resolve(stmt.value)
|
||||
|
||||
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||
# Not resolved in sub-environment because assignments in the test leak out of the if
|
||||
# For example:
|
||||
# if (m := 1 + 1) < 2:
|
||||
# ...
|
||||
# print(m) # <- m is still defined
|
||||
self.resolve(stmt.test)
|
||||
|
||||
# Body
|
||||
self.begin_scope()
|
||||
self.resolve(*stmt.body)
|
||||
self.end_scope()
|
||||
|
||||
# Else
|
||||
self.begin_scope()
|
||||
self.resolve(*stmt.orelse)
|
||||
self.end_scope()
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
|
||||
self.resolve(expr.left)
|
||||
self.resolve(expr.right)
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> None:
|
||||
self.resolve(expr.left)
|
||||
self.resolve(expr.right)
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> None:
|
||||
self.resolve(expr.right)
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> None:
|
||||
self.resolve(expr.callee)
|
||||
for arg in expr.arguments:
|
||||
self.resolve(arg)
|
||||
for arg in expr.keywords.values():
|
||||
self.resolve(arg)
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> None:
|
||||
self.resolve(expr.object)
|
||||
|
||||
def visit_literal_expr(self, expr: p.LiteralExpr) -> None:
|
||||
pass
|
||||
|
||||
def visit_variable_expr(self, expr: p.VariableExpr) -> None:
|
||||
if len(self.scopes) != 0 and self.scopes[-1].get(expr.name) is False:
|
||||
raise ResolverError(
|
||||
f"Cannot use local variable '{expr.name}' in its own initializer"
|
||||
) # aka. UnboundLocalError
|
||||
self.resolve_local(expr, expr.name)
|
||||
|
||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> None:
|
||||
self.resolve(expr.left)
|
||||
self.resolve(expr.right)
|
||||
|
||||
def visit_set_expr(self, expr: p.SetExpr) -> None:
|
||||
self.resolve(expr.value)
|
||||
self.resolve(expr.object)
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> None:
|
||||
self.resolve(expr.expr)
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
|
||||
self.resolve(expr.test)
|
||||
self.resolve(expr.if_true)
|
||||
self.resolve(expr.if_false)
|
||||
54
midas/utils.py
Normal file
54
midas/utils.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
AllowRepeat = Callable[[object], bool]
|
||||
|
||||
|
||||
class UniversalJSONDumper:
|
||||
@classmethod
|
||||
def dump(
|
||||
cls,
|
||||
obj: Any,
|
||||
include_keys: Optional[list[str | tuple[str, str]]] = None,
|
||||
allow_repeat: Optional[AllowRepeat] = None,
|
||||
) -> Any:
|
||||
if include_keys is None:
|
||||
include_keys = []
|
||||
return cls._dump(obj, include_keys, allow_repeat, [])
|
||||
|
||||
@classmethod
|
||||
def _dump(
|
||||
cls,
|
||||
obj: Any,
|
||||
include_keys: list[str | tuple[str, str]],
|
||||
allow_repeat: Optional[AllowRepeat],
|
||||
visited: list[Any],
|
||||
) -> Any:
|
||||
if obj in visited:
|
||||
return None
|
||||
match obj:
|
||||
case str() | int() | float() | None:
|
||||
return obj
|
||||
case list() | set() | tuple():
|
||||
return [
|
||||
cls._dump(child, include_keys, allow_repeat, visited)
|
||||
for child in obj
|
||||
]
|
||||
case dict():
|
||||
return {
|
||||
str(k): cls._dump(v, include_keys, allow_repeat, visited)
|
||||
for k, v in obj.items()
|
||||
}
|
||||
case object():
|
||||
if allow_repeat is None or not allow_repeat(obj):
|
||||
visited.append(obj)
|
||||
return {
|
||||
"_type": obj.__class__.__name__,
|
||||
} | {
|
||||
k: cls._dump(v, include_keys, allow_repeat, visited)
|
||||
for k, v in obj.__dict__.items()
|
||||
if not k.startswith("_")
|
||||
or k in include_keys
|
||||
or (obj.__class__.__name__, k) in include_keys
|
||||
}
|
||||
case _:
|
||||
raise ValueError(f"Unsupported value: {obj}")
|
||||
@@ -19,16 +19,24 @@ Comparison ::= Unary (ComparisonOp Unary)*
|
||||
Equality ::= Comparison (EqualityOp Comparison)*
|
||||
Constraint ::= Equality ("&" Equality)*
|
||||
|
||||
SimpleType ::= Identifier "?"?
|
||||
Template ::= "[" Type "]"
|
||||
Type ::= Identifier Template? "?"?
|
||||
TemplateParam ::= Identifier ("<:" Type)?
|
||||
Template ::= "[" (TemplateParam ("," TemplateParam)*)? "]"
|
||||
|
||||
|
||||
TypeProperty ::= Identifier ":" Type
|
||||
ComplexType ::= "{" TypeProperty* "}"
|
||||
NamedType ::= Identifier
|
||||
TypeParams ::= "[" (Type ("," Type)*)? "]"
|
||||
GenericType ::= NamedType TypeParams?
|
||||
GroupedType ::= "(" Type ")"
|
||||
BaseType ::= GroupedType | ComplexType | GenericType
|
||||
ConstraintType ::= BaseType ("where" Constraint)?
|
||||
Type ::= ConstraintType
|
||||
|
||||
TypeProperty ::= Identifier ":" Type ("where" Constraints)?
|
||||
ComplexTypeBody ::= "{" TypeProperty* "}"
|
||||
OpDefinition ::= "op" Identifier "(" Type ")" "->" Type
|
||||
ExtendBody ::= "{" OpDefinition* "}"
|
||||
|
||||
TypeStatement ::= "type" Identifier Template? ("(" Type ")" ("where" Constraint)? | ComplexTypeBody)
|
||||
TypeStatement ::= "type" Identifier Template? "=" Type
|
||||
ExtendStatement ::= "extend" Type ExtendBody
|
||||
PredicateStatement ::= "predicate" Identifier "(" Identifier ":" Type ")" "=" Constraint
|
||||
|
||||
|
||||
@@ -43,28 +43,52 @@ svg.railroad .terminal rect {
|
||||
{[`constraint` 'equality'*"&"]}
|
||||
```
|
||||
|
||||
#let simple-type = ```
|
||||
{[`simple-type` 'identifier' <!, "?">]}
|
||||
#let template-param = ```
|
||||
{[`template-param` 'identifier' <!, ["<:" 'type']>]}
|
||||
```
|
||||
|
||||
#let template = ```
|
||||
{[`template` "[" 'type' "]"]}
|
||||
```
|
||||
|
||||
#let type = ```
|
||||
{[`type` 'identifier' <!, 'template'> <!, "?">]}
|
||||
{[`template` "[" <!, 'template-param'*","> "]"]}
|
||||
```
|
||||
|
||||
#let type-property = ```
|
||||
{[`type-property` 'identifier' ":" 'type' <!, ["where" 'constraint']>]}
|
||||
{[`type-property` 'identifier' ":" 'type']}
|
||||
```
|
||||
|
||||
#let type-body = ```
|
||||
{[`type-body` "{" <!, 'type-property'*!> "}"]}
|
||||
#let complex-type = ```
|
||||
{[`complex-type` "{" <!, 'type-property'*!> "}"]}
|
||||
```
|
||||
|
||||
#let named-type = ```
|
||||
{[`named-type` 'identifier']}
|
||||
```
|
||||
|
||||
#let type-params = ```
|
||||
{[`type-params` "[" <!, 'type'*","> "]"]}
|
||||
```
|
||||
|
||||
#let generic-type = ```
|
||||
{[`generic-type` 'named-type' <!, 'type-params'>]}
|
||||
```
|
||||
|
||||
#let grouped-type = ```
|
||||
{[`grouped-type` "(" 'type' ")"]}
|
||||
```
|
||||
|
||||
#let base-type = ```
|
||||
{[`base-type` <'grouped-type', 'complex-type', 'generic-type'>]}
|
||||
```
|
||||
|
||||
#let constraint-type = ```
|
||||
{[`constraint-type` 'base-type' <!, ["where" 'constraint']>]}
|
||||
```
|
||||
|
||||
#let type = ```
|
||||
{[`type` 'constraint-type']}
|
||||
```
|
||||
|
||||
#let type-statement = ```
|
||||
{[`type-statement` "type" 'identifier' <!, 'template'> <[["(" 'type' ")"] <!, ["where" 'constraint']>], 'type-body'>]}
|
||||
{[`type-statement` "type" 'identifier' <!, 'template'> "=" 'type']}
|
||||
```
|
||||
|
||||
#let op-definition = ```
|
||||
@@ -92,11 +116,17 @@ svg.railroad .terminal rect {
|
||||
comparison: comparison,
|
||||
equality: equality,
|
||||
constraint: constraint,
|
||||
simple-type: simple-type,
|
||||
template-param: template-param,
|
||||
template: template,
|
||||
type: type,
|
||||
type-property: type-property,
|
||||
type-body: type-body,
|
||||
complex-type: complex-type,
|
||||
named-type: named-type,
|
||||
type-params: type-params,
|
||||
generic-type: generic-type,
|
||||
grouped-type: grouped-type,
|
||||
base-type: base-type,
|
||||
constraint-type: constraint-type,
|
||||
type: type,
|
||||
type-statement: type-statement,
|
||||
op-definition: op-definition,
|
||||
extend-statement: extend-statement,
|
||||
@@ -107,10 +137,16 @@ svg.railroad .terminal rect {
|
||||
#let inline = (
|
||||
"grouping",
|
||||
"value",
|
||||
"template-param",
|
||||
"template",
|
||||
"simple-type",
|
||||
"type-property",
|
||||
"type-body",
|
||||
"complex-type",
|
||||
"type-params",
|
||||
"named-type",
|
||||
"grouped-type",
|
||||
"generic-type",
|
||||
"base-type",
|
||||
"constraint-type",
|
||||
"op-definition",
|
||||
"type-statement",
|
||||
"extend-statement",
|
||||
|
||||
204
tester.py
204
tester.py
@@ -1,204 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import difflib
|
||||
import json
|
||||
import sys
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Optional
|
||||
|
||||
from midas.ast.json_serializer import AstJsonSerializer
|
||||
from midas.ast.midas import Stmt
|
||||
from midas.lexer.base import MidasSyntaxError
|
||||
from midas.lexer.midas import MidasLexer
|
||||
from midas.lexer.token import Token
|
||||
from midas.parser.midas import MidasParser
|
||||
|
||||
DEFAULT_BASE_DIR: Path = Path() / "tests"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CaseResult:
|
||||
tokens: Optional[list[dict]] = None
|
||||
stmts: Optional[list[dict]] = None
|
||||
errors: list[dict] = field(default_factory=list)
|
||||
|
||||
def dumps(self) -> str:
|
||||
return json.dumps(asdict(self), indent=2)
|
||||
|
||||
|
||||
class Tester:
|
||||
"""A test runner to check for regressions in the lexer and parser"""
|
||||
|
||||
def __init__(self, base_dir: Path):
|
||||
self.base_dir: Path = base_dir
|
||||
|
||||
def _list_tests(self) -> list[Path]:
|
||||
return list(self.base_dir.rglob("*.midas"))
|
||||
|
||||
def run_all_tests(self) -> bool:
|
||||
paths: list[Path] = self._list_tests()
|
||||
return self.run_tests(paths)
|
||||
|
||||
def run_tests(self, tests: list[Path]) -> bool:
|
||||
rule: str = "-" * 80
|
||||
n: int = len(tests)
|
||||
successes: int = 0
|
||||
failures: int = 0
|
||||
|
||||
print(rule)
|
||||
for i, test in enumerate(tests):
|
||||
print(f"Case {i+1}/{n}: {test}")
|
||||
success: bool = self._run_test(test)
|
||||
if success:
|
||||
successes += 1
|
||||
else:
|
||||
failures += 1
|
||||
|
||||
print(rule)
|
||||
print(f"Success: {successes}/{n}")
|
||||
print(f"Failed: {failures}/{n}")
|
||||
print(rule)
|
||||
return failures == 0
|
||||
|
||||
def _run_test(self, path: Path) -> bool:
|
||||
result: CaseResult = self._exec_case(path)
|
||||
result_path: Path = self._result_path(path)
|
||||
expected: str = result_path.read_text()
|
||||
actual: str = result.dumps()
|
||||
|
||||
if expected == actual:
|
||||
return True
|
||||
|
||||
diff = difflib.unified_diff(
|
||||
expected.splitlines(keepends=True),
|
||||
actual.splitlines(keepends=True),
|
||||
fromfile="Snapshot",
|
||||
tofile="Result",
|
||||
)
|
||||
self._print_diff(diff)
|
||||
return False
|
||||
|
||||
def _exec_case(self, path: Path) -> CaseResult:
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Could not find test '{path}'")
|
||||
if not path.is_file():
|
||||
raise TypeError(f"Test '{path}' is not a file")
|
||||
|
||||
result: CaseResult = CaseResult()
|
||||
content: str = path.read_text()
|
||||
lexer: MidasLexer = MidasLexer(content)
|
||||
tokens: list[Token] = []
|
||||
try:
|
||||
tokens = lexer.process()
|
||||
result.tokens = [
|
||||
{
|
||||
"type": token.type.name,
|
||||
"lexeme": token.lexeme,
|
||||
"line": token.position.line,
|
||||
"column": token.position.column,
|
||||
}
|
||||
for token in tokens
|
||||
]
|
||||
except MidasSyntaxError as e:
|
||||
result.errors.append(
|
||||
{
|
||||
"type": "SyntaxError",
|
||||
"line": e.pos.line,
|
||||
"column": e.pos.column,
|
||||
"message": e.message,
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
parser: MidasParser = MidasParser(tokens)
|
||||
stmts: list[Stmt] = parser.parse()
|
||||
result.stmts = AstJsonSerializer().serialize(stmts)
|
||||
result.errors.extend(
|
||||
[
|
||||
{
|
||||
"line": e.token.position.line,
|
||||
"column": e.token.position.column,
|
||||
"message": e.message,
|
||||
}
|
||||
for e in parser.errors
|
||||
]
|
||||
)
|
||||
return result
|
||||
|
||||
def update_all_tests(self):
|
||||
paths: list[Path] = self._list_tests()
|
||||
return self.update_tests(paths)
|
||||
|
||||
def update_tests(self, tests: list[Path]):
|
||||
updated: int = 0
|
||||
for test in tests:
|
||||
if self._update_test(test):
|
||||
updated += 1
|
||||
print(f"Updated {updated}/{len(tests)} tests")
|
||||
|
||||
def _update_test(self, path: Path) -> bool:
|
||||
result: CaseResult = self._exec_case(path)
|
||||
result_path: Path = self._result_path(path)
|
||||
current: str = result_path.read_text()
|
||||
new: str = result.dumps()
|
||||
if current == new:
|
||||
return False
|
||||
result_path.write_text(new)
|
||||
return True
|
||||
|
||||
def _result_path(self, test_path: Path) -> Path:
|
||||
return test_path.parent / (test_path.name + ".ref.json")
|
||||
|
||||
def _print_diff(self, diff: Iterator[str]):
|
||||
for line in diff:
|
||||
if line.startswith("+") and not line.startswith("+++"):
|
||||
print(f"\033[92m{line}\033[0m", end="")
|
||||
elif line.startswith("-") and not line.startswith("---"):
|
||||
print(f"\033[91m{line}\033[0m", end="")
|
||||
else:
|
||||
print(line, end="")
|
||||
print()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-D",
|
||||
"--base-dir",
|
||||
help="Base directory containing test files",
|
||||
type=Path,
|
||||
default=DEFAULT_BASE_DIR,
|
||||
)
|
||||
subparsers = parser.add_subparsers(dest="subcommand")
|
||||
|
||||
update = subparsers.add_parser("update")
|
||||
update.add_argument("-a", "--all", action="store_true")
|
||||
update.add_argument("FILE", type=Path, nargs="*")
|
||||
|
||||
run = subparsers.add_parser("run")
|
||||
run.add_argument("-a", "--all", action="store_true")
|
||||
run.add_argument("FILE", type=Path, nargs="*")
|
||||
args = parser.parse_args()
|
||||
|
||||
tester: Tester = Tester(args.base_dir)
|
||||
|
||||
match args.subcommand:
|
||||
case "update":
|
||||
if args.all:
|
||||
tester.update_all_tests()
|
||||
else:
|
||||
tester.update_tests(args.FILE)
|
||||
case "run":
|
||||
success: bool
|
||||
if args.all:
|
||||
success = tester.run_all_tests()
|
||||
else:
|
||||
success = tester.run_tests(args.FILE)
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
143
tests/base.py
Normal file
143
tests/base.py
Normal file
@@ -0,0 +1,143 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import difflib
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Protocol
|
||||
|
||||
|
||||
class CaseResult(Protocol):
|
||||
def dumps(self) -> str: ...
|
||||
|
||||
|
||||
class Tester(ABC):
|
||||
"""A test runner to check for regressions in the lexer and parser"""
|
||||
|
||||
CASES_DIR: Path = Path(__file__).parent / "cases"
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def namespace(self) -> str: ...
|
||||
|
||||
@property
|
||||
def base_dir(self) -> Path:
|
||||
return self.CASES_DIR / self.namespace
|
||||
|
||||
@abstractmethod
|
||||
def _list_tests(self) -> list[Path]: ...
|
||||
|
||||
def run_all_tests(self) -> bool:
|
||||
paths: list[Path] = self._list_tests()
|
||||
return self.run_tests(paths)
|
||||
|
||||
def run_tests(self, tests: list[Path]) -> bool:
|
||||
rule: str = "-" * 80
|
||||
n: int = len(tests)
|
||||
successes: int = 0
|
||||
failures: int = 0
|
||||
|
||||
print(rule)
|
||||
for i, test in enumerate(tests):
|
||||
print(f"Case {i+1}/{n}: {test.relative_to(self.CASES_DIR)}")
|
||||
success: bool = self._run_test(test)
|
||||
if success:
|
||||
successes += 1
|
||||
else:
|
||||
failures += 1
|
||||
|
||||
print(rule)
|
||||
print(f"Success: {successes}/{n}")
|
||||
print(f"Failed: {failures}/{n}")
|
||||
print(rule)
|
||||
return failures == 0
|
||||
|
||||
def _run_test(self, path: Path) -> bool:
|
||||
result_path: Path = self._result_path(path)
|
||||
if not result_path.exists():
|
||||
print("Missing snapshot. Please run the update command first")
|
||||
return False
|
||||
result: CaseResult = self._exec_case(path)
|
||||
expected: str = result_path.read_text()
|
||||
actual: str = result.dumps()
|
||||
|
||||
if expected == actual:
|
||||
return True
|
||||
|
||||
diff = difflib.unified_diff(
|
||||
expected.splitlines(keepends=True),
|
||||
actual.splitlines(keepends=True),
|
||||
fromfile="Snapshot",
|
||||
tofile="Result",
|
||||
)
|
||||
self._print_diff(diff)
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def _exec_case(self, path: Path) -> CaseResult: ...
|
||||
|
||||
def update_all_tests(self):
|
||||
paths: list[Path] = self._list_tests()
|
||||
return self.update_tests(paths)
|
||||
|
||||
def update_tests(self, tests: list[Path]):
|
||||
updated: int = 0
|
||||
for test in tests:
|
||||
if self._update_test(test):
|
||||
updated += 1
|
||||
print(f"Updated {updated}/{len(tests)} tests")
|
||||
|
||||
def _update_test(self, path: Path) -> bool:
|
||||
result: CaseResult = self._exec_case(path)
|
||||
result_path: Path = self._result_path(path)
|
||||
current: str = result_path.read_text() if result_path.exists() else ""
|
||||
new: str = result.dumps()
|
||||
if current == new:
|
||||
return False
|
||||
result_path.write_text(new)
|
||||
return True
|
||||
|
||||
def _result_path(self, test_path: Path) -> Path:
|
||||
return test_path.parent / (test_path.name + ".ref.json")
|
||||
|
||||
def _print_diff(self, diff: Iterator[str]):
|
||||
for line in diff:
|
||||
if line.startswith("+") and not line.startswith("+++"):
|
||||
print(f"\033[92m{line}\033[0m", end="")
|
||||
elif line.startswith("-") and not line.startswith("---"):
|
||||
print(f"\033[91m{line}\033[0m", end="")
|
||||
else:
|
||||
print(line, end="")
|
||||
print()
|
||||
|
||||
@classmethod
|
||||
def main(cls):
|
||||
parser = argparse.ArgumentParser()
|
||||
subparsers = parser.add_subparsers(dest="subcommand")
|
||||
|
||||
update = subparsers.add_parser("update")
|
||||
update.add_argument("-a", "--all", action="store_true")
|
||||
update.add_argument("FILE", type=Path, nargs="*")
|
||||
|
||||
run = subparsers.add_parser("run")
|
||||
run.add_argument("-a", "--all", action="store_true")
|
||||
run.add_argument("FILE", type=Path, nargs="*")
|
||||
args = parser.parse_args()
|
||||
|
||||
tester: Tester = cls()
|
||||
|
||||
match args.subcommand:
|
||||
case "update":
|
||||
if args.all:
|
||||
tester.update_all_tests()
|
||||
else:
|
||||
tester.update_tests(args.FILE)
|
||||
case "run":
|
||||
success: bool
|
||||
if args.all:
|
||||
success = tester.run_all_tests()
|
||||
else:
|
||||
success = tester.run_tests(args.FILE)
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
14
tests/cases/checker/01_simple_types.py
Normal file
14
tests/cases/checker/01_simple_types.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# type: ignore
|
||||
# ruff: disable[F821]
|
||||
from __future__ import annotations
|
||||
|
||||
df: Frame[
|
||||
verified: bool,
|
||||
birth_year: int,
|
||||
height: float + ( _ > 0 ) + ( _ < 250 ),
|
||||
name: str,
|
||||
date: datetime,
|
||||
float,
|
||||
unknown: _,
|
||||
_
|
||||
]
|
||||
3
tests/cases/checker/01_simple_types.py.ref.json
Normal file
3
tests/cases/checker/01_simple_types.py.ref.json
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"diagnostics": []
|
||||
}
|
||||
11
tests/cases/checker/02_simple_operations.py
Normal file
11
tests/cases/checker/02_simple_operations.py
Normal file
@@ -0,0 +1,11 @@
|
||||
a: int = 3
|
||||
b: int = 4
|
||||
|
||||
c = a + b
|
||||
|
||||
c = "invalid"
|
||||
|
||||
d = True
|
||||
e = d + d
|
||||
|
||||
f: float = a
|
||||
46
tests/cases/checker/02_simple_operations.py.ref.json
Normal file
46
tests/cases/checker/02_simple_operations.py.ref.json
Normal file
@@ -0,0 +1,46 @@
|
||||
{
|
||||
"diagnostics": [
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
6,
|
||||
0
|
||||
],
|
||||
"end": [
|
||||
6,
|
||||
13
|
||||
]
|
||||
},
|
||||
"message": "Cannot assign BaseType(name='str') to c of type BaseType(name='int')"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
9,
|
||||
4
|
||||
],
|
||||
"end": [
|
||||
9,
|
||||
9
|
||||
]
|
||||
},
|
||||
"message": "Undefined operation __add__ between BaseType(name='bool') and BaseType(name='bool')"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
11,
|
||||
0
|
||||
],
|
||||
"end": [
|
||||
11,
|
||||
12
|
||||
]
|
||||
},
|
||||
"message": "Cannot assign BaseType(name='int') to f of type BaseType(name='float')"
|
||||
}
|
||||
]
|
||||
}
|
||||
18
tests/cases/checker/03_functions.py
Normal file
18
tests/cases/checker/03_functions.py
Normal file
@@ -0,0 +1,18 @@
|
||||
def foo(a: int, /, b: float, *, c: str):
|
||||
return True
|
||||
|
||||
|
||||
r1 = foo()
|
||||
r2 = foo(1)
|
||||
r3 = foo(1, 2.0)
|
||||
r4 = foo(1, b=2.0)
|
||||
r5 = foo(1, 2.0, "test")
|
||||
r6 = foo(1, 2.0, b=3.0)
|
||||
r7 = foo(a=1)
|
||||
r8 = foo(g="test")
|
||||
|
||||
r9a = foo(1, 2.0, c="test")
|
||||
r9b = foo(1, b=2.0, c="test")
|
||||
r9c = foo(1, c="test", b=2.0)
|
||||
|
||||
r10 = foo("a", 3, c=False)
|
||||
270
tests/cases/checker/03_functions.py.ref.json
Normal file
270
tests/cases/checker/03_functions.py.ref.json
Normal file
@@ -0,0 +1,270 @@
|
||||
{
|
||||
"diagnostics": [
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
5,
|
||||
5
|
||||
],
|
||||
"end": [
|
||||
5,
|
||||
10
|
||||
]
|
||||
},
|
||||
"message": "Missing required positional arguments: 'a' and 'b'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
5,
|
||||
5
|
||||
],
|
||||
"end": [
|
||||
5,
|
||||
10
|
||||
]
|
||||
},
|
||||
"message": "Missing required keyword argument: 'c'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
6,
|
||||
5
|
||||
],
|
||||
"end": [
|
||||
6,
|
||||
11
|
||||
]
|
||||
},
|
||||
"message": "Missing required positional argument: 'b'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
6,
|
||||
5
|
||||
],
|
||||
"end": [
|
||||
6,
|
||||
11
|
||||
]
|
||||
},
|
||||
"message": "Missing required keyword argument: 'c'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
7,
|
||||
5
|
||||
],
|
||||
"end": [
|
||||
7,
|
||||
16
|
||||
]
|
||||
},
|
||||
"message": "Missing required keyword argument: 'c'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
8,
|
||||
5
|
||||
],
|
||||
"end": [
|
||||
8,
|
||||
18
|
||||
]
|
||||
},
|
||||
"message": "Missing required keyword argument: 'c'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
9,
|
||||
17
|
||||
],
|
||||
"end": [
|
||||
9,
|
||||
23
|
||||
]
|
||||
},
|
||||
"message": "Too many positional arguments"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
9,
|
||||
5
|
||||
],
|
||||
"end": [
|
||||
9,
|
||||
24
|
||||
]
|
||||
},
|
||||
"message": "Missing required keyword argument: 'c'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
10,
|
||||
19
|
||||
],
|
||||
"end": [
|
||||
10,
|
||||
22
|
||||
]
|
||||
},
|
||||
"message": "Multiple values for argument 'b'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
10,
|
||||
5
|
||||
],
|
||||
"end": [
|
||||
10,
|
||||
23
|
||||
]
|
||||
},
|
||||
"message": "Missing required keyword argument: 'c'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
11,
|
||||
11
|
||||
],
|
||||
"end": [
|
||||
11,
|
||||
12
|
||||
]
|
||||
},
|
||||
"message": "Unknown keyword argument 'a'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
11,
|
||||
5
|
||||
],
|
||||
"end": [
|
||||
11,
|
||||
13
|
||||
]
|
||||
},
|
||||
"message": "Missing required positional arguments: 'a' and 'b'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
11,
|
||||
5
|
||||
],
|
||||
"end": [
|
||||
11,
|
||||
13
|
||||
]
|
||||
},
|
||||
"message": "Missing required keyword argument: 'c'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
12,
|
||||
11
|
||||
],
|
||||
"end": [
|
||||
12,
|
||||
17
|
||||
]
|
||||
},
|
||||
"message": "Unknown keyword argument 'g'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
12,
|
||||
5
|
||||
],
|
||||
"end": [
|
||||
12,
|
||||
18
|
||||
]
|
||||
},
|
||||
"message": "Missing required positional arguments: 'a' and 'b'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
12,
|
||||
5
|
||||
],
|
||||
"end": [
|
||||
12,
|
||||
18
|
||||
]
|
||||
},
|
||||
"message": "Missing required keyword argument: 'c'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
18,
|
||||
10
|
||||
],
|
||||
"end": [
|
||||
18,
|
||||
13
|
||||
]
|
||||
},
|
||||
"message": "Wrong type for argument 'a', expected BaseType(name='int'), got BaseType(name='str')"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
18,
|
||||
15
|
||||
],
|
||||
"end": [
|
||||
18,
|
||||
16
|
||||
]
|
||||
},
|
||||
"message": "Wrong type for argument 'b', expected BaseType(name='float'), got BaseType(name='int')"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
18,
|
||||
20
|
||||
],
|
||||
"end": [
|
||||
18,
|
||||
25
|
||||
]
|
||||
},
|
||||
"message": "Wrong type for argument 'c', expected BaseType(name='str'), got BaseType(name='bool')"
|
||||
}
|
||||
]
|
||||
}
|
||||
14
tests/cases/checker/04_custom_types.midas
Normal file
14
tests/cases/checker/04_custom_types.midas
Normal file
@@ -0,0 +1,14 @@
|
||||
type Meter = float
|
||||
type Second = float
|
||||
type MeterPerSecond = float
|
||||
|
||||
extend Meter {
|
||||
op __add__(Meter) -> Meter
|
||||
op __sub__(Meter) -> Meter
|
||||
op __truediv__(Second) -> MeterPerSecond
|
||||
}
|
||||
|
||||
extend Second {
|
||||
op __add__(Second) -> Second
|
||||
op __sub__(Second) -> Second
|
||||
}
|
||||
6
tests/cases/checker/04_custom_types.py
Normal file
6
tests/cases/checker/04_custom_types.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# type: ignore
|
||||
# ruff: disable [F821]
|
||||
|
||||
distance: Meter = cast(Meter, 123.45)
|
||||
time: Second = cast(Second, 6.7)
|
||||
speed = distance / time
|
||||
3
tests/cases/checker/04_custom_types.py.ref.json
Normal file
3
tests/cases/checker/04_custom_types.py.ref.json
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"diagnostics": []
|
||||
}
|
||||
25
tests/cases/checker/05_control_flow.py
Normal file
25
tests/cases/checker/05_control_flow.py
Normal file
@@ -0,0 +1,25 @@
|
||||
def valid(a: int, b: int) -> int:
|
||||
return a + b
|
||||
|
||||
def with_if(a: int, b: int) -> int:
|
||||
if a < b:
|
||||
return b - a
|
||||
else:
|
||||
return a - b
|
||||
|
||||
def unreachable1():
|
||||
return
|
||||
a = 0
|
||||
|
||||
def unreachable2(a: int) -> int:
|
||||
if a > 10:
|
||||
return a - 10
|
||||
else:
|
||||
return a
|
||||
b = 0
|
||||
|
||||
def mixed(a: int, b: int):
|
||||
if a < b:
|
||||
return b - a
|
||||
else:
|
||||
return "oops"
|
||||
46
tests/cases/checker/05_control_flow.py.ref.json
Normal file
46
tests/cases/checker/05_control_flow.py.ref.json
Normal file
@@ -0,0 +1,46 @@
|
||||
{
|
||||
"diagnostics": [
|
||||
{
|
||||
"type": "Warning",
|
||||
"location": {
|
||||
"start": [
|
||||
12,
|
||||
4
|
||||
],
|
||||
"end": [
|
||||
12,
|
||||
9
|
||||
]
|
||||
},
|
||||
"message": "Unreachable statement"
|
||||
},
|
||||
{
|
||||
"type": "Warning",
|
||||
"location": {
|
||||
"start": [
|
||||
19,
|
||||
4
|
||||
],
|
||||
"end": [
|
||||
19,
|
||||
9
|
||||
]
|
||||
},
|
||||
"message": "Unreachable statement"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
21,
|
||||
0
|
||||
],
|
||||
"end": [
|
||||
25,
|
||||
21
|
||||
]
|
||||
},
|
||||
"message": "Mixed return types: [BaseType(name='int'), BaseType(name='str')]"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,15 +1,15 @@
|
||||
// Simple custom type derived from float
|
||||
type Custom(float)
|
||||
type Custom = float
|
||||
|
||||
// Simple custom types with constraints
|
||||
type Latitude(float) where (-90 <= _ <= 90)
|
||||
type Longitude(float) where (-180 <= _ <= 180)
|
||||
type Latitude = float where (-90 <= _ <= 90)
|
||||
type Longitude = float where (-180 <= _ <= 180)
|
||||
|
||||
// Generic custom type (a Difference of T is derived from T, e.g. a difference of floats is a float
|
||||
type Difference[T](T)
|
||||
type Difference[T] = T
|
||||
|
||||
// Complex custom type, containing two values accessible through properties
|
||||
type GeoLocation {
|
||||
type GeoLocation = {
|
||||
lat: Latitude
|
||||
lon: Longitude
|
||||
}
|
||||
@@ -24,7 +24,7 @@ extend GeoLocation {
|
||||
|
||||
// For complex generics, you need to specify how the genericity the properties
|
||||
// are handled
|
||||
type Difference[GeoLocation] {
|
||||
type Difference[GeoLocation] = {
|
||||
lat: Difference[Latitude]
|
||||
lon: Difference[Longitude]
|
||||
}
|
||||
@@ -44,11 +44,11 @@ predicate StrictlyPositive(v: float) = v > 0
|
||||
predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10)
|
||||
predicate Arctic(loc: GeoLocation) = (loc.lat >= 66)
|
||||
|
||||
type Person {
|
||||
type Person = {
|
||||
name: str
|
||||
|
||||
// Property with an inline constraint
|
||||
age: int? where (0 <= _ < 150)
|
||||
age: Optional[int where (0 <= _ < 150)]
|
||||
|
||||
// Property referencing a predicate
|
||||
height: float where StrictlyPositive
|
||||
File diff suppressed because it is too large
Load Diff
14
tests/cases/python-parser/01_simple_types.py
Normal file
14
tests/cases/python-parser/01_simple_types.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# type: ignore
|
||||
# ruff: disable[F821]
|
||||
from __future__ import annotations
|
||||
|
||||
df: Frame[
|
||||
verified: bool,
|
||||
birth_year: int,
|
||||
height: float + ( _ > 0 ) + ( _ < 250 ),
|
||||
name: str,
|
||||
date: datetime,
|
||||
float,
|
||||
unknown: _,
|
||||
_
|
||||
]
|
||||
85
tests/cases/python-parser/01_simple_types.py.ref.json
Normal file
85
tests/cases/python-parser/01_simple_types.py.ref.json
Normal file
@@ -0,0 +1,85 @@
|
||||
{
|
||||
"stmts": [
|
||||
{
|
||||
"_type": "TypeAssign",
|
||||
"name": "df",
|
||||
"type": {
|
||||
"_type": "FrameType",
|
||||
"columns": [
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "verified",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "bool",
|
||||
"param": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "birth_year",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "int",
|
||||
"param": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "height",
|
||||
"type": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "(_ > 0) + (_ < 250)"
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "name",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "str",
|
||||
"param": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "date",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "datetime",
|
||||
"param": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": null,
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "unknown",
|
||||
"type": null
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": null,
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "_",
|
||||
"param": null
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
25
tests/cases/python-parser/02_custom_types.py
Normal file
25
tests/cases/python-parser/02_custom_types.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# type: ignore
|
||||
# ruff: disable[F821]
|
||||
from __future__ import annotations
|
||||
|
||||
df: Frame[
|
||||
location: GeoLocation
|
||||
]
|
||||
|
||||
lat: Column[GeoLocation] = df["location"].lat
|
||||
lon: Column[GeoLocation] = df["location"].lon
|
||||
|
||||
lat + lon
|
||||
|
||||
lat1: Latitude = lat[0]
|
||||
lat2: Latitude = lat[1]
|
||||
lat_diff: Difference[Latitude] = lat2 - lat1
|
||||
|
||||
df2: Frame[
|
||||
age: int + (_ >= 0),
|
||||
height: float + (_ >= 0),
|
||||
]
|
||||
df2_bis: Frame[
|
||||
age: int + Positive,
|
||||
height: float + Positive,
|
||||
]
|
||||
141
tests/cases/python-parser/02_custom_types.py.ref.json
Normal file
141
tests/cases/python-parser/02_custom_types.py.ref.json
Normal file
@@ -0,0 +1,141 @@
|
||||
{
|
||||
"stmts": [
|
||||
{
|
||||
"_type": "TypeAssign",
|
||||
"name": "df",
|
||||
"type": {
|
||||
"_type": "FrameType",
|
||||
"columns": [
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "location",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "GeoLocation",
|
||||
"param": null
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "ExpressionStmt",
|
||||
"expr": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "lat"
|
||||
},
|
||||
"operator": "+",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "lon"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "TypeAssign",
|
||||
"name": "lat_diff",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Difference",
|
||||
"param": {
|
||||
"_type": "BaseType",
|
||||
"base": "Latitude",
|
||||
"param": null
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "AssignStmt",
|
||||
"targets": [
|
||||
{
|
||||
"_type": "VariableExpr",
|
||||
"name": "lat_diff"
|
||||
}
|
||||
],
|
||||
"value": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "lat2"
|
||||
},
|
||||
"operator": "-",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "lat1"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "TypeAssign",
|
||||
"name": "df2",
|
||||
"type": {
|
||||
"_type": "FrameType",
|
||||
"columns": [
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "age",
|
||||
"type": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "int",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "_ >= 0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "height",
|
||||
"type": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "_ >= 0"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "TypeAssign",
|
||||
"name": "df2_bis",
|
||||
"type": {
|
||||
"_type": "FrameType",
|
||||
"columns": [
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "age",
|
||||
"type": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "int",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "Positive"
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "height",
|
||||
"type": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "Positive"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
15
tests/cases/python-parser/03_functions.py
Normal file
15
tests/cases/python-parser/03_functions.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# type: ignore
|
||||
# ruff: disable[F821]
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def func(
|
||||
col1: Column[float + (0 <= _ <= 1)],
|
||||
col2: Column[float + (0 <= _ <= 1)],
|
||||
) -> Column[float + (0 <= _ <= 2)]:
|
||||
result: Column[float + (0 <= _ <= 2)] = col1 + col2
|
||||
return result
|
||||
|
||||
|
||||
def func2(a: int, /, b: float, *, c: str):
|
||||
pass
|
||||
149
tests/cases/python-parser/03_functions.py.ref.json
Normal file
149
tests/cases/python-parser/03_functions.py.ref.json
Normal file
@@ -0,0 +1,149 @@
|
||||
{
|
||||
"stmts": [
|
||||
{
|
||||
"_type": "Function",
|
||||
"name": "func",
|
||||
"posonlyargs": [],
|
||||
"args": [
|
||||
{
|
||||
"name": "col1",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "0 <= _ <= 1"
|
||||
}
|
||||
},
|
||||
"default": null
|
||||
},
|
||||
{
|
||||
"name": "col2",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "0 <= _ <= 1"
|
||||
}
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
],
|
||||
"sink": null,
|
||||
"kwonlyargs": [],
|
||||
"kw_sink": null,
|
||||
"returns": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "0 <= _ <= 2"
|
||||
}
|
||||
},
|
||||
"body": [
|
||||
{
|
||||
"_type": "TypeAssign",
|
||||
"name": "result",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "0 <= _ <= 2"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "AssignStmt",
|
||||
"targets": [
|
||||
{
|
||||
"_type": "VariableExpr",
|
||||
"name": "result"
|
||||
}
|
||||
],
|
||||
"value": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "col1"
|
||||
},
|
||||
"operator": "+",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "col2"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "ReturnStmt",
|
||||
"value": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "result"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"_type": "Function",
|
||||
"name": "func2",
|
||||
"posonlyargs": [
|
||||
{
|
||||
"name": "a",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "int",
|
||||
"param": null
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
],
|
||||
"args": [
|
||||
{
|
||||
"name": "b",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
],
|
||||
"sink": null,
|
||||
"kwonlyargs": [
|
||||
{
|
||||
"name": "c",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "str",
|
||||
"param": null
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
],
|
||||
"kw_sink": null,
|
||||
"returns": null,
|
||||
"body": []
|
||||
}
|
||||
]
|
||||
}
|
||||
75
tests/checker.py
Normal file
75
tests/checker.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import ast
|
||||
import json
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.checker.checker import Checker
|
||||
from midas.checker.diagnostic import Diagnostic
|
||||
from midas.parser.python import PythonParser
|
||||
from midas.resolver.resolver import Resolver
|
||||
from tests.base import Tester
|
||||
|
||||
|
||||
@dataclass
|
||||
class CaseResult:
|
||||
diagnostics: list[dict] = field(default_factory=list)
|
||||
|
||||
def dumps(self) -> str:
|
||||
return json.dumps(asdict(self), indent=2)
|
||||
|
||||
|
||||
class CheckerTester(Tester):
|
||||
@property
|
||||
def namespace(self) -> str:
|
||||
return "checker"
|
||||
|
||||
def _list_tests(self) -> list[Path]:
|
||||
return list(self.base_dir.rglob("*.py"))
|
||||
|
||||
def _exec_case(self, path: Path) -> CaseResult:
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Could not find test '{path}'")
|
||||
if not path.is_file():
|
||||
raise TypeError(f"Test '{path}' is not a file")
|
||||
|
||||
types_paths: list[Path] = []
|
||||
types_path: Path = path.with_suffix(".midas")
|
||||
if types_path.exists():
|
||||
types_paths.append(types_path)
|
||||
source: str = path.read_text()
|
||||
tree: ast.Module = ast.parse(source, filename=path)
|
||||
parser = PythonParser()
|
||||
stmts: list[p.Stmt] = parser.parse_module(tree)
|
||||
resolver = Resolver()
|
||||
resolver.resolve(*stmts)
|
||||
result: CaseResult = CaseResult()
|
||||
checker = Checker(
|
||||
resolver.locals,
|
||||
source_path=path,
|
||||
types_paths=types_paths,
|
||||
)
|
||||
diagnostics: list[Diagnostic] = checker.check(stmts)
|
||||
for diagnostic in diagnostics:
|
||||
result.diagnostics.append(
|
||||
{
|
||||
"type": str(diagnostic.type),
|
||||
"location": {
|
||||
"start": (
|
||||
diagnostic.location.lineno,
|
||||
diagnostic.location.col_offset,
|
||||
),
|
||||
"end": (
|
||||
diagnostic.location.end_lineno,
|
||||
diagnostic.location.end_col_offset,
|
||||
),
|
||||
},
|
||||
"message": diagnostic.message,
|
||||
}
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
CheckerTester.main()
|
||||
82
tests/midas.py
Normal file
82
tests/midas.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import json
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.midas import Stmt
|
||||
from midas.lexer.base import MidasSyntaxError
|
||||
from midas.lexer.midas import MidasLexer
|
||||
from midas.lexer.token import Token
|
||||
from midas.parser.midas import MidasParser
|
||||
from tests.base import Tester
|
||||
from tests.serializer.midas import MidasAstJsonSerializer
|
||||
|
||||
|
||||
@dataclass
|
||||
class CaseResult:
|
||||
tokens: Optional[list[dict]] = None
|
||||
stmts: Optional[list[dict]] = None
|
||||
errors: list[dict] = field(default_factory=list)
|
||||
|
||||
def dumps(self) -> str:
|
||||
return json.dumps(asdict(self), indent=2)
|
||||
|
||||
|
||||
class MidasTester(Tester):
|
||||
@property
|
||||
def namespace(self) -> str:
|
||||
return "midas-parser"
|
||||
|
||||
def _list_tests(self) -> list[Path]:
|
||||
return list(self.base_dir.rglob("*.midas"))
|
||||
|
||||
def _exec_case(self, path: Path) -> CaseResult:
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Could not find test '{path}'")
|
||||
if not path.is_file():
|
||||
raise TypeError(f"Test '{path}' is not a file")
|
||||
|
||||
result: CaseResult = CaseResult()
|
||||
content: str = path.read_text()
|
||||
lexer: MidasLexer = MidasLexer(content)
|
||||
tokens: list[Token] = []
|
||||
try:
|
||||
tokens = lexer.process()
|
||||
result.tokens = [
|
||||
{
|
||||
"type": token.type.name,
|
||||
"lexeme": token.lexeme,
|
||||
"line": token.position.line,
|
||||
"column": token.position.column,
|
||||
}
|
||||
for token in tokens
|
||||
]
|
||||
except MidasSyntaxError as e:
|
||||
result.errors.append(
|
||||
{
|
||||
"type": "SyntaxError",
|
||||
"line": e.pos.line,
|
||||
"column": e.pos.column,
|
||||
"message": e.message,
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
parser: MidasParser = MidasParser(tokens)
|
||||
stmts: list[Stmt] = parser.parse()
|
||||
result.stmts = MidasAstJsonSerializer().serialize(stmts)
|
||||
result.errors.extend(
|
||||
[
|
||||
{
|
||||
"line": e.token.position.line,
|
||||
"column": e.token.position.column,
|
||||
"message": e.message,
|
||||
}
|
||||
for e in parser.errors
|
||||
]
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
MidasTester.main()
|
||||
46
tests/python.py
Normal file
46
tests/python.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import ast
|
||||
import json
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.python import Stmt
|
||||
from midas.parser.python import PythonParser
|
||||
from tests.base import Tester
|
||||
from tests.serializer.python import PythonAstJsonSerializer
|
||||
|
||||
|
||||
@dataclass
|
||||
class CaseResult:
|
||||
stmts: Optional[list[dict]] = None
|
||||
|
||||
def dumps(self) -> str:
|
||||
return json.dumps(asdict(self), indent=2)
|
||||
|
||||
|
||||
class PythonTester(Tester):
|
||||
@property
|
||||
def namespace(self) -> str:
|
||||
return "python-parser"
|
||||
|
||||
def _list_tests(self) -> list[Path]:
|
||||
return list(self.base_dir.rglob("*.py"))
|
||||
|
||||
def _exec_case(self, path: Path) -> CaseResult:
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Could not find test '{path}'")
|
||||
if not path.is_file():
|
||||
raise TypeError(f"Test '{path}' is not a file")
|
||||
|
||||
result: CaseResult = CaseResult()
|
||||
content: str = path.read_text()
|
||||
tree: ast.Module = ast.parse(content)
|
||||
|
||||
parser: PythonParser = PythonParser()
|
||||
stmts: list[Stmt] = parser.parse_module(tree)
|
||||
result.stmts = PythonAstJsonSerializer().serialize(stmts)
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
PythonTester.main()
|
||||
@@ -2,56 +2,60 @@ from typing import Optional, Sequence
|
||||
|
||||
from midas.ast.midas import (
|
||||
BinaryExpr,
|
||||
ComplexTypeStmt,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
Expr,
|
||||
ExtendStmt,
|
||||
GenericType,
|
||||
GetExpr,
|
||||
GroupingExpr,
|
||||
LiteralExpr,
|
||||
LogicalExpr,
|
||||
NamedType,
|
||||
OpStmt,
|
||||
PredicateStmt,
|
||||
PropertyStmt,
|
||||
SimpleTypeExpr,
|
||||
SimpleTypeStmt,
|
||||
Stmt,
|
||||
TemplateExpr,
|
||||
TypeExpr,
|
||||
Type,
|
||||
TypeStmt,
|
||||
UnaryExpr,
|
||||
VariableExpr,
|
||||
WildcardExpr,
|
||||
)
|
||||
|
||||
|
||||
class AstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
|
||||
class MidasAstJsonSerializer(
|
||||
Stmt.Visitor[dict], Expr.Visitor[dict], Type.Visitor[dict]
|
||||
):
|
||||
"""An AST serializer which produces a JSON-compatible structure"""
|
||||
|
||||
def serialize(self, stmts: list[Stmt]) -> list[dict]:
|
||||
return [stmt.accept(self) for stmt in stmts]
|
||||
|
||||
def _serialize_optional(self, element: Optional[Stmt | Expr]) -> Optional[dict]:
|
||||
def _serialize_optional(
|
||||
self, element: Optional[Stmt | Expr | Type]
|
||||
) -> Optional[dict]:
|
||||
if element is None:
|
||||
return None
|
||||
return element.accept(self)
|
||||
|
||||
def _serialize_list(self, elements: Sequence[Stmt | Expr]) -> list[dict]:
|
||||
def _serialize_list(self, elements: Sequence[Stmt | Expr | Type]) -> list[dict]:
|
||||
return [element.accept(self) for element in elements]
|
||||
|
||||
def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> dict:
|
||||
def visit_type_stmt(self, stmt: TypeStmt) -> dict:
|
||||
return {
|
||||
"_type": "SimpleTypeStmt",
|
||||
"template": self._serialize_optional(stmt.template),
|
||||
"_type": "TypeStmt",
|
||||
"name": stmt.name.lexeme,
|
||||
"base": stmt.base.accept(self),
|
||||
"constraint": self._serialize_optional(stmt.constraint),
|
||||
"params": [
|
||||
self._serialize_type_stmt_template_param(param) for param in stmt.params
|
||||
],
|
||||
"type": stmt.type.accept(self),
|
||||
}
|
||||
|
||||
def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> dict:
|
||||
def _serialize_type_stmt_template_param(self, param: TypeStmt.Param) -> dict:
|
||||
return {
|
||||
"_type": "ComplexTypeStmt",
|
||||
"name": stmt.name.lexeme,
|
||||
"template": self._serialize_optional(stmt.template),
|
||||
"properties": self._serialize_list(stmt.properties),
|
||||
"name": param.name.lexeme,
|
||||
"bound": self._serialize_optional(param.bound),
|
||||
}
|
||||
|
||||
def visit_property_stmt(self, stmt: PropertyStmt) -> dict:
|
||||
@@ -59,7 +63,6 @@ class AstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
|
||||
"_type": "PropertyStmt",
|
||||
"name": stmt.name.lexeme,
|
||||
"type": stmt.type.accept(self),
|
||||
"constraint": self._serialize_optional(stmt.constraint),
|
||||
}
|
||||
|
||||
def visit_extend_stmt(self, stmt: ExtendStmt) -> dict:
|
||||
@@ -86,13 +89,6 @@ class AstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
|
||||
"condition": stmt.condition.accept(self),
|
||||
}
|
||||
|
||||
def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> dict:
|
||||
return {
|
||||
"_type": "SimpleTypeExpr",
|
||||
"name": expr.name.lexeme,
|
||||
"optional": expr.optional,
|
||||
}
|
||||
|
||||
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
|
||||
return {
|
||||
"_type": "LogicalExpr",
|
||||
@@ -144,16 +140,28 @@ class AstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
|
||||
def visit_wildcard_expr(self, expr: WildcardExpr) -> dict:
|
||||
return {"_type": "WildcardExpr"}
|
||||
|
||||
def visit_template_expr(self, expr: TemplateExpr) -> dict:
|
||||
def visit_named_type(self, type: NamedType) -> dict:
|
||||
return {
|
||||
"_type": "TemplateExpr",
|
||||
"type": expr.type.accept(self),
|
||||
"_type": "NamedType",
|
||||
"name": type.name.lexeme,
|
||||
}
|
||||
|
||||
def visit_type_expr(self, expr: TypeExpr) -> dict:
|
||||
def visit_generic_type(self, type: GenericType) -> dict:
|
||||
return {
|
||||
"_type": "TypeExpr",
|
||||
"name": expr.name.lexeme,
|
||||
"template": self._serialize_optional(expr.template),
|
||||
"optional": expr.optional,
|
||||
"_type": "GenericType",
|
||||
"type": type.type.accept(self),
|
||||
"params": self._serialize_list(type.params),
|
||||
}
|
||||
|
||||
def visit_constraint_type(self, type: ConstraintType) -> dict:
|
||||
return {
|
||||
"_type": "ConstraintType",
|
||||
"type": type.type.accept(self),
|
||||
"constraint": type.constraint.accept(self),
|
||||
}
|
||||
|
||||
def visit_complex_type(self, type: ComplexType) -> dict:
|
||||
return {
|
||||
"_type": "ComplexType",
|
||||
"properties": self._serialize_list(type.properties),
|
||||
}
|
||||
256
tests/serializer/python.py
Normal file
256
tests/serializer/python.py
Normal file
@@ -0,0 +1,256 @@
|
||||
import ast
|
||||
from typing import Optional, Sequence, Type
|
||||
|
||||
from midas.ast.python import (
|
||||
AssignStmt,
|
||||
BaseType,
|
||||
BinaryExpr,
|
||||
CallExpr,
|
||||
CastExpr,
|
||||
CompareExpr,
|
||||
ConstraintType,
|
||||
Expr,
|
||||
ExpressionStmt,
|
||||
FrameColumn,
|
||||
FrameType,
|
||||
Function,
|
||||
GetExpr,
|
||||
IfStmt,
|
||||
LiteralExpr,
|
||||
LogicalExpr,
|
||||
MidasType,
|
||||
ReturnStmt,
|
||||
SetExpr,
|
||||
Stmt,
|
||||
TernaryExpr,
|
||||
TypeAssign,
|
||||
UnaryExpr,
|
||||
VariableExpr,
|
||||
)
|
||||
|
||||
unary_ops: dict[Type[ast.unaryop], str] = {
|
||||
ast.Invert: "~",
|
||||
ast.Not: "not",
|
||||
ast.UAdd: "+",
|
||||
ast.USub: "-",
|
||||
}
|
||||
binary_ops: dict[Type[ast.operator], str] = {
|
||||
ast.Add: "+",
|
||||
ast.Sub: "-",
|
||||
ast.Mult: "*",
|
||||
ast.MatMult: "@",
|
||||
ast.Div: "/",
|
||||
ast.Mod: "%",
|
||||
ast.LShift: "<<",
|
||||
ast.RShift: ">>",
|
||||
ast.BitOr: "|",
|
||||
ast.BitXor: "^",
|
||||
ast.BitAnd: "&",
|
||||
ast.FloorDiv: "//",
|
||||
ast.Pow: "**",
|
||||
}
|
||||
compare_ops: dict[Type[ast.cmpop], str] = {
|
||||
ast.Eq: "==",
|
||||
ast.NotEq: "!=",
|
||||
ast.Lt: "<",
|
||||
ast.LtE: "<=",
|
||||
ast.Gt: ">",
|
||||
ast.GtE: ">=",
|
||||
ast.Is: "is",
|
||||
ast.IsNot: "is not",
|
||||
ast.In: "in",
|
||||
ast.NotIn: "not in",
|
||||
}
|
||||
boolean_ops: dict[Type[ast.boolop], str] = {
|
||||
ast.And: "and",
|
||||
ast.Or: "or",
|
||||
}
|
||||
|
||||
|
||||
class PythonAstJsonSerializer(
|
||||
Stmt.Visitor[dict], Expr.Visitor[dict], MidasType.Visitor[dict]
|
||||
):
|
||||
"""An AST serializer which produces a JSON-compatible structure"""
|
||||
|
||||
def serialize(self, stmts: list[Stmt]) -> list[dict]:
|
||||
return [stmt.accept(self) for stmt in stmts]
|
||||
|
||||
def _serialize_optional(
|
||||
self, element: Optional[Stmt | Expr | MidasType]
|
||||
) -> Optional[dict]:
|
||||
if element is None:
|
||||
return None
|
||||
return element.accept(self)
|
||||
|
||||
def _serialize_list(
|
||||
self, elements: Sequence[Stmt | Expr | MidasType]
|
||||
) -> list[dict]:
|
||||
return [element.accept(self) for element in elements]
|
||||
|
||||
def visit_base_type(self, node: BaseType) -> dict:
|
||||
return {
|
||||
"_type": "BaseType",
|
||||
"base": node.base,
|
||||
"param": self._serialize_optional(node.param),
|
||||
}
|
||||
|
||||
def visit_constraint_type(self, node: ConstraintType) -> dict:
|
||||
return {
|
||||
"_type": "ConstraintType",
|
||||
"type": node.type.accept(self),
|
||||
"constraint": ast.unparse(node.constraint),
|
||||
}
|
||||
|
||||
def visit_frame_column(self, node: FrameColumn) -> dict:
|
||||
return {
|
||||
"_type": "FrameColumn",
|
||||
"name": node.name,
|
||||
"type": self._serialize_optional(node.type),
|
||||
}
|
||||
|
||||
def visit_frame_type(self, node: FrameType) -> dict:
|
||||
return {
|
||||
"_type": "FrameType",
|
||||
"columns": self._serialize_list(node.columns),
|
||||
}
|
||||
|
||||
def visit_expression_stmt(self, stmt: ExpressionStmt) -> dict:
|
||||
return {
|
||||
"_type": "ExpressionStmt",
|
||||
"expr": stmt.expr.accept(self),
|
||||
}
|
||||
|
||||
def _serialize_argument(self, arg: Function.Argument) -> dict:
|
||||
return {
|
||||
"name": arg.name,
|
||||
"type": self._serialize_optional(arg.type),
|
||||
"default": self._serialize_optional(arg.default),
|
||||
}
|
||||
|
||||
def visit_function(self, stmt: Function) -> dict:
|
||||
return {
|
||||
"_type": "Function",
|
||||
"name": stmt.name,
|
||||
"posonlyargs": [self._serialize_argument(arg) for arg in stmt.posonlyargs],
|
||||
"args": [self._serialize_argument(arg) for arg in stmt.args],
|
||||
"sink": (
|
||||
self._serialize_argument(stmt.sink) if stmt.sink is not None else None
|
||||
),
|
||||
"kwonlyargs": [self._serialize_argument(arg) for arg in stmt.kwonlyargs],
|
||||
"kw_sink": (
|
||||
self._serialize_argument(stmt.kw_sink)
|
||||
if stmt.kw_sink is not None
|
||||
else None
|
||||
),
|
||||
"returns": self._serialize_optional(stmt.returns),
|
||||
"body": self._serialize_list(stmt.body),
|
||||
}
|
||||
|
||||
def visit_type_assign(self, stmt: TypeAssign) -> dict:
|
||||
return {
|
||||
"_type": "TypeAssign",
|
||||
"name": stmt.name,
|
||||
"type": stmt.type.accept(self),
|
||||
}
|
||||
|
||||
def visit_assign_stmt(self, stmt: AssignStmt) -> dict:
|
||||
return {
|
||||
"_type": "AssignStmt",
|
||||
"targets": self._serialize_list(stmt.targets),
|
||||
"value": stmt.value.accept(self),
|
||||
}
|
||||
|
||||
def visit_return_stmt(self, stmt: ReturnStmt) -> dict:
|
||||
return {
|
||||
"_type": "ReturnStmt",
|
||||
"value": self._serialize_optional(stmt.value),
|
||||
}
|
||||
|
||||
def visit_if_stmt(self, stmt: IfStmt) -> dict:
|
||||
return {
|
||||
"_type": "IfStmt",
|
||||
"test": stmt.test.accept(self),
|
||||
"body": self._serialize_list(stmt.body),
|
||||
"orelse": self._serialize_list(stmt.orelse),
|
||||
}
|
||||
|
||||
def visit_binary_expr(self, expr: BinaryExpr) -> dict:
|
||||
return {
|
||||
"_type": "BinaryExpr",
|
||||
"left": expr.left.accept(self),
|
||||
"operator": binary_ops[expr.operator.__class__],
|
||||
"right": expr.right.accept(self),
|
||||
}
|
||||
|
||||
def visit_compare_expr(self, expr: CompareExpr) -> dict:
|
||||
return {
|
||||
"_type": "CompareExpr",
|
||||
"left": expr.left.accept(self),
|
||||
"operator": compare_ops[expr.operator.__class__],
|
||||
"right": expr.right.accept(self),
|
||||
}
|
||||
|
||||
def visit_unary_expr(self, expr: UnaryExpr) -> dict:
|
||||
return {
|
||||
"_type": "UnaryExpr",
|
||||
"operator": unary_ops[expr.operator.__class__],
|
||||
"right": expr.right.accept(self),
|
||||
}
|
||||
|
||||
def visit_call_expr(self, expr: CallExpr) -> dict:
|
||||
return {
|
||||
"_type": "CallExpr",
|
||||
"callee": expr.callee.accept(self),
|
||||
"arguments": self._serialize_list(expr.arguments),
|
||||
"keywords": {name: arg.accept(self) for name, arg in expr.keywords.items()},
|
||||
}
|
||||
|
||||
def visit_get_expr(self, expr: GetExpr) -> dict:
|
||||
return {
|
||||
"_type": "GetExpr",
|
||||
"object": expr.object.accept(self),
|
||||
"name": expr.name,
|
||||
}
|
||||
|
||||
def visit_literal_expr(self, expr: LiteralExpr) -> dict:
|
||||
return {
|
||||
"_type": "LiteralExpr",
|
||||
"value": expr.value,
|
||||
}
|
||||
|
||||
def visit_variable_expr(self, expr: VariableExpr) -> dict:
|
||||
return {
|
||||
"_type": "VariableExpr",
|
||||
"name": expr.name,
|
||||
}
|
||||
|
||||
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
|
||||
return {
|
||||
"_type": "LogicalExpr",
|
||||
"left": expr.left.accept(self),
|
||||
"operator": boolean_ops[expr.operator.__class__],
|
||||
"right": expr.right.accept(self),
|
||||
}
|
||||
|
||||
def visit_set_expr(self, expr: SetExpr) -> dict:
|
||||
return {
|
||||
"_type": "SetExpr",
|
||||
"object": expr.object.accept(self),
|
||||
"name": expr.name,
|
||||
"value": expr.value.accept(self),
|
||||
}
|
||||
|
||||
def visit_cast_expr(self, expr: CastExpr) -> dict:
|
||||
return {
|
||||
"_type": "CastExpr",
|
||||
"type": expr.type.accept(self),
|
||||
"expr": expr.expr.accept(self),
|
||||
}
|
||||
|
||||
def visit_ternary_expr(self, expr: TernaryExpr) -> dict:
|
||||
return {
|
||||
"_type": "TernaryExpr",
|
||||
"test": expr.test.accept(self),
|
||||
"if_true": expr.if_true.accept(self),
|
||||
"if_false": expr.if_false.accept(self),
|
||||
}
|
||||
@@ -31,22 +31,32 @@
|
||||
]
|
||||
},
|
||||
"type-base": {
|
||||
"begin": "<",
|
||||
"end": ">",
|
||||
"begin": "(\\()([a-zA-Z_][a-zA-Z_\\d]*)(\\))",
|
||||
"end": "$",
|
||||
"beginCaptures": {
|
||||
"0": {
|
||||
"1": {
|
||||
"name": "punctuation.definition.base.begin.midas"
|
||||
}
|
||||
},
|
||||
"endCaptures": {
|
||||
"0": {
|
||||
"2": {
|
||||
"name": "variable.name"
|
||||
},
|
||||
"3": {
|
||||
"name": "punctuation.definition.base.end.midas"
|
||||
}
|
||||
},
|
||||
"patterns": [
|
||||
{"include": "source.python"}
|
||||
{ "include": "#type-cond" }
|
||||
]
|
||||
},
|
||||
"type-cond": {
|
||||
"begin": "where",
|
||||
"end": "$",
|
||||
"beginCaptures": {
|
||||
"0": {
|
||||
"name": "keyword.control.where.midas"
|
||||
}
|
||||
}
|
||||
},
|
||||
"type-body": {
|
||||
"begin": "\\{",
|
||||
"end": "\\}",
|
||||
@@ -61,7 +71,8 @@
|
||||
}
|
||||
},
|
||||
"patterns": [
|
||||
{"include": "#type-prop"}
|
||||
{"include": "#type-prop"},
|
||||
{"include": "#comment"}
|
||||
]
|
||||
},
|
||||
"type-prop": {
|
||||
@@ -78,44 +89,67 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"op-def": {
|
||||
"match": "\\b(op)\\s+<([a-zA-Z_][a-zA-Z_\\d]*)>\\s+(\\S+)\\s+<([a-zA-Z_][a-zA-Z_\\d]*)>\\s+(=)\\s+<([a-zA-Z_][a-zA-Z_\\d]*)>",
|
||||
"captures": {
|
||||
"1": {
|
||||
"name": "keyword.control.op.midas"
|
||||
},
|
||||
"2": {
|
||||
"name" : "variable.name"
|
||||
},
|
||||
"3": {
|
||||
"name" : "keyword.operator"
|
||||
},
|
||||
"4": {
|
||||
"name" : "variable.name"
|
||||
},
|
||||
"5": {
|
||||
"name" : "keyword.operator.assignment"
|
||||
},
|
||||
"6": {
|
||||
"name" : "variable.name"
|
||||
}
|
||||
},
|
||||
"patterns": [
|
||||
{ "include": "#type-base" },
|
||||
{ "include": "#type-body" }
|
||||
]
|
||||
},
|
||||
"constr-def": {
|
||||
"begin": "(constraint)\\s+([a-zA-Z_][a-zA-Z_\\d]*)\\s*(=)",
|
||||
"end": "$",
|
||||
"extend-def": {
|
||||
"begin": "\\b(extend)\\s*([a-zA-Z_][a-zA-Z_\\d]*)\\s+(\\{)",
|
||||
"end": "\\}",
|
||||
"beginCaptures": {
|
||||
"1": {
|
||||
"name": "keyword.control.constr.midas"
|
||||
"name": "keyword.control.extend.midas"
|
||||
},
|
||||
"2": {
|
||||
"name": "variable.name"
|
||||
},
|
||||
"3": {
|
||||
"name": "punctuation.definition.extend-body.begin.midas"
|
||||
}
|
||||
},
|
||||
"endCaptures": {
|
||||
"0": {
|
||||
"name": "punctuation.definition.extend-body.end.midas"
|
||||
}
|
||||
},
|
||||
"patterns": [
|
||||
{"include": "#op-def"},
|
||||
{"include": "#comment"}
|
||||
]
|
||||
},
|
||||
"op-def": {
|
||||
"match": "\\b(op)\\s+(\\S+)\\s*\\(\\s*([a-zA-Z_][a-zA-Z_\\d]*)\\s*\\)\\s*(->)\\s*([a-zA-Z_][a-zA-Z_\\d]*)",
|
||||
"captures": {
|
||||
"1": {
|
||||
"name": "keyword.control.op.midas"
|
||||
},
|
||||
"2": {
|
||||
"name" : "keyword.operator"
|
||||
},
|
||||
"3": {
|
||||
"name" : "variable.name"
|
||||
},
|
||||
"4": {
|
||||
"name" : "keyword.operator.assignment"
|
||||
},
|
||||
"5": {
|
||||
"name" : "variable.name"
|
||||
}
|
||||
}
|
||||
},
|
||||
"pred-def": {
|
||||
"begin": "(predicate)\\s+([a-zA-Z_][a-zA-Z_\\d]*)\\(([a-zA-Z_][a-zA-Z_\\d]*):\\s*([a-zA-Z_][a-zA-Z_\\d]*)\\)\\s*(=)",
|
||||
"end": "$",
|
||||
"beginCaptures": {
|
||||
"1": {
|
||||
"name": "keyword.control.pred.midas"
|
||||
},
|
||||
"2": {
|
||||
"name": "variable.name"
|
||||
},
|
||||
"3": {
|
||||
"name": "variable.name"
|
||||
},
|
||||
"4": {
|
||||
"name": "variable.name"
|
||||
},
|
||||
"5": {
|
||||
"name": "keyword.operator.assignment"
|
||||
}
|
||||
},
|
||||
@@ -127,8 +161,8 @@
|
||||
"patterns": [
|
||||
{ "include": "#comment" },
|
||||
{ "include": "#type-def" },
|
||||
{ "include": "#op-def" },
|
||||
{ "include": "#constr-def" }
|
||||
{ "include": "#extend-def" },
|
||||
{ "include": "#pred-def" }
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user