Merge pull request 'Basic type checker' (#6) from feat/basic-type-checker into main

Reviewed-on: #6
This commit was merged in pull request #6.
This commit is contained in:
2026-06-05 09:31:53 +00:00
61 changed files with 4539 additions and 1079 deletions

1
.gitignore vendored
View File

@@ -6,3 +6,4 @@ venv
*.pyc
uv.lock
.python-version
/out

150
docs/architecture.typ Normal file
View File

@@ -0,0 +1,150 @@
#import "@preview/cetz:0.5.2": canvas, draw
#let diagram-only = false
#set document(
title: [Midas Architecture],
//author: "Louis Heredero",
)
#set text(
font: "Source Sans 3",
)
#let diagram = canvas({
let framed = draw.content.with(
padding: (x: .8em, y: 1em),
frame: "rect",
stroke: black,
)
let arrow = draw.line.with(mark: (end: ">", fill: black))
framed(
(0, 0),
name: "python-parser",
)[Python parser]
draw.content(
(rel: (0, 1), to: "python-parser.north"),
padding: 5pt,
anchor: "south",
name: "source-py",
)[_`source.py`_]
arrow("source-py", "python-parser")
framed(
(rel: (3, 0), to: "python-parser.east"),
anchor: "west",
name: "custom-parser",
align(center)[Custom python\ parser],
)
arrow("python-parser", "custom-parser", name: "arrow-python-ast")
draw.content(
"arrow-python-ast",
anchor: "south",
padding: 5pt,
)[`ast.Module`]
framed(
(rel: (-3, -2), to: "custom-parser.south"),
anchor: "east",
name: "python-resolver",
)[Python Resolver]
arrow(
"custom-parser",
((), "|-", "python-resolver.east"),
"python-resolver",
name: "arrow-python-custom-ast",
)
draw.content(
(rel: (1.5, 0), to: "arrow-python-custom-ast.end"),
padding: 5pt,
anchor: "south",
)[P-AST#footnote[#strong[P]ython *AST*]<fn-past>]
draw.content(
"python-resolver.west",
padding: 5pt,
anchor: "south-east",
)[Resolved P-AST@fn-past]
draw.circle(
(rel: (1, -2), to: "custom-parser.south-east"),
radius: .4,
name: "midas-loader",
)
arrow(
"custom-parser",
"midas-loader",
name: "arrow-load-midas",
mark: (end: (symbol: ">", fill: black), start: "o"),
)
draw.content(
"arrow-load-midas",
anchor: "west",
padding: 5pt,
)[```python midas.using("types.midas")```]
framed(
(rel: (0, -2), to: "midas-loader.south"),
name: "midas-parser",
)[Midas lexer/parser]
arrow("midas-loader", "midas-parser", name: "arrow-midas-source")
draw.content(
"arrow-midas-source",
anchor: "west",
padding: 5pt,
)[_`types.midas`_]
framed(
(rel: (-2, 0), to: "midas-parser.west"),
anchor: "east",
name: "midas-resolver",
)[Midas Resolver]
arrow("midas-parser", "midas-resolver", name: "arrow-midas-ast")
draw.content(
"arrow-midas-ast",
anchor: "south",
padding: 5pt,
)[M-AST#footnote[#strong[M]idas *AST*]<fn-mast>]
framed(
(rel: (-3, 0), to: "midas-resolver.west"),
anchor: "east",
name: "checker",
)[Checker]
arrow("midas-resolver", "checker", name: "arrow-type-ctx")
arrow(
"python-resolver",
((), "-|", "checker.north"),
"checker",
)
draw.content(
"arrow-type-ctx",
anchor: "south",
padding: 5pt,
)[Types context]
})
#show: doc => if diagram-only {
set page(width: auto, height: auto, margin: .5cm)
diagram
} else { doc }
#align(center, title())
#v(1cm)
#figure(
diagram,
caption: [Midas type-checker architecture],
)
== Components
- *Python parser*: Python's built-in AST parser; extracts the abstract syntax tree from the raw Python source (```python ast.parse(...)```)
- *Custom python parser*: converts the raw Python AST into custom, more suitable constructs, especially for type annotations
- *Python resolver*: resolves bindings and references, tracks binding scopes
- *Midas lexer/parser*: parses a Midas type definition file and extracts its AST
- *Midas resolver*: walks the AST and fills the environment with the defined types and operations
- *Checker*: evaluates expressions and checks type coherence

View File

@@ -2,10 +2,6 @@
# ruff: disable[F821]
from __future__ import annotations
# Prototype of custom type import to use valid Python syntax
import midas
midas.using("02_custom_types.midas")
# A data-frame using a custom type
df: Frame[
location: GeoLocation

View File

@@ -0,0 +1,33 @@
type Foo1 = float
type Foo2 = float where (_ > 3)
type Foo3 = int | float
type Foo4 = int where (_ > 3) | float where (_ > 3)
type Foo5 = (int | float) where (_ > 3)
type Foo6 = {
foo: float
bar: float where (_ > 3)
}
type Foo7[T] = T where (_ > 3)
type Foo8[A, B<:int] = {
a: A
b: B
}
type Complex = {
a: int
b: int
}
type Complex2 = Complex where (_.a > 3 & _.b < 5)
predicate Positive(n: int) = n >= 0
extend Foo1 {
op __add__(Foo1) -> Foo1
}
extend Foo7[T] {
op __add__(Foo7[T]) -> Foo7[T]
}
type Optional[T] = None | T

View File

@@ -0,0 +1,11 @@
# Checker fixture: annotated declarations followed by plain (re)assignments.
a: int = 3
b: int = 4
# First plain assignment to `c`; the value is the result of `int + int`.
c = a + b # -> int
# Re-assigning a value of a different type to an existing variable is the
# error case this fixture is meant to trigger.
c = "invalid" # -> can't assign str to int variable
# NOTE(review): the two lines below presumably exercise `+` on booleans;
# the expected diagnostic (if any) is not stated here - confirm.
d = True
e = d + d
# NOTE(review): an int value bound to a float-annotated name; whether this is
# expected to be accepted is not stated here - confirm.
f: float = a

View File

@@ -0,0 +1,14 @@
type Meter = float
type Second = float
type MeterPerSecond = float
extend Meter {
op __add__(Meter) -> Meter
op __sub__(Meter) -> Meter
op __truediv__(Second) -> MeterPerSecond
}
extend Second {
op __add__(Second) -> Second
op __sub__(Second) -> Second
}

View File

@@ -0,0 +1,6 @@
# type: ignore
# ruff: disable [F821]
# Checker fixture using custom types (`Meter`, `Second`) that are not defined
# in Python, hence the linter/type-checker suppressions above.
# NOTE(review): presumably paired with the Midas definitions declaring
# `Meter`, `Second` and `Meter.__truediv__(Second) -> MeterPerSecond` - confirm.
distance: Meter = cast(Meter, 123.45)
time: Second = cast(Second, 6.7)
# Division of a `Meter` by a `Second`; the result type is left to the checker.
speed = distance / time

View File

@@ -0,0 +1,23 @@
# Checker fixture: function definitions, calls, recursion and a ternary.


# No return annotation: both branches return one of the `int` parameters.
def minimum(x: int, y: int):
    if x < y:
        return x
    else:
        return y


a = 15
b = 72
# Call with two plain-assigned variables.
c = minimum(a, b)


# Annotated return type and a recursive call to itself.
def factorial(n: int) -> int:
    if n <= 1:
        return 1
    return n * factorial(n - 1)


# Conditional (ternary) expression with a string in each branch.
category = "Category 1" if a < 10 else "Category 2"


# Annotated as returning None, with a body that contains no return statement.
def foo() -> None:
    pass

View File

@@ -1,5 +1,5 @@
from pathlib import Path
import re
from pathlib import Path
HEADER = '''"""
This file was generated by a script. Any manual changes might be overwritten.
@@ -11,7 +11,7 @@ SECTION_TEMPLATE = """{banner}
@dataclass(frozen=True, kw_only=True)
class {base}(ABC):
location: Optional[Location] = None
location: Location
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...

View File

@@ -13,40 +13,38 @@ from midas.lexer.token import Token
###> Stmt | Statements
class SimpleTypeStmt:
class TypeStmt:
name: Token
template: Optional[TemplateExpr]
base: TypeExpr
constraint: Optional[Expr]
params: list[Param]
type: Type
class ComplexTypeStmt:
@dataclass(frozen=True, kw_only=True)
class Param:
location: Location
name: Token
template: Optional[TemplateExpr]
properties: list[PropertyStmt]
bound: Optional[Type]
class PropertyStmt:
name: Token
type: TypeExpr
constraint: Optional[Expr]
type: Type
class ExtendStmt:
type: TypeExpr
type: Type
operations: list[OpStmt]
class OpStmt:
name: Token
operand: TypeExpr
result: TypeExpr
operand: Type
result: Type
class PredicateStmt:
name: Token
subject: Token
type: TypeExpr
type: Type
condition: Expr
@@ -54,9 +52,6 @@ class PredicateStmt:
###> Expr | Expressions
class SimpleTypeExpr:
name: Token
optional: bool
class LogicalExpr:
@@ -97,14 +92,27 @@ class WildcardExpr:
token: Token
class TemplateExpr:
type: TypeExpr
###<
###> Type | Types
class TypeExpr:
class NamedType:
name: Token
template: Optional[TemplateExpr]
optional: bool
class GenericType:
type: Type
params: list[Type]
class ConstraintType:
type: Type
constraint: Expr
class ComplexType:
properties: list[PropertyStmt]
###<

View File

@@ -44,14 +44,22 @@ class Function:
name: str
posonlyargs: list[Argument]
args: list[Argument]
sink: Optional[Argument]
kwonlyargs: list[Argument]
kw_sink: Optional[Argument]
returns: Optional[MidasType]
body: list[Stmt]
@dataclass(frozen=True, kw_only=True)
class Argument:
location: Optional[Location] = None
name: Optional[str]
name: str
type: Optional[MidasType]
default: Optional[Expr]
@property
def all_args(self) -> list[Argument]:
return self.posonlyargs + self.args + self.kwonlyargs
class TypeAssign:
@@ -64,6 +72,16 @@ class AssignStmt:
value: Expr
class ReturnStmt:
value: Optional[Expr]
class IfStmt:
test: Expr
body: list[Stmt]
orelse: list[Stmt]
###<
@@ -116,4 +134,15 @@ class SetExpr:
value: Expr
class CastExpr:
type: MidasType
expr: Expr
class TernaryExpr:
test: Expr
if_true: Expr
if_false: Expr
###<

View File

@@ -21,17 +21,14 @@ T = TypeVar("T")
@dataclass(frozen=True, kw_only=True)
class Stmt(ABC):
location: Optional[Location] = None
location: Location
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> T: ...
@abstractmethod
def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> T: ...
def visit_type_stmt(self, stmt: TypeStmt) -> T: ...
@abstractmethod
def visit_property_stmt(self, stmt: PropertyStmt) -> T: ...
@@ -47,31 +44,25 @@ class Stmt(ABC):
@dataclass(frozen=True)
class SimpleTypeStmt(Stmt):
class TypeStmt(Stmt):
name: Token
template: Optional[TemplateExpr]
base: TypeExpr
constraint: Optional[Expr]
params: list[Param]
type: Type
@dataclass(frozen=True, kw_only=True)
class Param:
location: Location
name: Token
bound: Optional[Type]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_simple_type_stmt(self)
@dataclass(frozen=True)
class ComplexTypeStmt(Stmt):
name: Token
template: Optional[TemplateExpr]
properties: list[PropertyStmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_complex_type_stmt(self)
return visitor.visit_type_stmt(self)
@dataclass(frozen=True)
class PropertyStmt(Stmt):
name: Token
type: TypeExpr
constraint: Optional[Expr]
type: Type
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_property_stmt(self)
@@ -79,7 +70,7 @@ class PropertyStmt(Stmt):
@dataclass(frozen=True)
class ExtendStmt(Stmt):
type: TypeExpr
type: Type
operations: list[OpStmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
@@ -89,8 +80,8 @@ class ExtendStmt(Stmt):
@dataclass(frozen=True)
class OpStmt(Stmt):
name: Token
operand: TypeExpr
result: TypeExpr
operand: Type
result: Type
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_op_stmt(self)
@@ -100,7 +91,7 @@ class OpStmt(Stmt):
class PredicateStmt(Stmt):
name: Token
subject: Token
type: TypeExpr
type: Type
condition: Expr
def accept(self, visitor: Stmt.Visitor[T]) -> T:
@@ -114,15 +105,12 @@ class PredicateStmt(Stmt):
@dataclass(frozen=True, kw_only=True)
class Expr(ABC):
location: Optional[Location] = None
location: Location
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> T: ...
@abstractmethod
def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
@@ -147,21 +135,6 @@ class Expr(ABC):
@abstractmethod
def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ...
@abstractmethod
def visit_template_expr(self, expr: TemplateExpr) -> T: ...
@abstractmethod
def visit_type_expr(self, expr: TypeExpr) -> T: ...
@dataclass(frozen=True)
class SimpleTypeExpr(Expr):
name: Token
optional: bool
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_simple_type_expr(self)
@dataclass(frozen=True)
class LogicalExpr(Expr):
@@ -233,19 +206,61 @@ class WildcardExpr(Expr):
return visitor.visit_wildcard_expr(self)
@dataclass(frozen=True)
class TemplateExpr(Expr):
type: TypeExpr
#########
# Types #
#########
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_template_expr(self)
@dataclass(frozen=True, kw_only=True)
class Type(ABC):
location: Location
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_named_type(self, type: NamedType) -> T: ...
@abstractmethod
def visit_generic_type(self, type: GenericType) -> T: ...
@abstractmethod
def visit_constraint_type(self, type: ConstraintType) -> T: ...
@abstractmethod
def visit_complex_type(self, type: ComplexType) -> T: ...
@dataclass(frozen=True)
class TypeExpr(Expr):
class NamedType(Type):
name: Token
template: Optional[TemplateExpr]
optional: bool
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_type_expr(self)
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_named_type(self)
@dataclass(frozen=True)
class GenericType(Type):
type: Type
params: list[Type]
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_generic_type(self)
@dataclass(frozen=True)
class ConstraintType(Type):
type: Type
constraint: Expr
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_constraint_type(self)
@dataclass(frozen=True)
class ComplexType(Type):
properties: list[PropertyStmt]
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_complex_type(self)

View File

@@ -85,40 +85,39 @@ class AstPrinter(Generic[T]):
child.accept(self)
class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
class MidasAstPrinter(
AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None], m.Type.Visitor[None]
):
# Statements
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt):
self._write_line("SimpleTypeStmt")
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
self._write_line("TypeStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_optional_child("template", stmt.template)
self._write_line("base")
with self._child_level(single=True):
stmt.base.accept(self)
self._write_optional_child("constraint", stmt.constraint, last=True)
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt):
self._write_line("ComplexTypeStmt")
self._write_line("params")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_optional_child("template", stmt.template)
self._write_line("properties", last=True)
with self._child_level():
for i, prop in enumerate(stmt.properties):
for i, param in enumerate(stmt.params):
self._idx = i
if i == len(stmt.properties) - 1:
if i == len(stmt.params) - 1:
self._mark_last()
prop.accept(self)
self._print_type_stmt_param(param)
self._write_line("type", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
def _print_type_stmt_param(self, param: m.TypeStmt.Param) -> None:
self._write_line("Param")
with self._child_level():
self._write_line(f'name: "{param.name.lexeme}"')
self._write_optional_child("bound", param.bound, last=True)
def visit_property_stmt(self, stmt: m.PropertyStmt):
self._write_line("PropertyStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("type")
self._write_line("type", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
self._write_optional_child("constraint", stmt.constraint, last=True)
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
self._write_line("ExtendStmt")
@@ -161,12 +160,6 @@ class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
# Expressions
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr):
self._write_line("SimpleTypeExpr")
with self._child_level():
self._write_line(f'name: "{expr.name.lexeme}"')
self._write_line(f"optional: {expr.optional}", last=True)
def visit_logical_expr(self, expr: m.LogicalExpr):
self._write_line("LogicalExpr")
with self._child_level():
@@ -230,22 +223,48 @@ class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
self._write_line("WildcardExpr")
def visit_template_expr(self, expr: m.TemplateExpr) -> None:
self._write_line("TemplateExpr")
with self._child_level(single=True):
def visit_named_type(self, type: m.NamedType) -> None:
self._write_line("NamedType")
with self._child_level():
self._write_line(f'name: "{type.name.lexeme}"', last=True)
def visit_generic_type(self, type: m.GenericType) -> None:
self._write_line("GenericType")
with self._child_level():
self._write_line("type")
with self._child_level():
type.type.accept(self)
self._write_line("params", last=True)
with self._child_level():
for i, param in enumerate(type.params):
self._idx = i
if i == len(type.params) - 1:
self._mark_last()
param.accept(self)
def visit_constraint_type(self, type: m.ConstraintType) -> None:
self._write_line("ConstraintType")
with self._child_level():
self._write_line("type")
with self._child_level(single=True):
expr.type.accept(self)
type.type.accept(self)
self._write_line("constraint", last=True)
with self._child_level(single=True):
type.constraint.accept(self)
def visit_type_expr(self, expr: m.TypeExpr):
self._write_line("TypeExpr")
def visit_complex_type(self, type: m.ComplexType) -> None:
self._write_line("ComplexType")
with self._child_level():
self._write_line(f'name: "{expr.name.lexeme}"')
self._write_optional_child("template", expr.template)
self._write_line(f"optional: {expr.optional}", last=True)
self._write_line("properties", last=True)
with self._child_level():
for i, prop in enumerate(type.properties):
self._idx = i
if i == len(type.properties) - 1:
self._mark_last()
prop.accept(self)
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
def __init__(self, indent: int = 4):
self.indent: int = indent
self.level: int = 0
@@ -253,33 +272,28 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
def indented(self, text: str) -> str:
return " " * (self.level * self.indent) + text
def print(self, expr: m.Expr | m.Stmt):
def print(self, expr: m.Expr | m.Stmt | m.Type) -> str:
self.level = 0
return expr.accept(self)
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt):
template: str = stmt.template.accept(self) if stmt.template is not None else ""
res: str = f"type {stmt.name.lexeme}{template}({stmt.base.accept(self)})"
if stmt.constraint is not None:
res += " where " + stmt.constraint.accept(self)
def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
template: str = ""
if len(stmt.params) != 0:
params: list[str] = [
self._print_type_template_param(param) for param in stmt.params
]
template = f"[{', '.join(params)}]"
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
return self.indented(res)
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt):
template: str = stmt.template.accept(self) if stmt.template is not None else ""
res: str = self.indented(f"type {stmt.name.lexeme}{template}")
res += " {\n"
self.level += 1
for prop in stmt.properties:
res += prop.accept(self)
res += "\n"
self.level -= 1
res += self.indented("}")
def _print_type_template_param(self, param: m.TypeStmt.Param) -> str:
res: str = param.name.lexeme
if param.bound is not None:
res += "<:" + param.bound.accept(self)
return res
def visit_property_stmt(self, stmt: m.PropertyStmt):
res: str = f"{stmt.name.lexeme}: {stmt.type.accept(self)}"
if stmt.constraint is not None:
res += " where " + stmt.constraint.accept(self)
return self.indented(res)
def visit_extend_stmt(self, stmt: m.ExtendStmt):
@@ -289,13 +303,13 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
for op in stmt.operations:
res += op.accept(self)
self.level -= 1
res += "\n" + self.indented("}")
res += self.indented("}")
return res
def visit_op_stmt(self, stmt: m.OpStmt):
operand: str = stmt.operand.accept(self)
result: str = stmt.result.accept(self)
return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}")
return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}\n")
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
name: str = stmt.name.lexeme
@@ -304,9 +318,6 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
condition: str = stmt.condition.accept(self)
return self.indented(f"predicate {name}({subject}: {type}) = {condition}")
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr):
return f"{expr.name.lexeme}{'?' if expr.optional else ''}"
def visit_logical_expr(self, expr: m.LogicalExpr):
left: str = expr.left.accept(self)
operator: str = expr.operator.lexeme
@@ -342,12 +353,30 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
def visit_wildcard_expr(self, expr: m.WildcardExpr):
return "_"
def visit_template_expr(self, expr: m.TemplateExpr):
return f"[{expr.type.accept(self)}]"
def visit_named_type(self, type: m.NamedType) -> str:
return type.name.lexeme
def visit_type_expr(self, expr: m.TypeExpr):
template: str = expr.template.accept(self) if expr.template is not None else ""
return f"{expr.name.lexeme}{template}{'?' if expr.optional else ''}"
def visit_generic_type(self, type: m.GenericType) -> str:
res: str = type.type.accept(self)
if len(type.params) != 0:
params: list[str] = [param.accept(self) for param in type.params]
res += f"[{', '.join(params)}]"
return res
def visit_constraint_type(self, type: m.ConstraintType) -> str:
res: str = type.type.accept(self)
res += " where " + type.constraint.accept(self)
return res
def visit_complex_type(self, type: m.ComplexType) -> str:
res: str = "{\n"
self.level += 1
for prop in type.properties:
res += prop.accept(self)
res += "\n"
self.level -= 1
res += self.indented("}")
return res
class PythonAstPrinter(
@@ -419,7 +448,14 @@ class PythonAstPrinter(
self._mark_last()
self._print_argument(arg)
self._write_optional_child("returns", stmt.returns, last=True)
self._write_optional_child("returns", stmt.returns)
self._write_line("body", last=True)
with self._child_level():
for i, body_stmt in enumerate(stmt.body):
self._idx = i
if i == len(stmt.body) - 1:
self._mark_last()
body_stmt.accept(self)
def _print_argument(self, arg: p.Function.Argument) -> None:
self._write_line("FunctionArgument")
@@ -449,6 +485,32 @@ class PythonAstPrinter(
with self._child_level(single=True):
stmt.value.accept(self)
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
self._write_line("ReturnStmt")
with self._child_level():
self._write_optional_child("value", stmt.value, last=True)
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
self._write_line("IfStmt")
with self._child_level():
self._write_line("test")
with self._child_level(single=True):
stmt.test.accept(self)
self._write_line("body")
with self._child_level():
for i, body_stmt in enumerate(stmt.body):
self._idx = i
if i == len(stmt.body) - 1:
self._mark_last()
body_stmt.accept(self)
self._write_line("orelse", last=True)
with self._child_level():
for i, else_stmt in enumerate(stmt.orelse):
self._idx = i
if i == len(stmt.orelse) - 1:
self._mark_last()
else_stmt.accept(self)
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
self._write_line("BinaryExpr")
with self._child_level():
@@ -550,3 +612,28 @@ class PythonAstPrinter(
self._write_line("value", last=True)
with self._child_level(single=True):
expr.value.accept(self)
def visit_cast_expr(self, expr: p.CastExpr) -> None:
self._write_line("CastExpr")
with self._child_level():
self._write_line("type")
with self._child_level(single=True):
expr.type.accept(self)
self._write_line("expr", last=True)
with self._child_level(single=True):
expr.expr.accept(self)
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
self._write_line("TernaryExpr")
with self._child_level():
self._write_line("test")
with self._child_level(single=True):
expr.test.accept(self)
self._write_line("if_true")
with self._child_level(single=True):
expr.if_true.accept(self)
self._write_line("if_false", last=True)
with self._child_level(single=True):
expr.if_false.accept(self)

View File

@@ -21,7 +21,7 @@ T = TypeVar("T")
@dataclass(frozen=True, kw_only=True)
class MidasType(ABC):
location: Optional[Location] = None
location: Location
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
@@ -82,7 +82,7 @@ class FrameType(MidasType):
@dataclass(frozen=True, kw_only=True)
class Stmt(ABC):
location: Optional[Location] = None
location: Location
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
@@ -100,6 +100,12 @@ class Stmt(ABC):
@abstractmethod
def visit_assign_stmt(self, stmt: AssignStmt) -> T: ...
@abstractmethod
def visit_return_stmt(self, stmt: ReturnStmt) -> T: ...
@abstractmethod
def visit_if_stmt(self, stmt: IfStmt) -> T: ...
@dataclass(frozen=True)
class ExpressionStmt(Stmt):
@@ -114,14 +120,22 @@ class Function(Stmt):
name: str
posonlyargs: list[Argument]
args: list[Argument]
sink: Optional[Argument]
kwonlyargs: list[Argument]
kw_sink: Optional[Argument]
returns: Optional[MidasType]
body: list[Stmt]
@dataclass(frozen=True, kw_only=True)
class Argument:
location: Optional[Location] = None
name: Optional[str]
name: str
type: Optional[MidasType]
default: Optional[Expr]
@property
def all_args(self) -> list[Argument]:
return self.posonlyargs + self.args + self.kwonlyargs
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_function(self)
@@ -145,6 +159,24 @@ class AssignStmt(Stmt):
return visitor.visit_assign_stmt(self)
@dataclass(frozen=True)
class ReturnStmt(Stmt):
value: Optional[Expr]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_return_stmt(self)
@dataclass(frozen=True)
class IfStmt(Stmt):
test: Expr
body: list[Stmt]
orelse: list[Stmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_if_stmt(self)
###############
# Expressions #
###############
@@ -152,7 +184,7 @@ class AssignStmt(Stmt):
@dataclass(frozen=True, kw_only=True)
class Expr(ABC):
location: Optional[Location] = None
location: Location
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
@@ -185,6 +217,12 @@ class Expr(ABC):
@abstractmethod
def visit_set_expr(self, expr: SetExpr) -> T: ...
@abstractmethod
def visit_cast_expr(self, expr: CastExpr) -> T: ...
@abstractmethod
def visit_ternary_expr(self, expr: TernaryExpr) -> T: ...
@dataclass(frozen=True)
class BinaryExpr(Expr):
@@ -268,3 +306,22 @@ class SetExpr(Expr):
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_set_expr(self)
@dataclass(frozen=True)
class CastExpr(Expr):
type: MidasType
expr: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_cast_expr(self)
@dataclass(frozen=True)
class TernaryExpr(Expr):
test: Expr
if_true: Expr
if_false: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_ternary_expr(self)

540
midas/checker/checker.py Normal file
View File

@@ -0,0 +1,540 @@
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.checker.environment import Environment
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
from midas.checker.types import Function, Type, UnitType, UnknownType
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token
from midas.parser.midas import MidasParser
from midas.resolver.midas import MidasResolver
class ReturnException(Exception):
    """Control-flow signal raised by the checker when it visits a ``return``.

    It carries no payload: the returned type is recorded on the environment
    before the exception is raised, and block processing catches it to stop
    walking the remaining statements.
    """
@dataclass(frozen=True, kw_only=True)
class MappedArgument:
    """Immutable, keyword-only record grouping an expression, a type and a
    function argument.

    NOTE(review): presumably pairs a call-site expression (and its evaluated
    type) with the declared argument it binds to; the usage site is outside
    this excerpt - confirm.
    """

    # Python expression node
    expr: p.Expr
    # Checker type associated with `expr`
    type: Type
    # Declared function argument this entry refers to
    argument: Function.Argument
class Checker(
p.Stmt.Visitor[None],
p.Expr.Visitor[Type],
p.MidasType.Visitor[Type],
):
"""A type checker which can use custom type definitions"""
def __init__(
    self,
    locals: dict[p.Expr, int],
    source_path: Path,
    types_paths: list[Path],
):
    """Create a checker for one Python source file.

    Args:
        locals (dict[p.Expr, int]): scope distance of each resolved
            expression, consulted when looking variables up
        source_path (Path): path of the checked file, attached to every
            diagnostic
        types_paths (list[Path]): Midas definition files imported at the
            start of each check
    """
    self.logger: logging.Logger = logging.getLogger("Checker")

    # Inputs
    self.source_path: Path = source_path
    self.types_paths: list[Path] = types_paths
    self.locals: dict[p.Expr, int] = locals

    # Midas type context, filled by `import_midas`
    self.ctx: MidasResolver = MidasResolver()

    # The current environment starts out as the global one
    self.global_env: Environment = Environment()
    self.env: Environment = self.global_env

    self.diagnostics: list[Diagnostic] = []
def diagnostic(self, type: DiagnosticType, location: Location, message: str):
    """Record a diagnostic for the checked source file

    Args:
        type (DiagnosticType): the severity of the diagnostic
        location (Location): where in the source the diagnostic applies
        message (str): the human-readable description
    """
    entry = Diagnostic(
        file_path=self.source_path,
        location=location,
        type=type,
        message=message,
    )
    self.diagnostics.append(entry)
def error(self, location: Location, message: str):
    """Record an ERROR diagnostic

    Args:
        location (Location): where in the source the error applies
        message (str): the error description
    """
    self.diagnostic(DiagnosticType.ERROR, location, message)
def warning(self, location: Location, message: str):
    """Record a WARNING diagnostic

    Args:
        location (Location): where in the source the warning applies
        message (str): the warning description
    """
    self.diagnostic(DiagnosticType.WARNING, location, message)
def info(self, location: Location, message: str):
    """Record an INFO diagnostic

    Args:
        location (Location): where in the source the note applies
        message (str): the informational message
    """
    self.diagnostic(DiagnosticType.INFO, location, message)
def type_of(self, expr: p.Expr) -> Type:
    """Compute the type of an expression by dispatching it to this visitor

    Args:
        expr (p.Expr): the expression to evaluate

    Returns:
        Type: the type produced by the matching ``visit_*`` method
    """
    result: Type = expr.accept(self)
    return result
def process_block(self, block: list[p.Stmt], env: Environment) -> bool:
    """Evaluate a sequence of statements in the given environment

    The current environment is swapped for `env` while the block is processed
    and is always restored afterwards, even if a statement raises an
    unexpected exception.

    Args:
        block (list[p.Stmt]): the statements to evaluate
        env (Environment): the environment in which to evaluate

    Returns:
        bool: whether a statement of the block signalled a return
            (i.e. raised `ReturnException`)
    """
    previous_env: Environment = self.env
    self.env = env
    returned: bool = False
    try:
        for i, stmt in enumerate(block):
            try:
                stmt.accept(self)
            except ReturnException:
                returned = True
                # Anything after a returning statement can never run; only
                # the first such statement is reported.
                if i < len(block) - 1:
                    self.warning(block[i + 1].location, "Unreachable statement")
                break
    finally:
        # Restore unconditionally so a failure inside the block cannot leave
        # the checker stuck in the inner environment.
        self.env = previous_env
    return returned
def check(self, statements: list[p.Stmt]) -> list[Diagnostic]:
    """Type-check a sequence of statements and return the diagnostics

    Previously collected diagnostics are discarded, every configured Midas
    definition file is imported, then each statement is visited in order.

    Args:
        statements (list[p.Stmt]): the statements to evaluate and check

    Returns:
        list[Diagnostic]: the list of diagnostics (errors, warnings, etc.)
    """
    debug = self.logger.debug
    self.diagnostics = []

    for types_path in self.types_paths:
        self.import_midas(types_path)
    debug(f"Midas types: {self.ctx._types}")
    debug(f"Midas operations: {self.ctx._operations}")

    for statement in statements:
        statement.accept(self)
    debug(f"Final environment: {self.env.flat_dict()}")

    return self.diagnostics
def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]:
    """Look up a variable in the environment where it was declared

    Args:
        name (str): the name of the variable
        expr (p.Expr): the variable expression, used as the key to find the
            scope distance

    Returns:
        Optional[Type]: the type of the variable, or None if it was not found
    """
    distance: Optional[int] = self.locals.get(expr)
    # No recorded distance: fall back to the global environment
    if distance is None:
        return self.global_env.get(name)
    return self.env.get_at(distance, name)
def import_midas(self, path: Path) -> None:
    """Import Midas definitions from a path

    The file is lexed, parsed and resolved into the checker's type context.

    Args:
        path (Path): the import path
    """
    self.logger.debug(f"Importing type definitions from {path}")
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    source: str = path.read_text(encoding="utf-8")
    lexer: MidasLexer = MidasLexer(source)
    tokens: list[Token] = lexer.process()
    parser: MidasParser = MidasParser(tokens)
    stmts: list[m.Stmt] = parser.parse()
    self.ctx.resolve(stmts)
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
    """Evaluate the wrapped expression; its resulting type is discarded"""
    _ = self.type_of(stmt.expr)
def visit_function(self, stmt: p.Function) -> None:
    """Check a function definition and bind its signature in the current
    environment

    The arguments are evaluated and defined in a fresh child environment, the
    body is processed in it, and the return type is inferred from the return
    types collected there. When a return annotation is present it takes
    precedence, and a mismatch with the inferred type is reported.
    """
    env: Environment = Environment(self.env)

    def eval_arg_type(arg: p.Function.Argument) -> Type:
        # The annotation wins; otherwise fall back to the default's type
        if arg.type is not None:
            return arg.type.accept(self)
        if arg.default is not None:
            return arg.default.accept(self)
        return UnknownType()

    def convert(group: list[p.Function.Argument]) -> list[Function.Argument]:
        # Same conversion for each argument kind; an argument is required
        # exactly when it has no default value
        return [
            Function.Argument(
                name=arg.name,
                type=eval_arg_type(arg),
                required=arg.default is None,
            )
            for arg in group
        ]

    pos_args: list[Function.Argument] = convert(stmt.posonlyargs)
    args: list[Function.Argument] = convert(stmt.args)
    kw_args: list[Function.Argument] = convert(stmt.kwonlyargs)

    for arg in pos_args + args + kw_args:
        env.define(arg.name, arg.type)

    returns_hint: Optional[Type] = (
        stmt.returns.accept(self) if stmt.returns is not None else None
    )

    # Early define to handle simple fully-typed recursion
    self.env.define(
        stmt.name,
        Function(
            name=stmt.name,
            pos_args=pos_args,
            args=args,
            kw_args=kw_args,
            returns=returns_hint,
        ),
    )

    returned: bool = self.process_block(stmt.body, env)
    if not returned:
        # The body did not signal a return: record a unit return type
        env.return_types.append(UnitType())

    inferred_return: Type = UnknownType()
    return_types: set[Type] = set(env.return_types)
    if len(return_types) == 1:
        inferred_return = next(iter(return_types))
    elif len(return_types) > 1:
        self.error(
            stmt.location,
            f"Mixed return types: {env.return_types}",
        )

    returns: Type
    if returns_hint is None:
        returns = inferred_return
    else:
        assert stmt.returns is not None
        returns = returns_hint
        if returns != inferred_return:
            self.error(
                stmt.returns.location,
                f"Return type mismatch, annotated {returns} but returns {inferred_return}",
            )

    # TODO: handle *args and **kwargs sinks
    # Final definition replaces the early one with the resolved return type
    self.env.define(
        stmt.name,
        Function(
            name=stmt.name,
            pos_args=pos_args,
            args=args,
            kw_args=kw_args,
            returns=returns,
        ),
    )
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
    """Bind a name to the type produced by its annotation

    Args:
        stmt (p.TypeAssign): the annotated declaration
    """
    # TODO check not yet defined locally
    declared: Type = stmt.type.accept(self)
    self.env.define(stmt.name, declared)
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
    """Check an assignment against the known type of each target

    The value's type is computed once. A target that has no known type yet
    is defined with the value's type, otherwise both types must be equal

    Args:
        stmt (p.AssignStmt): the assignment statement
    """
    assigned: Type = self.type_of(stmt.value)
    for target in stmt.targets:
        if isinstance(target, p.VariableExpr):
            declared: Optional[Type] = self.look_up_variable(target.name, target)
            if declared is None:
                # First time this name is seen: it takes the type of the value
                self.env.define(target.name, assigned)
            # TODO: implement real comparison method
            elif declared != assigned:
                self.error(
                    stmt.location,
                    f"Cannot assign {assigned} to {target.name} of type {declared}",
                )
        else:
            # Only plain variables are supported as assignment targets
            self.logger.warning(f"Unsupported assignment to {target}")
            self.warning(target.location, f"Unsupported assignment to {target}")
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
    """Record the type of a return statement and stop processing the block

    Args:
        stmt (p.ReturnStmt): the return statement

    Raises:
        ReturnException: always, after the return type has been recorded
    """
    if stmt.value is None:
        # A bare `return` yields the unit type
        returned: Type = UnitType()
    else:
        returned = stmt.value.accept(self)
    self.env.return_types.append(returned)
    raise ReturnException()
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
    """Check an if statement: its test must be a boolean, then both branches are processed

    Args:
        stmt (p.IfStmt): the if statement

    Raises:
        ReturnException: if both the body and the else branch returned
    """
    # The test is evaluated in the current environment (not a sub-environment)
    # because assignments made in the test leak out of the if, e.g.
    #     if (m := 1 + 1) < 2:
    #         ...
    #     print(m)  # <- m is still defined
    condition: Type = stmt.test.accept(self)

    # TODO Allow subtypes or any type
    if condition != self.ctx.get_type("bool"):
        self.error(
            stmt.test.location, f"If test must be a boolean, got {condition}"
        )

    # NOTE(review): both branches are processed in the same sub-environment,
    # so names defined in the body are visible while checking the else branch
    # - confirm this is intended
    branch_env: Environment = Environment(self.env)
    then_returned: bool = self.process_block(stmt.body, branch_env)
    otherwise_returned: bool = self.process_block(stmt.orelse, branch_env)

    # Return types seen in the branches count for the enclosing function
    self.env.return_types.extend(branch_env.return_types)

    if not (then_returned and otherwise_returned):
        return
    raise ReturnException()
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
    """Compute the result type of a binary operation

    Args:
        expr (p.BinaryExpr): the binary expression

    Returns:
        Type: the result type, or UnknownType if the operator is unsupported
            or the operation is undefined for the operand types
    """
    dunder: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
    if dunder is not None:
        lhs: Type = self.type_of(expr.left)
        rhs: Type = self.type_of(expr.right)
        outcome: Optional[Type] = self.ctx.get_operation_result(lhs, dunder, rhs)
        if outcome is not None:
            return outcome
        self.error(
            expr.location,
            f"Undefined operation {dunder} between {lhs} and {rhs}",
        )
        return UnknownType()

    # No method name is registered for this operator class
    self.logger.warning(f"Unsupported operator {expr.operator}")
    self.warning(expr.location, f"Unsupported operator {expr.operator}")
    return UnknownType()
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
    """Compute the result type of a comparison

    Args:
        expr (p.CompareExpr): the comparison expression

    Returns:
        Type: the result type, or UnknownType if the comparator is unsupported
            or the comparison is undefined for the operand types
    """
    dunder: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
    if dunder is not None:
        lhs: Type = self.type_of(expr.left)
        rhs: Type = self.type_of(expr.right)
        outcome: Optional[Type] = self.ctx.get_operation_result(lhs, dunder, rhs)
        if outcome is not None:
            return outcome
        self.error(
            expr.location,
            f"Undefined operation {dunder} between {lhs} and {rhs}",
        )
        return UnknownType()

    # No method name is registered for this comparator class
    self.logger.warning(f"Unsupported operator {expr.operator}")
    self.warning(expr.location, f"Unsupported operator {expr.operator}")
    return UnknownType()
# Not implemented yet: the `...` body means this visitor returns None
# although it is annotated as returning a Type
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ...
def visit_call_expr(self, expr: p.CallExpr) -> Type:
    """Check a call against the signature of its callee

    Args:
        expr (p.CallExpr): the call expression

    Returns:
        Type: the return type of the callee, or UnknownType if the callee is not a function
    """
    callee_type: Type = self.type_of(expr.callee)
    if isinstance(callee_type, Function):
        # Each argument that could be mapped must have exactly the parameter's type
        for bound in self.map_call_arguments(callee_type, expr):
            if bound.type != bound.argument.type:
                self.error(
                    bound.expr.location,
                    f"Wrong type for argument '{bound.argument.name}', expected {bound.argument.type}, got {bound.type}",
                )
        return callee_type.returns

    self.error(expr.callee.location, "Callee is not a function")
    return UnknownType()
# Not implemented yet: the `...` body means this visitor returns None
# although it is annotated as returning a Type
def visit_get_expr(self, expr: p.GetExpr) -> Type: ...
def visit_literal_expr(self, expr: p.LiteralExpr) -> Type:
    """Map a literal value to the type registered under its Python type name

    Args:
        expr (p.LiteralExpr): the literal expression

    Returns:
        Type: the type of the literal, or UnknownType for unsupported literals
    """
    literal = expr.value
    # bool must be tested before int, since bool is a subclass of int
    if isinstance(literal, bool):
        return self.ctx.get_type("bool")
    if isinstance(literal, int):
        return self.ctx.get_type("int")
    if isinstance(literal, float):
        return self.ctx.get_type("float")
    if isinstance(literal, str):
        return self.ctx.get_type("str")

    self.warning(expr.location, f"Unknown literal {expr}")
    return UnknownType()
def visit_variable_expr(self, expr: p.VariableExpr) -> Type:
    """Look up the type bound to a variable

    Args:
        expr (p.VariableExpr): the variable expression

    Returns:
        Type: the type found for the variable, or UnknownType if the lookup yields nothing
    """
    found: Optional[Type] = self.look_up_variable(expr.name, expr)
    if found:
        return found
    return UnknownType()
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
    """Check that both operands of a logical expression have the same type

    Args:
        expr (p.LogicalExpr): the logical expression

    Returns:
        Type: the type of the left operand, even when the operands differ
    """
    lhs: Type = expr.left.accept(self)
    rhs: Type = expr.right.accept(self)

    # TODO: union type
    if not (lhs == rhs):
        self.error(
            expr.location,
            f"Operands must be of the same type, left={lhs} != right={rhs}",
        )
    return lhs
# Not implemented yet: the `...` body means this visitor returns None
# although it is annotated as returning a Type
def visit_set_expr(self, expr: p.SetExpr) -> Type: ...
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
    """Evaluate a cast to the type it names; the cast operand is not inspected here

    Args:
        expr (p.CastExpr): the cast expression

    Returns:
        Type: the target type of the cast
    """
    target: Type = expr.type.accept(self)
    return target
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
    """Check a ternary expression: boolean test and branches of the same type

    Args:
        expr (p.TernaryExpr): the ternary expression

    Returns:
        Type: the common type of both branches, or UnknownType if they differ
    """
    condition: Type = expr.test.accept(self)

    # TODO Allow subtypes or any type
    if condition != self.ctx.get_type("bool"):
        self.error(
            expr.test.location, f"If test must be a boolean, got {condition}"
        )

    when_true: Type = expr.if_true.accept(self)
    when_false: Type = expr.if_false.accept(self)
    if when_true == when_false:
        return when_true

    self.error(
        expr.location,
        f"Type mismatch in ternary if branches: true={when_true} != false={when_false}",
    )
    return UnknownType()
def visit_base_type(self, node: p.BaseType) -> Type:
    """Resolve a base type annotation through the context

    Args:
        node (p.BaseType): the type annotation node

    Returns:
        Type: whatever the context returns for the node's base name
    """
    resolved: Type = self.ctx.get_type(node.base)
    return resolved
# The three type visitors below are not implemented yet: their `...` bodies
# return None although they are annotated as returning a Type
def visit_constraint_type(self, node: p.ConstraintType) -> Type: ...
def visit_frame_column(self, node: p.FrameColumn) -> Type: ...
def visit_frame_type(self, node: p.FrameType) -> Type: ...
def map_call_arguments(
    self, function: Function, call: p.CallExpr
) -> list[MappedArgument]:
    """Pair the arguments of a call with the parameters of the called function

    Positional arguments are matched first with the positional-only parameters,
    then with the mixed (positional-or-keyword) ones. Keyword arguments are
    matched by name with the keyword-only parameters and the mixed parameters
    which were not already consumed positionally

    Any mismatched, missing or unexpected argument is reported as a diagnostic

    Args:
        function (Function): the function definition
        call (p.CallExpr): the call expression

    Returns:
        list[MappedArgument]: the list of mapped arguments
    """
    # Argument types are evaluated up front: positional ones first, then keywords
    positional: list[tuple[p.Expr, Type]] = [
        (expr, self.type_of(expr)) for expr in call.arguments
    ]
    keywords: dict[str, tuple[p.Expr, Type]] = {
        name: (expr, self.type_of(expr)) for name, expr in call.keywords.items()
    }

    bound_names: set[str] = set()
    missing_positional: list[str] = [
        param.name for param in function.pos_args + function.args if param.required
    ]
    missing_keyword: list[str] = [
        param.name for param in function.kw_args if param.required
    ]
    mapped: list[MappedArgument] = []

    def bind(param: Function.Argument, expr: p.Expr, type_: Type) -> None:
        # Record the pairing and mark the parameter as provided
        for pending in (missing_positional, missing_keyword):
            if param.name in pending:
                pending.remove(param.name)
        bound_names.add(param.name)
        mapped.append(MappedArgument(expr=expr, type=type_, argument=param))

    # TODO: handle *args and **kwargs sinks
    by_position: list[Function.Argument] = list(function.pos_args) + list(
        function.args
    )
    consumed: int = 0
    for expr, type_ in positional:
        if consumed == len(by_position):
            # Reported once, the remaining positional arguments are ignored
            self.error(expr.location, "Too many positional arguments")
            break
        bind(by_position[consumed], expr, type_)
        consumed += 1

    # Mixed parameters not consumed positionally can still be passed by keyword
    by_keyword: dict[str, Function.Argument] = {
        param.name: param for param in function.kw_args
    }
    for param in by_position[max(consumed, len(function.pos_args)):]:
        by_keyword[param.name] = param

    for name, (expr, type_) in keywords.items():
        if name in by_keyword:
            bind(by_keyword.pop(name), expr, type_)
        elif name in bound_names:
            self.error(expr.location, f"Multiple values for argument '{name}'")
        else:
            self.error(expr.location, f"Unknown keyword argument '{name}'")

    def join_args(args: list[str]) -> str:
        # 'a', 'b' and 'c'
        quoted: list[str] = [f"'{a}'" for a in args]
        if not quoted:
            return ""
        *head, last = quoted
        if not head:
            return last
        return ", ".join(head) + " and " + last

    if missing_positional:
        plural: str = "" if len(missing_positional) == 1 else "s"
        args: str = join_args(missing_positional)
        self.error(
            call.location,
            f"Missing required positional argument{plural}: {args}",
        )
    if missing_keyword:
        plural = "" if len(missing_keyword) == 1 else "s"
        args = join_args(missing_keyword)
        self.error(
            call.location,
            f"Missing required keyword argument{plural}: {args}",
        )
    return mapped

View File

@@ -0,0 +1,33 @@
from dataclasses import dataclass
from enum import StrEnum
from pathlib import Path
from typing import Optional
from midas.ast.location import Location
class DiagnosticType(StrEnum):
    """Severity of a diagnostic; the member value is the label used when printing it"""

    ERROR = "Error"
    WARNING = "Warning"
    INFO = "Info"
@dataclass(frozen=True)
class Diagnostic:
    """A message attached to a location in a source file"""

    # Path of the file the diagnostic refers to
    file_path: Path
    # Span of the source the diagnostic refers to
    location: Location
    # Severity of the diagnostic
    type: DiagnosticType
    # Human-readable description
    message: str

    def __str__(self) -> str:
        """Format the diagnostic as a single line

        Columns are printed 1-based (`col_offset + 1`). When the location has
        no end position, only the start position is printed

        Returns:
            str: `<type> in <file> at <start>: <message>` or
                `<type> in <file> from <start> to <end>: <message>`
        """
        start_loc: str = f"L{self.location.lineno}:{self.location.col_offset+1}"
        # None (not an empty string) marks a missing end position, so that the
        # "at ..." form below is actually reachable
        end_loc: Optional[str] = None
        if (
            self.location.end_lineno is not None
            and self.location.end_col_offset is not None
        ):
            end_loc = f"L{self.location.end_lineno}:{self.location.end_col_offset+1}"
        loc: str = (
            f"at {start_loc}" if end_loc is None else f"from {start_loc} to {end_loc}"
        )
        return f"{self.type} in {self.file_path} {loc}: {self.message}"

View File

@@ -0,0 +1,142 @@
from __future__ import annotations
from typing import Optional
from midas.checker.types import Type
class Environment:
    """
    A scoped environment in which variables are defined

    Each environment can inherit from a parent/enclosing environment.
    """

    def __init__(self, enclosing: Optional[Environment] = None) -> None:
        """Create an environment, registering it as a child of `enclosing` if given

        Args:
            enclosing (Optional[Environment]): the parent environment, if any
        """
        self.enclosing: Optional[Environment] = enclosing
        self.values: dict[str, Type] = {}
        self.return_types: list[Type] = []
        # Children are only tracked so that `dump` can walk the whole tree
        self._children: list[Environment] = []
        if enclosing is not None:
            enclosing._children.append(self)

    def define(self, name: str, value: Type) -> None:
        """Define a variable in this environment

        Args:
            name (str): the name of the variable
            value (Type): the value
        """
        self.values[name] = value

    def get(self, name: str) -> Optional[Type]:
        """Get a variable in the closest environment which has a definition for it

        Args:
            name (str): the name of the variable

        Returns:
            Optional[Type]: the value of the variable, or None if it was not found
        """
        if name in self.values:
            return self.values[name]

        if self.enclosing is not None:
            return self.enclosing.get(name)

        # raise NameError(f"Undefined variable '{name}'")
        return None

    def assign(self, name: str, value: Type) -> bool:
        """Assign a new value to a variable in the environment it was defined in

        If the variable is not defined here, the assignment is delegated to the
        enclosing environment. When no ancestor accepts it, the variable is
        bound in this environment, unless this environment has no parent

        Args:
            name (str): the name of the variable
            value (Type): the new value

        Returns:
            bool: True if the variable was assigned in this environment or an ancestor, False otherwise
        """
        if name not in self.values:
            if self.enclosing is None:
                return False
            if self.enclosing.assign(name, value):
                return True
        self.values[name] = value
        return True

    def clear(self) -> None:
        """Clear all definitions in this environment"""
        self.values = {}

    def get_at(self, distance: int, name: str) -> Optional[Type]:
        """Get the value of a variable at a given distance

        A distance of 0 looks up in this environment, 1 in the parent environment, etc.
        This method expects `distance` to be valid. An error will be raised if
        the stack does not extend far enough to reach `distance`

        Args:
            distance (int): the scope distance
            name (str): the name of the variable

        Returns:
            Optional[Type]: the value at the given distance, or None if it is not defined in that environment

        Raises:
            AssertionError: if the stack does not extend far enough to reach `distance`
        """
        return self.ancestor(distance).values.get(name)

    def assign_at(self, distance: int, name: str, value: Type) -> None:
        """Assign a new value to a variable at a given distance

        A distance of 0 assigns in this environment, 1 in the parent environment, etc.

        Args:
            distance (int): the scope distance
            name (str): the name of the variable
            value (Type): the new value

        Raises:
            AssertionError: if the stack does not extend far enough to reach `distance`
        """
        self.ancestor(distance).values[name] = value

    def ancestor(self, distance: int) -> Environment:
        """Get the ancestor at a given distance

        A distance of 0 references this environment, 1 the parent environment, etc.

        Args:
            distance (int): the scope distance

        Returns:
            Environment: the environment

        Raises:
            AssertionError: if the stack does not extend far enough to reach `distance`
        """
        env: Environment = self
        for _ in range(distance):
            assert env.enclosing is not None
            env = env.enclosing
        return env

    def flat_dict(self) -> dict[str, Type]:
        """Get the current environment including definitions in its ancestors as a flat dictionary

        This method recursively combines this environment's definitions with its
        ancestors', the closest definition taking precedence. The result is always
        a new dictionary: modifying it does not affect any environment

        Returns:
            dict: the combined environment
        """
        if self.enclosing is None:
            # Copy, so that the root behaves like every other level and never
            # hands out its internal dictionary
            return dict(self.values)
        return self.enclosing.flat_dict() | self.values

    def dump(self) -> dict:
        """Get this environment and all its children as nested dictionaries

        Returns:
            dict: a dictionary with the keys "values", "return_types" and "children"
        """
        return {
            "values": self.values,
            "return_types": self.return_types,
            "children": [child.dump() for child in self._children],
        }

View File

@@ -0,0 +1,31 @@
import ast
from typing import Type
# Maps each binary operator node class of the `ast` module to the name of the
# special method implementing it on the left operand
OPERATOR_METHODS: dict[Type[ast.operator], str] = {
    ast.Add: "__add__",
    ast.Sub: "__sub__",
    ast.Mult: "__mul__",
    ast.MatMult: "__matmul__",
    ast.Div: "__truediv__",
    ast.Mod: "__mod__",
    ast.Pow: "__pow__",
    ast.LShift: "__lshift__",
    ast.RShift: "__rshift__",
    ast.BitOr: "__or__",
    ast.BitXor: "__xor__",
    ast.BitAnd: "__and__",
    ast.FloorDiv: "__floordiv__",
}
# Maps each supported comparison operator node class of the `ast` module to the
# name of the special method implementing it.
# The commented-out entries are not supported yet. Their names are placeholders
# rather than real special methods: in Python `!=` is `__ne__`, `in` / `not in`
# go through `__contains__` of the right operand, and `is` / `is not` cannot be
# overloaded
COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
    ast.Eq: "__eq__",
    # ast.NotEq: "__noteq__",
    ast.Lt: "__lt__",
    ast.LtE: "__le__",
    ast.Gt: "__gt__",
    ast.GtE: "__ge__",
    # ast.Is: "__is__",
    # ast.IsNot: "__isnot__",
    # ast.In: "__in__",
    # ast.NotIn: "__notin__",
}

47
midas/checker/types.py Normal file
View File

@@ -0,0 +1,47 @@
from __future__ import annotations
from dataclasses import dataclass
@dataclass(frozen=True, kw_only=True)
class BaseType:
    """A type identified only by its name"""

    name: str
@dataclass(frozen=True, kw_only=True)
class AliasType:
    """A named type wrapping another type"""

    name: str
    # The type this alias refers to
    type: Type
@dataclass(frozen=True, kw_only=True)
class UnknownType:
    """Placeholder for a type which could not be determined

    It has no fields, so all instances compare equal
    """

    pass
@dataclass(frozen=True, kw_only=True)
class UnitType:
    """Type of the absence of a value (e.g. a bare `return`)

    It has no fields, so all instances compare equal
    """

    pass
@dataclass(frozen=True, kw_only=True)
class Function:
name: str
pos_args: list[Argument]
args: list[Argument]
kw_args: list[Argument]
returns: Type
@dataclass(frozen=True, kw_only=True)
class Argument:
name: str
type: Type
required: bool
@dataclass(frozen=True, kw_only=True)
class ComplexType:
properties: dict[str, Type]
# Union of every class which can represent a type in the checker
Type = BaseType | AliasType | UnknownType | UnitType | Function | ComplexType

View File

@@ -53,5 +53,6 @@ span {
&.keyword {
color: rgb(211, 72, 9);
pointer-events: none;
}
}

View File

@@ -1,12 +1,15 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Generic, Optional, Protocol, TextIO, TypeVar
from midas.ast.location import Location
import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.diagnostic import Diagnostic
from midas.lexer.token import Token
H = TypeVar("H", bound="Highlighter", contravariant=True)
@@ -21,6 +24,15 @@ class Locatable(Protocol):
def location(self) -> Optional[Location]: ...
@dataclass(frozen=True)
class LocatableToken:
    """Adapter exposing the location of a token as a `location` property"""

    token: Token

    @property
    def location(self) -> Location:
        """The location reported by the wrapped token"""
        return self.token.get_location()
class Highlighter(ABC):
BASE_CSS_PATH: Path = Path(__file__).parent / "highlight.css"
EXTRA_CSS_PATH: Optional[Path] = None
@@ -71,6 +83,7 @@ class Highlighter(ABC):
openings: list[str] = self.openings.get(pos, [])
line_buf += "".join(closings + openings)
line_buf += char
line_buf += "".join(self.closings.get((lineno, len(line)), []))
line_buf += "</div></div>"
lines.append(" " + line_buf)
lines.extend(
@@ -83,7 +96,7 @@ class Highlighter(ABC):
buf.write("\n".join(lines))
def wrap(self, node: Locatable, cls: str):
def wrap(self, node: Locatable, cls: str, message: Optional[str] = None):
if node.location is None:
return
if node.location.end_lineno is None or node.location.end_col_offset is None:
@@ -95,6 +108,10 @@ class Highlighter(ABC):
)
opening: str = f'<span class="{cls}" title="{cls}">'
closing: str = "</span>"
if message is not None:
opening = f'<span class="with-msg">{opening}'
closing = f'{closing}<span class="message">{message}</span></span>'
self.openings.setdefault(start_pos, []).append(opening)
self.closings.setdefault(end_pos, []).insert(0, closing)
if start_pos[0] != end_pos[0]:
@@ -142,6 +159,8 @@ class PythonHighlighter(
self.wrap(stmt, "function")
for arg in stmt.posonlyargs + stmt.args + stmt.kwonlyargs:
self._highlight_function_argument(arg)
for body_stmt in stmt.body:
body_stmt.accept(self)
def _highlight_function_argument(self, arg: p.Function.Argument) -> None:
self.wrap(arg, "argument")
@@ -151,7 +170,23 @@ class PythonHighlighter(
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
stmt.type.accept(self)
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: ...
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
for target in stmt.targets:
target.accept(self)
stmt.value.accept(self)
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
self.wrap(stmt, "return")
if stmt.value is not None:
stmt.value.accept(self)
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
self.wrap(stmt, "if")
stmt.test.accept(self)
for body_stmt in stmt.body:
body_stmt.accept(self)
for else_stmt in stmt.orelse:
else_stmt.accept(self)
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ...
@@ -159,7 +194,13 @@ class PythonHighlighter(
def visit_unary_expr(self, expr: p.UnaryExpr) -> None: ...
def visit_call_expr(self, expr: p.CallExpr) -> None: ...
def visit_call_expr(self, expr: p.CallExpr) -> None:
self.wrap(expr, "call")
expr.callee.accept(self)
for arg in expr.arguments:
arg.accept(self)
for arg in expr.keywords.values():
arg.accept(self)
def visit_get_expr(self, expr: p.GetExpr) -> None: ...
@@ -171,35 +212,27 @@ class PythonHighlighter(
def visit_set_expr(self, expr: p.SetExpr) -> None: ...
def visit_cast_expr(self, expr: p.CastExpr) -> None: ...
class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ...
class MidasHighlighter(
Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None]
):
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_midas.css"
def highlight(self, node: Highlightable[MidasHighlighter]):
node.accept(self)
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt) -> None:
self.wrap(stmt, "simple-type")
if stmt.template is not None:
stmt.template.accept(self)
stmt.base.accept(self)
if stmt.constraint is not None:
self.wrap(stmt.constraint, "constraint")
stmt.constraint.accept(self)
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt) -> None:
self.wrap(stmt, "complex-type")
if stmt.template is not None:
stmt.template.accept(self)
for prop in stmt.properties:
prop.accept(self)
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
self.wrap(stmt, "type-stmt")
self.wrap(LocatableToken(stmt.name), "type-name")
stmt.type.accept(self)
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None:
self.wrap(stmt, "property")
stmt.type.accept(self)
if stmt.constraint is not None:
self.wrap(stmt.constraint, "constraint")
stmt.constraint.accept(self)
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
self.wrap(stmt, "extend")
@@ -209,17 +242,16 @@ class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
def visit_op_stmt(self, stmt: m.OpStmt) -> None:
self.wrap(stmt, "op")
self.wrap(LocatableToken(stmt.name), "op-name")
stmt.operand.accept(self)
stmt.result.accept(self)
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
self.wrap(stmt, "predicate")
self.wrap(LocatableToken(stmt.name), "predicate-name")
stmt.type.accept(self)
stmt.condition.accept(self)
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr) -> None:
self.wrap(expr, "simple-type-expr")
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
self.wrap(expr, "logical-expr")
expr.left.accept(self)
@@ -248,11 +280,29 @@ class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
def visit_template_expr(self, expr: m.TemplateExpr) -> None:
self.wrap(expr, "template")
expr.type.accept(self)
def visit_named_type(self, type: m.NamedType) -> None:
self.wrap(type, "named-type")
def visit_type_expr(self, expr: m.TypeExpr) -> None:
self.wrap(expr, "type")
if expr.template is not None:
expr.template.accept(self)
def visit_generic_type(self, type: m.GenericType) -> None:
self.wrap(type, "generic-type")
type.type.accept(self)
for param in type.params:
param.accept(self)
def visit_constraint_type(self, type: m.ConstraintType) -> None:
self.wrap(type, "constraint-type")
type.type.accept(self)
type.constraint.accept(self)
def visit_complex_type(self, type: m.ComplexType) -> None:
self.wrap(type, "complex-type")
for prop in type.properties:
prop.accept(self)
class DiagnosticsHighlighter(Highlighter):
    """Highlighter which marks the location of each diagnostic and attaches its message"""

    EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"

    def highlight(self, diagnostics: list[Diagnostic]):
        """Wrap the location of every diagnostic

        The lowercase diagnostic type is used as the CSS class

        Args:
            diagnostics (list[Diagnostic]): the diagnostics to display
        """
        import html

        for diagnostic in diagnostics:
            # The message is inserted as-is into the generated HTML and may
            # contain markup characters (e.g. the repr of an AST node such as
            # "<ast.Pow object at ...>"), so it must be escaped first
            self.wrap(
                diagnostic,
                str(diagnostic.type).lower(),
                html.escape(diagnostic.message),
            )

View File

@@ -0,0 +1,39 @@
/* Styles for diagnostics: one colour per severity and a message shown on hover */
span {
--opacity: 0.4;
/* Severity colours, stored as an "r, g, b" triplet in --col */
&.error {
--col: 255, 0, 0;
}
&.warning {
--col: 250, 160, 0;
}
&.info {
--col: 150, 190, 250;
}
/* Wrapper added around a highlighted span which carries a message */
&.with-msg {
position: relative;
/* The message is hidden by default */
.message {
display: none;
}
/* Show the message on hover, unless a nested .with-msg is hovered: only the innermost one shows its message */
&:hover:not(:has(.with-msg:hover)) {
.message {
display: inline-block;
}
}
/* Message box placed just below the wrapped span */
.message {
position: absolute;
top: calc(100% + 0.2em);
left: -.2em;
background-color: black;
color: white;
padding: 0.2em 0.4em;
border-radius: .2em;
z-index: 10;
width: 300%;
}
}
}

View File

@@ -5,12 +5,11 @@ span {
font-style: italic;
}
&.simple-type {
--col: 108, 233, 108;
}
&.named-type,
&.generic-type,
&.constraint-type,
&.complex-type {
--col: 233, 206, 108;
--col: 150, 150, 150;
}
&.constraint {
@@ -33,10 +32,6 @@ span {
--col: 193, 108, 233;
}
&.simple-type-expr {
--col: 150, 150, 150;
}
&.logical-expr,
&.binary-expr,
&.unary-expr,
@@ -48,7 +43,9 @@ span {
--col: 163, 117, 71;
}
&.type {
&.type-name,
&.op-name,
&.predicate-name {
--col: 200, 200, 200;
font-weight: bold;
}

View File

@@ -1,18 +1,30 @@
import ast
from dataclasses import dataclass
from typing import Optional, TextIO
import json
import logging
from pathlib import Path
from typing import Optional, TextIO, get_args
import click
import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.ast.printer import PythonAstPrinter
from midas.cli.highlighter import Highlighter, MidasHighlighter, PythonHighlighter
from midas.ast.printer import MidasAstPrinter, MidasPrinter, PythonAstPrinter
from midas.checker.checker import Checker
from midas.checker.diagnostic import Diagnostic
from midas.checker.types import Type
from midas.cli.highlighter import (
DiagnosticsHighlighter,
Highlighter,
LocatableToken,
MidasHighlighter,
PythonHighlighter,
)
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token, TokenType
from midas.parser.midas import MidasParser
from midas.parser.python import PythonParser
from midas.resolver.resolver import Resolver
from midas.utils import UniversalJSONDumper
@click.group()
@@ -21,9 +33,41 @@ def midas():
@midas.command()
@click.option("-l", "--highlight", type=click.File("w"))
@click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.argument("file", type=click.File("r"))
def compile(file: TextIO):
raise NotImplementedError
def compile(highlight: Optional[TextIO], file: TextIO, types: tuple[TextIO]):
logging.basicConfig(level=logging.DEBUG)
source: str = file.read()
tree: ast.Module = ast.parse(source, filename=file.name)
parser = PythonParser()
stmts: list[p.Stmt] = parser.parse_module(tree)
resolver = Resolver()
resolver.resolve(*stmts)
types_paths: list[Path] = [Path(t.name).resolve() for t in types]
checker = Checker(
resolver.locals,
source_path=Path(file.name).resolve(),
types_paths=types_paths,
)
diagnostics: list[Diagnostic] = checker.check(stmts)
for diagnostic in diagnostics:
print(diagnostic)
print(
json.dumps(
UniversalJSONDumper.dump(
checker.global_env,
[("Environment", "_children")],
lambda obj: isinstance(obj, get_args(Type)),
),
indent=4,
)
)
if highlight is not None:
highlighter = DiagnosticsHighlighter(source)
highlighter.highlight(diagnostics)
highlighter.dump(highlight)
@midas.group()
@@ -31,26 +75,52 @@ def utils():
pass
def dump_python_ast(tree: ast.Module) -> str:
    """Parse a Python module with the custom parser and print the resulting statements

    Args:
        tree (ast.Module): the module, as parsed by the `ast` module

    Returns:
        str: the printed statements, each followed by a newline
    """
    stmts: list[p.Stmt] = PythonParser().parse_module(tree)
    printer = PythonAstPrinter()
    return "".join(printer.print(stmt) + "\n" for stmt in stmts)
def dump_midas_ast(source: str, filename: str) -> str:
    """Lex and parse a Midas source and print the resulting statements

    Args:
        source (str): the Midas source code
        filename (str): the name of the file, passed to the lexer

    Returns:
        str: the printed statements, each followed by a newline

    Raises:
        RuntimeError: if the parser recorded at least one error
    """
    tokens: list[Token] = MidasLexer(source, file=filename).process()
    parser = MidasParser(tokens)
    stmts: list[m.Stmt] = parser.parse()

    if len(parser.errors) != 0:
        # Print every report before giving up
        for err in parser.errors:
            print(err.get_report())
        raise RuntimeError("A parsing error occurred")

    printer = MidasAstPrinter()
    return "".join(printer.print(stmt) + "\n" for stmt in stmts)
@utils.command()
@click.option("-o", "--output", type=click.File("w"))
@click.option("-p", "--parse", is_flag=True)
@click.argument("file", type=click.File("r"))
def dump_ast(output: Optional[TextIO], parse: bool, file: TextIO):
source: str = file.read()
tree: ast.Module = ast.parse(source, filename=file.name)
dump: str
if file.name.endswith(".py"):
tree: ast.Module = ast.parse(source, filename=file.name)
if parse:
parser = PythonParser()
stmts: list[p.Stmt] = parser.parse_module(tree)
printer = PythonAstPrinter()
dump = ""
for stmt in stmts:
dump += printer.print(stmt)
dump += "\n"
dump = dump_python_ast(tree)
else:
dump = ast.dump(tree, indent=4)
elif file.name.endswith(".midas"):
dump = dump_midas_ast(source, file.name)
else:
raise ValueError("Unsupported file type")
if output is None:
click.echo(dump)
@@ -77,14 +147,6 @@ def highlight_midas(source: str, path: str) -> Highlighter:
for err in parser.errors:
print(err.get_report())
@dataclass(frozen=True)
class LocatableToken:
token: Token
@property
def location(self) -> Location:
return self.token.get_location()
for stmt in stmts:
highlighter.highlight(stmt)
for token in tokens:
@@ -109,3 +171,23 @@ def highlight(output: TextIO, file: TextIO):
else:
raise ValueError("Unsupported file type")
highlighter.dump(output)
@midas.command()
@click.option("-o", "--output", type=click.File("w"), default="-")
@click.argument("file", type=click.File("r"))
def format(output: TextIO, file: TextIO):
    """Parse a Midas file and write its statements back through the Midas printer

    Parsing errors are printed, and the statements returned by the parser
    are still written
    """
    source: str = file.read()
    printer = MidasPrinter()

    tokens: list[Token] = MidasLexer(source, file=file.name).process()
    parser = MidasParser(tokens)
    stmts: list[m.Stmt] = parser.parse()

    for err in parser.errors:
        print(err.get_report())

    # One statement per line
    output.writelines(printer.print(stmt) + "\n" for stmt in stmts)
if __name__ == "__main__":
midas()

View File

@@ -40,8 +40,8 @@ class MidasLexer(Lexer):
self.add_token(TokenType.AND)
case "?":
self.add_token(TokenType.QMARK)
# case ",":
# self.add_token(TokenType.COMMA)
case ",":
self.add_token(TokenType.COMMA)
case "_" if not self.is_identifier_char(self.peek_next(), start=False):
self.add_token(TokenType.UNDERSCORE)
case "-" if self.match(">"):

View File

@@ -17,7 +17,7 @@ class TokenType(Enum):
LEFT_BRACE = auto()
RIGHT_BRACE = auto()
COLON = auto()
# COMMA = auto()
COMMA = auto()
UNDERSCORE = auto()
ARROW = auto()
AND = auto()

View File

@@ -3,21 +3,22 @@ from typing import Optional
from midas.ast.location import Location
from midas.ast.midas import (
BinaryExpr,
ComplexTypeStmt,
ComplexType,
ConstraintType,
Expr,
ExtendStmt,
GenericType,
GetExpr,
GroupingExpr,
LiteralExpr,
LogicalExpr,
NamedType,
OpStmt,
PredicateStmt,
PropertyStmt,
SimpleTypeExpr,
SimpleTypeStmt,
Stmt,
TemplateExpr,
TypeExpr,
Type,
TypeStmt,
UnaryExpr,
VariableExpr,
WildcardExpr,
@@ -81,7 +82,7 @@ class MidasParser(Parser):
self.synchronize()
return None
def type_declaration(self) -> SimpleTypeStmt | ComplexTypeStmt:
def type_declaration(self) -> TypeStmt:
"""Parse a type declaration
A type declaration can either be a simple type alias or a new complex type.
@@ -107,33 +108,22 @@ class MidasParser(Parser):
"""
keyword: Token = self.previous()
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
template: Optional[TemplateExpr] = None
params: list[TypeStmt.Param] = []
if self.check(TokenType.LEFT_BRACKET):
template = self.template_expr()
params = self.type_stmt_params()
if self.match(TokenType.LEFT_PAREN):
base: TypeExpr = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Unclosed base type parenthesis")
constraint: Optional[Expr] = None
if self.match(TokenType.WHERE):
constraint = self.constraint()
return SimpleTypeStmt(
self.consume(TokenType.EQUAL, "Expected '=' before type definition")
type: Type = self.type_expr()
return TypeStmt(
location=keyword.location_to(self.previous()),
name=name,
template=template,
base=base,
constraint=constraint,
)
else:
properties: list[PropertyStmt] = self.type_properties()
return ComplexTypeStmt(
location=keyword.location_to(self.previous()),
name=name,
template=template,
properties=properties,
params=params,
type=type,
)
def template_expr(self) -> TemplateExpr:
def type_stmt_params(self) -> list[TypeStmt.Param]:
"""Parse a generic template expression
A template is written `[TypeExpr]`
@@ -141,16 +131,27 @@ class MidasParser(Parser):
Returns:
TemplateExpr: the parsed template expression
"""
left: Token = self.consume(
TokenType.LEFT_BRACKET, "Missing '[' before template expression"
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression")
params: list[TypeStmt.Param] = []
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable")
bound: Optional[Type] = None
if self.match(TokenType.LESS):
self.consume(TokenType.COLON, "Expected ':' after '<'")
bound = self.type_expr()
params.append(
TypeStmt.Param(
location=name.location_to(self.previous()),
name=name,
bound=bound,
)
type: TypeExpr = self.type_expr()
right: Token = self.consume(
TokenType.RIGHT_BRACKET, "Missing ']' after template expression"
)
return TemplateExpr(location=left.location_to(right), type=type)
if not self.match(TokenType.COMMA):
break
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression")
return params
def type_expr(self) -> TypeExpr:
def type_expr(self) -> Type:
"""Parse a type expression
A type is an identifier, optionally followed by a template expression.
@@ -159,30 +160,82 @@ class MidasParser(Parser):
Returns:
TypeExpr: the parsed type expression
"""
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
template: Optional[TemplateExpr] = None
return self.constraint_type()
def constraint_type(self) -> Type:
type: Type = self.base_type()
if self.match(TokenType.WHERE):
constraint: Expr = self.constraint()
return ConstraintType(
location=Location.span(type.location, constraint.location),
type=type,
constraint=constraint,
)
return type
def base_type(self) -> Type:
if self.match(TokenType.LEFT_PAREN):
type: Type = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
return type
if self.check(TokenType.LEFT_BRACE):
return self.complex_type()
return self.generic_type()
def generic_type(self) -> Type:
type: Type = self.named_type()
if self.check(TokenType.LEFT_BRACKET):
template = self.template_expr()
optional: bool = self.match(TokenType.QMARK)
return TypeExpr(
location=name.location_to(self.previous()),
params: list[Type] = self.type_params()
return GenericType(
location=Location.span(type.location, self.previous().get_location()),
type=type,
params=params,
)
return type
def type_params(self) -> list[Type]:
params: list[Type] = []
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic parameters")
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
params.append(self.type_expr())
if not self.match(TokenType.COMMA):
break
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic parameters")
return params
def named_type(self) -> Type:
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
return NamedType(
location=name.get_location(),
name=name,
template=template,
optional=optional,
)
def simple_type_expr(self) -> SimpleTypeExpr:
"""Parse a simple type expression
def complex_type(self) -> Type:
"""Parse a type definition body
A simple type is just an identifier optionally followed by a '?'
A type definition body is a set of whitespace-separated
property statements enclosed in curly braces
Returns:
SimpleTypeExpr: the parsed simple type expression
list[PropertyStmt]: the parsed type properties
"""
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
optional: bool = self.match(TokenType.QMARK)
return SimpleTypeExpr(
location=name.location_to(self.previous()), name=name, optional=optional
left: Token = self.consume(
TokenType.LEFT_BRACE, "Expected '{' to start type body"
)
properties: list[PropertyStmt] = []
names: set[str] = set()
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
prop: PropertyStmt = self.property_stmt()
if prop.name.lexeme in names:
raise self.error(prop.name, "Duplicate property")
names.add(prop.name.lexeme)
properties.append(prop)
right: Token = self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
return ComplexType(
location=left.location_to(right),
properties=properties,
)
def constraint(self) -> Expr:
@@ -205,9 +258,7 @@ class MidasParser(Parser):
while self.match(TokenType.AND):
operator: Token = self.previous()
right: Expr = self.equality()
location: Optional[Location] = None
if expr.location and right.location:
location = Location.span(expr.location, right.location)
location: Location = Location.span(expr.location, right.location)
expr = LogicalExpr(
location=location, left=expr, operator=operator, right=right
)
@@ -223,9 +274,7 @@ class MidasParser(Parser):
while self.match(TokenType.BANG_EQUAL, TokenType.EQUAL_EQUAL):
operator: Token = self.previous()
right: Expr = self.comparison()
location: Optional[Location] = None
if expr.location and right.location:
location = Location.span(expr.location, right.location)
location: Location = Location.span(expr.location, right.location)
expr = BinaryExpr(
location=location, left=expr, operator=operator, right=right
)
@@ -246,9 +295,7 @@ class MidasParser(Parser):
):
operator: Token = self.previous()
right: Expr = self.unary()
location: Optional[Location] = None
if expr.location and right.location:
location = Location.span(expr.location, right.location)
location: Location = Location.span(expr.location, right.location)
expr = BinaryExpr(
location=location, left=expr, operator=operator, right=right
)
@@ -263,9 +310,7 @@ class MidasParser(Parser):
if self.match(TokenType.MINUS):
operator: Token = self.previous()
right: Expr = self.unary()
location: Optional[Location] = None
if right.location:
location = Location.span(operator.get_location(), right.location)
location: Location = Location.span(operator.get_location(), right.location)
return UnaryExpr(location=location, operator=operator, right=right)
return self.reference()
@@ -280,9 +325,7 @@ class MidasParser(Parser):
name: Token = self.consume(
TokenType.IDENTIFIER, "Expected property name after '.'"
)
location: Optional[Location] = None
if expr.location:
location = Location.span(expr.location, name.get_location())
location: Location = Location.span(expr.location, name.get_location())
expr = GetExpr(location=location, expr=expr, name=name)
return expr
@@ -318,22 +361,6 @@ class MidasParser(Parser):
raise self.error(self.peek(), "Expected expression")
def type_properties(self) -> list[PropertyStmt]:
"""Parse a type definition body
A type definition body is a set of whitespace-separated
property statements enclosed in curly braces
Returns:
list[PropertyStmt]: the parsed type properties
"""
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start type body")
properties: list[PropertyStmt] = []
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
properties.append(self.property_stmt())
self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
return properties
def property_stmt(self) -> PropertyStmt:
"""Parse a property statement
@@ -344,15 +371,11 @@ class MidasParser(Parser):
"""
name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name")
self.consume(TokenType.COLON, "Expected ':' after property name")
type: TypeExpr = self.type_expr()
constraint: Optional[Expr] = None
if self.match(TokenType.WHERE):
constraint = self.constraint()
type: Type = self.type_expr()
return PropertyStmt(
location=name.location_to(self.previous()),
name=name,
type=type,
constraint=constraint,
)
def extend_declaration(self) -> ExtendStmt:
@@ -364,15 +387,13 @@ class MidasParser(Parser):
ExtendStmt: the parsed extension statement
"""
keyword: Token = self.previous()
type: TypeExpr = self.type_expr()
type: Type = self.type_expr()
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
operations: list[OpStmt] = []
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
operations.append(self.op_declaration())
self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body")
location: Optional[Location] = None
if type.location:
location = keyword.location_to(self.previous())
location: Location = keyword.location_to(self.previous())
return ExtendStmt(location=location, type=type, operations=operations)
def op_declaration(self) -> OpStmt:
@@ -387,11 +408,11 @@ class MidasParser(Parser):
name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name")
self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type")
operand: TypeExpr = self.type_expr()
operand: Type = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after operand type")
self.consume(TokenType.ARROW, "Expected '->' before result type")
result: TypeExpr = self.type_expr()
result: Type = self.type_expr()
return OpStmt(
location=keyword.location_to(self.previous()),
@@ -413,7 +434,7 @@ class MidasParser(Parser):
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name")
self.consume(TokenType.COLON, "Expected ':' after subject name")
type: TypeExpr = self.type_expr()
type: Type = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
condition: Expr = self.constraint()

View File

@@ -2,12 +2,12 @@ import ast
from typing import Optional
from midas.ast.location import Location
from midas.ast.python import (
AssignStmt,
BaseType,
BinaryExpr,
CallExpr,
CastExpr,
CompareExpr,
ConstraintType,
Expr,
@@ -16,10 +16,13 @@ from midas.ast.python import (
FrameType,
Function,
GetExpr,
IfStmt,
LiteralExpr,
LogicalExpr,
MidasType,
ReturnStmt,
Stmt,
TernaryExpr,
TypeAssign,
UnaryExpr,
VariableExpr,
@@ -38,6 +41,8 @@ class UnsupportedSyntaxError(Exception):
class PythonParser:
CAST_FUNCTION = "cast"
def parse_module(self, node: ast.Module) -> list[Stmt]:
statements: list[Stmt] = []
for stmt in node.body:
@@ -53,6 +58,7 @@ class PythonParser:
return statements
def parse_stmt(self, node: ast.stmt) -> None | Stmt | list[Stmt]:
location: Location = Location.from_ast(node)
match node:
case ast.AnnAssign():
return self.parse_annotation_assign(node)
@@ -60,11 +66,29 @@ class PythonParser:
case ast.Assign():
return self.parse_assign(node)
case ast.AugAssign():
return self.parse_aug_assign(node)
case ast.FunctionDef():
return self.parse_function(node)
case ast.Expr(value=expr):
return ExpressionStmt(expr=self.parse_expr(expr))
return ExpressionStmt(
location=location,
expr=self.parse_expr(expr),
)
case ast.Return(value=value):
return ReturnStmt(
location=location,
value=self.parse_expr(value) if value is not None else None,
)
case ast.If():
return self.parse_if(node)
case ast.Pass():
return None
case _:
print(f"Unsupported statement: {ast.unparse(node)}")
@@ -80,8 +104,7 @@ class PythonParser:
value=value,
simple=1,
):
type = self._parse_type(annotation, root=True)
if type is not None:
type = self._parse_type(annotation)
statements.append(
TypeAssign(
location=loc,
@@ -117,6 +140,45 @@ class PythonParser:
value=value,
)
def parse_aug_assign(self, node: ast.AugAssign) -> AssignStmt:
    """Desugar an augmented assignment into a plain assignment

    `x += y` is represented as the assignment of the binary expression
    `x + y` to `x`. The parsed target is used both as the assignment target
    and as the left operand, and every created node shares the location of
    the whole statement.

    Args:
        node (ast.AugAssign): the augmented assignment node

    Returns:
        AssignStmt: the equivalent assignment statement
    """
    location: Location = Location.from_ast(node)
    target: Expr = self.parse_expr(node.target)
    operand: Expr = self.parse_expr(node.value)
    combined: BinaryExpr = BinaryExpr(
        location=location,
        left=target,
        operator=node.op,
        right=operand,
    )
    return AssignStmt(location=location, targets=[target], value=combined)
def parse_if(self, node: ast.If) -> IfStmt:
    """Parse an if statement together with its else branch

    Both branches are parsed before the test expression. Statements for
    which parse_stmt returns None are left out, and lists of statements are
    flattened into the branch.

    Args:
        node (ast.If): the if statement node

    Returns:
        IfStmt: the parsed if statement
    """

    def flatten(nodes: list[ast.stmt]) -> list[Stmt]:
        parsed_block: list[Stmt] = []
        for raw in nodes:
            parsed = self.parse_stmt(raw)
            if isinstance(parsed, Stmt):
                parsed_block.append(parsed)
            elif parsed is not None:
                parsed_block.extend(parsed)
        return parsed_block

    then_branch: list[Stmt] = flatten(node.body)
    else_branch: list[Stmt] = flatten(node.orelse)
    return IfStmt(
        location=Location.from_ast(node),
        test=self.parse_expr(node.test),
        body=then_branch,
        orelse=else_branch,
    )
def parse_function(self, node: ast.FunctionDef) -> Function:
loc: Location = Location.from_ast(node)
match node:
@@ -125,26 +187,74 @@ class PythonParser:
args=ast.arguments(
posonlyargs=posonlyargs,
args=args,
vararg=sink,
kwonlyargs=kwonlyargs,
kwarg=kw_sink,
defaults=defaults,
kw_defaults=kw_defaults,
),
returns=returns,
body=raw_body,
):
def parse_args(args_list: list[ast.arg]) -> list[Function.Argument]:
return [self._parse_function_argument(arg) for arg in args_list]
def parse_args(
args_list: list[ast.arg], defaults: list[Optional[Expr]]
) -> list[Function.Argument]:
return [
self._parse_function_argument(arg, default)
for arg, default in zip(args_list, defaults)
]
body: list[Stmt] = []
for stmt in raw_body:
stmts = self.parse_stmt(stmt)
if isinstance(stmts, Stmt):
body.append(stmts)
elif stmts is not None:
body.extend(stmts)
parsed_defaults: list[Optional[Expr]] = [
self.parse_expr(default) for default in defaults
]
n_posargs: int = len(posonlyargs)
n_args: int = len(args)
n_all_posargs = n_posargs + n_args
parsed_defaults = [
None,
] * (n_all_posargs - len(defaults)) + parsed_defaults
posargs_defaults: list[Optional[Expr]] = parsed_defaults[:n_posargs]
args_defaults: list[Optional[Expr]] = parsed_defaults[n_posargs:]
kwargs_defaults: list[Optional[Expr]] = [
self.parse_expr(default) if default is not None else None
for default in kw_defaults
]
return Function(
location=loc,
name=name,
posonlyargs=parse_args(posonlyargs),
args=parse_args(args),
kwonlyargs=parse_args(kwonlyargs),
posonlyargs=parse_args(posonlyargs, posargs_defaults),
args=parse_args(args, args_defaults),
sink=(
self._parse_function_argument(sink, None)
if sink is not None
else None
),
kwonlyargs=parse_args(kwonlyargs, kwargs_defaults),
kw_sink=(
self._parse_function_argument(kw_sink, None)
if kw_sink is not None
else None
),
returns=self._parse_type(returns) if returns is not None else None,
body=body,
)
case _:
print(f"Unsupported function definition: {ast.unparse(node)}")
def _parse_function_argument(self, arg: ast.arg) -> Function.Argument:
def _parse_function_argument(
self, arg: ast.arg, default: Optional[Expr]
) -> Function.Argument:
loc: Location = Location.from_ast(arg)
name: str = arg.arg
type: Optional[MidasType] = None
@@ -154,11 +264,10 @@ class PythonParser:
location=loc,
name=name,
type=type,
default=default,
)
def _parse_type(
self, type_expr: ast.expr, root: bool = False
) -> Optional[MidasType]:
def _parse_type(self, type_expr: ast.expr) -> MidasType:
loc: Location = Location.from_ast(type_expr)
match type_expr:
case ast.Subscript(value=ast.Name(id="Frame"), slice=schema):
@@ -205,9 +314,14 @@ class PythonParser:
constraint=right_expr,
)
case ast.Constant(value=None):
return BaseType(
location=loc,
base="None",
param=None,
)
case _:
if root:
return None
raise UnsupportedSyntaxError(type_expr)
def _parse_frame_type(self, schema: ast.expr) -> FrameType:
@@ -257,12 +371,14 @@ class PythonParser:
raise UnsupportedSyntaxError(column)
def parse_expr(self, node: ast.expr) -> Expr:
location: Location = Location.from_ast(node)
match node:
case ast.BoolOp():
return self.parse_bool_op(node)
case ast.BinOp(left=left, op=op, right=right):
return BinaryExpr(
location=location,
left=self.parse_expr(left),
operator=op,
right=self.parse_expr(right),
@@ -270,6 +386,7 @@ class PythonParser:
case ast.UnaryOp(op=op, operand=right):
return UnaryExpr(
location=location,
operator=op,
right=self.parse_expr(right),
)
@@ -277,62 +394,96 @@ class PythonParser:
case ast.Compare():
return self.parse_compare(node)
case ast.Call(func=ast.Name(id=self.CAST_FUNCTION)):
return self.parse_cast(node)
case ast.Call():
return self.parse_call(node)
case ast.IfExp():
return self.parse_ternary(node)
case ast.Constant(value=value):
return LiteralExpr(value=value)
return LiteralExpr(location=location, value=value)
case ast.Attribute(value=object, attr=name):
return GetExpr(
location=location,
object=self.parse_expr(object),
name=name,
)
case ast.Name(id=name):
return VariableExpr(name=name)
return VariableExpr(location=location, name=name)
case _:
raise UnsupportedSyntaxError(node)
def parse_bool_op(self, node: ast.BoolOp) -> LogicalExpr:
op: ast.boolop = node.op
values: list[ast.expr] = node.values
rights: list[Expr] = [self.parse_expr(expr) for expr in node.values]
expr: LogicalExpr = LogicalExpr(
left=self.parse_expr(values[0]),
location=Location.span(
rights[0].location,
rights[1].location,
),
left=rights[0],
operator=op,
right=self.parse_expr(values[1]),
right=rights[1],
)
for value in values[2:]:
for right in rights[2:]:
expr = LogicalExpr(
location=Location.span(expr.location, right.location),
left=expr,
operator=op,
right=self.parse_expr(value),
right=right,
)
return expr
def parse_compare(self, node: ast.Compare) -> Expr:
    """Parse a (possibly chained) comparison

    A single comparison `a < b` becomes one CompareExpr. A chained
    comparison such as `a < b <= c` becomes the conjunction
    `(a < b) and (b <= c)`, built as left-nested LogicalExpr nodes using
    `ast.And()` as operator.

    Args:
        node (ast.Compare): the comparison node

    Returns:
        Expr: a CompareExpr for a single comparison, otherwise a LogicalExpr
    """
    ops: list[ast.cmpop] = node.ops
    left: Expr = self.parse_expr(node.left)
    rights: list[Expr] = [self.parse_expr(expr) for expr in node.comparators]

    expr: Expr = CompareExpr(
        location=Location.span(left.location, rights[0].location),
        left=left,
        operator=ops[0],
        right=rights[0],
    )

    # ops[k] sits between comparators[k - 1] and comparators[k], so the pair
    # (rights[k - 1], rights[k]) must be combined with ops[k], not ops[k - 1]
    for previous, operator, right in zip(rights, ops[1:], rights[1:]):
        comparison = CompareExpr(
            location=Location.span(previous.location, right.location),
            left=previous,
            operator=operator,
            right=right,
        )
        expr = LogicalExpr(
            location=Location.span(expr.location, comparison.location),
            left=expr,
            operator=ast.And(),
            right=comparison,
        )
    return expr
def parse_cast(self, node: ast.Call) -> CastExpr:
    """Parse a call to the cast function

    The call must have exactly two positional arguments (the target type
    and the expression to cast) and no keyword arguments.

    Args:
        node (ast.Call): the call node

    Raises:
        InvalidSyntaxError: if the call does not have that shape

    Returns:
        CastExpr: the parsed cast expression
    """
    well_formed: bool = (
        isinstance(node, ast.Call)
        and len(node.args) == 2
        and len(node.keywords) == 0
    )
    if not well_formed:
        raise InvalidSyntaxError(
            f"Invalid call to {self.CAST_FUNCTION}, expected type and expression"
        )

    target_type, operand = node.args
    return CastExpr(
        location=Location.from_ast(node),
        type=self._parse_type(target_type),
        expr=self.parse_expr(operand),
    )
def parse_call(self, node: ast.Call) -> CallExpr:
return CallExpr(
location=Location.from_ast(node),
callee=self.parse_expr(node.func),
arguments=[self.parse_expr(arg) for arg in node.args],
keywords={
@@ -341,3 +492,11 @@ class PythonParser:
if arg.arg is not None # Should always be True, type checker happy
},
)
def parse_ternary(self, node: ast.IfExp) -> TernaryExpr:
    """Parse a conditional expression (`a if test else b`)

    Args:
        node (ast.IfExp): the conditional expression node

    Returns:
        TernaryExpr: the parsed expression, with the test and both branches
    """
    location: Location = Location.from_ast(node)
    condition: Expr = self.parse_expr(node.test)
    when_true: Expr = self.parse_expr(node.body)
    when_false: Expr = self.parse_expr(node.orelse)
    return TernaryExpr(
        location=location,
        test=condition,
        if_true=when_true,
        if_false=when_false,
    )

72
midas/resolver/builtin.py Normal file
View File

@@ -0,0 +1,72 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from midas.checker.types import BaseType, Type, UnitType
if TYPE_CHECKING:
from midas.resolver.midas import MidasResolver
def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type):
    """Define the operation `t1 <operator> t2 -> t3` on the resolver

    Args:
        ctx (MidasResolver): the resolver in which to define the operation
        t1 (Type): the type of the left operand
        operator (str): the operation name
        t2 (Type): the type of the right operand
        t3 (Type): the result type
    """
    signature = {
        "left": t1,
        "operator": operator,
        "right": t2,
        "result": t3,
    }
    ctx.define_operation(**signature)
def basic_op(ctx: MidasResolver, type: Type, op: str):
    """Define an operation whose operands and result all share one type

    Args:
        ctx (MidasResolver): the resolver in which to define the operation
        type (Type): the type of both operands and of the result
        op (str): the operation name
    """
    operand = type
    ctx.define_operation(
        left=operand,
        operator=op,
        right=operand,
        result=operand,
    )
def define_builtins(ctx: MidasResolver):
    """Define builtin types and operations

    Registers the types None, bool, int, float and str, then the arithmetic,
    bitwise and comparison operations between them.

    Args:
        ctx (MidasResolver): the resolver in which to define the builtins
    """
    ctx.define_type("None", UnitType())
    bool_t = ctx.define_type("bool", BaseType(name="bool"))
    int_t = ctx.define_type("int", BaseType(name="int"))
    float_t = ctx.define_type("float", BaseType(name="float"))
    str_t = ctx.define_type("str", BaseType(name="str"))

    comparisons = ("__lt__", "__gt__", "__le__", "__ge__", "__eq__")

    # int <op> int = int
    for name in (
        "__add__",
        "__sub__",
        "__mul__",
        "__pow__",
        "__mod__",
        "__and__",
        "__or__",
        "__xor__",
    ):
        basic_op(ctx, int_t, name)
    # int <cmp> int = bool
    for name in comparisons:
        op(ctx, int_t, name, int_t, bool_t)

    # float <op> float = float
    for name in ("__add__", "__sub__", "__mul__", "__truediv__"):
        basic_op(ctx, float_t, name)
    # float <cmp> float = bool
    for name in comparisons:
        op(ctx, float_t, name, float_t, bool_t)

    # str + str = str, str == str = bool
    basic_op(ctx, str_t, "__add__")
    op(ctx, str_t, "__eq__", str_t, bool_t)

    # Mixed comparisons: int <cmp> float = bool, then float <cmp> int = bool
    for left, right in ((int_t, float_t), (float_t, int_t)):
        for name in comparisons:
            op(ctx, left, name, right, bool_t)

163
midas/resolver/midas.py Normal file
View File

@@ -0,0 +1,163 @@
from typing import Optional
import midas.ast.midas as m
from midas.checker.types import (
AliasType,
Type,
UnknownType,
)
from midas.resolver.builtin import define_builtins
class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
    """A resolver which evaluates Midas type definitions and builds a registry

    The registry holds named types and the result type of each operation,
    keyed by (left type, operation name, right type). Builtins are defined
    as soon as the resolver is created.
    """

    def __init__(self) -> None:
        self._types: dict[str, Type] = {}
        self._operations: dict[tuple[Type, str, Type], Type] = {}
        define_builtins(self)

    def get_type(self, name: str) -> Type:
        """Look up a type by name

        Args:
            name (str): the name of the type

        Raises:
            NameError: if the type is not defined

        Returns:
            Type: the type
        """
        found: Optional[Type] = self._types.get(name)
        if found is not None:
            return found
        raise NameError(f"Undefined type {name}")

    def get_operation_result(
        self, left: Type, operator: str, right: Type
    ) -> Optional[Type]:
        """Look up the resulting type of an operation

        Args:
            left (Type): the type of the left operand
            operator (str): the operation name
            right (Type): the type of the right operand

        Returns:
            Optional[Type]: the result type, or None if no matching operation was found
        """
        return self._operations.get((left, operator, right))

    def define_type(self, name: str, type: Type) -> Type:
        """Add a named type to the registry

        Args:
            name (str): the name of the type
            type (Type): the type to define

        Raises:
            ValueError: if a type is already defined with that name

        Returns:
            Type: the defined type
        """
        if name in self._types:
            raise ValueError(f"Type {name} already defined")
        self._types[name] = type
        return type

    def define_operation(self, left: Type, operator: str, right: Type, result: Type):
        """Add an operation to the registry

        Args:
            left (Type): the type of the left operand
            operator (str): the operation name
            right (Type): the type of the right operand
            result (Type): the result type

        Raises:
            ValueError: if an operation is already defined with these operands and name
        """
        key: tuple[Type, str, Type] = (left, operator, right)
        if key in self._operations:
            raise ValueError(
                f"Operation {operator} already defined between {left} and {right}"
            )
        self._operations[key] = result

    def resolve(self, stmts: list[m.Stmt]):
        """Visit each statement of a sequence, in order

        Args:
            stmts (list[m.Stmt]): the statements
        """
        for statement in stmts:
            statement.accept(self)

    def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
        # The aliased type is evaluated first, then every template bound
        aliased: Type = stmt.type.accept(self)
        for template_param in stmt.params:
            if template_param.bound is not None:
                template_param.bound.accept(self)
        alias_name: str = stmt.name.lexeme
        self.define_type(alias_name, AliasType(name=alias_name, type=aliased))

    def visit_property_stmt(self, stmt: m.PropertyStmt) -> None:
        """No-op"""

    def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
        # Each operation of the extend body is defined with the extended
        # type as its left operand
        extended: Type = stmt.type.accept(self)
        for operation in stmt.operations:
            operand: Type = operation.operand.accept(self)
            outcome: Type = operation.result.accept(self)
            self.define_operation(
                left=extended,
                operator=operation.name.lexeme,
                right=operand,
                result=outcome,
            )

    def visit_op_stmt(self, stmt: m.OpStmt) -> None:
        """No-op"""

    def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
        """No-op"""

    def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
        """No-op"""

    def visit_binary_expr(self, expr: m.BinaryExpr) -> None:
        """No-op"""

    def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
        """No-op"""

    def visit_get_expr(self, expr: m.GetExpr) -> None:
        """No-op"""

    def visit_variable_expr(self, expr: m.VariableExpr) -> None:
        """No-op"""

    def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
        return expr.expr.accept(self)

    def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
        """No-op"""

    def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
        """No-op"""

    def visit_named_type(self, type: m.NamedType) -> Type:
        return self.get_type(type.name.lexeme)

    def visit_generic_type(self, type: m.GenericType) -> Type:
        # The generic type and its parameters are visited (so undefined
        # names still raise), but the results are not used yet
        type.type.accept(self)
        for param in type.params:
            param.accept(self)
        # TODO
        return UnknownType()

    def visit_constraint_type(self, type: m.ConstraintType) -> Type:
        # Both parts are visited but the results are not used yet
        type.type.accept(self)
        type.constraint.accept(self)
        # TODO
        return UnknownType()

    def visit_complex_type(self, type: m.ComplexType) -> Type:
        for property_stmt in type.properties:
            property_stmt.accept(self)
        # TODO
        return UnknownType()

187
midas/resolver/resolver.py Normal file
View File

@@ -0,0 +1,187 @@
import midas.ast.python as p
class ResolverError(Exception):
    """Error raised by the resolver for invalid declarations or references"""
class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
    """A resolver for variable assignments and references

    It tracks the scope in which each variable is declared and records, for
    every resolved reference, how many scope levels separate the reference
    from the closest declaration
    """

    def __init__(self):
        # One table per open scope, innermost last. A name maps to False
        # while it is only declared and to True once it is defined
        self.scopes: list[dict[str, bool]] = []
        # Scope distance recorded for each resolved expression
        self.locals: dict[p.Expr, int] = {}

    def resolve(self, *objects: p.Stmt | p.Expr) -> None:
        """Visit the given statements or expressions, in order"""
        for node in objects:
            node.accept(self)

    def begin_scope(self):
        """Open a new scope nested in the current one"""
        self.scopes.append({})

    def end_scope(self):
        """Close the innermost scope"""
        self.scopes.pop()

    def declare(self, name: str) -> None:
        """Declare a variable in the current scope

        This method must be called *before* evaluating the variable initializer.
        Nothing is recorded when no scope is open.

        Args:
            name (str): the name of the variable

        Raises:
            ResolverError: if the variable has already been declared in the current scope
        """
        if not self.scopes:
            return
        innermost: dict[str, bool] = self.scopes[-1]
        if name in innermost:
            raise ResolverError(
                f"A variable with the name {name} is already declared in this scope"
            )
        innermost[name] = False

    def define(self, name: str) -> None:
        """Mark a variable as defined in the current scope

        This method must be called *after* evaluating the variable initializer.
        Nothing is recorded when no scope is open.

        Args:
            name (str): the name of the variable
        """
        if not self.scopes:
            return
        self.scopes[-1][name] = True

    def resolve_local(self, expr: p.Expr, name: str) -> None:
        """Resolve a variable reference and store its scope distance

        The distance is the number of scope levels to go "up" from the
        innermost scope to reach the closest scope containing the name
        (0 for the innermost scope). Nothing is stored when no open scope
        contains the name.

        Args:
            expr (p.Expr): the variable expression
            name (str): the name of the variable
        """
        for distance in range(len(self.scopes)):
            if name in self.scopes[-1 - distance]:
                self.locals[expr] = distance
                return

    def resolve_function(self, function: p.Function) -> None:
        """Resolve a function definition

        A new scope is opened for the function, every parameter is declared
        and defined in it, then the body is resolved and the scope is closed.

        Args:
            function (p.Function): the function to resolve
        """
        self.begin_scope()
        for parameter in function.all_args:
            self.declare(parameter.name)
            self.define(parameter.name)
        self.resolve(*function.body)
        self.end_scope()

    def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
        stmt.expr.accept(self)

    def visit_function(self, stmt: p.Function) -> None:
        # The name is declared and defined before the body is resolved so
        # that the function can refer to itself (recursion)
        function_name = stmt.name
        self.declare(function_name)
        self.define(function_name)
        self.resolve_function(stmt)

    def visit_type_assign(self, stmt: p.TypeAssign) -> None:
        variable = stmt.name
        self.declare(variable)
        # NOTE: resolve type here?
        self.define(variable)

    def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
        # The assigned value is resolved before the targets
        self.resolve(stmt.value)
        for target in stmt.targets:
            if isinstance(target, p.VariableExpr):
                self.resolve_local(target, target.name)
                # TODO: declare if not found
            else:
                raise Exception(f"Unsupported assignment to {target}")

    def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
        returned = stmt.value
        if returned is not None:
            self.resolve(returned)

    def visit_if_stmt(self, stmt: p.IfStmt) -> None:
        # The test is resolved in the enclosing scope because assignments
        # made in it leak out of the if. For example:
        # if (m := 1 + 1) < 2:
        #     ...
        # print(m)  # <- m is still defined
        self.resolve(stmt.test)

        # Each branch is resolved in a scope of its own: body, then else
        for branch in (stmt.body, stmt.orelse):
            self.begin_scope()
            self.resolve(*branch)
            self.end_scope()

    def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
        self.resolve(expr.left, expr.right)

    def visit_compare_expr(self, expr: p.CompareExpr) -> None:
        self.resolve(expr.left, expr.right)

    def visit_unary_expr(self, expr: p.UnaryExpr) -> None:
        self.resolve(expr.right)

    def visit_call_expr(self, expr: p.CallExpr) -> None:
        # Callee first, then positional arguments, then keyword arguments
        self.resolve(expr.callee, *expr.arguments, *expr.keywords.values())

    def visit_get_expr(self, expr: p.GetExpr) -> None:
        self.resolve(expr.object)

    def visit_literal_expr(self, expr: p.LiteralExpr) -> None:
        return None

    def visit_variable_expr(self, expr: p.VariableExpr) -> None:
        # A name mapped to False in the innermost scope is declared but not
        # defined yet, i.e. it is being read inside its own initializer
        if self.scopes and self.scopes[-1].get(expr.name) is False:
            raise ResolverError(
                f"Cannot use local variable '{expr.name}' in its own initializer"
            )  # aka. UnboundLocalError
        self.resolve_local(expr, expr.name)

    def visit_logical_expr(self, expr: p.LogicalExpr) -> None:
        self.resolve(expr.left, expr.right)

    def visit_set_expr(self, expr: p.SetExpr) -> None:
        # The assigned value is resolved before the object being assigned to
        self.resolve(expr.value, expr.object)

    def visit_cast_expr(self, expr: p.CastExpr) -> None:
        self.resolve(expr.expr)

    def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
        self.resolve(expr.test, expr.if_true, expr.if_false)

54
midas/utils.py Normal file
View File

@@ -0,0 +1,54 @@
from typing import Any, Callable, Optional
AllowRepeat = Callable[[object], bool]
class UniversalJSONDumper:
@classmethod
def dump(
cls,
obj: Any,
include_keys: Optional[list[str | tuple[str, str]]] = None,
allow_repeat: Optional[AllowRepeat] = None,
) -> Any:
if include_keys is None:
include_keys = []
return cls._dump(obj, include_keys, allow_repeat, [])
@classmethod
def _dump(
cls,
obj: Any,
include_keys: list[str | tuple[str, str]],
allow_repeat: Optional[AllowRepeat],
visited: list[Any],
) -> Any:
if obj in visited:
return None
match obj:
case str() | int() | float() | None:
return obj
case list() | set() | tuple():
return [
cls._dump(child, include_keys, allow_repeat, visited)
for child in obj
]
case dict():
return {
str(k): cls._dump(v, include_keys, allow_repeat, visited)
for k, v in obj.items()
}
case object():
if allow_repeat is None or not allow_repeat(obj):
visited.append(obj)
return {
"_type": obj.__class__.__name__,
} | {
k: cls._dump(v, include_keys, allow_repeat, visited)
for k, v in obj.__dict__.items()
if not k.startswith("_")
or k in include_keys
or (obj.__class__.__name__, k) in include_keys
}
case _:
raise ValueError(f"Unsupported value: {obj}")

View File

@@ -19,16 +19,24 @@ Comparison ::= Unary (ComparisonOp Unary)*
Equality ::= Comparison (EqualityOp Comparison)*
Constraint ::= Equality ("&" Equality)*
SimpleType ::= Identifier "?"?
Template ::= "[" Type "]"
Type ::= Identifier Template? "?"?
TemplateParam ::= Identifier ("<:" Type)?
Template ::= "[" (TemplateParam ("," TemplateParam)*)? "]"
TypeProperty ::= Identifier ":" Type
ComplexType ::= "{" TypeProperty* "}"
NamedType ::= Identifier
TypeParams ::= "[" (Type ("," Type)*)? "]"
GenericType ::= NamedType TypeParams?
GroupedType ::= "(" Type ")"
BaseType ::= GroupedType | ComplexType | GenericType
ConstraintType ::= BaseType ("where" Constraint)?
Type ::= ConstraintType
TypeProperty ::= Identifier ":" Type ("where" Constraints)?
ComplexTypeBody ::= "{" TypeProperty* "}"
OpDefinition ::= "op" Identifier "(" Type ")" "->" Type
ExtendBody ::= "{" OpDefinition* "}"
TypeStatement ::= "type" Identifier Template? ("(" Type ")" ("where" Constraint)? | ComplexTypeBody)
TypeStatement ::= "type" Identifier Template? "=" Type
ExtendStatement ::= "extend" Type ExtendBody
PredicateStatement ::= "predicate" Identifier "(" Identifier ":" Type ")" "=" Constraint

View File

@@ -43,28 +43,52 @@ svg.railroad .terminal rect {
{[`constraint` 'equality'*"&"]}
```
#let simple-type = ```
{[`simple-type` 'identifier' <!, "?">]}
#let template-param = ```
{[`template-param` 'identifier' <!, ["<:" 'type']>]}
```
#let template = ```
{[`template` "[" 'type' "]"]}
```
#let type = ```
{[`type` 'identifier' <!, 'template'> <!, "?">]}
{[`template` "[" <!, 'template-param'*","> "]"]}
```
#let type-property = ```
{[`type-property` 'identifier' ":" 'type' <!, ["where" 'constraint']>]}
{[`type-property` 'identifier' ":" 'type']}
```
#let type-body = ```
{[`type-body` "{" <!, 'type-property'*!> "}"]}
#let complex-type = ```
{[`complex-type` "{" <!, 'type-property'*!> "}"]}
```
#let named-type = ```
{[`named-type` 'identifier']}
```
#let type-params = ```
{[`type-params` "[" <!, 'type'*","> "]"]}
```
#let generic-type = ```
{[`generic-type` 'named-type' <!, 'type-params'>]}
```
#let grouped-type = ```
{[`grouped-type` "(" 'type' ")"]}
```
#let base-type = ```
{[`base-type` <'grouped-type', 'complex-type', 'generic-type'>]}
```
#let constraint-type = ```
{[`constraint-type` 'base-type' <!, ["where" 'constraint']>]}
```
#let type = ```
{[`type` 'constraint-type']}
```
#let type-statement = ```
{[`type-statement` "type" 'identifier' <!, 'template'> <[["(" 'type' ")"] <!, ["where" 'constraint']>], 'type-body'>]}
{[`type-statement` "type" 'identifier' <!, 'template'> "=" 'type']}
```
#let op-definition = ```
@@ -92,11 +116,17 @@ svg.railroad .terminal rect {
comparison: comparison,
equality: equality,
constraint: constraint,
simple-type: simple-type,
template-param: template-param,
template: template,
type: type,
type-property: type-property,
type-body: type-body,
complex-type: complex-type,
named-type: named-type,
type-params: type-params,
generic-type: generic-type,
grouped-type: grouped-type,
base-type: base-type,
constraint-type: constraint-type,
type: type,
type-statement: type-statement,
op-definition: op-definition,
extend-statement: extend-statement,
@@ -107,10 +137,16 @@ svg.railroad .terminal rect {
#let inline = (
"grouping",
"value",
"template-param",
"template",
"simple-type",
"type-property",
"type-body",
"complex-type",
"type-params",
"named-type",
"grouped-type",
"generic-type",
"base-type",
"constraint-type",
"op-definition",
"type-statement",
"extend-statement",

204
tester.py
View File

@@ -1,204 +0,0 @@
from __future__ import annotations
import argparse
import difflib
import json
import sys
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Iterator, Optional
from midas.ast.json_serializer import AstJsonSerializer
from midas.ast.midas import Stmt
from midas.lexer.base import MidasSyntaxError
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token
from midas.parser.midas import MidasParser
DEFAULT_BASE_DIR: Path = Path() / "tests"
@dataclass
class CaseResult:
    """Snapshot payload produced by executing one test case.

    ``tokens`` and ``stmts`` default to ``None``; ``errors`` starts out as a
    fresh empty list for every instance.
    """

    tokens: Optional[list[dict]] = None
    stmts: Optional[list[dict]] = None
    errors: list[dict] = field(default_factory=list)

    def dumps(self) -> str:
        """Return the result as JSON text indented by two spaces."""
        payload = asdict(self)
        return json.dumps(payload, indent=2)
class Tester:
    """A test runner to check for regressions in the lexer and parser.

    Every ``*.midas`` file found below ``base_dir`` is one case. Its expected
    output (the snapshot) is stored next to it as ``<file name>.ref.json``.
    """

    def __init__(self, base_dir: Path):
        self.base_dir: Path = base_dir

    def _list_tests(self) -> list[Path]:
        """Return every ``*.midas`` file found recursively under ``base_dir``."""
        return list(self.base_dir.rglob("*.midas"))

    def run_all_tests(self) -> bool:
        """Run every discovered case; see :meth:`run_tests`."""
        paths: list[Path] = self._list_tests()
        return self.run_tests(paths)

    def run_tests(self, tests: list[Path]) -> bool:
        """Run the given cases, print a summary and return True if none failed."""
        rule: str = "-" * 80
        n: int = len(tests)
        successes: int = 0
        failures: int = 0

        print(rule)
        for i, test in enumerate(tests):
            print(f"Case {i+1}/{n}: {test}")
            success: bool = self._run_test(test)
            if success:
                successes += 1
            else:
                failures += 1
        print(rule)
        print(f"Success: {successes}/{n}")
        print(f"Failed: {failures}/{n}")
        print(rule)
        return failures == 0

    def _run_test(self, path: Path) -> bool:
        """Execute one case and compare its output with the stored snapshot.

        Returns True on an exact match. A missing snapshot is reported as a
        failure instead of crashing the whole run with ``FileNotFoundError``.
        """
        result_path: Path = self._result_path(path)
        if not result_path.exists():
            print("Missing snapshot. Please run the update command first")
            return False

        result: CaseResult = self._exec_case(path)
        expected: str = result_path.read_text()
        actual: str = result.dumps()
        if expected == actual:
            return True

        diff = difflib.unified_diff(
            expected.splitlines(keepends=True),
            actual.splitlines(keepends=True),
            fromfile="Snapshot",
            tofile="Result",
        )
        self._print_diff(diff)
        return False

    def _exec_case(self, path: Path) -> CaseResult:
        """Lex and parse the case file and collect tokens, AST and errors.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
            TypeError: if ``path`` is not a regular file.
        """
        if not path.exists():
            raise FileNotFoundError(f"Could not find test '{path}'")
        if not path.is_file():
            raise TypeError(f"Test '{path}' is not a file")

        result: CaseResult = CaseResult()
        content: str = path.read_text()
        lexer: MidasLexer = MidasLexer(content)
        tokens: list[Token] = []
        try:
            tokens = lexer.process()
            result.tokens = [
                {
                    "type": token.type.name,
                    "lexeme": token.lexeme,
                    "line": token.position.line,
                    "column": token.position.column,
                }
                for token in tokens
            ]
        except MidasSyntaxError as e:
            # A lexing failure ends the case: there are no tokens to parse.
            result.errors.append(
                {
                    "type": "SyntaxError",
                    "line": e.pos.line,
                    "column": e.pos.column,
                    "message": e.message,
                }
            )
            return result

        parser: MidasParser = MidasParser(tokens)
        stmts: list[Stmt] = parser.parse()
        result.stmts = AstJsonSerializer().serialize(stmts)
        result.errors.extend(
            [
                {
                    "line": e.token.position.line,
                    "column": e.token.position.column,
                    "message": e.message,
                }
                for e in parser.errors
            ]
        )
        return result

    def update_all_tests(self):
        """Refresh the snapshot of every discovered case."""
        paths: list[Path] = self._list_tests()
        return self.update_tests(paths)

    def update_tests(self, tests: list[Path]):
        """Refresh the snapshots of the given cases and print how many changed."""
        updated: int = 0
        for test in tests:
            if self._update_test(test):
                updated += 1
        print(f"Updated {updated}/{len(tests)} tests")

    def _update_test(self, path: Path) -> bool:
        """Rewrite the snapshot of one case; return True if it was (re)written."""
        result: CaseResult = self._exec_case(path)
        result_path: Path = self._result_path(path)
        # A case without a snapshot yet is treated as having an empty one so
        # that the first `update` creates the file instead of crashing.
        current: str = result_path.read_text() if result_path.exists() else ""
        new: str = result.dumps()
        if current == new:
            return False
        result_path.write_text(new)
        return True

    def _result_path(self, test_path: Path) -> Path:
        """Return the snapshot path: the case file name plus ``.ref.json``."""
        return test_path.parent / (test_path.name + ".ref.json")

    def _print_diff(self, diff: Iterator[str]):
        """Print a unified diff, colouring added lines green and removed red."""
        for line in diff:
            # The "+++" / "---" file headers are left uncoloured.
            if line.startswith("+") and not line.startswith("+++"):
                print(f"\033[92m{line}\033[0m", end="")
            elif line.startswith("-") and not line.startswith("---"):
                print(f"\033[91m{line}\033[0m", end="")
            else:
                print(line, end="")
        print()
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"-D",
"--base-dir",
help="Base directory containing test files",
type=Path,
default=DEFAULT_BASE_DIR,
)
subparsers = parser.add_subparsers(dest="subcommand")
update = subparsers.add_parser("update")
update.add_argument("-a", "--all", action="store_true")
update.add_argument("FILE", type=Path, nargs="*")
run = subparsers.add_parser("run")
run.add_argument("-a", "--all", action="store_true")
run.add_argument("FILE", type=Path, nargs="*")
args = parser.parse_args()
tester: Tester = Tester(args.base_dir)
match args.subcommand:
case "update":
if args.all:
tester.update_all_tests()
else:
tester.update_tests(args.FILE)
case "run":
success: bool
if args.all:
success = tester.run_all_tests()
else:
success = tester.run_tests(args.FILE)
if not success:
sys.exit(1)
if __name__ == "__main__":
main()

143
tests/base.py Normal file
View File

@@ -0,0 +1,143 @@
from __future__ import annotations
import argparse
import difflib
import sys
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Iterator, Protocol
class CaseResult(Protocol):
    """Structural type for the result of executing a test case."""

    def dumps(self) -> str:
        """Return the text that is compared with the stored snapshot."""
        ...
class Tester(ABC):
"""A test runner to check for regressions in the lexer and parser"""
CASES_DIR: Path = Path(__file__).parent / "cases"
@property
@abstractmethod
def namespace(self) -> str: ...
@property
def base_dir(self) -> Path:
return self.CASES_DIR / self.namespace
@abstractmethod
def _list_tests(self) -> list[Path]: ...
def run_all_tests(self) -> bool:
paths: list[Path] = self._list_tests()
return self.run_tests(paths)
def run_tests(self, tests: list[Path]) -> bool:
rule: str = "-" * 80
n: int = len(tests)
successes: int = 0
failures: int = 0
print(rule)
for i, test in enumerate(tests):
print(f"Case {i+1}/{n}: {test.relative_to(self.CASES_DIR)}")
success: bool = self._run_test(test)
if success:
successes += 1
else:
failures += 1
print(rule)
print(f"Success: {successes}/{n}")
print(f"Failed: {failures}/{n}")
print(rule)
return failures == 0
def _run_test(self, path: Path) -> bool:
result_path: Path = self._result_path(path)
if not result_path.exists():
print("Missing snapshot. Please run the update command first")
return False
result: CaseResult = self._exec_case(path)
expected: str = result_path.read_text()
actual: str = result.dumps()
if expected == actual:
return True
diff = difflib.unified_diff(
expected.splitlines(keepends=True),
actual.splitlines(keepends=True),
fromfile="Snapshot",
tofile="Result",
)
self._print_diff(diff)
return False
@abstractmethod
def _exec_case(self, path: Path) -> CaseResult: ...
def update_all_tests(self):
paths: list[Path] = self._list_tests()
return self.update_tests(paths)
def update_tests(self, tests: list[Path]):
updated: int = 0
for test in tests:
if self._update_test(test):
updated += 1
print(f"Updated {updated}/{len(tests)} tests")
def _update_test(self, path: Path) -> bool:
result: CaseResult = self._exec_case(path)
result_path: Path = self._result_path(path)
current: str = result_path.read_text() if result_path.exists() else ""
new: str = result.dumps()
if current == new:
return False
result_path.write_text(new)
return True
def _result_path(self, test_path: Path) -> Path:
return test_path.parent / (test_path.name + ".ref.json")
def _print_diff(self, diff: Iterator[str]):
for line in diff:
if line.startswith("+") and not line.startswith("+++"):
print(f"\033[92m{line}\033[0m", end="")
elif line.startswith("-") and not line.startswith("---"):
print(f"\033[91m{line}\033[0m", end="")
else:
print(line, end="")
print()
@classmethod
def main(cls):
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="subcommand")
update = subparsers.add_parser("update")
update.add_argument("-a", "--all", action="store_true")
update.add_argument("FILE", type=Path, nargs="*")
run = subparsers.add_parser("run")
run.add_argument("-a", "--all", action="store_true")
run.add_argument("FILE", type=Path, nargs="*")
args = parser.parse_args()
tester: Tester = cls()
match args.subcommand:
case "update":
if args.all:
tester.update_all_tests()
else:
tester.update_tests(args.FILE)
case "run":
success: bool
if args.all:
success = tester.run_all_tests()
else:
success = tester.run_tests(args.FILE)
if not success:
sys.exit(1)

View File

@@ -0,0 +1,14 @@
# type: ignore
# ruff: disable[F821]
from __future__ import annotations
df: Frame[
verified: bool,
birth_year: int,
height: float + ( _ > 0 ) + ( _ < 250 ),
name: str,
date: datetime,
float,
unknown: _,
_
]

View File

@@ -0,0 +1,3 @@
{
"diagnostics": []
}

View File

@@ -0,0 +1,11 @@
a: int = 3
b: int = 4
c = a + b
c = "invalid"
d = True
e = d + d
f: float = a

View File

@@ -0,0 +1,46 @@
{
"diagnostics": [
{
"type": "Error",
"location": {
"start": [
6,
0
],
"end": [
6,
13
]
},
"message": "Cannot assign BaseType(name='str') to c of type BaseType(name='int')"
},
{
"type": "Error",
"location": {
"start": [
9,
4
],
"end": [
9,
9
]
},
"message": "Undefined operation __add__ between BaseType(name='bool') and BaseType(name='bool')"
},
{
"type": "Error",
"location": {
"start": [
11,
0
],
"end": [
11,
12
]
},
"message": "Cannot assign BaseType(name='int') to f of type BaseType(name='float')"
}
]
}

View File

@@ -0,0 +1,18 @@
def foo(a: int, /, b: float, *, c: str):
return True
r1 = foo()
r2 = foo(1)
r3 = foo(1, 2.0)
r4 = foo(1, b=2.0)
r5 = foo(1, 2.0, "test")
r6 = foo(1, 2.0, b=3.0)
r7 = foo(a=1)
r8 = foo(g="test")
r9a = foo(1, 2.0, c="test")
r9b = foo(1, b=2.0, c="test")
r9c = foo(1, c="test", b=2.0)
r10 = foo("a", 3, c=False)

View File

@@ -0,0 +1,270 @@
{
"diagnostics": [
{
"type": "Error",
"location": {
"start": [
5,
5
],
"end": [
5,
10
]
},
"message": "Missing required positional arguments: 'a' and 'b'"
},
{
"type": "Error",
"location": {
"start": [
5,
5
],
"end": [
5,
10
]
},
"message": "Missing required keyword argument: 'c'"
},
{
"type": "Error",
"location": {
"start": [
6,
5
],
"end": [
6,
11
]
},
"message": "Missing required positional argument: 'b'"
},
{
"type": "Error",
"location": {
"start": [
6,
5
],
"end": [
6,
11
]
},
"message": "Missing required keyword argument: 'c'"
},
{
"type": "Error",
"location": {
"start": [
7,
5
],
"end": [
7,
16
]
},
"message": "Missing required keyword argument: 'c'"
},
{
"type": "Error",
"location": {
"start": [
8,
5
],
"end": [
8,
18
]
},
"message": "Missing required keyword argument: 'c'"
},
{
"type": "Error",
"location": {
"start": [
9,
17
],
"end": [
9,
23
]
},
"message": "Too many positional arguments"
},
{
"type": "Error",
"location": {
"start": [
9,
5
],
"end": [
9,
24
]
},
"message": "Missing required keyword argument: 'c'"
},
{
"type": "Error",
"location": {
"start": [
10,
19
],
"end": [
10,
22
]
},
"message": "Multiple values for argument 'b'"
},
{
"type": "Error",
"location": {
"start": [
10,
5
],
"end": [
10,
23
]
},
"message": "Missing required keyword argument: 'c'"
},
{
"type": "Error",
"location": {
"start": [
11,
11
],
"end": [
11,
12
]
},
"message": "Unknown keyword argument 'a'"
},
{
"type": "Error",
"location": {
"start": [
11,
5
],
"end": [
11,
13
]
},
"message": "Missing required positional arguments: 'a' and 'b'"
},
{
"type": "Error",
"location": {
"start": [
11,
5
],
"end": [
11,
13
]
},
"message": "Missing required keyword argument: 'c'"
},
{
"type": "Error",
"location": {
"start": [
12,
11
],
"end": [
12,
17
]
},
"message": "Unknown keyword argument 'g'"
},
{
"type": "Error",
"location": {
"start": [
12,
5
],
"end": [
12,
18
]
},
"message": "Missing required positional arguments: 'a' and 'b'"
},
{
"type": "Error",
"location": {
"start": [
12,
5
],
"end": [
12,
18
]
},
"message": "Missing required keyword argument: 'c'"
},
{
"type": "Error",
"location": {
"start": [
18,
10
],
"end": [
18,
13
]
},
"message": "Wrong type for argument 'a', expected BaseType(name='int'), got BaseType(name='str')"
},
{
"type": "Error",
"location": {
"start": [
18,
15
],
"end": [
18,
16
]
},
"message": "Wrong type for argument 'b', expected BaseType(name='float'), got BaseType(name='int')"
},
{
"type": "Error",
"location": {
"start": [
18,
20
],
"end": [
18,
25
]
},
"message": "Wrong type for argument 'c', expected BaseType(name='str'), got BaseType(name='bool')"
}
]
}

View File

@@ -0,0 +1,14 @@
type Meter = float
type Second = float
type MeterPerSecond = float
extend Meter {
op __add__(Meter) -> Meter
op __sub__(Meter) -> Meter
op __truediv__(Second) -> MeterPerSecond
}
extend Second {
op __add__(Second) -> Second
op __sub__(Second) -> Second
}

View File

@@ -0,0 +1,6 @@
# type: ignore
# ruff: disable [F821]
distance: Meter = cast(Meter, 123.45)
time: Second = cast(Second, 6.7)
speed = distance / time

View File

@@ -0,0 +1,3 @@
{
"diagnostics": []
}

View File

@@ -0,0 +1,25 @@
def valid(a: int, b: int) -> int:
return a + b
def with_if(a: int, b: int) -> int:
if a < b:
return b - a
else:
return a - b
def unreachable1():
return
a = 0
def unreachable2(a: int) -> int:
if a > 10:
return a - 10
else:
return a
b = 0
def mixed(a: int, b: int):
if a < b:
return b - a
else:
return "oops"

View File

@@ -0,0 +1,46 @@
{
"diagnostics": [
{
"type": "Warning",
"location": {
"start": [
12,
4
],
"end": [
12,
9
]
},
"message": "Unreachable statement"
},
{
"type": "Warning",
"location": {
"start": [
19,
4
],
"end": [
19,
9
]
},
"message": "Unreachable statement"
},
{
"type": "Error",
"location": {
"start": [
21,
0
],
"end": [
25,
21
]
},
"message": "Mixed return types: [BaseType(name='int'), BaseType(name='str')]"
}
]
}

View File

@@ -1,15 +1,15 @@
// Simple custom type derived from float
type Custom(float)
type Custom = float
// Simple custom types with constraints
type Latitude(float) where (-90 <= _ <= 90)
type Longitude(float) where (-180 <= _ <= 180)
type Latitude = float where (-90 <= _ <= 90)
type Longitude = float where (-180 <= _ <= 180)
// Generic custom type (a Difference of T is derived from T, e.g. a difference of floats is a float)
type Difference[T](T)
type Difference[T] = T
// Complex custom type, containing two values accessible through properties
type GeoLocation {
type GeoLocation = {
lat: Latitude
lon: Longitude
}
@@ -24,7 +24,7 @@ extend GeoLocation {
// For complex generics, you need to specify how the genericity of the
// properties is handled
type Difference[GeoLocation] {
type Difference[GeoLocation] = {
lat: Difference[Latitude]
lon: Difference[Longitude]
}
@@ -44,11 +44,11 @@ predicate StrictlyPositive(v: float) = v > 0
predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10)
predicate Arctic(loc: GeoLocation) = (loc.lat >= 66)
type Person {
type Person = {
name: str
// Property with an inline constraint
age: int? where (0 <= _ < 150)
age: Optional[int where (0 <= _ < 150)]
// Property referencing a predicate
height: float where StrictlyPositive

View File

@@ -0,0 +1,14 @@
# type: ignore
# ruff: disable[F821]
from __future__ import annotations
df: Frame[
verified: bool,
birth_year: int,
height: float + ( _ > 0 ) + ( _ < 250 ),
name: str,
date: datetime,
float,
unknown: _,
_
]

View File

@@ -0,0 +1,85 @@
{
"stmts": [
{
"_type": "TypeAssign",
"name": "df",
"type": {
"_type": "FrameType",
"columns": [
{
"_type": "FrameColumn",
"name": "verified",
"type": {
"_type": "BaseType",
"base": "bool",
"param": null
}
},
{
"_type": "FrameColumn",
"name": "birth_year",
"type": {
"_type": "BaseType",
"base": "int",
"param": null
}
},
{
"_type": "FrameColumn",
"name": "height",
"type": {
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"param": null
},
"constraint": "(_ > 0) + (_ < 250)"
}
},
{
"_type": "FrameColumn",
"name": "name",
"type": {
"_type": "BaseType",
"base": "str",
"param": null
}
},
{
"_type": "FrameColumn",
"name": "date",
"type": {
"_type": "BaseType",
"base": "datetime",
"param": null
}
},
{
"_type": "FrameColumn",
"name": null,
"type": {
"_type": "BaseType",
"base": "float",
"param": null
}
},
{
"_type": "FrameColumn",
"name": "unknown",
"type": null
},
{
"_type": "FrameColumn",
"name": null,
"type": {
"_type": "BaseType",
"base": "_",
"param": null
}
}
]
}
}
]
}

View File

@@ -0,0 +1,25 @@
# type: ignore
# ruff: disable[F821]
from __future__ import annotations
df: Frame[
location: GeoLocation
]
lat: Column[GeoLocation] = df["location"].lat
lon: Column[GeoLocation] = df["location"].lon
lat + lon
lat1: Latitude = lat[0]
lat2: Latitude = lat[1]
lat_diff: Difference[Latitude] = lat2 - lat1
df2: Frame[
age: int + (_ >= 0),
height: float + (_ >= 0),
]
df2_bis: Frame[
age: int + Positive,
height: float + Positive,
]

View File

@@ -0,0 +1,141 @@
{
"stmts": [
{
"_type": "TypeAssign",
"name": "df",
"type": {
"_type": "FrameType",
"columns": [
{
"_type": "FrameColumn",
"name": "location",
"type": {
"_type": "BaseType",
"base": "GeoLocation",
"param": null
}
}
]
}
},
{
"_type": "ExpressionStmt",
"expr": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "lat"
},
"operator": "+",
"right": {
"_type": "VariableExpr",
"name": "lon"
}
}
},
{
"_type": "TypeAssign",
"name": "lat_diff",
"type": {
"_type": "BaseType",
"base": "Difference",
"param": {
"_type": "BaseType",
"base": "Latitude",
"param": null
}
}
},
{
"_type": "AssignStmt",
"targets": [
{
"_type": "VariableExpr",
"name": "lat_diff"
}
],
"value": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "lat2"
},
"operator": "-",
"right": {
"_type": "VariableExpr",
"name": "lat1"
}
}
},
{
"_type": "TypeAssign",
"name": "df2",
"type": {
"_type": "FrameType",
"columns": [
{
"_type": "FrameColumn",
"name": "age",
"type": {
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "int",
"param": null
},
"constraint": "_ >= 0"
}
},
{
"_type": "FrameColumn",
"name": "height",
"type": {
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"param": null
},
"constraint": "_ >= 0"
}
}
]
}
},
{
"_type": "TypeAssign",
"name": "df2_bis",
"type": {
"_type": "FrameType",
"columns": [
{
"_type": "FrameColumn",
"name": "age",
"type": {
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "int",
"param": null
},
"constraint": "Positive"
}
},
{
"_type": "FrameColumn",
"name": "height",
"type": {
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"param": null
},
"constraint": "Positive"
}
}
]
}
}
]
}

View File

@@ -0,0 +1,15 @@
# type: ignore
# ruff: disable[F821]
from __future__ import annotations
def func(
col1: Column[float + (0 <= _ <= 1)],
col2: Column[float + (0 <= _ <= 1)],
) -> Column[float + (0 <= _ <= 2)]:
result: Column[float + (0 <= _ <= 2)] = col1 + col2
return result
def func2(a: int, /, b: float, *, c: str):
pass

View File

@@ -0,0 +1,149 @@
{
"stmts": [
{
"_type": "Function",
"name": "func",
"posonlyargs": [],
"args": [
{
"name": "col1",
"type": {
"_type": "BaseType",
"base": "Column",
"param": {
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"param": null
},
"constraint": "0 <= _ <= 1"
}
},
"default": null
},
{
"name": "col2",
"type": {
"_type": "BaseType",
"base": "Column",
"param": {
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"param": null
},
"constraint": "0 <= _ <= 1"
}
},
"default": null
}
],
"sink": null,
"kwonlyargs": [],
"kw_sink": null,
"returns": {
"_type": "BaseType",
"base": "Column",
"param": {
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"param": null
},
"constraint": "0 <= _ <= 2"
}
},
"body": [
{
"_type": "TypeAssign",
"name": "result",
"type": {
"_type": "BaseType",
"base": "Column",
"param": {
"_type": "ConstraintType",
"type": {
"_type": "BaseType",
"base": "float",
"param": null
},
"constraint": "0 <= _ <= 2"
}
}
},
{
"_type": "AssignStmt",
"targets": [
{
"_type": "VariableExpr",
"name": "result"
}
],
"value": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "col1"
},
"operator": "+",
"right": {
"_type": "VariableExpr",
"name": "col2"
}
}
},
{
"_type": "ReturnStmt",
"value": {
"_type": "VariableExpr",
"name": "result"
}
}
]
},
{
"_type": "Function",
"name": "func2",
"posonlyargs": [
{
"name": "a",
"type": {
"_type": "BaseType",
"base": "int",
"param": null
},
"default": null
}
],
"args": [
{
"name": "b",
"type": {
"_type": "BaseType",
"base": "float",
"param": null
},
"default": null
}
],
"sink": null,
"kwonlyargs": [
{
"name": "c",
"type": {
"_type": "BaseType",
"base": "str",
"param": null
},
"default": null
}
],
"kw_sink": null,
"returns": null,
"body": []
}
]
}

75
tests/checker.py Normal file
View File

@@ -0,0 +1,75 @@
import ast
import json
from dataclasses import asdict, dataclass, field
from pathlib import Path
import midas.ast.python as p
from midas.checker.checker import Checker
from midas.checker.diagnostic import Diagnostic
from midas.parser.python import PythonParser
from midas.resolver.resolver import Resolver
from tests.base import Tester
@dataclass
class CaseResult:
    """Snapshot payload of a checker case: the reported diagnostics."""

    # One JSON-compatible record per diagnostic; a fresh list per instance.
    diagnostics: list[dict] = field(default_factory=list)

    def dumps(self) -> str:
        """Return the result as JSON text indented by two spaces."""
        as_mapping = asdict(self)
        return json.dumps(as_mapping, indent=2)
class CheckerTester(Tester):
    """Snapshot tester that records the diagnostics reported by the checker."""

    @property
    def namespace(self) -> str:
        return "checker"

    def _list_tests(self) -> list[Path]:
        """Return every ``*.py`` file found recursively under ``base_dir``."""
        return list(self.base_dir.rglob("*.py"))

    def _exec_case(self, path: Path) -> CaseResult:
        """Parse, resolve and check one Python case file.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
            TypeError: if ``path`` is not a regular file.
        """
        if not path.exists():
            raise FileNotFoundError(f"Could not find test '{path}'")
        if not path.is_file():
            raise TypeError(f"Test '{path}' is not a file")

        # A sibling file with the ".midas" suffix, when present, is handed to
        # the checker through `types_paths`.
        companion: Path = path.with_suffix(".midas")
        types_paths: list[Path] = [companion] if companion.exists() else []

        source: str = path.read_text()
        module: ast.Module = ast.parse(source, filename=path)
        stmts: list[p.Stmt] = PythonParser().parse_module(module)

        resolver = Resolver()
        resolver.resolve(*stmts)

        outcome: CaseResult = CaseResult()
        checker = Checker(
            resolver.locals,
            source_path=path,
            types_paths=types_paths,
        )
        reported: list[Diagnostic] = checker.check(stmts)
        for entry in reported:
            start = (entry.location.lineno, entry.location.col_offset)
            end = (entry.location.end_lineno, entry.location.end_col_offset)
            outcome.diagnostics.append(
                {
                    "type": str(entry.type),
                    "location": {"start": start, "end": end},
                    "message": entry.message,
                }
            )
        return outcome


if __name__ == "__main__":
    CheckerTester.main()

82
tests/midas.py Normal file
View File

@@ -0,0 +1,82 @@
import json
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Optional
from midas.ast.midas import Stmt
from midas.lexer.base import MidasSyntaxError
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token
from midas.parser.midas import MidasParser
from tests.base import Tester
from tests.serializer.midas import MidasAstJsonSerializer
@dataclass
class CaseResult:
    """Snapshot payload of a Midas parser case.

    ``tokens`` and ``stmts`` default to ``None``; ``errors`` starts out as a
    fresh empty list for every instance.
    """

    tokens: Optional[list[dict]] = None
    stmts: Optional[list[dict]] = None
    errors: list[dict] = field(default_factory=list)

    def dumps(self) -> str:
        """Return the result as JSON text indented by two spaces."""
        fields_as_dict = asdict(self)
        return json.dumps(fields_as_dict, indent=2)
class MidasTester(Tester):
    """Snapshot tester for the Midas lexer and parser."""

    @property
    def namespace(self) -> str:
        return "midas-parser"

    def _list_tests(self) -> list[Path]:
        """Return every ``*.midas`` file found recursively under ``base_dir``."""
        return list(self.base_dir.rglob("*.midas"))

    def _exec_case(self, path: Path) -> CaseResult:
        """Lex and parse one case file, collecting tokens, AST and errors.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
            TypeError: if ``path`` is not a regular file.
        """
        if not path.exists():
            raise FileNotFoundError(f"Could not find test '{path}'")
        if not path.is_file():
            raise TypeError(f"Test '{path}' is not a file")

        outcome: CaseResult = CaseResult()
        lexer: MidasLexer = MidasLexer(path.read_text())
        tokens: list[Token] = []
        try:
            tokens = lexer.process()
            token_records: list[dict] = []
            for tok in tokens:
                token_records.append(
                    {
                        "type": tok.type.name,
                        "lexeme": tok.lexeme,
                        "line": tok.position.line,
                        "column": tok.position.column,
                    }
                )
            outcome.tokens = token_records
        except MidasSyntaxError as err:
            # A lexing failure ends the case: the parser is never run.
            outcome.errors.append(
                {
                    "type": "SyntaxError",
                    "line": err.pos.line,
                    "column": err.pos.column,
                    "message": err.message,
                }
            )
            return outcome

        parser: MidasParser = MidasParser(tokens)
        parsed: list[Stmt] = parser.parse()
        outcome.stmts = MidasAstJsonSerializer().serialize(parsed)
        for problem in parser.errors:
            outcome.errors.append(
                {
                    "line": problem.token.position.line,
                    "column": problem.token.position.column,
                    "message": problem.message,
                }
            )
        return outcome


if __name__ == "__main__":
    MidasTester.main()

46
tests/python.py Normal file
View File

@@ -0,0 +1,46 @@
import ast
import json
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Optional
from midas.ast.python import Stmt
from midas.parser.python import PythonParser
from tests.base import Tester
from tests.serializer.python import PythonAstJsonSerializer
@dataclass
class CaseResult:
    """Snapshot payload of a Python parser case: the serialized statements."""

    # Defaults to None; the tester assigns the serialized AST after parsing.
    stmts: Optional[list[dict]] = None

    def dumps(self) -> str:
        """Return the result as JSON text indented by two spaces."""
        snapshot = asdict(self)
        return json.dumps(snapshot, indent=2)
class PythonTester(Tester):
    """Snapshot tester for the custom Python parser."""

    @property
    def namespace(self) -> str:
        return "python-parser"

    def _list_tests(self) -> list[Path]:
        """Return every ``*.py`` file found recursively under ``base_dir``."""
        return list(self.base_dir.rglob("*.py"))

    def _exec_case(self, path: Path) -> CaseResult:
        """Parse one Python case file and serialize the resulting statements.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
            TypeError: if ``path`` is not a regular file.
        """
        if not path.exists():
            raise FileNotFoundError(f"Could not find test '{path}'")
        if not path.is_file():
            raise TypeError(f"Test '{path}' is not a file")

        outcome: CaseResult = CaseResult()
        module: ast.Module = ast.parse(path.read_text())
        parsed: list[Stmt] = PythonParser().parse_module(module)
        outcome.stmts = PythonAstJsonSerializer().serialize(parsed)
        return outcome


if __name__ == "__main__":
    PythonTester.main()

View File

@@ -2,56 +2,60 @@ from typing import Optional, Sequence
from midas.ast.midas import (
BinaryExpr,
ComplexTypeStmt,
ComplexType,
ConstraintType,
Expr,
ExtendStmt,
GenericType,
GetExpr,
GroupingExpr,
LiteralExpr,
LogicalExpr,
NamedType,
OpStmt,
PredicateStmt,
PropertyStmt,
SimpleTypeExpr,
SimpleTypeStmt,
Stmt,
TemplateExpr,
TypeExpr,
Type,
TypeStmt,
UnaryExpr,
VariableExpr,
WildcardExpr,
)
class AstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
class MidasAstJsonSerializer(
Stmt.Visitor[dict], Expr.Visitor[dict], Type.Visitor[dict]
):
"""An AST serializer which produces a JSON-compatible structure"""
def serialize(self, stmts: list[Stmt]) -> list[dict]:
return [stmt.accept(self) for stmt in stmts]
def _serialize_optional(self, element: Optional[Stmt | Expr]) -> Optional[dict]:
def _serialize_optional(
self, element: Optional[Stmt | Expr | Type]
) -> Optional[dict]:
if element is None:
return None
return element.accept(self)
def _serialize_list(self, elements: Sequence[Stmt | Expr]) -> list[dict]:
def _serialize_list(self, elements: Sequence[Stmt | Expr | Type]) -> list[dict]:
return [element.accept(self) for element in elements]
def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> dict:
def visit_type_stmt(self, stmt: TypeStmt) -> dict:
return {
"_type": "SimpleTypeStmt",
"template": self._serialize_optional(stmt.template),
"_type": "TypeStmt",
"name": stmt.name.lexeme,
"base": stmt.base.accept(self),
"constraint": self._serialize_optional(stmt.constraint),
"params": [
self._serialize_type_stmt_template_param(param) for param in stmt.params
],
"type": stmt.type.accept(self),
}
def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> dict:
def _serialize_type_stmt_template_param(self, param: TypeStmt.Param) -> dict:
return {
"_type": "ComplexTypeStmt",
"name": stmt.name.lexeme,
"template": self._serialize_optional(stmt.template),
"properties": self._serialize_list(stmt.properties),
"name": param.name.lexeme,
"bound": self._serialize_optional(param.bound),
}
def visit_property_stmt(self, stmt: PropertyStmt) -> dict:
@@ -59,7 +63,6 @@ class AstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
"_type": "PropertyStmt",
"name": stmt.name.lexeme,
"type": stmt.type.accept(self),
"constraint": self._serialize_optional(stmt.constraint),
}
def visit_extend_stmt(self, stmt: ExtendStmt) -> dict:
@@ -86,13 +89,6 @@ class AstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
"condition": stmt.condition.accept(self),
}
def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> dict:
return {
"_type": "SimpleTypeExpr",
"name": expr.name.lexeme,
"optional": expr.optional,
}
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
return {
"_type": "LogicalExpr",
@@ -144,16 +140,28 @@ class AstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
def visit_wildcard_expr(self, expr: WildcardExpr) -> dict:
return {"_type": "WildcardExpr"}
def visit_template_expr(self, expr: TemplateExpr) -> dict:
def visit_named_type(self, type: NamedType) -> dict:
return {
"_type": "TemplateExpr",
"type": expr.type.accept(self),
"_type": "NamedType",
"name": type.name.lexeme,
}
def visit_type_expr(self, expr: TypeExpr) -> dict:
def visit_generic_type(self, type: GenericType) -> dict:
return {
"_type": "TypeExpr",
"name": expr.name.lexeme,
"template": self._serialize_optional(expr.template),
"optional": expr.optional,
"_type": "GenericType",
"type": type.type.accept(self),
"params": self._serialize_list(type.params),
}
def visit_constraint_type(self, type: ConstraintType) -> dict:
return {
"_type": "ConstraintType",
"type": type.type.accept(self),
"constraint": type.constraint.accept(self),
}
def visit_complex_type(self, type: ComplexType) -> dict:
return {
"_type": "ComplexType",
"properties": self._serialize_list(type.properties),
}

256
tests/serializer/python.py Normal file
View File

@@ -0,0 +1,256 @@
import ast
from typing import Optional, Sequence, Type
from midas.ast.python import (
AssignStmt,
BaseType,
BinaryExpr,
CallExpr,
CastExpr,
CompareExpr,
ConstraintType,
Expr,
ExpressionStmt,
FrameColumn,
FrameType,
Function,
GetExpr,
IfStmt,
LiteralExpr,
LogicalExpr,
MidasType,
ReturnStmt,
SetExpr,
Stmt,
TernaryExpr,
TypeAssign,
UnaryExpr,
VariableExpr,
)
# Lookup tables from `ast` operator node classes to their source spelling.

# Prefix operators.
unary_ops: dict[Type[ast.unaryop], str] = {
    ast.Invert: "~",
    ast.Not: "not",
    ast.UAdd: "+",
    ast.USub: "-",
}

# Infix arithmetic, shift and bitwise operators.
binary_ops: dict[Type[ast.operator], str] = {
    ast.Add: "+",
    ast.Sub: "-",
    ast.Mult: "*",
    ast.MatMult: "@",
    ast.Div: "/",
    ast.Mod: "%",
    ast.LShift: "<<",
    ast.RShift: ">>",
    ast.BitOr: "|",
    ast.BitXor: "^",
    ast.BitAnd: "&",
    ast.FloorDiv: "//",
    ast.Pow: "**",
}

# Comparison, identity and membership operators.
compare_ops: dict[Type[ast.cmpop], str] = {
    ast.Eq: "==",
    ast.NotEq: "!=",
    ast.Lt: "<",
    ast.LtE: "<=",
    ast.Gt: ">",
    ast.GtE: ">=",
    ast.Is: "is",
    ast.IsNot: "is not",
    ast.In: "in",
    ast.NotIn: "not in",
}

# Short-circuiting boolean operators.
boolean_ops: dict[Type[ast.boolop], str] = {
    ast.And: "and",
    ast.Or: "or",
}
class PythonAstJsonSerializer(
Stmt.Visitor[dict], Expr.Visitor[dict], MidasType.Visitor[dict]
):
"""An AST serializer which produces a JSON-compatible structure"""
def serialize(self, stmts: list[Stmt]) -> list[dict]:
return [stmt.accept(self) for stmt in stmts]
def _serialize_optional(
self, element: Optional[Stmt | Expr | MidasType]
) -> Optional[dict]:
if element is None:
return None
return element.accept(self)
def _serialize_list(
self, elements: Sequence[Stmt | Expr | MidasType]
) -> list[dict]:
return [element.accept(self) for element in elements]
def visit_base_type(self, node: BaseType) -> dict:
return {
"_type": "BaseType",
"base": node.base,
"param": self._serialize_optional(node.param),
}
def visit_constraint_type(self, node: ConstraintType) -> dict:
return {
"_type": "ConstraintType",
"type": node.type.accept(self),
"constraint": ast.unparse(node.constraint),
}
def visit_frame_column(self, node: FrameColumn) -> dict:
return {
"_type": "FrameColumn",
"name": node.name,
"type": self._serialize_optional(node.type),
}
def visit_frame_type(self, node: FrameType) -> dict:
return {
"_type": "FrameType",
"columns": self._serialize_list(node.columns),
}
def visit_expression_stmt(self, stmt: ExpressionStmt) -> dict:
return {
"_type": "ExpressionStmt",
"expr": stmt.expr.accept(self),
}
def _serialize_argument(self, arg: Function.Argument) -> dict:
return {
"name": arg.name,
"type": self._serialize_optional(arg.type),
"default": self._serialize_optional(arg.default),
}
def visit_function(self, stmt: Function) -> dict:
return {
"_type": "Function",
"name": stmt.name,
"posonlyargs": [self._serialize_argument(arg) for arg in stmt.posonlyargs],
"args": [self._serialize_argument(arg) for arg in stmt.args],
"sink": (
self._serialize_argument(stmt.sink) if stmt.sink is not None else None
),
"kwonlyargs": [self._serialize_argument(arg) for arg in stmt.kwonlyargs],
"kw_sink": (
self._serialize_argument(stmt.kw_sink)
if stmt.kw_sink is not None
else None
),
"returns": self._serialize_optional(stmt.returns),
"body": self._serialize_list(stmt.body),
}
def visit_type_assign(self, stmt: TypeAssign) -> dict:
return {
"_type": "TypeAssign",
"name": stmt.name,
"type": stmt.type.accept(self),
}
def visit_assign_stmt(self, stmt: AssignStmt) -> dict:
return {
"_type": "AssignStmt",
"targets": self._serialize_list(stmt.targets),
"value": stmt.value.accept(self),
}
def visit_return_stmt(self, stmt: ReturnStmt) -> dict:
    """Serialize a ``ReturnStmt``; its value goes through ``_serialize_optional``."""
    returned = self._serialize_optional(stmt.value)
    return {"_type": "ReturnStmt", "value": returned}
def visit_if_stmt(self, stmt: IfStmt) -> dict:
    """Serialize an ``IfStmt``: the test, then the ``body`` and ``orelse`` lists."""
    condition = stmt.test.accept(self)
    then_branch = self._serialize_list(stmt.body)
    else_branch = self._serialize_list(stmt.orelse)
    return {
        "_type": "IfStmt",
        "test": condition,
        "body": then_branch,
        "orelse": else_branch,
    }
def visit_binary_expr(self, expr: BinaryExpr) -> dict:
    """Serialize a ``BinaryExpr``.

    The operator entry is looked up in ``binary_ops`` using the class of
    ``expr.operator``; the lookup happens after the left operand has been
    serialized and before the right one.
    """
    lhs = expr.left.accept(self)
    operator = binary_ops[expr.operator.__class__]
    rhs = expr.right.accept(self)
    return {"_type": "BinaryExpr", "left": lhs, "operator": operator, "right": rhs}
def visit_compare_expr(self, expr: CompareExpr) -> dict:
    """Serialize a ``CompareExpr``.

    The operator entry is looked up in ``compare_ops`` using the class of
    ``expr.operator``, between serializing the left and right operands.
    """
    lhs = expr.left.accept(self)
    operator = compare_ops[expr.operator.__class__]
    rhs = expr.right.accept(self)
    return {"_type": "CompareExpr", "left": lhs, "operator": operator, "right": rhs}
def visit_unary_expr(self, expr: UnaryExpr) -> dict:
    """Serialize a ``UnaryExpr``.

    The operator entry is looked up in ``unary_ops`` using the class of
    ``expr.operator`` before the operand is serialized.
    """
    operator = unary_ops[expr.operator.__class__]
    operand = expr.right.accept(self)
    return {"_type": "UnaryExpr", "operator": operator, "right": operand}
def visit_call_expr(self, expr: CallExpr) -> dict:
    """Serialize a ``CallExpr``.

    The callee is serialized first, then the positional arguments, then the
    keyword arguments as a mapping from keyword name to serialized value.
    """
    callee = expr.callee.accept(self)
    positional = self._serialize_list(expr.arguments)

    keywords: dict = {}
    for keyword, value in expr.keywords.items():
        keywords[keyword] = value.accept(self)

    return {
        "_type": "CallExpr",
        "callee": callee,
        "arguments": positional,
        "keywords": keywords,
    }
def visit_get_expr(self, expr: GetExpr) -> dict:
    """Serialize a ``GetExpr``: the serialized target object and the accessed name."""
    target = expr.object.accept(self)
    return {"_type": "GetExpr", "object": target, "name": expr.name}
def visit_literal_expr(self, expr: LiteralExpr) -> dict:
    """Serialize a ``LiteralExpr``; the literal value is emitted unchanged."""
    literal: dict = {"_type": "LiteralExpr"}
    literal["value"] = expr.value
    return literal
def visit_variable_expr(self, expr: VariableExpr) -> dict:
    """Serialize a ``VariableExpr`` as its type tag plus the variable name."""
    variable: dict = {"_type": "VariableExpr"}
    variable["name"] = expr.name
    return variable
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
    """Serialize a ``LogicalExpr``.

    The operator entry is looked up in ``boolean_ops`` using the class of
    ``expr.operator``, between serializing the left and right operands.
    """
    lhs = expr.left.accept(self)
    operator = boolean_ops[expr.operator.__class__]
    rhs = expr.right.accept(self)
    return {"_type": "LogicalExpr", "left": lhs, "operator": operator, "right": rhs}
def visit_set_expr(self, expr: SetExpr) -> dict:
    """Serialize a ``SetExpr``: the target object, the attribute name and the new value."""
    assignment: dict = {"_type": "SetExpr"}
    assignment["object"] = expr.object.accept(self)
    assignment["name"] = expr.name
    assignment["value"] = expr.value.accept(self)
    return assignment
def visit_cast_expr(self, expr: CastExpr) -> dict:
    """Serialize a ``CastExpr``: the target type first, then the expression being cast."""
    target_type = expr.type.accept(self)
    operand = expr.expr.accept(self)
    return {"_type": "CastExpr", "type": target_type, "expr": operand}
def visit_ternary_expr(self, expr: TernaryExpr) -> dict:
    """Serialize a ``TernaryExpr``: the test, then the true and false branches."""
    condition = expr.test.accept(self)
    when_true = expr.if_true.accept(self)
    when_false = expr.if_false.accept(self)
    return {
        "_type": "TernaryExpr",
        "test": condition,
        "if_true": when_true,
        "if_false": when_false,
    }

View File

@@ -31,22 +31,32 @@
]
},
"type-base": {
"begin": "<",
"end": ">",
"begin": "(\\()([a-zA-Z_][a-zA-Z_\\d]*)(\\))",
"end": "$",
"beginCaptures": {
"0": {
"1": {
"name": "punctuation.definition.base.begin.midas"
}
},
"endCaptures": {
"0": {
"2": {
"name": "variable.name"
},
"3": {
"name": "punctuation.definition.base.end.midas"
}
},
"patterns": [
{"include": "source.python"}
{ "include": "#type-cond" }
]
},
"type-cond": {
"begin": "where",
"end": "$",
"beginCaptures": {
"0": {
"name": "keyword.control.where.midas"
}
}
},
"type-body": {
"begin": "\\{",
"end": "\\}",
@@ -61,7 +71,8 @@
}
},
"patterns": [
{"include": "#type-prop"}
{"include": "#type-prop"},
{"include": "#comment"}
]
},
"type-prop": {
@@ -78,44 +89,67 @@
}
}
},
"op-def": {
"match": "\\b(op)\\s+<([a-zA-Z_][a-zA-Z_\\d]*)>\\s+(\\S+)\\s+<([a-zA-Z_][a-zA-Z_\\d]*)>\\s+(=)\\s+<([a-zA-Z_][a-zA-Z_\\d]*)>",
"captures": {
"1": {
"name": "keyword.control.op.midas"
},
"2": {
"name" : "variable.name"
},
"3": {
"name" : "keyword.operator"
},
"4": {
"name" : "variable.name"
},
"5": {
"name" : "keyword.operator.assignment"
},
"6": {
"name" : "variable.name"
}
},
"patterns": [
{ "include": "#type-base" },
{ "include": "#type-body" }
]
},
"constr-def": {
"begin": "(constraint)\\s+([a-zA-Z_][a-zA-Z_\\d]*)\\s*(=)",
"end": "$",
"extend-def": {
"begin": "\\b(extend)\\s*([a-zA-Z_][a-zA-Z_\\d]*)\\s+(\\{)",
"end": "\\}",
"beginCaptures": {
"1": {
"name": "keyword.control.constr.midas"
"name": "keyword.control.extend.midas"
},
"2": {
"name": "variable.name"
},
"3": {
"name": "punctuation.definition.extend-body.begin.midas"
}
},
"endCaptures": {
"0": {
"name": "punctuation.definition.extend-body.end.midas"
}
},
"patterns": [
{"include": "#op-def"},
{"include": "#comment"}
]
},
"op-def": {
"match": "\\b(op)\\s+(\\S+)\\s*\\(\\s*([a-zA-Z_][a-zA-Z_\\d]*)\\s*\\)\\s*(->)\\s*([a-zA-Z_][a-zA-Z_\\d]*)",
"captures": {
"1": {
"name": "keyword.control.op.midas"
},
"2": {
"name" : "keyword.operator"
},
"3": {
"name" : "variable.name"
},
"4": {
"name" : "keyword.operator.assignment"
},
"5": {
"name" : "variable.name"
}
}
},
"pred-def": {
"begin": "(predicate)\\s+([a-zA-Z_][a-zA-Z_\\d]*)\\(([a-zA-Z_][a-zA-Z_\\d]*):\\s*([a-zA-Z_][a-zA-Z_\\d]*)\\)\\s*(=)",
"end": "$",
"beginCaptures": {
"1": {
"name": "keyword.control.pred.midas"
},
"2": {
"name": "variable.name"
},
"3": {
"name": "variable.name"
},
"4": {
"name": "variable.name"
},
"5": {
"name": "keyword.operator.assignment"
}
},
@@ -127,8 +161,8 @@
"patterns": [
{ "include": "#comment" },
{ "include": "#type-def" },
{ "include": "#op-def" },
{ "include": "#constr-def" }
{ "include": "#extend-def" },
{ "include": "#pred-def" }
]
}
}