diff --git a/.gitignore b/.gitignore index b540ed8..66f4aa9 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ venv .venv *.pyc uv.lock -.python-version \ No newline at end of file +.python-version +/out \ No newline at end of file diff --git a/docs/architecture.typ b/docs/architecture.typ new file mode 100644 index 0000000..d83bdc6 --- /dev/null +++ b/docs/architecture.typ @@ -0,0 +1,150 @@ +#import "@preview/cetz:0.5.2": canvas, draw + +#let diagram-only = false + +#set document( + title: [Midas Architecture], + //author: "Louis Heredero", +) + +#set text( + font: "Source Sans 3", +) + +#let diagram = canvas({ + let framed = draw.content.with( + padding: (x: .8em, y: 1em), + frame: "rect", + stroke: black, + ) + let arrow = draw.line.with(mark: (end: ">", fill: black)) + framed( + (0, 0), + name: "python-parser", + )[Python parser] + + draw.content( + (rel: (0, 1), to: "python-parser.north"), + padding: 5pt, + anchor: "south", + name: "source-py", + )[_`source.py`_] + arrow("source-py", "python-parser") + + framed( + (rel: (3, 0), to: "python-parser.east"), + anchor: "west", + name: "custom-parser", + align(center)[Custom python\ parser], + ) + + arrow("python-parser", "custom-parser", name: "arrow-python-ast") + draw.content( + "arrow-python-ast", + anchor: "south", + padding: 5pt, + )[`ast.Module`] + + framed( + (rel: (-3, -2), to: "custom-parser.south"), + anchor: "east", + name: "python-resolver", + )[Python Resolver] + arrow( + "custom-parser", + ((), "|-", "python-resolver.east"), + "python-resolver", + name: "arrow-python-custom-ast", + ) + draw.content( + (rel: (1.5, 0), to: "arrow-python-custom-ast.end"), + padding: 5pt, + anchor: "south", + )[P-AST#footnote[#strong[P]ython *AST*] <fn-past>] + draw.content( + "python-resolver.west", + padding: 5pt, + anchor: "south-east", + )[Resolved P-AST@fn-past] + + draw.circle( + (rel: (1, -2), to: "custom-parser.south-east"), + radius: .4, + name: "midas-loader", + ) + arrow( + "custom-parser", + "midas-loader", 
name: "arrow-load-midas", + mark: (end: (symbol: ">", fill: black), start: "o"), + ) + draw.content( + "arrow-load-midas", + anchor: "west", + padding: 5pt, + )[```python midas.using("types.midas")```] + + framed( + (rel: (0, -2), to: "midas-loader.south"), + name: "midas-parser", + )[Midas lexer/parser] + arrow("midas-loader", "midas-parser", name: "arrow-midas-source") + draw.content( + "arrow-midas-source", + anchor: "west", + padding: 5pt, + )[_`types.midas`_] + + + framed( + (rel: (-2, 0), to: "midas-parser.west"), + anchor: "east", + name: "midas-resolver", + )[Midas Resolver] + arrow("midas-parser", "midas-resolver", name: "arrow-midas-ast") + draw.content( + "arrow-midas-ast", + anchor: "south", + padding: 5pt, + )[M-AST#footnote[#strong[M]idas *AST*]] + + framed( + (rel: (-3, 0), to: "midas-resolver.west"), + anchor: "east", + name: "checker", + )[Checker] + arrow("midas-resolver", "checker", name: "arrow-type-ctx") + arrow( + "python-resolver", + ((), "-|", "checker.north"), + "checker", + ) + draw.content( + "arrow-type-ctx", + anchor: "south", + padding: 5pt, + )[Types context] +}) + +#show: doc => if diagram-only { + set page(width: auto, height: auto, margin: .5cm) + diagram +} else { doc } + +#align(center, title()) + +#v(1cm) + +#figure( + diagram, + caption: [Midas type-checker architecture], +) + +== Components + +- *Python parser*: builtin Python AST parser, extracts abstract syntax from the raw Python source (```python ast.parse(...)```) +- *Custom python parser*: converts the raw Python AST into custom, more suitable constructs, especially for type annotations +- *Python resolver*: resolves bindings and references, tracks binding scopes +- *Midas lexer/parser*: parses a Midas type definition file and extracts its AST +- *Midas resolver*: walks the AST and fills the environment with the defined types and operations +- *Checker*: evaluates expressions and checks type coherence diff --git a/examples/00_syntax_prototype/02_custom_types.py 
b/examples/00_syntax_prototype/02_custom_types.py index 16bf442..8678ba0 100644 --- a/examples/00_syntax_prototype/02_custom_types.py +++ b/examples/00_syntax_prototype/02_custom_types.py @@ -2,10 +2,6 @@ # ruff: disable[F821] from __future__ import annotations -# Prototype of custom type import to use valid Python syntax -import midas -midas.using("02_custom_types.midas") - # A data-frame using a custom type df: Frame[ location: GeoLocation diff --git a/examples/00_syntax_prototype/05_custom_types_v3.midas b/examples/00_syntax_prototype/05_custom_types_v3.midas new file mode 100644 index 0000000..a339318 --- /dev/null +++ b/examples/00_syntax_prototype/05_custom_types_v3.midas @@ -0,0 +1,33 @@ +type Foo1 = float +type Foo2 = float where (_ > 3) +type Foo3 = int | float +type Foo4 = int where (_ > 3) | float where (_ > 3) +type Foo5 = (int | float) where (_ > 3) +type Foo6 = { + foo: float + bar: float where (_ > 3) +} + +type Foo7[T] = T where (_ > 3) +type Foo8[A, B<:int] = { + a: A + b: B +} + +type Complex = { + a: int + b: int +} +type Complex2 = Complex where (_.a > 3 & _.b < 5) + +predicate Positive(n: int) = n >= 0 + +extend Foo1 { + op __add__(Foo1) -> Foo1 +} + +extend Foo7[T] { + op __add__(Foo7[T]) -> Foo7[T] +} + +type Optional[T] = None | T diff --git a/examples/01_simple_type_checking/01_simple_operations.py b/examples/01_simple_type_checking/01_simple_operations.py new file mode 100644 index 0000000..a3ac707 --- /dev/null +++ b/examples/01_simple_type_checking/01_simple_operations.py @@ -0,0 +1,11 @@ +a: int = 3 +b: int = 4 + +c = a + b # -> int + +c = "invalid" # -> can't assign str to int variable + +d = True +e = d + d + +f: float = a diff --git a/examples/01_simple_type_checking/02_simple_types.midas b/examples/01_simple_type_checking/02_simple_types.midas new file mode 100644 index 0000000..6a1a6a2 --- /dev/null +++ b/examples/01_simple_type_checking/02_simple_types.midas @@ -0,0 +1,14 @@ +type Meter = float +type Second = float +type 
MeterPerSecond = float + +extend Meter { + op __add__(Meter) -> Meter + op __sub__(Meter) -> Meter + op __truediv__(Second) -> MeterPerSecond +} + +extend Second { + op __add__(Second) -> Second + op __sub__(Second) -> Second +} diff --git a/examples/01_simple_type_checking/02_simple_types.py b/examples/01_simple_type_checking/02_simple_types.py new file mode 100644 index 0000000..c015a75 --- /dev/null +++ b/examples/01_simple_type_checking/02_simple_types.py @@ -0,0 +1,6 @@ +# type: ignore +# ruff: disable [F821] + +distance: Meter = cast(Meter, 123.45) +time: Second = cast(Second, 6.7) +speed = distance / time diff --git a/examples/01_simple_type_checking/03_control_flow.py b/examples/01_simple_type_checking/03_control_flow.py new file mode 100644 index 0000000..772c9ac --- /dev/null +++ b/examples/01_simple_type_checking/03_control_flow.py @@ -0,0 +1,23 @@ +def minimum(x: int, y: int): + if x < y: + return x + else: + return y + + +a = 15 +b = 72 +c = minimum(a, b) + + +def factorial(n: int) -> int: + if n <= 1: + return 1 + return n * factorial(n - 1) + + +category = "Category 1" if a < 10 else "Category 2" + + +def foo() -> None: + pass diff --git a/gen/gen.py b/gen/gen.py index 75e6100..e78c872 100644 --- a/gen/gen.py +++ b/gen/gen.py @@ -1,5 +1,5 @@ -from pathlib import Path import re +from pathlib import Path HEADER = '''""" This file was generated by a script. Any manual changes might be overwritten. @@ -11,7 +11,7 @@ SECTION_TEMPLATE = """{banner} @dataclass(frozen=True, kw_only=True) class {base}(ABC): - location: Optional[Location] = None + location: Location @abstractmethod def accept(self, visitor: Visitor[T]) -> T: ... 
diff --git a/gen/midas.py b/gen/midas.py index 7187554..e1c304d 100644 --- a/gen/midas.py +++ b/gen/midas.py @@ -13,40 +13,38 @@ from midas.lexer.token import Token ###> Stmt | Statements -class SimpleTypeStmt: +class TypeStmt: name: Token - template: Optional[TemplateExpr] - base: TypeExpr - constraint: Optional[Expr] + params: list[Param] + type: Type - -class ComplexTypeStmt: - name: Token - template: Optional[TemplateExpr] - properties: list[PropertyStmt] + @dataclass(frozen=True, kw_only=True) + class Param: + location: Location + name: Token + bound: Optional[Type] class PropertyStmt: name: Token - type: TypeExpr - constraint: Optional[Expr] + type: Type class ExtendStmt: - type: TypeExpr + type: Type operations: list[OpStmt] class OpStmt: name: Token - operand: TypeExpr - result: TypeExpr + operand: Type + result: Type class PredicateStmt: name: Token subject: Token - type: TypeExpr + type: Type condition: Expr @@ -54,9 +52,6 @@ class PredicateStmt: ###> Expr | Expressions -class SimpleTypeExpr: - name: Token - optional: bool class LogicalExpr: @@ -97,14 +92,27 @@ class WildcardExpr: token: Token -class TemplateExpr: - type: TypeExpr +###< + +###> Type | Types -class TypeExpr: +class NamedType: name: Token - template: Optional[TemplateExpr] - optional: bool + + +class GenericType: + type: Type + params: list[Type] + + +class ConstraintType: + type: Type + constraint: Expr + + +class ComplexType: + properties: list[PropertyStmt] ###< diff --git a/gen/python.py b/gen/python.py index 0aadd57..09d21b8 100644 --- a/gen/python.py +++ b/gen/python.py @@ -44,14 +44,22 @@ class Function: name: str posonlyargs: list[Argument] args: list[Argument] + sink: Optional[Argument] kwonlyargs: list[Argument] + kw_sink: Optional[Argument] returns: Optional[MidasType] + body: list[Stmt] @dataclass(frozen=True, kw_only=True) class Argument: location: Optional[Location] = None - name: Optional[str] + name: str type: Optional[MidasType] + default: Optional[Expr] + + @property + def 
all_args(self) -> list[Argument]: + return self.posonlyargs + self.args + self.kwonlyargs class TypeAssign: @@ -64,6 +72,16 @@ class AssignStmt: value: Expr +class ReturnStmt: + value: Optional[Expr] + + +class IfStmt: + test: Expr + body: list[Stmt] + orelse: list[Stmt] + + ###< @@ -116,4 +134,15 @@ class SetExpr: value: Expr +class CastExpr: + type: MidasType + expr: Expr + + +class TernaryExpr: + test: Expr + if_true: Expr + if_false: Expr + + ###< diff --git a/midas/ast/midas.py b/midas/ast/midas.py index 9cea8c2..335e5cf 100644 --- a/midas/ast/midas.py +++ b/midas/ast/midas.py @@ -21,17 +21,14 @@ T = TypeVar("T") @dataclass(frozen=True, kw_only=True) class Stmt(ABC): - location: Optional[Location] = None + location: Location @abstractmethod def accept(self, visitor: Visitor[T]) -> T: ... class Visitor(ABC, Generic[T]): @abstractmethod - def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> T: ... - - @abstractmethod - def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> T: ... + def visit_type_stmt(self, stmt: TypeStmt) -> T: ... @abstractmethod def visit_property_stmt(self, stmt: PropertyStmt) -> T: ... 
@@ -47,31 +44,25 @@ class Stmt(ABC): @dataclass(frozen=True) -class SimpleTypeStmt(Stmt): +class TypeStmt(Stmt): name: Token - template: Optional[TemplateExpr] - base: TypeExpr - constraint: Optional[Expr] + params: list[Param] + type: Type + + @dataclass(frozen=True, kw_only=True) + class Param: + location: Location + name: Token + bound: Optional[Type] def accept(self, visitor: Stmt.Visitor[T]) -> T: - return visitor.visit_simple_type_stmt(self) - - -@dataclass(frozen=True) -class ComplexTypeStmt(Stmt): - name: Token - template: Optional[TemplateExpr] - properties: list[PropertyStmt] - - def accept(self, visitor: Stmt.Visitor[T]) -> T: - return visitor.visit_complex_type_stmt(self) + return visitor.visit_type_stmt(self) @dataclass(frozen=True) class PropertyStmt(Stmt): name: Token - type: TypeExpr - constraint: Optional[Expr] + type: Type def accept(self, visitor: Stmt.Visitor[T]) -> T: return visitor.visit_property_stmt(self) @@ -79,7 +70,7 @@ class PropertyStmt(Stmt): @dataclass(frozen=True) class ExtendStmt(Stmt): - type: TypeExpr + type: Type operations: list[OpStmt] def accept(self, visitor: Stmt.Visitor[T]) -> T: @@ -89,8 +80,8 @@ class ExtendStmt(Stmt): @dataclass(frozen=True) class OpStmt(Stmt): name: Token - operand: TypeExpr - result: TypeExpr + operand: Type + result: Type def accept(self, visitor: Stmt.Visitor[T]) -> T: return visitor.visit_op_stmt(self) @@ -100,7 +91,7 @@ class OpStmt(Stmt): class PredicateStmt(Stmt): name: Token subject: Token - type: TypeExpr + type: Type condition: Expr def accept(self, visitor: Stmt.Visitor[T]) -> T: @@ -114,15 +105,12 @@ class PredicateStmt(Stmt): @dataclass(frozen=True, kw_only=True) class Expr(ABC): - location: Optional[Location] = None + location: Location @abstractmethod def accept(self, visitor: Visitor[T]) -> T: ... class Visitor(ABC, Generic[T]): - @abstractmethod - def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> T: ... - @abstractmethod def visit_logical_expr(self, expr: LogicalExpr) -> T: ... 
@@ -147,21 +135,6 @@ class Expr(ABC): @abstractmethod def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ... - @abstractmethod - def visit_template_expr(self, expr: TemplateExpr) -> T: ... - - @abstractmethod - def visit_type_expr(self, expr: TypeExpr) -> T: ... - - -@dataclass(frozen=True) -class SimpleTypeExpr(Expr): - name: Token - optional: bool - - def accept(self, visitor: Expr.Visitor[T]) -> T: - return visitor.visit_simple_type_expr(self) - @dataclass(frozen=True) class LogicalExpr(Expr): @@ -233,19 +206,61 @@ class WildcardExpr(Expr): return visitor.visit_wildcard_expr(self) -@dataclass(frozen=True) -class TemplateExpr(Expr): - type: TypeExpr +######### +# Types # +######### - def accept(self, visitor: Expr.Visitor[T]) -> T: - return visitor.visit_template_expr(self) + +@dataclass(frozen=True, kw_only=True) +class Type(ABC): + location: Location + + @abstractmethod + def accept(self, visitor: Visitor[T]) -> T: ... + + class Visitor(ABC, Generic[T]): + @abstractmethod + def visit_named_type(self, type: NamedType) -> T: ... + + @abstractmethod + def visit_generic_type(self, type: GenericType) -> T: ... + + @abstractmethod + def visit_constraint_type(self, type: ConstraintType) -> T: ... + + @abstractmethod + def visit_complex_type(self, type: ComplexType) -> T: ... 
@dataclass(frozen=True) -class TypeExpr(Expr): +class NamedType(Type): name: Token - template: Optional[TemplateExpr] - optional: bool - def accept(self, visitor: Expr.Visitor[T]) -> T: - return visitor.visit_type_expr(self) + def accept(self, visitor: Type.Visitor[T]) -> T: + return visitor.visit_named_type(self) + + +@dataclass(frozen=True) +class GenericType(Type): + type: Type + params: list[Type] + + def accept(self, visitor: Type.Visitor[T]) -> T: + return visitor.visit_generic_type(self) + + +@dataclass(frozen=True) +class ConstraintType(Type): + type: Type + constraint: Expr + + def accept(self, visitor: Type.Visitor[T]) -> T: + return visitor.visit_constraint_type(self) + + +@dataclass(frozen=True) +class ComplexType(Type): + properties: list[PropertyStmt] + + def accept(self, visitor: Type.Visitor[T]) -> T: + return visitor.visit_complex_type(self) diff --git a/midas/ast/printer.py b/midas/ast/printer.py index e3ecde9..41dd6a0 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -85,40 +85,39 @@ class AstPrinter(Generic[T]): child.accept(self) -class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]): +class MidasAstPrinter( + AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None], m.Type.Visitor[None] +): # Statements - def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt): - self._write_line("SimpleTypeStmt") + def visit_type_stmt(self, stmt: m.TypeStmt) -> None: + self._write_line("TypeStmt") with self._child_level(): self._write_line(f'name: "{stmt.name.lexeme}"') - self._write_optional_child("template", stmt.template) - self._write_line("base") - with self._child_level(single=True): - stmt.base.accept(self) - self._write_optional_child("constraint", stmt.constraint, last=True) - - def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt): - self._write_line("ComplexTypeStmt") - with self._child_level(): - self._write_line(f'name: "{stmt.name.lexeme}"') - self._write_optional_child("template", stmt.template) - 
self._write_line("properties", last=True) + self._write_line("params") with self._child_level(): - for i, prop in enumerate(stmt.properties): + for i, param in enumerate(stmt.params): self._idx = i - if i == len(stmt.properties) - 1: + if i == len(stmt.params) - 1: self._mark_last() - prop.accept(self) + self._print_type_stmt_param(param) + self._write_line("type", last=True) + with self._child_level(single=True): + stmt.type.accept(self) + + def _print_type_stmt_param(self, param: m.TypeStmt.Param) -> None: + self._write_line("Param") + with self._child_level(): + self._write_line(f'name: "{param.name.lexeme}"') + self._write_optional_child("bound", param.bound, last=True) def visit_property_stmt(self, stmt: m.PropertyStmt): self._write_line("PropertyStmt") with self._child_level(): self._write_line(f'name: "{stmt.name.lexeme}"') - self._write_line("type") + self._write_line("type", last=True) with self._child_level(single=True): stmt.type.accept(self) - self._write_optional_child("constraint", stmt.constraint, last=True) def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: self._write_line("ExtendStmt") @@ -161,12 +160,6 @@ class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]): # Expressions - def visit_simple_type_expr(self, expr: m.SimpleTypeExpr): - self._write_line("SimpleTypeExpr") - with self._child_level(): - self._write_line(f'name: "{expr.name.lexeme}"') - self._write_line(f"optional: {expr.optional}", last=True) - def visit_logical_expr(self, expr: m.LogicalExpr): self._write_line("LogicalExpr") with self._child_level(): @@ -230,22 +223,48 @@ class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]): def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: self._write_line("WildcardExpr") - def visit_template_expr(self, expr: m.TemplateExpr) -> None: - self._write_line("TemplateExpr") - with self._child_level(single=True): + def visit_named_type(self, type: m.NamedType) -> None: + 
self._write_line("NamedType") + with self._child_level(): + self._write_line(f'name: "{type.name.lexeme}"', last=True) + + def visit_generic_type(self, type: m.GenericType) -> None: + self._write_line("GenericType") + with self._child_level(): + self._write_line("type") + with self._child_level(): + type.type.accept(self) + self._write_line("params", last=True) + with self._child_level(): + for i, param in enumerate(type.params): + self._idx = i + if i == len(type.params) - 1: + self._mark_last() + param.accept(self) + + def visit_constraint_type(self, type: m.ConstraintType) -> None: + self._write_line("ConstraintType") + with self._child_level(): self._write_line("type") with self._child_level(single=True): - expr.type.accept(self) + type.type.accept(self) + self._write_line("constraint", last=True) + with self._child_level(single=True): + type.constraint.accept(self) - def visit_type_expr(self, expr: m.TypeExpr): - self._write_line("TypeExpr") + def visit_complex_type(self, type: m.ComplexType) -> None: + self._write_line("ComplexType") with self._child_level(): - self._write_line(f'name: "{expr.name.lexeme}"') - self._write_optional_child("template", expr.template) - self._write_line(f"optional: {expr.optional}", last=True) + self._write_line("properties", last=True) + with self._child_level(): + for i, prop in enumerate(type.properties): + self._idx = i + if i == len(type.properties) - 1: + self._mark_last() + prop.accept(self) -class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]): +class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]): def __init__(self, indent: int = 4): self.indent: int = indent self.level: int = 0 @@ -253,33 +272,28 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]): def indented(self, text: str) -> str: return " " * (self.level * self.indent) + text - def print(self, expr: m.Expr | m.Stmt): + def print(self, expr: m.Expr | m.Stmt | m.Type) -> str: self.level = 0 return expr.accept(self) - 
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt): - template: str = stmt.template.accept(self) if stmt.template is not None else "" - res: str = f"type {stmt.name.lexeme}{template}({stmt.base.accept(self)})" - if stmt.constraint is not None: - res += " where " + stmt.constraint.accept(self) + def visit_type_stmt(self, stmt: m.TypeStmt) -> str: + template: str = "" + if len(stmt.params) != 0: + params: list[str] = [ + self._print_type_template_param(param) for param in stmt.params + ] + template = f"[{', '.join(params)}]" + res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}" return self.indented(res) - def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt): - template: str = stmt.template.accept(self) if stmt.template is not None else "" - res: str = self.indented(f"type {stmt.name.lexeme}{template}") - res += " {\n" - self.level += 1 - for prop in stmt.properties: - res += prop.accept(self) - res += "\n" - self.level -= 1 - res += self.indented("}") + def _print_type_template_param(self, param: m.TypeStmt.Param) -> str: + res: str = param.name.lexeme + if param.bound is not None: + res += "<:" + param.bound.accept(self) return res def visit_property_stmt(self, stmt: m.PropertyStmt): res: str = f"{stmt.name.lexeme}: {stmt.type.accept(self)}" - if stmt.constraint is not None: - res += " where " + stmt.constraint.accept(self) return self.indented(res) def visit_extend_stmt(self, stmt: m.ExtendStmt): @@ -289,13 +303,13 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]): for op in stmt.operations: res += op.accept(self) self.level -= 1 - res += "\n" + self.indented("}") + res += self.indented("}") return res def visit_op_stmt(self, stmt: m.OpStmt): operand: str = stmt.operand.accept(self) result: str = stmt.result.accept(self) - return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}") + return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}\n") def visit_predicate_stmt(self, stmt: 
m.PredicateStmt): name: str = stmt.name.lexeme @@ -304,9 +318,6 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]): condition: str = stmt.condition.accept(self) return self.indented(f"predicate {name}({subject}: {type}) = {condition}") - def visit_simple_type_expr(self, expr: m.SimpleTypeExpr): - return f"{expr.name.lexeme}{'?' if expr.optional else ''}" - def visit_logical_expr(self, expr: m.LogicalExpr): left: str = expr.left.accept(self) operator: str = expr.operator.lexeme @@ -342,12 +353,30 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]): def visit_wildcard_expr(self, expr: m.WildcardExpr): return "_" - def visit_template_expr(self, expr: m.TemplateExpr): - return f"[{expr.type.accept(self)}]" + def visit_named_type(self, type: m.NamedType) -> str: + return type.name.lexeme - def visit_type_expr(self, expr: m.TypeExpr): - template: str = expr.template.accept(self) if expr.template is not None else "" - return f"{expr.name.lexeme}{template}{'?' if expr.optional else ''}" + def visit_generic_type(self, type: m.GenericType) -> str: + res: str = type.type.accept(self) + if len(type.params) != 0: + params: list[str] = [param.accept(self) for param in type.params] + res += f"[{', '.join(params)}]" + return res + + def visit_constraint_type(self, type: m.ConstraintType) -> str: + res: str = type.type.accept(self) + res += " where " + type.constraint.accept(self) + return res + + def visit_complex_type(self, type: m.ComplexType) -> str: + res: str = "{\n" + self.level += 1 + for prop in type.properties: + res += prop.accept(self) + res += "\n" + self.level -= 1 + res += self.indented("}") + return res class PythonAstPrinter( @@ -419,7 +448,14 @@ class PythonAstPrinter( self._mark_last() self._print_argument(arg) - self._write_optional_child("returns", stmt.returns, last=True) + self._write_optional_child("returns", stmt.returns) + self._write_line("body", last=True) + with self._child_level(): + for i, body_stmt in enumerate(stmt.body): + 
self._idx = i + if i == len(stmt.body) - 1: + self._mark_last() + body_stmt.accept(self) def _print_argument(self, arg: p.Function.Argument) -> None: self._write_line("FunctionArgument") @@ -449,6 +485,32 @@ class PythonAstPrinter( with self._child_level(single=True): stmt.value.accept(self) + def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: + self._write_line("ReturnStmt") + with self._child_level(): + self._write_optional_child("value", stmt.value, last=True) + + def visit_if_stmt(self, stmt: p.IfStmt) -> None: + self._write_line("IfStmt") + with self._child_level(): + self._write_line("test") + with self._child_level(single=True): + stmt.test.accept(self) + self._write_line("body") + with self._child_level(): + for i, body_stmt in enumerate(stmt.body): + self._idx = i + if i == len(stmt.body) - 1: + self._mark_last() + body_stmt.accept(self) + self._write_line("orelse", last=True) + with self._child_level(): + for i, else_stmt in enumerate(stmt.orelse): + self._idx = i + if i == len(stmt.orelse) - 1: + self._mark_last() + else_stmt.accept(self) + def visit_binary_expr(self, expr: p.BinaryExpr) -> None: self._write_line("BinaryExpr") with self._child_level(): @@ -550,3 +612,28 @@ class PythonAstPrinter( self._write_line("value", last=True) with self._child_level(single=True): expr.value.accept(self) + + def visit_cast_expr(self, expr: p.CastExpr) -> None: + self._write_line("CastExpr") + with self._child_level(): + self._write_line("type") + with self._child_level(single=True): + expr.type.accept(self) + self._write_line("expr", last=True) + with self._child_level(single=True): + expr.expr.accept(self) + + def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: + self._write_line("TernaryExpr") + with self._child_level(): + self._write_line("test") + with self._child_level(single=True): + expr.test.accept(self) + + self._write_line("if_true") + with self._child_level(single=True): + expr.if_true.accept(self) + + self._write_line("if_false", last=True) 
+ with self._child_level(single=True): + expr.if_false.accept(self) diff --git a/midas/ast/python.py b/midas/ast/python.py index d4fc032..8607cd2 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -21,7 +21,7 @@ T = TypeVar("T") @dataclass(frozen=True, kw_only=True) class MidasType(ABC): - location: Optional[Location] = None + location: Location @abstractmethod def accept(self, visitor: Visitor[T]) -> T: ... @@ -82,7 +82,7 @@ class FrameType(MidasType): @dataclass(frozen=True, kw_only=True) class Stmt(ABC): - location: Optional[Location] = None + location: Location @abstractmethod def accept(self, visitor: Visitor[T]) -> T: ... @@ -100,6 +100,12 @@ class Stmt(ABC): @abstractmethod def visit_assign_stmt(self, stmt: AssignStmt) -> T: ... + @abstractmethod + def visit_return_stmt(self, stmt: ReturnStmt) -> T: ... + + @abstractmethod + def visit_if_stmt(self, stmt: IfStmt) -> T: ... + @dataclass(frozen=True) class ExpressionStmt(Stmt): @@ -114,14 +120,22 @@ class Function(Stmt): name: str posonlyargs: list[Argument] args: list[Argument] + sink: Optional[Argument] kwonlyargs: list[Argument] + kw_sink: Optional[Argument] returns: Optional[MidasType] + body: list[Stmt] @dataclass(frozen=True, kw_only=True) class Argument: location: Optional[Location] = None - name: Optional[str] + name: str type: Optional[MidasType] + default: Optional[Expr] + + @property + def all_args(self) -> list[Argument]: + return self.posonlyargs + self.args + self.kwonlyargs def accept(self, visitor: Stmt.Visitor[T]) -> T: return visitor.visit_function(self) @@ -145,6 +159,24 @@ class AssignStmt(Stmt): return visitor.visit_assign_stmt(self) +@dataclass(frozen=True) +class ReturnStmt(Stmt): + value: Optional[Expr] + + def accept(self, visitor: Stmt.Visitor[T]) -> T: + return visitor.visit_return_stmt(self) + + +@dataclass(frozen=True) +class IfStmt(Stmt): + test: Expr + body: list[Stmt] + orelse: list[Stmt] + + def accept(self, visitor: Stmt.Visitor[T]) -> T: + return 
visitor.visit_if_stmt(self) + + ############### # Expressions # ############### @@ -152,7 +184,7 @@ class AssignStmt(Stmt): @dataclass(frozen=True, kw_only=True) class Expr(ABC): - location: Optional[Location] = None + location: Location @abstractmethod def accept(self, visitor: Visitor[T]) -> T: ... @@ -185,6 +217,12 @@ class Expr(ABC): @abstractmethod def visit_set_expr(self, expr: SetExpr) -> T: ... + @abstractmethod + def visit_cast_expr(self, expr: CastExpr) -> T: ... + + @abstractmethod + def visit_ternary_expr(self, expr: TernaryExpr) -> T: ... + @dataclass(frozen=True) class BinaryExpr(Expr): @@ -268,3 +306,22 @@ class SetExpr(Expr): def accept(self, visitor: Expr.Visitor[T]) -> T: return visitor.visit_set_expr(self) + + +@dataclass(frozen=True) +class CastExpr(Expr): + type: MidasType + expr: Expr + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_cast_expr(self) + + +@dataclass(frozen=True) +class TernaryExpr(Expr): + test: Expr + if_true: Expr + if_false: Expr + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_ternary_expr(self) diff --git a/midas/checker/checker.py b/midas/checker/checker.py new file mode 100644 index 0000000..a96b472 --- /dev/null +++ b/midas/checker/checker.py @@ -0,0 +1,540 @@ +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +import midas.ast.midas as m +import midas.ast.python as p +from midas.ast.location import Location +from midas.checker.diagnostic import Diagnostic, DiagnosticType +from midas.checker.environment import Environment +from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS +from midas.checker.types import Function, Type, UnitType, UnknownType +from midas.lexer.midas import MidasLexer +from midas.lexer.token import Token +from midas.parser.midas import MidasParser +from midas.resolver.midas import MidasResolver + + +class ReturnException(Exception): + pass + + +@dataclass(frozen=True, 
kw_only=True) +class MappedArgument: + expr: p.Expr + type: Type + argument: Function.Argument + + +class Checker( + p.Stmt.Visitor[None], + p.Expr.Visitor[Type], + p.MidasType.Visitor[Type], +): + """A type checker which can use custom type definitions""" + + def __init__( + self, + locals: dict[p.Expr, int], + source_path: Path, + types_paths: list[Path], + ): + self.logger: logging.Logger = logging.getLogger("Checker") + self.source_path: Path = source_path + self.types_paths: list[Path] = types_paths + self.ctx: MidasResolver = MidasResolver() + self.global_env: Environment = Environment() + self.env: Environment = self.global_env + self.locals: dict[p.Expr, int] = locals + self.diagnostics: list[Diagnostic] = [] + + def diagnostic(self, type: DiagnosticType, location: Location, message: str): + self.diagnostics.append( + Diagnostic( + file_path=self.source_path, + location=location, + type=type, + message=message, + ) + ) + + def error(self, location: Location, message: str): + self.diagnostic( + type=DiagnosticType.ERROR, + location=location, + message=message, + ) + + def warning(self, location: Location, message: str): + self.diagnostic( + type=DiagnosticType.WARNING, + location=location, + message=message, + ) + + def info(self, location: Location, message: str): + self.diagnostic( + type=DiagnosticType.INFO, + location=location, + message=message, + ) + + def type_of(self, expr: p.Expr) -> Type: + """Evaluate the type of an expression + + Args: + expr (p.Expr): the expression to evaluate + + Returns: + Type: the type of the given expression + """ + return expr.accept(self) + + def process_block(self, block: list[p.Stmt], env: Environment) -> bool: + """Evaluate a sequence of statements + + Args: + block (list[p.Stmt]): the statements to evaluate + env (Environment): the environment in which to evaluate + + Returns: + bool: whether a return statement is present in the block + """ + previous_env: Environment = self.env + self.env = env + returned: bool = 
False + for i, stmt in enumerate(block): + try: + stmt.accept(self) + except ReturnException: + returned = True + if i < len(block) - 1: + self.warning(block[i + 1].location, "Unreachable statement") + break + self.env = previous_env + return returned + + def check(self, statements: list[p.Stmt]) -> list[Diagnostic]: + """Type check a sequence of statements and returns diagnostics + + Args: + statements (list[p.Stmt]): the statements to evaluate and check + + Returns: + list[Diagnostic]: the list of diagnostics (errors, warning, etc.) + """ + self.diagnostics = [] + + for path in self.types_paths: + self.import_midas(path) + self.logger.debug(f"Midas types: {self.ctx._types}") + self.logger.debug(f"Midas operations: {self.ctx._operations}") + + for stmt in statements: + stmt.accept(self) + + self.logger.debug(f"Final environment: {self.env.flat_dict()}") + return self.diagnostics + + def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]: + """Look up a variable in the environment it was declared + + Args: + name (str): the name of the variable + expr (p.Expr): the variable expression, used to lookup the scope distance + + Returns: + Optional[Type]: the type of the variable, or None if it was not found + """ + distance: Optional[int] = self.locals.get(expr) + if distance is not None: + return self.env.get_at(distance, name) + return self.global_env.get(name) + + def import_midas(self, path: Path) -> None: + """Import Midas definitions from a path + + Args: + path (Path): the import path + """ + self.logger.debug(f"Importing type definitions from {path}") + lexer: MidasLexer = MidasLexer(path.read_text()) + tokens: list[Token] = lexer.process() + parser: MidasParser = MidasParser(tokens) + stmts: list[m.Stmt] = parser.parse() + self.ctx.resolve(stmts) + + def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None: + self.type_of(stmt.expr) + + def visit_function(self, stmt: p.Function) -> None: + env: Environment = Environment(self.env) + 
pos_args: list[Function.Argument] = [] + args: list[Function.Argument] = [] + kw_args: list[Function.Argument] = [] + + def eval_arg_type(arg: p.Function.Argument) -> Type: + if arg.type is not None: + return arg.type.accept(self) + if arg.default is not None: + return arg.default.accept(self) + return UnknownType() + + for arg in stmt.posonlyargs: + pos_args.append( + Function.Argument( + name=arg.name, + type=eval_arg_type(arg), + required=arg.default is None, + ) + ) + for arg in stmt.args: + args.append( + Function.Argument( + name=arg.name, + type=eval_arg_type(arg), + required=arg.default is None, + ) + ) + for arg in stmt.kwonlyargs: + kw_args.append( + Function.Argument( + name=arg.name, + type=eval_arg_type(arg), + required=arg.default is None, + ) + ) + + for arg in pos_args + args + kw_args: + env.define(arg.name, arg.type) + + returns_hint: Optional[Type] = None + if stmt.returns is not None: + returns_hint = stmt.returns.accept(self) + # Early define to handle simple fully-typed recursion + inside_function: Function = Function( + name=stmt.name, + pos_args=pos_args, + args=args, + kw_args=kw_args, + returns=returns_hint, + ) + self.env.define(stmt.name, inside_function) + + returned: bool = self.process_block(stmt.body, env) + inferred_return: Type = UnknownType() + if not returned: + env.return_types.append(UnitType()) + return_types: set[Type] = set(env.return_types) + if len(return_types) == 1: + inferred_return = list(return_types)[0] + elif len(return_types) > 1: + self.error( + stmt.location, + f"Mixed return types: {env.return_types}", + ) + + returns: Type = UnknownType() + if returns_hint is not None: + assert stmt.returns is not None + returns = returns_hint + if returns != inferred_return: + self.error( + stmt.returns.location, + f"Return type mismatch, annotated {returns} but returns {inferred_return}", + ) + else: + returns = inferred_return + + # TODO: handle *args and **kwargs sinks + function: Function = Function( + name=stmt.name, + 
pos_args=pos_args, + args=args, + kw_args=kw_args, + returns=returns, + ) + self.env.define(stmt.name, function) + + def visit_type_assign(self, stmt: p.TypeAssign) -> None: + # TODO check not yet defined locally + type: Type = stmt.type.accept(self) + self.env.define(stmt.name, type) + + def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: + value: Type = self.type_of(stmt.value) + for target in stmt.targets: + if not isinstance(target, p.VariableExpr): + self.logger.warning(f"Unsupported assignment to {target}") + self.warning(target.location, f"Unsupported assignment to {target}") + continue + name: str = target.name + var_type: Optional[Type] = self.look_up_variable(name, target) + + if var_type is None: + self.env.define(name, value) + else: + # TODO: implement real comparison method + if var_type != value: + self.error( + stmt.location, + f"Cannot assign {value} to {name} of type {var_type}", + ) + + def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: + type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType() + self.env.return_types.append(type) + raise ReturnException() + + def visit_if_stmt(self, stmt: p.IfStmt) -> None: + # Not evaluated in sub-environment because assignments in the test leak out of the if + # For example: + # if (m := 1 + 1) < 2: + # ... 
+ # print(m) # <- m is still defined + test_type: Type = stmt.test.accept(self) + + # TODO Allow subtypes or any type + if test_type != self.ctx.get_type("bool"): + self.error( + stmt.test.location, f"If test must be a boolean, got {test_type}" + ) + + env: Environment = Environment(self.env) + body_returned: bool = self.process_block(stmt.body, env) + else_returned: bool = self.process_block(stmt.orelse, env) + self.env.return_types.extend(env.return_types) + if body_returned and else_returned: + raise ReturnException() + + def visit_binary_expr(self, expr: p.BinaryExpr) -> Type: + method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__) + if method is None: + self.logger.warning(f"Unsupported operator {expr.operator}") + self.warning(expr.location, f"Unsupported operator {expr.operator}") + return UnknownType() + left: Type = self.type_of(expr.left) + right: Type = self.type_of(expr.right) + + result: Optional[Type] = self.ctx.get_operation_result(left, method, right) + if result is None: + self.error( + expr.location, + f"Undefined operation {method} between {left} and {right}", + ) + return UnknownType() + return result + + def visit_compare_expr(self, expr: p.CompareExpr) -> Type: + method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__) + if method is None: + self.logger.warning(f"Unsupported operator {expr.operator}") + self.warning(expr.location, f"Unsupported operator {expr.operator}") + return UnknownType() + left: Type = self.type_of(expr.left) + right: Type = self.type_of(expr.right) + + result: Optional[Type] = self.ctx.get_operation_result(left, method, right) + if result is None: + self.error( + expr.location, + f"Undefined operation {method} between {left} and {right}", + ) + return UnknownType() + return result + + def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ... 
+ + def visit_call_expr(self, expr: p.CallExpr) -> Type: + callee: Type = self.type_of(expr.callee) + if not isinstance(callee, Function): + self.error(expr.callee.location, "Callee is not a function") + return UnknownType() + function: Function = callee + mapped: list[MappedArgument] = self.map_call_arguments(function, expr) + for arg in mapped: + if arg.type != arg.argument.type: + self.error( + arg.expr.location, + f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}", + ) + return function.returns + + def visit_get_expr(self, expr: p.GetExpr) -> Type: ... + + def visit_literal_expr(self, expr: p.LiteralExpr) -> Type: + match expr.value: + case bool(): # Must be before int + return self.ctx.get_type("bool") + case int(): + return self.ctx.get_type("int") + case float(): + return self.ctx.get_type("float") + case str(): + return self.ctx.get_type("str") + case _: + self.warning(expr.location, f"Unknown literal {expr}") + return UnknownType() + + def visit_variable_expr(self, expr: p.VariableExpr) -> Type: + return self.look_up_variable(expr.name, expr) or UnknownType() + + def visit_logical_expr(self, expr: p.LogicalExpr) -> Type: + left: Type = expr.left.accept(self) + right: Type = expr.right.accept(self) + # TODO: union type + if left != right: + self.error( + expr.location, + f"Operands must be of the same type, left={left} != right={right}", + ) + return left + + def visit_set_expr(self, expr: p.SetExpr) -> Type: ... 
+ + def visit_cast_expr(self, expr: p.CastExpr) -> Type: + return expr.type.accept(self) + + def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type: + test_type: Type = expr.test.accept(self) + + # TODO Allow subtypes or any type + if test_type != self.ctx.get_type("bool"): + self.error( + expr.test.location, f"If test must be a boolean, got {test_type}" + ) + + true_type: Type = expr.if_true.accept(self) + false_type: Type = expr.if_false.accept(self) + if true_type != false_type: + self.error( + expr.location, + f"Type mismatch in ternary if branches: true={true_type} != false={false_type}", + ) + return UnknownType() + return true_type + + def visit_base_type(self, node: p.BaseType) -> Type: + return self.ctx.get_type(node.base) + + def visit_constraint_type(self, node: p.ConstraintType) -> Type: ... + + def visit_frame_column(self, node: p.FrameColumn) -> Type: ... + + def visit_frame_type(self, node: p.FrameType) -> Type: ... + + def map_call_arguments( + self, function: Function, call: p.CallExpr + ) -> list[MappedArgument]: + """Map call arguments to function parameters as defined in its signature + + This method maps positional-only, keyword-only and mixed parameter definitions + with the arguments passed at the call site + + Any mismatched, missing or unexpected argument is reported as a diagnostic + + Args: + function (Function): the function definition + call (p.CallExpr): the call expression + + Returns: + list[MappedArgument]: the list of mapped arguments + """ + positional: list[tuple[p.Expr, Type]] = [ + (arg, self.type_of(arg)) for arg in call.arguments + ] + keywords: dict[str, tuple[p.Expr, Type]] = { + name: (arg, self.type_of(arg)) for name, arg in call.keywords.items() + } + set_args: set[str] = set() + + required_positional: list[str] = [ + arg.name for arg in function.pos_args + function.args if arg.required + ] + required_keyword: list[str] = [ + arg.name for arg in function.kw_args if arg.required + ] + + mapped: list[MappedArgument] = 
[] + + pos_params: list[Function.Argument] = list(function.pos_args) + mixed_params: list[Function.Argument] = list(function.args) + kw_params: dict[str, Function.Argument] = { + arg.name: arg for arg in function.kw_args + } + + # TODO: handle *args and **kwargs sinks + for arg in positional: + param: Function.Argument + if len(pos_params) != 0: + param = pos_params.pop(0) + elif len(mixed_params) != 0: + param = mixed_params.pop(0) + else: + self.error(arg[0].location, "Too many positional arguments") + break + name: str = param.name + if name in required_positional: + required_positional.remove(name) + if name in required_keyword: + required_keyword.remove(name) + set_args.add(name) + mapped.append( + MappedArgument( + expr=arg[0], + type=arg[1], + argument=param, + ) + ) + + kw_params.update({arg.name: arg for arg in mixed_params}) + for name, arg in keywords.items(): + param: Function.Argument + if name not in kw_params: + if name in set_args: + self.error( + arg[0].location, f"Multiple values for argument '{name}'" + ) + else: + self.error(arg[0].location, f"Unknown keyword argument '{name}'") + continue + param = kw_params.pop(name) + if name in required_positional: + required_positional.remove(name) + if name in required_keyword: + required_keyword.remove(name) + set_args.add(name) + mapped.append( + MappedArgument( + expr=arg[0], + type=arg[1], + argument=param, + ) + ) + + def join_args(args: list[str]) -> str: + args = list(map(lambda a: f"'{a}'", args)) + if len(args) == 0: + return "" + if len(args) == 1: + return args[0] + return ", ".join(args[:-1]) + " and " + args[-1] + + if len(required_positional) != 0: + plural: str = "" if len(required_positional) == 1 else "s" + args: str = join_args(required_positional) + self.error( + call.location, + f"Missing required positional argument{plural}: {args}", + ) + + if len(required_keyword) != 0: + plural: str = "" if len(required_keyword) == 1 else "s" + args: str = join_args(required_keyword) + self.error( + 
@dataclass(frozen=True)
class Diagnostic:
    """A message (error, warning or info) attached to a location in a file."""

    file_path: Path
    location: Location
    type: DiagnosticType
    message: str

    def __str__(self) -> str:
        """Render the diagnostic as a single human-readable line

        The location part reads `at L<line>:<col>` when the location has no
        end position, and `from L<line>:<col> to L<line>:<col>` when both the
        end line and the end column are known. Column offsets are printed
        incremented by one.

        Returns:
            str: `<type> in <file_path> <location>: <message>`
        """
        start_loc: str = f"L{self.location.lineno}:{self.location.col_offset+1}"
        # Must start as None (not ""), otherwise the "at ..." form below is
        # unreachable and open-ended locations print a dangling "to ".
        end_loc: Optional[str] = None
        if (
            self.location.end_lineno is not None
            and self.location.end_col_offset is not None
        ):
            end_loc = f"L{self.location.end_lineno}:{self.location.end_col_offset+1}"
        loc: str = (
            f"at {start_loc}" if end_loc is None else f"from {start_loc} to {end_loc}"
        )
        return f"{self.type} in {self.file_path} {loc}: {self.message}"
class Environment:
    """
    A scoped environment in which variables are defined

    Each environment can inherit from a parent/enclosing environment.
    """

    def __init__(self, enclosing: Optional[Environment] = None) -> None:
        """Create a new scope, optionally nested inside `enclosing`

        Args:
            enclosing (Optional[Environment]): the parent scope, or None for a root scope
        """
        self.enclosing: Optional[Environment] = enclosing
        self.values: dict[str, Type] = {}
        self.return_types: list[Type] = []

        # Parents keep track of their child scopes so `dump` can walk the tree
        self._children: list[Environment] = []
        if enclosing is not None:
            enclosing._children.append(self)

    def define(self, name: str, value: Type) -> None:
        """Define a variable in this environment

        Args:
            name (str): the name of the variable
            value (Type): the value
        """
        self.values[name] = value

    def get(self, name: str) -> Optional[Type]:
        """Get a variable in the closest environment which has a definition for it

        Args:
            name (str): the name of the variable

        Returns:
            Optional[Type]: the value of the variable, or None if it was not found
        """
        if name in self.values:
            return self.values[name]
        if self.enclosing is not None:
            return self.enclosing.get(name)
        return None

    def assign(self, name: str, value: Type) -> bool:
        """Assign a new value to a variable in the environment it was defined in

        The closest environment (starting from this one) that already defines
        `name` is updated. If no environment in the chain defines it, nothing
        is modified.

        Args:
            name (str): the name of the variable
            value (Type): the new value

        Returns:
            bool: True if the variable was assigned in this environment or an ancestor, False otherwise
        """
        if name in self.values:
            self.values[name] = value
            return True
        if self.enclosing is None:
            return False
        # Propagate the ancestor's answer as-is: an unknown name must not be
        # silently created in an intermediate scope.
        return self.enclosing.assign(name, value)

    def clear(self):
        """Clear all definitions in this environment"""
        self.values = {}

    def get_at(self, distance: int, name: str) -> Optional[Type]:
        """Get the value of a variable at a given distance

        A distance of 0 looks up in this environment, 1 in the parent environment, etc.
        This method expects `distance` to be valid. An error will be raised if
        the stack does not extend far enough to reach `distance`

        Args:
            distance (int): the scope distance
            name (str): the name of the variable

        Returns:
            Optional[Type]: the value at the given distance, or None if it is not defined in that environment

        Raises:
            AssertionError: if the stack does not extend far enough to reach `distance`
        """
        return self.ancestor(distance).values.get(name)

    def assign_at(self, distance: int, name: str, value: Type) -> None:
        """Assign a new value to a variable at a given distance

        A distance of 0 assigns in this environment, 1 in the parent environment, etc.

        Args:
            distance (int): the scope distance
            name (str): the name of the variable
            value (Type): the new value

        Raises:
            AssertionError: if the stack does not extend far enough to reach `distance`
        """
        self.ancestor(distance).values[name] = value

    def ancestor(self, distance: int) -> Environment:
        """Get the ancestor at a given distance

        A distance of 0 references this environment, 1 the parent environment, etc.

        Args:
            distance (int): the scope distance

        Returns:
            Environment: the environment

        Raises:
            AssertionError: if the stack does not extend far enough to reach `distance`
        """
        env: Environment = self
        for _ in range(distance):
            assert env.enclosing is not None
            env = env.enclosing
        return env

    def flat_dict(self) -> dict[str, Type]:
        """Get the current environment including definitions in its ancestors as a flat dictionary

        This method recursively combines this environment's definitions with its
        ancestors'; inner definitions take precedence over outer ones. The result
        is always a new dictionary, so mutating it never affects any environment.

        Returns:
            dict: the combined environment
        """
        if self.enclosing is None:
            # Copy so that a root scope behaves like nested ones (fresh dict)
            return dict(self.values)
        return self.enclosing.flat_dict() | self.values

    def dump(self) -> dict:
        """Get this environment and, recursively, its child scopes as nested dictionaries

        Returns:
            dict: a dictionary with the keys "values", "return_types" and "children"
        """
        return {
            "values": self.values,
            "return_types": self.return_types,
            "children": [child.dump() for child in self._children],
        }
+@dataclass(frozen=True, kw_only=True) +class BaseType: + name: str + + +@dataclass(frozen=True, kw_only=True) +class AliasType: + name: str + type: Type + + +@dataclass(frozen=True, kw_only=True) +class UnknownType: + pass + + +@dataclass(frozen=True, kw_only=True) +class UnitType: + pass + + +@dataclass(frozen=True, kw_only=True) +class Function: + name: str + pos_args: list[Argument] + args: list[Argument] + kw_args: list[Argument] + returns: Type + + @dataclass(frozen=True, kw_only=True) + class Argument: + name: str + type: Type + required: bool + + +@dataclass(frozen=True, kw_only=True) +class ComplexType: + properties: dict[str, Type] + + +Type = BaseType | AliasType | UnknownType | UnitType | Function | ComplexType diff --git a/midas/cli/highlight.css b/midas/cli/highlight.css index 31f005d..8da787b 100644 --- a/midas/cli/highlight.css +++ b/midas/cli/highlight.css @@ -53,5 +53,6 @@ span { &.keyword { color: rgb(211, 72, 9); + pointer-events: none; } } \ No newline at end of file diff --git a/midas/cli/highlighter.py b/midas/cli/highlighter.py index f4801bb..b1b705f 100644 --- a/midas/cli/highlighter.py +++ b/midas/cli/highlighter.py @@ -1,12 +1,15 @@ from __future__ import annotations from abc import ABC, abstractmethod +from dataclasses import dataclass from pathlib import Path from typing import Generic, Optional, Protocol, TextIO, TypeVar -from midas.ast.location import Location import midas.ast.midas as m import midas.ast.python as p +from midas.ast.location import Location +from midas.checker.diagnostic import Diagnostic +from midas.lexer.token import Token H = TypeVar("H", bound="Highlighter", contravariant=True) @@ -21,6 +24,15 @@ class Locatable(Protocol): def location(self) -> Optional[Location]: ... 
+@dataclass(frozen=True) +class LocatableToken: + token: Token + + @property + def location(self) -> Location: + return self.token.get_location() + + class Highlighter(ABC): BASE_CSS_PATH: Path = Path(__file__).parent / "highlight.css" EXTRA_CSS_PATH: Optional[Path] = None @@ -71,6 +83,7 @@ class Highlighter(ABC): openings: list[str] = self.openings.get(pos, []) line_buf += "".join(closings + openings) line_buf += char + line_buf += "".join(self.closings.get((lineno, len(line)), [])) line_buf += "" lines.append(" " + line_buf) lines.extend( @@ -83,7 +96,7 @@ class Highlighter(ABC): buf.write("\n".join(lines)) - def wrap(self, node: Locatable, cls: str): + def wrap(self, node: Locatable, cls: str, message: Optional[str] = None): if node.location is None: return if node.location.end_lineno is None or node.location.end_col_offset is None: @@ -95,6 +108,10 @@ class Highlighter(ABC): ) opening: str = f'' closing: str = "" + if message is not None: + opening = f'{opening}' + closing = f'{closing}{message}' + self.openings.setdefault(start_pos, []).append(opening) self.closings.setdefault(end_pos, []).insert(0, closing) if start_pos[0] != end_pos[0]: @@ -142,6 +159,8 @@ class PythonHighlighter( self.wrap(stmt, "function") for arg in stmt.posonlyargs + stmt.args + stmt.kwonlyargs: self._highlight_function_argument(arg) + for body_stmt in stmt.body: + body_stmt.accept(self) def _highlight_function_argument(self, arg: p.Function.Argument) -> None: self.wrap(arg, "argument") @@ -151,7 +170,23 @@ class PythonHighlighter( def visit_type_assign(self, stmt: p.TypeAssign) -> None: stmt.type.accept(self) - def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: ... 
+ def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: + for target in stmt.targets: + target.accept(self) + stmt.value.accept(self) + + def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: + self.wrap(stmt, "return") + if stmt.value is not None: + stmt.value.accept(self) + + def visit_if_stmt(self, stmt: p.IfStmt) -> None: + self.wrap(stmt, "if") + stmt.test.accept(self) + for body_stmt in stmt.body: + body_stmt.accept(self) + for else_stmt in stmt.orelse: + else_stmt.accept(self) def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ... @@ -159,7 +194,13 @@ class PythonHighlighter( def visit_unary_expr(self, expr: p.UnaryExpr) -> None: ... - def visit_call_expr(self, expr: p.CallExpr) -> None: ... + def visit_call_expr(self, expr: p.CallExpr) -> None: + self.wrap(expr, "call") + expr.callee.accept(self) + for arg in expr.arguments: + arg.accept(self) + for arg in expr.keywords.values(): + arg.accept(self) def visit_get_expr(self, expr: p.GetExpr) -> None: ... @@ -171,35 +212,27 @@ class PythonHighlighter( def visit_set_expr(self, expr: p.SetExpr) -> None: ... + def visit_cast_expr(self, expr: p.CastExpr) -> None: ... -class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]): + def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ... 
+ + +class MidasHighlighter( + Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None] +): EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_midas.css" def highlight(self, node: Highlightable[MidasHighlighter]): node.accept(self) - def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt) -> None: - self.wrap(stmt, "simple-type") - if stmt.template is not None: - stmt.template.accept(self) - stmt.base.accept(self) - if stmt.constraint is not None: - self.wrap(stmt.constraint, "constraint") - stmt.constraint.accept(self) - - def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt) -> None: - self.wrap(stmt, "complex-type") - if stmt.template is not None: - stmt.template.accept(self) - for prop in stmt.properties: - prop.accept(self) + def visit_type_stmt(self, stmt: m.TypeStmt) -> None: + self.wrap(stmt, "type-stmt") + self.wrap(LocatableToken(stmt.name), "type-name") + stmt.type.accept(self) def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: self.wrap(stmt, "property") stmt.type.accept(self) - if stmt.constraint is not None: - self.wrap(stmt.constraint, "constraint") - stmt.constraint.accept(self) def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: self.wrap(stmt, "extend") @@ -209,17 +242,16 @@ class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]): def visit_op_stmt(self, stmt: m.OpStmt) -> None: self.wrap(stmt, "op") + self.wrap(LocatableToken(stmt.name), "op-name") stmt.operand.accept(self) stmt.result.accept(self) def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: self.wrap(stmt, "predicate") + self.wrap(LocatableToken(stmt.name), "predicate-name") stmt.type.accept(self) stmt.condition.accept(self) - def visit_simple_type_expr(self, expr: m.SimpleTypeExpr) -> None: - self.wrap(expr, "simple-type-expr") - def visit_logical_expr(self, expr: m.LogicalExpr) -> None: self.wrap(expr, "logical-expr") expr.left.accept(self) @@ -248,11 +280,29 @@ class MidasHighlighter(Highlighter, 
m.Stmt.Visitor[None], m.Expr.Visitor[None]): def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ... - def visit_template_expr(self, expr: m.TemplateExpr) -> None: - self.wrap(expr, "template") - expr.type.accept(self) + def visit_named_type(self, type: m.NamedType) -> None: + self.wrap(type, "named-type") - def visit_type_expr(self, expr: m.TypeExpr) -> None: - self.wrap(expr, "type") - if expr.template is not None: - expr.template.accept(self) + def visit_generic_type(self, type: m.GenericType) -> None: + self.wrap(type, "generic-type") + type.type.accept(self) + for param in type.params: + param.accept(self) + + def visit_constraint_type(self, type: m.ConstraintType) -> None: + self.wrap(type, "constraint-type") + type.type.accept(self) + type.constraint.accept(self) + + def visit_complex_type(self, type: m.ComplexType) -> None: + self.wrap(type, "complex-type") + for prop in type.properties: + prop.accept(self) + + +class DiagnosticsHighlighter(Highlighter): + EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css" + + def highlight(self, diagnostics: list[Diagnostic]): + for diagnostic in diagnostics: + self.wrap(diagnostic, str(diagnostic.type).lower(), diagnostic.message) diff --git a/midas/cli/hl_diagnostic.css b/midas/cli/hl_diagnostic.css new file mode 100644 index 0000000..8b09b7f --- /dev/null +++ b/midas/cli/hl_diagnostic.css @@ -0,0 +1,39 @@ +span { + --opacity: 0.4; + + &.error { + --col: 255, 0, 0; + } + &.warning { + --col: 250, 160, 0; + } + &.info { + --col: 150, 190, 250; + } + + &.with-msg { + position: relative; + + .message { + display: none; + } + + &:hover:not(:has(.with-msg:hover)) { + .message { + display: inline-block; + } + } + + .message { + position: absolute; + top: calc(100% + 0.2em); + left: -.2em; + background-color: black; + color: white; + padding: 0.2em 0.4em; + border-radius: .2em; + z-index: 10; + width: 300%; + } + } +} \ No newline at end of file diff --git a/midas/cli/hl_midas.css 
b/midas/cli/hl_midas.css index e8adef6..fabb84e 100644 --- a/midas/cli/hl_midas.css +++ b/midas/cli/hl_midas.css @@ -5,12 +5,11 @@ span { font-style: italic; } - &.simple-type { - --col: 108, 233, 108; - } - + &.named-type, + &.generic-type, + &.constraint-type, &.complex-type { - --col: 233, 206, 108; + --col: 150, 150, 150; } &.constraint { @@ -33,10 +32,6 @@ span { --col: 193, 108, 233; } - &.simple-type-expr { - --col: 150, 150, 150; - } - &.logical-expr, &.binary-expr, &.unary-expr, @@ -48,7 +43,9 @@ span { --col: 163, 117, 71; } - &.type { + &.type-name, + &.op-name, + &.predicate-name { --col: 200, 200, 200; font-weight: bold; } diff --git a/midas/cli/main.py b/midas/cli/main.py index 11d69e0..07dfd07 100644 --- a/midas/cli/main.py +++ b/midas/cli/main.py @@ -1,18 +1,30 @@ import ast -from dataclasses import dataclass -from typing import Optional, TextIO +import json +import logging +from pathlib import Path +from typing import Optional, TextIO, get_args import click import midas.ast.midas as m import midas.ast.python as p -from midas.ast.location import Location -from midas.ast.printer import PythonAstPrinter -from midas.cli.highlighter import Highlighter, MidasHighlighter, PythonHighlighter +from midas.ast.printer import MidasAstPrinter, MidasPrinter, PythonAstPrinter +from midas.checker.checker import Checker +from midas.checker.diagnostic import Diagnostic +from midas.checker.types import Type +from midas.cli.highlighter import ( + DiagnosticsHighlighter, + Highlighter, + LocatableToken, + MidasHighlighter, + PythonHighlighter, +) from midas.lexer.midas import MidasLexer from midas.lexer.token import Token, TokenType from midas.parser.midas import MidasParser from midas.parser.python import PythonParser +from midas.resolver.resolver import Resolver +from midas.utils import UniversalJSONDumper @click.group() @@ -21,9 +33,41 @@ def midas(): @midas.command() +@click.option("-l", "--highlight", type=click.File("w")) +@click.option("-t", "--types", 
type=click.File("r"), multiple=True) @click.argument("file", type=click.File("r")) -def compile(file: TextIO): - raise NotImplementedError +def compile(highlight: Optional[TextIO], file: TextIO, types: tuple[TextIO]): + logging.basicConfig(level=logging.DEBUG) + source: str = file.read() + tree: ast.Module = ast.parse(source, filename=file.name) + parser = PythonParser() + stmts: list[p.Stmt] = parser.parse_module(tree) + resolver = Resolver() + resolver.resolve(*stmts) + types_paths: list[Path] = [Path(t.name).resolve() for t in types] + checker = Checker( + resolver.locals, + source_path=Path(file.name).resolve(), + types_paths=types_paths, + ) + diagnostics: list[Diagnostic] = checker.check(stmts) + for diagnostic in diagnostics: + print(diagnostic) + + print( + json.dumps( + UniversalJSONDumper.dump( + checker.global_env, + [("Environment", "_children")], + lambda obj: isinstance(obj, get_args(Type)), + ), + indent=4, + ) + ) + if highlight is not None: + highlighter = DiagnosticsHighlighter(source) + highlighter.highlight(diagnostics) + highlighter.dump(highlight) @midas.group() @@ -31,26 +75,52 @@ def utils(): pass +def dump_python_ast(tree: ast.Module) -> str: + parser = PythonParser() + stmts: list[p.Stmt] = parser.parse_module(tree) + printer = PythonAstPrinter() + dump: str = "" + for stmt in stmts: + dump += printer.print(stmt) + dump += "\n" + return dump + + +def dump_midas_ast(source: str, filename: str) -> str: + lexer = MidasLexer(source, file=filename) + tokens: list[Token] = lexer.process() + parser = MidasParser(tokens) + stmts: list[m.Stmt] = parser.parse() + if len(parser.errors) != 0: + for err in parser.errors: + print(err.get_report()) + raise RuntimeError("A parsing error occurred") + printer = MidasAstPrinter() + dump: str = "" + for stmt in stmts: + dump += printer.print(stmt) + dump += "\n" + return dump + + @utils.command() @click.option("-o", "--output", type=click.File("w")) @click.option("-p", "--parse", is_flag=True) 
@click.argument("file", type=click.File("r")) def dump_ast(output: Optional[TextIO], parse: bool, file: TextIO): source: str = file.read() - tree: ast.Module = ast.parse(source, filename=file.name) + dump: str - - if parse: - parser = PythonParser() - stmts: list[p.Stmt] = parser.parse_module(tree) - printer = PythonAstPrinter() - dump = "" - for stmt in stmts: - dump += printer.print(stmt) - dump += "\n" - + if file.name.endswith(".py"): + tree: ast.Module = ast.parse(source, filename=file.name) + if parse: + dump = dump_python_ast(tree) + else: + dump = ast.dump(tree, indent=4) + elif file.name.endswith(".midas"): + dump = dump_midas_ast(source, file.name) else: - dump = ast.dump(tree, indent=4) + raise ValueError("Unsupported file type") if output is None: click.echo(dump) @@ -77,14 +147,6 @@ def highlight_midas(source: str, path: str) -> Highlighter: for err in parser.errors: print(err.get_report()) - @dataclass(frozen=True) - class LocatableToken: - token: Token - - @property - def location(self) -> Location: - return self.token.get_location() - for stmt in stmts: highlighter.highlight(stmt) for token in tokens: @@ -109,3 +171,23 @@ def highlight(output: TextIO, file: TextIO): else: raise ValueError("Unsupported file type") highlighter.dump(output) + + +@midas.command() +@click.option("-o", "--output", type=click.File("w"), default="-") +@click.argument("file", type=click.File("r")) +def format(output: TextIO, file: TextIO): + source: str = file.read() + printer = MidasPrinter() + lexer = MidasLexer(source, file=file.name) + tokens: list[Token] = lexer.process() + parser = MidasParser(tokens) + stmts: list[m.Stmt] = parser.parse() + for err in parser.errors: + print(err.get_report()) + for stmt in stmts: + output.write(printer.print(stmt) + "\n") + + +if __name__ == "__main__": + midas() diff --git a/midas/lexer/midas.py b/midas/lexer/midas.py index acc97d6..124ea09 100644 --- a/midas/lexer/midas.py +++ b/midas/lexer/midas.py @@ -40,8 +40,8 @@ class 
MidasLexer(Lexer): self.add_token(TokenType.AND) case "?": self.add_token(TokenType.QMARK) - # case ",": - # self.add_token(TokenType.COMMA) + case ",": + self.add_token(TokenType.COMMA) case "_" if not self.is_identifier_char(self.peek_next(), start=False): self.add_token(TokenType.UNDERSCORE) case "-" if self.match(">"): diff --git a/midas/lexer/token.py b/midas/lexer/token.py index a518a8b..f08964a 100644 --- a/midas/lexer/token.py +++ b/midas/lexer/token.py @@ -17,7 +17,7 @@ class TokenType(Enum): LEFT_BRACE = auto() RIGHT_BRACE = auto() COLON = auto() - # COMMA = auto() + COMMA = auto() UNDERSCORE = auto() ARROW = auto() AND = auto() diff --git a/midas/parser/midas.py b/midas/parser/midas.py index 4998c51..5d09b83 100644 --- a/midas/parser/midas.py +++ b/midas/parser/midas.py @@ -3,21 +3,22 @@ from typing import Optional from midas.ast.location import Location from midas.ast.midas import ( BinaryExpr, - ComplexTypeStmt, + ComplexType, + ConstraintType, Expr, ExtendStmt, + GenericType, GetExpr, GroupingExpr, LiteralExpr, LogicalExpr, + NamedType, OpStmt, PredicateStmt, PropertyStmt, - SimpleTypeExpr, - SimpleTypeStmt, Stmt, - TemplateExpr, - TypeExpr, + Type, + TypeStmt, UnaryExpr, VariableExpr, WildcardExpr, @@ -81,7 +82,7 @@ class MidasParser(Parser): self.synchronize() return None - def type_declaration(self) -> SimpleTypeStmt | ComplexTypeStmt: + def type_declaration(self) -> TypeStmt: """Parse a type declaration A type declaration can either be a simple type alias or a new complex type. 
@@ -107,33 +108,22 @@ class MidasParser(Parser): """ keyword: Token = self.previous() name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") - template: Optional[TemplateExpr] = None + params: list[TypeStmt.Param] = [] if self.check(TokenType.LEFT_BRACKET): - template = self.template_expr() + params = self.type_stmt_params() - if self.match(TokenType.LEFT_PAREN): - base: TypeExpr = self.type_expr() - self.consume(TokenType.RIGHT_PAREN, "Unclosed base type parenthesis") - constraint: Optional[Expr] = None - if self.match(TokenType.WHERE): - constraint = self.constraint() - return SimpleTypeStmt( - location=keyword.location_to(self.previous()), - name=name, - template=template, - base=base, - constraint=constraint, - ) - else: - properties: list[PropertyStmt] = self.type_properties() - return ComplexTypeStmt( - location=keyword.location_to(self.previous()), - name=name, - template=template, - properties=properties, - ) + self.consume(TokenType.EQUAL, "Expected '=' before type definition") - def template_expr(self) -> TemplateExpr: + type: Type = self.type_expr() + + return TypeStmt( + location=keyword.location_to(self.previous()), + name=name, + params=params, + type=type, + ) + + def type_stmt_params(self) -> list[TypeStmt.Param]: """Parse a generic template expression A template is written `[TypeExpr]` @@ -141,16 +131,27 @@ class MidasParser(Parser): Returns: TemplateExpr: the parsed template expression """ - left: Token = self.consume( - TokenType.LEFT_BRACKET, "Missing '[' before template expression" - ) - type: TypeExpr = self.type_expr() - right: Token = self.consume( - TokenType.RIGHT_BRACKET, "Missing ']' after template expression" - ) - return TemplateExpr(location=left.location_to(right), type=type) + self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression") + params: list[TypeStmt.Param] = [] + while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET): + name: Token = self.consume(TokenType.IDENTIFIER, 
"Expected type variable") + bound: Optional[Type] = None + if self.match(TokenType.LESS): + self.consume(TokenType.COLON, "Expected ':' after '<'") + bound = self.type_expr() + params.append( + TypeStmt.Param( + location=name.location_to(self.previous()), + name=name, + bound=bound, + ) + ) + if not self.match(TokenType.COMMA): + break + self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression") + return params - def type_expr(self) -> TypeExpr: + def type_expr(self) -> Type: """Parse a type expression A type is an identifier, optionally followed by a template expression. @@ -159,30 +160,82 @@ class MidasParser(Parser): Returns: TypeExpr: the parsed type expression """ - name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") - template: Optional[TemplateExpr] = None + return self.constraint_type() + + def constraint_type(self) -> Type: + type: Type = self.base_type() + if self.match(TokenType.WHERE): + constraint: Expr = self.constraint() + return ConstraintType( + location=Location.span(type.location, constraint.location), + type=type, + constraint=constraint, + ) + return type + + def base_type(self) -> Type: + if self.match(TokenType.LEFT_PAREN): + type: Type = self.type_expr() + self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis") + return type + + if self.check(TokenType.LEFT_BRACE): + return self.complex_type() + + return self.generic_type() + + def generic_type(self) -> Type: + type: Type = self.named_type() if self.check(TokenType.LEFT_BRACKET): - template = self.template_expr() - optional: bool = self.match(TokenType.QMARK) - return TypeExpr( - location=name.location_to(self.previous()), + params: list[Type] = self.type_params() + return GenericType( + location=Location.span(type.location, self.previous().get_location()), + type=type, + params=params, + ) + return type + + def type_params(self) -> list[Type]: + params: list[Type] = [] + self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic parameters") 
+ while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET): + params.append(self.type_expr()) + if not self.match(TokenType.COMMA): + break + self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic parameters") + return params + + def named_type(self) -> Type: + name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") + return NamedType( + location=name.get_location(), name=name, - template=template, - optional=optional, ) - def simple_type_expr(self) -> SimpleTypeExpr: - """Parse a simple type expression + def complex_type(self) -> Type: + """Parse a type definition body - A simple type is just an identifier optionally followed by a '?' + A type definition body is a set of whitespace-separated + property statements enclosed in curly braces Returns: - SimpleTypeExpr: the parsed simple type expression + list[PropertyStmt]: the parsed type properties """ - name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") - optional: bool = self.match(TokenType.QMARK) - return SimpleTypeExpr( - location=name.location_to(self.previous()), name=name, optional=optional + left: Token = self.consume( + TokenType.LEFT_BRACE, "Expected '{' to start type body" + ) + properties: list[PropertyStmt] = [] + names: set[str] = set() + while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end(): + prop: PropertyStmt = self.property_stmt() + if prop.name.lexeme in names: + raise self.error(prop.name, "Duplicate property") + names.add(prop.name.lexeme) + properties.append(prop) + right: Token = self.consume(TokenType.RIGHT_BRACE, "Unclosed type body") + return ComplexType( + location=left.location_to(right), + properties=properties, ) def constraint(self) -> Expr: @@ -205,9 +258,7 @@ class MidasParser(Parser): while self.match(TokenType.AND): operator: Token = self.previous() right: Expr = self.equality() - location: Optional[Location] = None - if expr.location and right.location: - location = Location.span(expr.location, 
right.location) + location: Location = Location.span(expr.location, right.location) expr = LogicalExpr( location=location, left=expr, operator=operator, right=right ) @@ -223,9 +274,7 @@ class MidasParser(Parser): while self.match(TokenType.BANG_EQUAL, TokenType.EQUAL_EQUAL): operator: Token = self.previous() right: Expr = self.comparison() - location: Optional[Location] = None - if expr.location and right.location: - location = Location.span(expr.location, right.location) + location: Location = Location.span(expr.location, right.location) expr = BinaryExpr( location=location, left=expr, operator=operator, right=right ) @@ -246,9 +295,7 @@ class MidasParser(Parser): ): operator: Token = self.previous() right: Expr = self.unary() - location: Optional[Location] = None - if expr.location and right.location: - location = Location.span(expr.location, right.location) + location: Location = Location.span(expr.location, right.location) expr = BinaryExpr( location=location, left=expr, operator=operator, right=right ) @@ -263,9 +310,7 @@ class MidasParser(Parser): if self.match(TokenType.MINUS): operator: Token = self.previous() right: Expr = self.unary() - location: Optional[Location] = None - if right.location: - location = Location.span(operator.get_location(), right.location) + location: Location = Location.span(operator.get_location(), right.location) return UnaryExpr(location=location, operator=operator, right=right) return self.reference() @@ -280,9 +325,7 @@ class MidasParser(Parser): name: Token = self.consume( TokenType.IDENTIFIER, "Expected property name after '.'" ) - location: Optional[Location] = None - if expr.location: - location = Location.span(expr.location, name.get_location()) + location: Location = Location.span(expr.location, name.get_location()) expr = GetExpr(location=location, expr=expr, name=name) return expr @@ -318,22 +361,6 @@ class MidasParser(Parser): raise self.error(self.peek(), "Expected expression") - def type_properties(self) -> 
list[PropertyStmt]: - """Parse a type definition body - - A type definition body is a set of whitespace-separated - property statements enclosed in curly braces - - Returns: - list[PropertyStmt]: the parsed type properties - """ - self.consume(TokenType.LEFT_BRACE, "Expected '{' to start type body") - properties: list[PropertyStmt] = [] - while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end(): - properties.append(self.property_stmt()) - self.consume(TokenType.RIGHT_BRACE, "Unclosed type body") - return properties - def property_stmt(self) -> PropertyStmt: """Parse a property statement @@ -344,15 +371,11 @@ class MidasParser(Parser): """ name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name") self.consume(TokenType.COLON, "Expected ':' after property name") - type: TypeExpr = self.type_expr() - constraint: Optional[Expr] = None - if self.match(TokenType.WHERE): - constraint = self.constraint() + type: Type = self.type_expr() return PropertyStmt( location=name.location_to(self.previous()), name=name, type=type, - constraint=constraint, ) def extend_declaration(self) -> ExtendStmt: @@ -364,15 +387,13 @@ class MidasParser(Parser): ExtendStmt: the parsed extension statement """ keyword: Token = self.previous() - type: TypeExpr = self.type_expr() + type: Type = self.type_expr() self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body") operations: list[OpStmt] = [] while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE): operations.append(self.op_declaration()) self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body") - location: Optional[Location] = None - if type.location: - location = keyword.location_to(self.previous()) + location: Location = keyword.location_to(self.previous()) return ExtendStmt(location=location, type=type, operations=operations) def op_declaration(self) -> OpStmt: @@ -387,11 +408,11 @@ class MidasParser(Parser): name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name") 
self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type") - operand: TypeExpr = self.type_expr() + operand: Type = self.type_expr() self.consume(TokenType.RIGHT_PAREN, "Expected ')' after operand type") self.consume(TokenType.ARROW, "Expected '->' before result type") - result: TypeExpr = self.type_expr() + result: Type = self.type_expr() return OpStmt( location=keyword.location_to(self.previous()), @@ -413,7 +434,7 @@ class MidasParser(Parser): self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject") subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name") self.consume(TokenType.COLON, "Expected ':' after subject name") - type: TypeExpr = self.type_expr() + type: Type = self.type_expr() self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject") self.consume(TokenType.EQUAL, "Expected '=' after predicate subject") condition: Expr = self.constraint() diff --git a/midas/parser/python.py b/midas/parser/python.py index 4b6a3f1..79011bc 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -2,12 +2,12 @@ import ast from typing import Optional from midas.ast.location import Location - from midas.ast.python import ( AssignStmt, BaseType, BinaryExpr, CallExpr, + CastExpr, CompareExpr, ConstraintType, Expr, @@ -16,10 +16,13 @@ from midas.ast.python import ( FrameType, Function, GetExpr, + IfStmt, LiteralExpr, LogicalExpr, MidasType, + ReturnStmt, Stmt, + TernaryExpr, TypeAssign, UnaryExpr, VariableExpr, @@ -38,6 +41,8 @@ class UnsupportedSyntaxError(Exception): class PythonParser: + CAST_FUNCTION = "cast" + def parse_module(self, node: ast.Module) -> list[Stmt]: statements: list[Stmt] = [] for stmt in node.body: @@ -53,6 +58,7 @@ class PythonParser: return statements def parse_stmt(self, node: ast.stmt) -> None | Stmt | list[Stmt]: + location: Location = Location.from_ast(node) match node: case ast.AnnAssign(): return self.parse_annotation_assign(node) @@ -60,11 +66,29 @@ class PythonParser: 
case ast.Assign(): return self.parse_assign(node) + case ast.AugAssign(): + return self.parse_aug_assign(node) + case ast.FunctionDef(): return self.parse_function(node) case ast.Expr(value=expr): - return ExpressionStmt(expr=self.parse_expr(expr)) + return ExpressionStmt( + location=location, + expr=self.parse_expr(expr), + ) + + case ast.Return(value=value): + return ReturnStmt( + location=location, + value=self.parse_expr(value) if value is not None else None, + ) + + case ast.If(): + return self.parse_if(node) + + case ast.Pass(): + return None case _: print(f"Unsupported statement: {ast.unparse(node)}") @@ -80,15 +104,14 @@ class PythonParser: value=value, simple=1, ): - type = self._parse_type(annotation, root=True) - if type is not None: - statements.append( - TypeAssign( - location=loc, - name=target, - type=type, - ) + type = self._parse_type(annotation) + statements.append( + TypeAssign( + location=loc, + name=target, + type=type, ) + ) if value is not None: statements.append( @@ -117,6 +140,45 @@ class PythonParser: value=value, ) + def parse_aug_assign(self, node: ast.AugAssign) -> AssignStmt: + location: Location = Location.from_ast(node) + target: Expr = self.parse_expr(node.target) + value: Expr = self.parse_expr(node.value) + return AssignStmt( + location=location, + targets=[target], + value=BinaryExpr( + location=location, + left=target, + operator=node.op, + right=value, + ), + ) + + def parse_if(self, node: ast.If) -> IfStmt: + body: list[Stmt] = [] + for stmt in node.body: + stmts = self.parse_stmt(stmt) + if isinstance(stmts, Stmt): + body.append(stmts) + elif stmts is not None: + body.extend(stmts) + + orelse: list[Stmt] = [] + for stmt in node.orelse: + stmts = self.parse_stmt(stmt) + if isinstance(stmts, Stmt): + orelse.append(stmts) + elif stmts is not None: + orelse.extend(stmts) + + return IfStmt( + location=Location.from_ast(node), + test=self.parse_expr(node.test), + body=body, + orelse=orelse, + ) + def parse_function(self, node: 
ast.FunctionDef) -> Function: loc: Location = Location.from_ast(node) match node: @@ -125,26 +187,74 @@ class PythonParser: args=ast.arguments( posonlyargs=posonlyargs, args=args, + vararg=sink, kwonlyargs=kwonlyargs, + kwarg=kw_sink, + defaults=defaults, + kw_defaults=kw_defaults, ), returns=returns, + body=raw_body, ): - def parse_args(args_list: list[ast.arg]) -> list[Function.Argument]: - return [self._parse_function_argument(arg) for arg in args_list] + def parse_args( + args_list: list[ast.arg], defaults: list[Optional[Expr]] + ) -> list[Function.Argument]: + return [ + self._parse_function_argument(arg, default) + for arg, default in zip(args_list, defaults) + ] + + body: list[Stmt] = [] + for stmt in raw_body: + stmts = self.parse_stmt(stmt) + if isinstance(stmts, Stmt): + body.append(stmts) + elif stmts is not None: + body.extend(stmts) + + parsed_defaults: list[Optional[Expr]] = [ + self.parse_expr(default) for default in defaults + ] + n_posargs: int = len(posonlyargs) + n_args: int = len(args) + n_all_posargs = n_posargs + n_args + parsed_defaults = [ + None, + ] * (n_all_posargs - len(defaults)) + parsed_defaults + + posargs_defaults: list[Optional[Expr]] = parsed_defaults[:n_posargs] + args_defaults: list[Optional[Expr]] = parsed_defaults[n_posargs:] + kwargs_defaults: list[Optional[Expr]] = [ + self.parse_expr(default) if default is not None else None + for default in kw_defaults + ] return Function( location=loc, name=name, - posonlyargs=parse_args(posonlyargs), - args=parse_args(args), - kwonlyargs=parse_args(kwonlyargs), + posonlyargs=parse_args(posonlyargs, posargs_defaults), + args=parse_args(args, args_defaults), + sink=( + self._parse_function_argument(sink, None) + if sink is not None + else None + ), + kwonlyargs=parse_args(kwonlyargs, kwargs_defaults), + kw_sink=( + self._parse_function_argument(kw_sink, None) + if kw_sink is not None + else None + ), returns=self._parse_type(returns) if returns is not None else None, + body=body, ) case _: 
print(f"Unsupported function definition: {ast.unparse(node)}") - def _parse_function_argument(self, arg: ast.arg) -> Function.Argument: + def _parse_function_argument( + self, arg: ast.arg, default: Optional[Expr] + ) -> Function.Argument: loc: Location = Location.from_ast(arg) name: str = arg.arg type: Optional[MidasType] = None @@ -154,11 +264,10 @@ class PythonParser: location=loc, name=name, type=type, + default=default, ) - def _parse_type( - self, type_expr: ast.expr, root: bool = False - ) -> Optional[MidasType]: + def _parse_type(self, type_expr: ast.expr) -> MidasType: loc: Location = Location.from_ast(type_expr) match type_expr: case ast.Subscript(value=ast.Name(id="Frame"), slice=schema): @@ -205,9 +314,14 @@ class PythonParser: constraint=right_expr, ) + case ast.Constant(value=None): + return BaseType( + location=loc, + base="None", + param=None, + ) + case _: - if root: - return None raise UnsupportedSyntaxError(type_expr) def _parse_frame_type(self, schema: ast.expr) -> FrameType: @@ -257,12 +371,14 @@ class PythonParser: raise UnsupportedSyntaxError(column) def parse_expr(self, node: ast.expr) -> Expr: + location: Location = Location.from_ast(node) match node: case ast.BoolOp(): return self.parse_bool_op(node) case ast.BinOp(left=left, op=op, right=right): return BinaryExpr( + location=location, left=self.parse_expr(left), operator=op, right=self.parse_expr(right), @@ -270,6 +386,7 @@ class PythonParser: case ast.UnaryOp(op=op, operand=right): return UnaryExpr( + location=location, operator=op, right=self.parse_expr(right), ) @@ -277,62 +394,96 @@ class PythonParser: case ast.Compare(): return self.parse_compare(node) + case ast.Call(func=ast.Name(id=self.CAST_FUNCTION)): + return self.parse_cast(node) + case ast.Call(): return self.parse_call(node) + case ast.IfExp(): + return self.parse_ternary(node) + case ast.Constant(value=value): - return LiteralExpr(value=value) + return LiteralExpr(location=location, value=value) case 
ast.Attribute(value=object, attr=name): return GetExpr( + location=location, object=self.parse_expr(object), name=name, ) case ast.Name(id=name): - return VariableExpr(name=name) + return VariableExpr(location=location, name=name) case _: raise UnsupportedSyntaxError(node) def parse_bool_op(self, node: ast.BoolOp) -> LogicalExpr: op: ast.boolop = node.op - values: list[ast.expr] = node.values + rights: list[Expr] = [self.parse_expr(expr) for expr in node.values] expr: LogicalExpr = LogicalExpr( - left=self.parse_expr(values[0]), + location=Location.span( + rights[0].location, + rights[1].location, + ), + left=rights[0], operator=op, - right=self.parse_expr(values[1]), + right=rights[1], ) - for value in values[2:]: + for right in rights[2:]: expr = LogicalExpr( + location=Location.span(expr.location, right.location), left=expr, operator=op, - right=self.parse_expr(value), + right=right, ) return expr def parse_compare(self, node: ast.Compare) -> Expr: ops: list[ast.cmpop] = node.ops + left: Expr = self.parse_expr(node.left) rights: list[Expr] = [self.parse_expr(expr) for expr in node.comparators] expr: Expr = CompareExpr( - left=self.parse_expr(node.left), + location=Location.span( + left.location, + rights[0].location, + ), + left=left, operator=ops[0], right=rights[0], ) for i, right in enumerate(rights[1:]): + comparison = CompareExpr( + location=Location.span(rights[i].location, right.location), + left=rights[i], + operator=ops[i], + right=right, + ) expr = LogicalExpr( + location=Location.span(expr.location, comparison.location), left=expr, operator=ast.And(), - right=CompareExpr( - left=rights[i], - operator=ops[i], - right=right, - ), + right=comparison, ) return expr + def parse_cast(self, node: ast.Call) -> CastExpr: + match node: + case ast.Call(args=[type, expr], keywords=[]): + return CastExpr( + location=Location.from_ast(node), + type=self._parse_type(type), + expr=self.parse_expr(expr), + ) + case _: + raise InvalidSyntaxError( + f"Invalid call to 
{self.CAST_FUNCTION}, expected type and expression" + ) + def parse_call(self, node: ast.Call) -> CallExpr: return CallExpr( + location=Location.from_ast(node), callee=self.parse_expr(node.func), arguments=[self.parse_expr(arg) for arg in node.args], keywords={ @@ -341,3 +492,11 @@ class PythonParser: if arg.arg is not None # Should always be True, type checker happy }, ) + + def parse_ternary(self, node: ast.IfExp) -> TernaryExpr: + return TernaryExpr( + location=Location.from_ast(node), + test=self.parse_expr(node.test), + if_true=self.parse_expr(node.body), + if_false=self.parse_expr(node.orelse), + ) diff --git a/midas/resolver/builtin.py b/midas/resolver/builtin.py new file mode 100644 index 0000000..04bc6e3 --- /dev/null +++ b/midas/resolver/builtin.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from midas.checker.types import BaseType, Type, UnitType + +if TYPE_CHECKING: + from midas.resolver.midas import MidasResolver + + +def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type): + ctx.define_operation( + left=t1, + operator=operator, + right=t2, + result=t3, + ) + + +def basic_op(ctx: MidasResolver, type: Type, op: str): + ctx.define_operation( + left=type, + operator=op, + right=type, + result=type, + ) + + +def define_builtins(ctx: MidasResolver): + """Define builtin types and operations""" + unit = ctx.define_type("None", UnitType()) + bool = ctx.define_type("bool", BaseType(name="bool")) + int = ctx.define_type("int", BaseType(name="int")) + float = ctx.define_type("float", BaseType(name="float")) + str = ctx.define_type("str", BaseType(name="str")) + + basic_op(ctx, int, "__add__") # int + int = int + basic_op(ctx, int, "__sub__") # int - int = int + basic_op(ctx, int, "__mul__") # int * int = int + basic_op(ctx, int, "__pow__") # int ** int = int + basic_op(ctx, int, "__mod__") # int % int = int + basic_op(ctx, int, "__and__") # int & int = int + basic_op(ctx, int, "__or__") # int | int = 
int + basic_op(ctx, int, "__xor__") # int ^ int = int + op(ctx, int, "__lt__", int, bool) # int < int = bool + op(ctx, int, "__gt__", int, bool) # int > int = bool + op(ctx, int, "__le__", int, bool) # int <= int = bool + op(ctx, int, "__ge__", int, bool) # int >= int = bool + op(ctx, int, "__eq__", int, bool) # int == int = bool + basic_op(ctx, float, "__add__") # float + float = float + basic_op(ctx, float, "__sub__") # float - float = float + basic_op(ctx, float, "__mul__") # float * float = float + basic_op(ctx, float, "__truediv__") # float / float = float + op(ctx, float, "__lt__", float, bool) # float < float = bool + op(ctx, float, "__gt__", float, bool) # float > float = bool + op(ctx, float, "__le__", float, bool) # float <= float = bool + op(ctx, float, "__ge__", float, bool) # float >= float = bool + op(ctx, float, "__eq__", float, bool) # float == float = bool + basic_op(ctx, str, "__add__") # str + str = str + op(ctx, str, "__eq__", str, bool) # str == str = bool + + op(ctx, int, "__lt__", float, bool) # int < float = bool + op(ctx, int, "__gt__", float, bool) # int > float = bool + op(ctx, int, "__le__", float, bool) # int <= float = bool + op(ctx, int, "__ge__", float, bool) # int >= float = bool + op(ctx, int, "__eq__", float, bool) # int == float = bool + + op(ctx, float, "__lt__", int, bool) # float < int = bool + op(ctx, float, "__gt__", int, bool) # float > int = bool + op(ctx, float, "__le__", int, bool) # float <= int = bool + op(ctx, float, "__ge__", int, bool) # float >= int = bool + op(ctx, float, "__eq__", int, bool) # float == int = bool diff --git a/midas/resolver/midas.py b/midas/resolver/midas.py new file mode 100644 index 0000000..acbbe96 --- /dev/null +++ b/midas/resolver/midas.py @@ -0,0 +1,163 @@ +from typing import Optional + +import midas.ast.midas as m +from midas.checker.types import ( + AliasType, + Type, + UnknownType, +) +from midas.resolver.builtin import define_builtins + + +class MidasResolver(m.Stmt.Visitor[None], 
m.Expr.Visitor[None], m.Type.Visitor[Type]): + """A resolver which evaluates Midas type definitions and build a registry""" + + def __init__(self) -> None: + self._types: dict[str, Type] = {} + self._operations: dict[tuple[Type, str, Type], Type] = {} + + define_builtins(self) + + def get_type(self, name: str) -> Type: + """Get a type from its name + + Args: + name (str): the name of the type + + Raises: + NameError: if the type is not defined + + Returns: + Type: the type + """ + type: Optional[Type] = self._types.get(name) + if type is None: + raise NameError(f"Undefined type {name}") + return type + + def get_operation_result( + self, left: Type, operator: str, right: Type + ) -> Optional[Type]: + """Get the resulting type of an operation + + Args: + left (Type): the type of the left operand + operator (str): the operation name + right (Type): the type of the right operand + + Returns: + Optional[Type]: the result type, or None if no matching operation was found + """ + operation: tuple[Type, str, Type] = (left, operator, right) + result: Optional[Type] = self._operations.get(operation) + return result + + def define_type(self, name: str, type: Type) -> Type: + """Define a type in the registry + + Args: + name (str): the name of the type + type (Type): the type to define + + Raises: + ValueError: if a type is already defined with that name + + Returns: + Type: the defined type + """ + if name in self._types: + raise ValueError(f"Type {name} already defined") + self._types[name] = type + return type + + def define_operation(self, left: Type, operator: str, right: Type, result: Type): + """Define an operation in the registry + + Args: + left (Type): the type of the left operand + operator (str): the operation name + right (Type): the type of the right operand + result (Type): the result type + + Raises: + ValueError: if an operation is already defined with these operands and name + """ + operation: tuple[Type, str, Type] = (left, operator, right) + if operation in 
self._operations: + raise ValueError( + f"Operation {operator} already defined between {left} and {right}" + ) + self._operations[operation] = result + + def resolve(self, stmts: list[m.Stmt]): + """Process a sequence of statements + + Args: + stmts (list[m.Stmt]): the statements + """ + for stmt in stmts: + stmt.accept(self) + + def visit_type_stmt(self, stmt: m.TypeStmt) -> None: + type: Type = stmt.type.accept(self) + for param in stmt.params: + if param.bound is not None: + param.bound.accept(self) + name: str = stmt.name.lexeme + self.define_type(name, AliasType(name=name, type=type)) + + def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ... + + def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: + base: Type = stmt.type.accept(self) + for op in stmt.operations: + right: Type = op.operand.accept(self) + result: Type = op.result.accept(self) + self.define_operation( + left=base, + operator=op.name.lexeme, + right=right, + result=result, + ) + + def visit_op_stmt(self, stmt: m.OpStmt) -> None: ... + + def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ... + + def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ... + + def visit_binary_expr(self, expr: m.BinaryExpr) -> None: ... + + def visit_unary_expr(self, expr: m.UnaryExpr) -> None: ... + + def visit_get_expr(self, expr: m.GetExpr) -> None: ... + + def visit_variable_expr(self, expr: m.VariableExpr) -> None: ... + + def visit_grouping_expr(self, expr: m.GroupingExpr) -> None: + return expr.expr.accept(self) + + def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ... + + def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ... 
+ + def visit_named_type(self, type: m.NamedType) -> Type: + return self.get_type(type.name.lexeme) + + def visit_generic_type(self, type: m.GenericType) -> Type: + type_: Type = type.type.accept(self) + params: list[Type] = [param.accept(self) for param in type.params] + # TODO + return UnknownType() + + def visit_constraint_type(self, type: m.ConstraintType) -> Type: + type_: Type = type.type.accept(self) + type.constraint.accept(self) + # TODO + return UnknownType() + + def visit_complex_type(self, type: m.ComplexType) -> Type: + for prop in type.properties: + prop.accept(self) + # TODO + return UnknownType() diff --git a/midas/resolver/resolver.py b/midas/resolver/resolver.py new file mode 100644 index 0000000..15166bd --- /dev/null +++ b/midas/resolver/resolver.py @@ -0,0 +1,187 @@ +import midas.ast.python as p + + +class ResolverError(Exception): ... + + +class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]): + """A variable assignment and reference resolver + + This class keeps track of which scope a variable is defined in and which + scope is referred to when a variable is referenced + """ + + def __init__(self): + self.locals: dict[p.Expr, int] = {} + self.scopes: list[dict[str, bool]] = [] + + def resolve(self, *objects: p.Stmt | p.Expr) -> None: + """Resolve the given statements or expressions""" + + for obj in objects: + obj.accept(self) + + def begin_scope(self): + """Begin a new scope inside the current one""" + self.scopes.append({}) + + def end_scope(self): + """Close the current scope""" + self.scopes.pop() + + def declare(self, name: str) -> None: + """Declare a variable in the current scope + + This method must be called *before* evaluating the variable initializer + + Args: + name (str): the name of the variable + + Raises: + ResolverError: if the variable has already been declared in the current scope + """ + if len(self.scopes) == 0: + return + scope: dict[str, bool] = self.scopes[-1] + if name in scope: + raise ResolverError( + f"A 
variable with the name {name} is already declared in this scope" + ) + scope[name] = False + + def define(self, name: str) -> None: + """Define a variable in the current scope + + This method must be called *after* evaluating the variable initializer + + Args: + name (str): the name of the variable + """ + if len(self.scopes) == 0: + return + self.scopes[-1][name] = True + + def resolve_local(self, expr: p.Expr, name: str) -> None: + """Resolve a variable reference and store the scope distance + + This method associates to the variable expression a number representing + the "distance" of the variable declaration, i.e. the number of scope + levels to go "up" to find the closest declaration for that variable. + + Args: + expr (p.Expr): the variable expression + name (str): the name of the variable + """ + for i, scope in enumerate(reversed(self.scopes)): + if name in scope: + self.locals[expr] = i + return + + def resolve_function(self, function: p.Function) -> None: + """Resolve a function definition + + This method creates a new scope for the function, resolves all the + parameter declarations and then the body. + + Args: + function (p.Function): the function to resolve + """ + self.begin_scope() + for param in function.all_args: + self.declare(param.name) + self.define(param.name) + self.resolve(*function.body) + self.end_scope() + + def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None: + stmt.expr.accept(self) + + def visit_function(self, stmt: p.Function) -> None: + # Declare before resolving body to allow recursion + self.declare(stmt.name) + self.define(stmt.name) + self.resolve_function(stmt) + + def visit_type_assign(self, stmt: p.TypeAssign) -> None: + self.declare(stmt.name) + # NOTE: resolve type here? 
+ self.define(stmt.name) + + def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: + self.resolve(stmt.value) + for target in stmt.targets: + match target: + case p.VariableExpr(name=name): + self.resolve_local(target, name) + # TODO: declare if not found + case _: + raise Exception(f"Unsupported assignment to {target}") + + def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: + if stmt.value is not None: + self.resolve(stmt.value) + + def visit_if_stmt(self, stmt: p.IfStmt) -> None: + # Not resolved in sub-environment because assignments in the test leak out of the if + # For example: + # if (m := 1 + 1) < 2: + # ... + # print(m) # <- m is still defined + self.resolve(stmt.test) + + # Body + self.begin_scope() + self.resolve(*stmt.body) + self.end_scope() + + # Else + self.begin_scope() + self.resolve(*stmt.orelse) + self.end_scope() + + def visit_binary_expr(self, expr: p.BinaryExpr) -> None: + self.resolve(expr.left) + self.resolve(expr.right) + + def visit_compare_expr(self, expr: p.CompareExpr) -> None: + self.resolve(expr.left) + self.resolve(expr.right) + + def visit_unary_expr(self, expr: p.UnaryExpr) -> None: + self.resolve(expr.right) + + def visit_call_expr(self, expr: p.CallExpr) -> None: + self.resolve(expr.callee) + for arg in expr.arguments: + self.resolve(arg) + for arg in expr.keywords.values(): + self.resolve(arg) + + def visit_get_expr(self, expr: p.GetExpr) -> None: + self.resolve(expr.object) + + def visit_literal_expr(self, expr: p.LiteralExpr) -> None: + pass + + def visit_variable_expr(self, expr: p.VariableExpr) -> None: + if len(self.scopes) != 0 and self.scopes[-1].get(expr.name) is False: + raise ResolverError( + f"Cannot use local variable '{expr.name}' in its own initializer" + ) # aka. 
UnboundLocalError + self.resolve_local(expr, expr.name) + + def visit_logical_expr(self, expr: p.LogicalExpr) -> None: + self.resolve(expr.left) + self.resolve(expr.right) + + def visit_set_expr(self, expr: p.SetExpr) -> None: + self.resolve(expr.value) + self.resolve(expr.object) + + def visit_cast_expr(self, expr: p.CastExpr) -> None: + self.resolve(expr.expr) + + def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: + self.resolve(expr.test) + self.resolve(expr.if_true) + self.resolve(expr.if_false) diff --git a/midas/utils.py b/midas/utils.py new file mode 100644 index 0000000..bf10b95 --- /dev/null +++ b/midas/utils.py @@ -0,0 +1,54 @@ +from typing import Any, Callable, Optional + +AllowRepeat = Callable[[object], bool] + + +class UniversalJSONDumper: + @classmethod + def dump( + cls, + obj: Any, + include_keys: Optional[list[str | tuple[str, str]]] = None, + allow_repeat: Optional[AllowRepeat] = None, + ) -> Any: + if include_keys is None: + include_keys = [] + return cls._dump(obj, include_keys, allow_repeat, []) + + @classmethod + def _dump( + cls, + obj: Any, + include_keys: list[str | tuple[str, str]], + allow_repeat: Optional[AllowRepeat], + visited: list[Any], + ) -> Any: + if obj in visited: + return None + match obj: + case str() | int() | float() | None: + return obj + case list() | set() | tuple(): + return [ + cls._dump(child, include_keys, allow_repeat, visited) + for child in obj + ] + case dict(): + return { + str(k): cls._dump(v, include_keys, allow_repeat, visited) + for k, v in obj.items() + } + case object(): + if allow_repeat is None or not allow_repeat(obj): + visited.append(obj) + return { + "_type": obj.__class__.__name__, + } | { + k: cls._dump(v, include_keys, allow_repeat, visited) + for k, v in obj.__dict__.items() + if not k.startswith("_") + or k in include_keys + or (obj.__class__.__name__, k) in include_keys + } + case _: + raise ValueError(f"Unsupported value: {obj}") diff --git a/syntax/midas.ebnf b/syntax/midas.ebnf index 
526e122..4626412 100644 --- a/syntax/midas.ebnf +++ b/syntax/midas.ebnf @@ -19,16 +19,24 @@ Comparison ::= Unary (ComparisonOp Unary)* Equality ::= Comparison (EqualityOp Comparison)* Constraint ::= Equality ("&" Equality)* -SimpleType ::= Identifier "?"? -Template ::= "[" Type "]" -Type ::= Identifier Template? "?"? +TemplateParam ::= Identifier ("<:" Type)? +Template ::= "[" (TemplateParam ("," TemplateParam)*)? "]" + + +TypeProperty ::= Identifier ":" Type +ComplexType ::= "{" TypeProperty* "}" +NamedType ::= Identifier +TypeParams ::= "[" (Type ("," Type)*)? "]" +GenericType ::= NamedType TypeParams? +GroupedType ::= "(" Type ")" +BaseType ::= GroupedType | ComplexType | GenericType +ConstraintType ::= BaseType ("where" Constraint)? +Type ::= ConstraintType -TypeProperty ::= Identifier ":" Type ("where" Constraints)? -ComplexTypeBody ::= "{" TypeProperty* "}" OpDefinition ::= "op" Identifier "(" Type ")" "->" Type ExtendBody ::= "{" OpDefinition* "}" -TypeStatement ::= "type" Identifier Template? ("(" Type ")" ("where" Constraint)? | ComplexTypeBody) +TypeStatement ::= "type" Identifier Template? 
"=" Type ExtendStatement ::= "extend" Type ExtendBody PredicateStatement ::= "predicate" Identifier "(" Identifier ":" Type ")" "=" Constraint diff --git a/syntax/midas.typ b/syntax/midas.typ index 3e16f19..b0b6438 100644 --- a/syntax/midas.typ +++ b/syntax/midas.typ @@ -43,28 +43,52 @@ svg.railroad .terminal rect { {[`constraint` 'equality'*"&"]} ``` -#let simple-type = ``` -{[`simple-type` 'identifier' ]} +#let template-param = ``` +{[`template-param` 'identifier' ]} ``` #let template = ``` -{[`template` "[" 'type' "]"]} -``` - -#let type = ``` -{[`type` 'identifier' ]} +{[`template` "[" "]"]} ``` #let type-property = ``` -{[`type-property` 'identifier' ":" 'type' ]} +{[`type-property` 'identifier' ":" 'type']} ``` -#let type-body = ``` -{[`type-body` "{" "}"]} +#let complex-type = ``` +{[`complex-type` "{" "}"]} +``` + +#let named-type = ``` +{[`named-type` 'identifier']} +``` + +#let type-params = ``` +{[`type-params` "[" "]"]} +``` + +#let generic-type = ``` +{[`generic-type` 'named-type' ]} +``` + +#let grouped-type = ``` +{[`grouped-type` "(" 'type' ")"]} +``` + +#let base-type = ``` +{[`base-type` <'grouped-type', 'complex-type', 'generic-type'>]} +``` + +#let constraint-type = ``` +{[`constraint-type` 'base-type' ]} +``` + +#let type = ``` +{[`type` 'constraint-type']} ``` #let type-statement = ``` -{[`type-statement` "type" 'identifier' <[["(" 'type' ")"] ], 'type-body'>]} +{[`type-statement` "type" 'identifier' "=" 'type']} ``` #let op-definition = ``` @@ -92,11 +116,17 @@ svg.railroad .terminal rect { comparison: comparison, equality: equality, constraint: constraint, - simple-type: simple-type, + template-param: template-param, template: template, - type: type, type-property: type-property, - type-body: type-body, + complex-type: complex-type, + named-type: named-type, + type-params: type-params, + generic-type: generic-type, + grouped-type: grouped-type, + base-type: base-type, + constraint-type: constraint-type, + type: type, type-statement: 
type-statement, op-definition: op-definition, extend-statement: extend-statement, @@ -107,10 +137,16 @@ svg.railroad .terminal rect { #let inline = ( "grouping", "value", + "template-param", "template", - "simple-type", "type-property", - "type-body", + "complex-type", + "type-params", + "named-type", + "grouped-type", + "generic-type", + "base-type", + "constraint-type", "op-definition", "type-statement", "extend-statement", diff --git a/tester.py b/tester.py deleted file mode 100644 index 3238a67..0000000 --- a/tester.py +++ /dev/null @@ -1,204 +0,0 @@ -from __future__ import annotations - -import argparse -import difflib -import json -import sys -from dataclasses import asdict, dataclass, field -from pathlib import Path -from typing import Iterator, Optional - -from midas.ast.json_serializer import AstJsonSerializer -from midas.ast.midas import Stmt -from midas.lexer.base import MidasSyntaxError -from midas.lexer.midas import MidasLexer -from midas.lexer.token import Token -from midas.parser.midas import MidasParser - -DEFAULT_BASE_DIR: Path = Path() / "tests" - - -@dataclass -class CaseResult: - tokens: Optional[list[dict]] = None - stmts: Optional[list[dict]] = None - errors: list[dict] = field(default_factory=list) - - def dumps(self) -> str: - return json.dumps(asdict(self), indent=2) - - -class Tester: - """A test runner to check for regressions in the lexer and parser""" - - def __init__(self, base_dir: Path): - self.base_dir: Path = base_dir - - def _list_tests(self) -> list[Path]: - return list(self.base_dir.rglob("*.midas")) - - def run_all_tests(self) -> bool: - paths: list[Path] = self._list_tests() - return self.run_tests(paths) - - def run_tests(self, tests: list[Path]) -> bool: - rule: str = "-" * 80 - n: int = len(tests) - successes: int = 0 - failures: int = 0 - - print(rule) - for i, test in enumerate(tests): - print(f"Case {i+1}/{n}: {test}") - success: bool = self._run_test(test) - if success: - successes += 1 - else: - failures += 1 - - 
print(rule) - print(f"Success: {successes}/{n}") - print(f"Failed: {failures}/{n}") - print(rule) - return failures == 0 - - def _run_test(self, path: Path) -> bool: - result: CaseResult = self._exec_case(path) - result_path: Path = self._result_path(path) - expected: str = result_path.read_text() - actual: str = result.dumps() - - if expected == actual: - return True - - diff = difflib.unified_diff( - expected.splitlines(keepends=True), - actual.splitlines(keepends=True), - fromfile="Snapshot", - tofile="Result", - ) - self._print_diff(diff) - return False - - def _exec_case(self, path: Path) -> CaseResult: - if not path.exists(): - raise FileNotFoundError(f"Could not find test '{path}'") - if not path.is_file(): - raise TypeError(f"Test '{path}' is not a file") - - result: CaseResult = CaseResult() - content: str = path.read_text() - lexer: MidasLexer = MidasLexer(content) - tokens: list[Token] = [] - try: - tokens = lexer.process() - result.tokens = [ - { - "type": token.type.name, - "lexeme": token.lexeme, - "line": token.position.line, - "column": token.position.column, - } - for token in tokens - ] - except MidasSyntaxError as e: - result.errors.append( - { - "type": "SyntaxError", - "line": e.pos.line, - "column": e.pos.column, - "message": e.message, - } - ) - return result - - parser: MidasParser = MidasParser(tokens) - stmts: list[Stmt] = parser.parse() - result.stmts = AstJsonSerializer().serialize(stmts) - result.errors.extend( - [ - { - "line": e.token.position.line, - "column": e.token.position.column, - "message": e.message, - } - for e in parser.errors - ] - ) - return result - - def update_all_tests(self): - paths: list[Path] = self._list_tests() - return self.update_tests(paths) - - def update_tests(self, tests: list[Path]): - updated: int = 0 - for test in tests: - if self._update_test(test): - updated += 1 - print(f"Updated {updated}/{len(tests)} tests") - - def _update_test(self, path: Path) -> bool: - result: CaseResult = self._exec_case(path) 
- result_path: Path = self._result_path(path) - current: str = result_path.read_text() - new: str = result.dumps() - if current == new: - return False - result_path.write_text(new) - return True - - def _result_path(self, test_path: Path) -> Path: - return test_path.parent / (test_path.name + ".ref.json") - - def _print_diff(self, diff: Iterator[str]): - for line in diff: - if line.startswith("+") and not line.startswith("+++"): - print(f"\033[92m{line}\033[0m", end="") - elif line.startswith("-") and not line.startswith("---"): - print(f"\033[91m{line}\033[0m", end="") - else: - print(line, end="") - print() - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument( - "-D", - "--base-dir", - help="Base directory containing test files", - type=Path, - default=DEFAULT_BASE_DIR, - ) - subparsers = parser.add_subparsers(dest="subcommand") - - update = subparsers.add_parser("update") - update.add_argument("-a", "--all", action="store_true") - update.add_argument("FILE", type=Path, nargs="*") - - run = subparsers.add_parser("run") - run.add_argument("-a", "--all", action="store_true") - run.add_argument("FILE", type=Path, nargs="*") - args = parser.parse_args() - - tester: Tester = Tester(args.base_dir) - - match args.subcommand: - case "update": - if args.all: - tester.update_all_tests() - else: - tester.update_tests(args.FILE) - case "run": - success: bool - if args.all: - success = tester.run_all_tests() - else: - success = tester.run_tests(args.FILE) - if not success: - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/tests/base.py b/tests/base.py new file mode 100644 index 0000000..a4c4c3e --- /dev/null +++ b/tests/base.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +import argparse +import difflib +import sys +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Iterator, Protocol + + +class CaseResult(Protocol): + def dumps(self) -> str: ... 
+ + +class Tester(ABC): + """A test runner to check for regressions in the lexer and parser""" + + CASES_DIR: Path = Path(__file__).parent / "cases" + + @property + @abstractmethod + def namespace(self) -> str: ... + + @property + def base_dir(self) -> Path: + return self.CASES_DIR / self.namespace + + @abstractmethod + def _list_tests(self) -> list[Path]: ... + + def run_all_tests(self) -> bool: + paths: list[Path] = self._list_tests() + return self.run_tests(paths) + + def run_tests(self, tests: list[Path]) -> bool: + rule: str = "-" * 80 + n: int = len(tests) + successes: int = 0 + failures: int = 0 + + print(rule) + for i, test in enumerate(tests): + print(f"Case {i+1}/{n}: {test.relative_to(self.CASES_DIR)}") + success: bool = self._run_test(test) + if success: + successes += 1 + else: + failures += 1 + + print(rule) + print(f"Success: {successes}/{n}") + print(f"Failed: {failures}/{n}") + print(rule) + return failures == 0 + + def _run_test(self, path: Path) -> bool: + result_path: Path = self._result_path(path) + if not result_path.exists(): + print("Missing snapshot. Please run the update command first") + return False + result: CaseResult = self._exec_case(path) + expected: str = result_path.read_text() + actual: str = result.dumps() + + if expected == actual: + return True + + diff = difflib.unified_diff( + expected.splitlines(keepends=True), + actual.splitlines(keepends=True), + fromfile="Snapshot", + tofile="Result", + ) + self._print_diff(diff) + return False + + @abstractmethod + def _exec_case(self, path: Path) -> CaseResult: ... 
+ + def update_all_tests(self): + paths: list[Path] = self._list_tests() + return self.update_tests(paths) + + def update_tests(self, tests: list[Path]): + updated: int = 0 + for test in tests: + if self._update_test(test): + updated += 1 + print(f"Updated {updated}/{len(tests)} tests") + + def _update_test(self, path: Path) -> bool: + result: CaseResult = self._exec_case(path) + result_path: Path = self._result_path(path) + current: str = result_path.read_text() if result_path.exists() else "" + new: str = result.dumps() + if current == new: + return False + result_path.write_text(new) + return True + + def _result_path(self, test_path: Path) -> Path: + return test_path.parent / (test_path.name + ".ref.json") + + def _print_diff(self, diff: Iterator[str]): + for line in diff: + if line.startswith("+") and not line.startswith("+++"): + print(f"\033[92m{line}\033[0m", end="") + elif line.startswith("-") and not line.startswith("---"): + print(f"\033[91m{line}\033[0m", end="") + else: + print(line, end="") + print() + + @classmethod + def main(cls): + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(dest="subcommand") + + update = subparsers.add_parser("update") + update.add_argument("-a", "--all", action="store_true") + update.add_argument("FILE", type=Path, nargs="*") + + run = subparsers.add_parser("run") + run.add_argument("-a", "--all", action="store_true") + run.add_argument("FILE", type=Path, nargs="*") + args = parser.parse_args() + + tester: Tester = cls() + + match args.subcommand: + case "update": + if args.all: + tester.update_all_tests() + else: + tester.update_tests(args.FILE) + case "run": + success: bool + if args.all: + success = tester.run_all_tests() + else: + success = tester.run_tests(args.FILE) + if not success: + sys.exit(1) diff --git a/tests/cases/checker/01_simple_types.py b/tests/cases/checker/01_simple_types.py new file mode 100644 index 0000000..3566f9a --- /dev/null +++ b/tests/cases/checker/01_simple_types.py @@ 
-0,0 +1,14 @@ +# type: ignore +# ruff: disable[F821] +from __future__ import annotations + +df: Frame[ + verified: bool, + birth_year: int, + height: float + ( _ > 0 ) + ( _ < 250 ), + name: str, + date: datetime, + float, + unknown: _, + _ +] diff --git a/tests/cases/checker/01_simple_types.py.ref.json b/tests/cases/checker/01_simple_types.py.ref.json new file mode 100644 index 0000000..c37fb01 --- /dev/null +++ b/tests/cases/checker/01_simple_types.py.ref.json @@ -0,0 +1,3 @@ +{ + "diagnostics": [] +} \ No newline at end of file diff --git a/tests/cases/checker/02_simple_operations.py b/tests/cases/checker/02_simple_operations.py new file mode 100644 index 0000000..9e936c1 --- /dev/null +++ b/tests/cases/checker/02_simple_operations.py @@ -0,0 +1,11 @@ +a: int = 3 +b: int = 4 + +c = a + b + +c = "invalid" + +d = True +e = d + d + +f: float = a diff --git a/tests/cases/checker/02_simple_operations.py.ref.json b/tests/cases/checker/02_simple_operations.py.ref.json new file mode 100644 index 0000000..c390e27 --- /dev/null +++ b/tests/cases/checker/02_simple_operations.py.ref.json @@ -0,0 +1,46 @@ +{ + "diagnostics": [ + { + "type": "Error", + "location": { + "start": [ + 6, + 0 + ], + "end": [ + 6, + 13 + ] + }, + "message": "Cannot assign BaseType(name='str') to c of type BaseType(name='int')" + }, + { + "type": "Error", + "location": { + "start": [ + 9, + 4 + ], + "end": [ + 9, + 9 + ] + }, + "message": "Undefined operation __add__ between BaseType(name='bool') and BaseType(name='bool')" + }, + { + "type": "Error", + "location": { + "start": [ + 11, + 0 + ], + "end": [ + 11, + 12 + ] + }, + "message": "Cannot assign BaseType(name='int') to f of type BaseType(name='float')" + } + ] +} \ No newline at end of file diff --git a/tests/cases/checker/03_functions.py b/tests/cases/checker/03_functions.py new file mode 100644 index 0000000..ddc0a56 --- /dev/null +++ b/tests/cases/checker/03_functions.py @@ -0,0 +1,18 @@ +def foo(a: int, /, b: float, *, c: str): + return 
True + + +r1 = foo() +r2 = foo(1) +r3 = foo(1, 2.0) +r4 = foo(1, b=2.0) +r5 = foo(1, 2.0, "test") +r6 = foo(1, 2.0, b=3.0) +r7 = foo(a=1) +r8 = foo(g="test") + +r9a = foo(1, 2.0, c="test") +r9b = foo(1, b=2.0, c="test") +r9c = foo(1, c="test", b=2.0) + +r10 = foo("a", 3, c=False) diff --git a/tests/cases/checker/03_functions.py.ref.json b/tests/cases/checker/03_functions.py.ref.json new file mode 100644 index 0000000..40b33b5 --- /dev/null +++ b/tests/cases/checker/03_functions.py.ref.json @@ -0,0 +1,270 @@ +{ + "diagnostics": [ + { + "type": "Error", + "location": { + "start": [ + 5, + 5 + ], + "end": [ + 5, + 10 + ] + }, + "message": "Missing required positional arguments: 'a' and 'b'" + }, + { + "type": "Error", + "location": { + "start": [ + 5, + 5 + ], + "end": [ + 5, + 10 + ] + }, + "message": "Missing required keyword argument: 'c'" + }, + { + "type": "Error", + "location": { + "start": [ + 6, + 5 + ], + "end": [ + 6, + 11 + ] + }, + "message": "Missing required positional argument: 'b'" + }, + { + "type": "Error", + "location": { + "start": [ + 6, + 5 + ], + "end": [ + 6, + 11 + ] + }, + "message": "Missing required keyword argument: 'c'" + }, + { + "type": "Error", + "location": { + "start": [ + 7, + 5 + ], + "end": [ + 7, + 16 + ] + }, + "message": "Missing required keyword argument: 'c'" + }, + { + "type": "Error", + "location": { + "start": [ + 8, + 5 + ], + "end": [ + 8, + 18 + ] + }, + "message": "Missing required keyword argument: 'c'" + }, + { + "type": "Error", + "location": { + "start": [ + 9, + 17 + ], + "end": [ + 9, + 23 + ] + }, + "message": "Too many positional arguments" + }, + { + "type": "Error", + "location": { + "start": [ + 9, + 5 + ], + "end": [ + 9, + 24 + ] + }, + "message": "Missing required keyword argument: 'c'" + }, + { + "type": "Error", + "location": { + "start": [ + 10, + 19 + ], + "end": [ + 10, + 22 + ] + }, + "message": "Multiple values for argument 'b'" + }, + { + "type": "Error", + "location": { + "start": [ + 10, + 5 + 
], + "end": [ + 10, + 23 + ] + }, + "message": "Missing required keyword argument: 'c'" + }, + { + "type": "Error", + "location": { + "start": [ + 11, + 11 + ], + "end": [ + 11, + 12 + ] + }, + "message": "Unknown keyword argument 'a'" + }, + { + "type": "Error", + "location": { + "start": [ + 11, + 5 + ], + "end": [ + 11, + 13 + ] + }, + "message": "Missing required positional arguments: 'a' and 'b'" + }, + { + "type": "Error", + "location": { + "start": [ + 11, + 5 + ], + "end": [ + 11, + 13 + ] + }, + "message": "Missing required keyword argument: 'c'" + }, + { + "type": "Error", + "location": { + "start": [ + 12, + 11 + ], + "end": [ + 12, + 17 + ] + }, + "message": "Unknown keyword argument 'g'" + }, + { + "type": "Error", + "location": { + "start": [ + 12, + 5 + ], + "end": [ + 12, + 18 + ] + }, + "message": "Missing required positional arguments: 'a' and 'b'" + }, + { + "type": "Error", + "location": { + "start": [ + 12, + 5 + ], + "end": [ + 12, + 18 + ] + }, + "message": "Missing required keyword argument: 'c'" + }, + { + "type": "Error", + "location": { + "start": [ + 18, + 10 + ], + "end": [ + 18, + 13 + ] + }, + "message": "Wrong type for argument 'a', expected BaseType(name='int'), got BaseType(name='str')" + }, + { + "type": "Error", + "location": { + "start": [ + 18, + 15 + ], + "end": [ + 18, + 16 + ] + }, + "message": "Wrong type for argument 'b', expected BaseType(name='float'), got BaseType(name='int')" + }, + { + "type": "Error", + "location": { + "start": [ + 18, + 20 + ], + "end": [ + 18, + 25 + ] + }, + "message": "Wrong type for argument 'c', expected BaseType(name='str'), got BaseType(name='bool')" + } + ] +} \ No newline at end of file diff --git a/tests/cases/checker/04_custom_types.midas b/tests/cases/checker/04_custom_types.midas new file mode 100644 index 0000000..6a1a6a2 --- /dev/null +++ b/tests/cases/checker/04_custom_types.midas @@ -0,0 +1,14 @@ +type Meter = float +type Second = float +type MeterPerSecond = float + +extend Meter { 
+ op __add__(Meter) -> Meter + op __sub__(Meter) -> Meter + op __truediv__(Second) -> MeterPerSecond +} + +extend Second { + op __add__(Second) -> Second + op __sub__(Second) -> Second +} diff --git a/tests/cases/checker/04_custom_types.py b/tests/cases/checker/04_custom_types.py new file mode 100644 index 0000000..c015a75 --- /dev/null +++ b/tests/cases/checker/04_custom_types.py @@ -0,0 +1,6 @@ +# type: ignore +# ruff: disable [F821] + +distance: Meter = cast(Meter, 123.45) +time: Second = cast(Second, 6.7) +speed = distance / time diff --git a/tests/cases/checker/04_custom_types.py.ref.json b/tests/cases/checker/04_custom_types.py.ref.json new file mode 100644 index 0000000..c37fb01 --- /dev/null +++ b/tests/cases/checker/04_custom_types.py.ref.json @@ -0,0 +1,3 @@ +{ + "diagnostics": [] +} \ No newline at end of file diff --git a/tests/cases/checker/05_control_flow.py b/tests/cases/checker/05_control_flow.py new file mode 100644 index 0000000..486818b --- /dev/null +++ b/tests/cases/checker/05_control_flow.py @@ -0,0 +1,25 @@ +def valid(a: int, b: int) -> int: + return a + b + +def with_if(a: int, b: int) -> int: + if a < b: + return b - a + else: + return a - b + +def unreachable1(): + return + a = 0 + +def unreachable2(a: int) -> int: + if a > 10: + return a - 10 + else: + return a + b = 0 + +def mixed(a: int, b: int): + if a < b: + return b - a + else: + return "oops" diff --git a/tests/cases/checker/05_control_flow.py.ref.json b/tests/cases/checker/05_control_flow.py.ref.json new file mode 100644 index 0000000..a68a7b9 --- /dev/null +++ b/tests/cases/checker/05_control_flow.py.ref.json @@ -0,0 +1,46 @@ +{ + "diagnostics": [ + { + "type": "Warning", + "location": { + "start": [ + 12, + 4 + ], + "end": [ + 12, + 9 + ] + }, + "message": "Unreachable statement" + }, + { + "type": "Warning", + "location": { + "start": [ + 19, + 4 + ], + "end": [ + 19, + 9 + ] + }, + "message": "Unreachable statement" + }, + { + "type": "Error", + "location": { + "start": [ + 21, 
+ 0 + ], + "end": [ + 25, + 21 + ] + }, + "message": "Mixed return types: [BaseType(name='int'), BaseType(name='str')]" + } + ] +} \ No newline at end of file diff --git a/tests/cases/parser/01_simple_types.midas b/tests/cases/midas-parser/01_simple_types.midas similarity index 84% rename from tests/cases/parser/01_simple_types.midas rename to tests/cases/midas-parser/01_simple_types.midas index 9432751..6446790 100644 --- a/tests/cases/parser/01_simple_types.midas +++ b/tests/cases/midas-parser/01_simple_types.midas @@ -1,15 +1,15 @@ // Simple custom type derived from float -type Custom(float) +type Custom = float // Simple custom types with constraints -type Latitude(float) where (-90 <= _ <= 90) -type Longitude(float) where (-180 <= _ <= 180) +type Latitude = float where (-90 <= _ <= 90) +type Longitude = float where (-180 <= _ <= 180) // Generic custom type (a Difference of T is derived from T, e.g. a difference of floats is a float -type Difference[T](T) +type Difference[T] = T // Complex custom type, containing two values accessible through properties -type GeoLocation { +type GeoLocation = { lat: Latitude lon: Longitude } @@ -24,7 +24,7 @@ extend GeoLocation { // For complex generics, you need to specify how the genericity the properties // are handled -type Difference[GeoLocation] { +type Difference[GeoLocation] = { lat: Difference[Latitude] lon: Difference[Longitude] } @@ -44,11 +44,11 @@ predicate StrictlyPositive(v: float) = v > 0 predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10) predicate Arctic(loc: GeoLocation) = (loc.lat >= 66) -type Person { +type Person = { name: str // Property with an inline constraint - age: int? 
where (0 <= _ < 150) + age: Optional[int where (0 <= _ < 150)] // Property referencing a predicate height: float where StrictlyPositive diff --git a/tests/cases/parser/01_simple_types.midas.ref.json b/tests/cases/midas-parser/01_simple_types.midas.ref.json similarity index 83% rename from tests/cases/parser/01_simple_types.midas.ref.json rename to tests/cases/midas-parser/01_simple_types.midas.ref.json index 9c9aa5b..55b4813 100644 --- a/tests/cases/parser/01_simple_types.midas.ref.json +++ b/tests/cases/midas-parser/01_simple_types.midas.ref.json @@ -31,28 +31,34 @@ "column": 6 }, { - "type": "LEFT_PAREN", - "lexeme": "(", + "type": "WHITESPACE", + "lexeme": " ", "line": 2, "column": 12 }, + { + "type": "EQUAL", + "lexeme": "=", + "line": 2, + "column": 13 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 2, + "column": 14 + }, { "type": "IDENTIFIER", "lexeme": "float", "line": 2, - "column": 13 - }, - { - "type": "RIGHT_PAREN", - "lexeme": ")", - "line": 2, - "column": 18 + "column": 15 }, { "type": "NEWLINE", "lexeme": "\n", "line": 2, - "column": 19 + "column": 20 }, { "type": "NEWLINE", @@ -91,118 +97,124 @@ "column": 6 }, { - "type": "LEFT_PAREN", - "lexeme": "(", + "type": "WHITESPACE", + "lexeme": " ", "line": 5, "column": 14 }, + { + "type": "EQUAL", + "lexeme": "=", + "line": 5, + "column": 15 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 5, + "column": 16 + }, { "type": "IDENTIFIER", "lexeme": "float", "line": 5, - "column": 15 - }, - { - "type": "RIGHT_PAREN", - "lexeme": ")", - "line": 5, - "column": 20 + "column": 17 }, { "type": "WHITESPACE", "lexeme": " ", "line": 5, - "column": 21 + "column": 22 }, { "type": "WHERE", "lexeme": "where", "line": 5, - "column": 22 + "column": 23 }, { "type": "WHITESPACE", "lexeme": " ", "line": 5, - "column": 27 + "column": 28 }, { "type": "LEFT_PAREN", "lexeme": "(", "line": 5, - "column": 28 + "column": 29 }, { "type": "MINUS", "lexeme": "-", "line": 5, - "column": 29 + "column": 30 }, { 
"type": "NUMBER", "lexeme": "90", "line": 5, - "column": 30 + "column": 31 }, { "type": "WHITESPACE", "lexeme": " ", "line": 5, - "column": 32 + "column": 33 }, { "type": "LESS_EQUAL", "lexeme": "<=", "line": 5, - "column": 33 + "column": 34 }, { "type": "WHITESPACE", "lexeme": " ", "line": 5, - "column": 35 + "column": 36 }, { "type": "UNDERSCORE", "lexeme": "_", "line": 5, - "column": 36 + "column": 37 }, { "type": "WHITESPACE", "lexeme": " ", "line": 5, - "column": 37 + "column": 38 }, { "type": "LESS_EQUAL", "lexeme": "<=", "line": 5, - "column": 38 + "column": 39 }, { "type": "WHITESPACE", "lexeme": " ", "line": 5, - "column": 40 + "column": 41 }, { "type": "NUMBER", "lexeme": "90", "line": 5, - "column": 41 + "column": 42 }, { "type": "RIGHT_PAREN", "lexeme": ")", "line": 5, - "column": 43 + "column": 44 }, { "type": "NEWLINE", "lexeme": "\n", "line": 5, - "column": 44 + "column": 45 }, { "type": "TYPE", @@ -223,118 +235,124 @@ "column": 6 }, { - "type": "LEFT_PAREN", - "lexeme": "(", + "type": "WHITESPACE", + "lexeme": " ", "line": 6, "column": 15 }, + { + "type": "EQUAL", + "lexeme": "=", + "line": 6, + "column": 16 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 6, + "column": 17 + }, { "type": "IDENTIFIER", "lexeme": "float", "line": 6, - "column": 16 - }, - { - "type": "RIGHT_PAREN", - "lexeme": ")", - "line": 6, - "column": 21 + "column": 18 }, { "type": "WHITESPACE", "lexeme": " ", "line": 6, - "column": 22 + "column": 23 }, { "type": "WHERE", "lexeme": "where", "line": 6, - "column": 23 + "column": 24 }, { "type": "WHITESPACE", "lexeme": " ", "line": 6, - "column": 28 + "column": 29 }, { "type": "LEFT_PAREN", "lexeme": "(", "line": 6, - "column": 29 + "column": 30 }, { "type": "MINUS", "lexeme": "-", "line": 6, - "column": 30 + "column": 31 }, { "type": "NUMBER", "lexeme": "180", "line": 6, - "column": 31 + "column": 32 }, { "type": "WHITESPACE", "lexeme": " ", "line": 6, - "column": 34 + "column": 35 }, { "type": "LESS_EQUAL", "lexeme": 
"<=", "line": 6, - "column": 35 + "column": 36 }, { "type": "WHITESPACE", "lexeme": " ", "line": 6, - "column": 37 + "column": 38 }, { "type": "UNDERSCORE", "lexeme": "_", "line": 6, - "column": 38 + "column": 39 }, { "type": "WHITESPACE", "lexeme": " ", "line": 6, - "column": 39 + "column": 40 }, { "type": "LESS_EQUAL", "lexeme": "<=", "line": 6, - "column": 40 + "column": 41 }, { "type": "WHITESPACE", "lexeme": " ", "line": 6, - "column": 42 + "column": 43 }, { "type": "NUMBER", "lexeme": "180", "line": 6, - "column": 43 + "column": 44 }, { "type": "RIGHT_PAREN", "lexeme": ")", "line": 6, - "column": 46 + "column": 47 }, { "type": "NEWLINE", "lexeme": "\n", "line": 6, - "column": 47 + "column": 48 }, { "type": "NEWLINE", @@ -391,28 +409,34 @@ "column": 18 }, { - "type": "LEFT_PAREN", - "lexeme": "(", + "type": "WHITESPACE", + "lexeme": " ", "line": 9, "column": 19 }, + { + "type": "EQUAL", + "lexeme": "=", + "line": 9, + "column": 20 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 9, + "column": 21 + }, { "type": "IDENTIFIER", "lexeme": "T", "line": 9, - "column": 20 - }, - { - "type": "RIGHT_PAREN", - "lexeme": ")", - "line": 9, - "column": 21 + "column": 22 }, { "type": "NEWLINE", "lexeme": "\n", "line": 9, - "column": 22 + "column": 23 }, { "type": "NEWLINE", @@ -456,17 +480,29 @@ "line": 12, "column": 17 }, + { + "type": "EQUAL", + "lexeme": "=", + "line": 12, + "column": 18 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 12, + "column": 19 + }, { "type": "LEFT_BRACE", "lexeme": "{", "line": 12, - "column": 18 + "column": 20 }, { "type": "NEWLINE", "lexeme": "\n", "line": 12, - "column": 19 + "column": 21 }, { "type": "WHITESPACE", @@ -834,17 +870,29 @@ "line": 27, "column": 29 }, + { + "type": "EQUAL", + "lexeme": "=", + "line": 27, + "column": 30 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 27, + "column": 31 + }, { "type": "LEFT_BRACE", "lexeme": "{", "line": 27, - "column": 30 + "column": 32 }, { "type": "NEWLINE", 
"lexeme": "\n", "line": 27, - "column": 31 + "column": 33 }, { "type": "WHITESPACE", @@ -1824,17 +1872,29 @@ "line": 47, "column": 12 }, + { + "type": "EQUAL", + "lexeme": "=", + "line": 47, + "column": 13 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 47, + "column": 14 + }, { "type": "LEFT_BRACE", "lexeme": "{", "line": 47, - "column": 13 + "column": 15 }, { "type": "NEWLINE", "lexeme": "\n", "line": 47, - "column": 14 + "column": 16 }, { "type": "WHITESPACE", @@ -1922,70 +1982,34 @@ }, { "type": "IDENTIFIER", - "lexeme": "int", + "lexeme": "Optional", "line": 51, "column": 10 }, { - "type": "QMARK", - "lexeme": "?", + "type": "LEFT_BRACKET", + "lexeme": "[", "line": 51, - "column": 13 + "column": 18 + }, + { + "type": "IDENTIFIER", + "lexeme": "int", + "line": 51, + "column": 19 }, { "type": "WHITESPACE", "lexeme": " ", "line": 51, - "column": 14 + "column": 22 }, { "type": "WHERE", "lexeme": "where", "line": 51, - "column": 15 - }, - { - "type": "WHITESPACE", - "lexeme": " ", - "line": 51, - "column": 20 - }, - { - "type": "LEFT_PAREN", - "lexeme": "(", - "line": 51, - "column": 21 - }, - { - "type": "NUMBER", - "lexeme": "0", - "line": 51, - "column": 22 - }, - { - "type": "WHITESPACE", - "lexeme": " ", - "line": 51, "column": 23 }, - { - "type": "LESS_EQUAL", - "lexeme": "<=", - "line": 51, - "column": 24 - }, - { - "type": "WHITESPACE", - "lexeme": " ", - "line": 51, - "column": 26 - }, - { - "type": "UNDERSCORE", - "lexeme": "_", - "line": 51, - "column": 27 - }, { "type": "WHITESPACE", "lexeme": " ", @@ -1993,34 +2017,82 @@ "column": 28 }, { - "type": "LESS", - "lexeme": "<", + "type": "LEFT_PAREN", + "lexeme": "(", "line": 51, "column": 29 }, + { + "type": "NUMBER", + "lexeme": "0", + "line": 51, + "column": 30 + }, { "type": "WHITESPACE", "lexeme": " ", "line": 51, - "column": 30 + "column": 31 + }, + { + "type": "LESS_EQUAL", + "lexeme": "<=", + "line": 51, + "column": 32 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 51, + 
"column": 34 + }, + { + "type": "UNDERSCORE", + "lexeme": "_", + "line": 51, + "column": 35 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 51, + "column": 36 + }, + { + "type": "LESS", + "lexeme": "<", + "line": 51, + "column": 37 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 51, + "column": 38 }, { "type": "NUMBER", "lexeme": "150", "line": 51, - "column": 31 + "column": 39 }, { "type": "RIGHT_PAREN", "lexeme": ")", "line": 51, - "column": 34 + "column": 42 + }, + { + "type": "RIGHT_BRACKET", + "lexeme": "]", + "line": 51, + "column": 43 }, { "type": "NEWLINE", "lexeme": "\n", "line": 51, - "column": 35 + "column": 44 }, { "type": "NEWLINE", @@ -2169,259 +2241,235 @@ ], "stmts": [ { - "_type": "SimpleTypeStmt", - "template": null, + "_type": "TypeStmt", "name": "Custom", - "base": { - "_type": "TypeExpr", - "name": "float", - "template": null, - "optional": false - }, - "constraint": null + "params": [], + "type": { + "_type": "NamedType", + "name": "float" + } }, { - "_type": "SimpleTypeStmt", - "template": null, + "_type": "TypeStmt", "name": "Latitude", - "base": { - "_type": "TypeExpr", - "name": "float", - "template": null, - "optional": false - }, - "constraint": { - "_type": "GroupingExpr", - "expr": { - "_type": "BinaryExpr", - "left": { + "params": [], + "type": { + "_type": "ConstraintType", + "type": { + "_type": "NamedType", + "name": "float" + }, + "constraint": { + "_type": "GroupingExpr", + "expr": { "_type": "BinaryExpr", "left": { - "_type": "UnaryExpr", - "operator": "-", + "_type": "BinaryExpr", + "left": { + "_type": "UnaryExpr", + "operator": "-", + "right": { + "_type": "LiteralExpr", + "value": 90.0 + } + }, + "operator": "<=", "right": { - "_type": "LiteralExpr", - "value": 90.0 + "_type": "WildcardExpr" } }, "operator": "<=", "right": { - "_type": "WildcardExpr" + "_type": "LiteralExpr", + "value": 90.0 } - }, - "operator": "<=", - "right": { - "_type": "LiteralExpr", - "value": 90.0 } } } }, { - "_type": 
"SimpleTypeStmt", - "template": null, + "_type": "TypeStmt", "name": "Longitude", - "base": { - "_type": "TypeExpr", - "name": "float", - "template": null, - "optional": false - }, - "constraint": { - "_type": "GroupingExpr", - "expr": { - "_type": "BinaryExpr", - "left": { + "params": [], + "type": { + "_type": "ConstraintType", + "type": { + "_type": "NamedType", + "name": "float" + }, + "constraint": { + "_type": "GroupingExpr", + "expr": { "_type": "BinaryExpr", "left": { - "_type": "UnaryExpr", - "operator": "-", + "_type": "BinaryExpr", + "left": { + "_type": "UnaryExpr", + "operator": "-", + "right": { + "_type": "LiteralExpr", + "value": 180.0 + } + }, + "operator": "<=", "right": { - "_type": "LiteralExpr", - "value": 180.0 + "_type": "WildcardExpr" } }, "operator": "<=", "right": { - "_type": "WildcardExpr" + "_type": "LiteralExpr", + "value": 180.0 } - }, - "operator": "<=", - "right": { - "_type": "LiteralExpr", - "value": 180.0 } } } }, { - "_type": "SimpleTypeStmt", - "template": { - "_type": "TemplateExpr", - "type": { - "_type": "TypeExpr", + "_type": "TypeStmt", + "name": "Difference", + "params": [ + { "name": "T", - "template": null, - "optional": false + "bound": null } - }, - "name": "Difference", - "base": { - "_type": "TypeExpr", - "name": "T", - "template": null, - "optional": false - }, - "constraint": null + ], + "type": { + "_type": "NamedType", + "name": "T" + } }, { - "_type": "ComplexTypeStmt", + "_type": "TypeStmt", "name": "GeoLocation", - "template": null, - "properties": [ - { - "_type": "PropertyStmt", - "name": "lat", - "type": { - "_type": "TypeExpr", - "name": "Latitude", - "template": null, - "optional": false + "params": [], + "type": { + "_type": "ComplexType", + "properties": [ + { + "_type": "PropertyStmt", + "name": "lat", + "type": { + "_type": "NamedType", + "name": "Latitude" + } }, - "constraint": null - }, - { - "_type": "PropertyStmt", - "name": "lon", - "type": { - "_type": "TypeExpr", - "name": "Longitude", - 
"template": null, - "optional": false - }, - "constraint": null - } - ] + { + "_type": "PropertyStmt", + "name": "lon", + "type": { + "_type": "NamedType", + "name": "Longitude" + } + } + ] + } }, { "_type": "ExtendStmt", "type": { - "_type": "TypeExpr", - "name": "GeoLocation", - "template": null, - "optional": false + "_type": "NamedType", + "name": "GeoLocation" }, "operations": [ { "_type": "OpStmt", "name": "__sub__", "operand": { - "_type": "TypeExpr", - "name": "GeoLocation", - "template": null, - "optional": false + "_type": "NamedType", + "name": "GeoLocation" }, "result": { - "_type": "TypeExpr", - "name": "Difference", - "template": { - "_type": "TemplateExpr", - "type": { - "_type": "TypeExpr", - "name": "GeoLocation", - "template": null, - "optional": false - } + "_type": "GenericType", + "type": { + "_type": "NamedType", + "name": "Difference" }, - "optional": false + "params": [ + { + "_type": "NamedType", + "name": "GeoLocation" + } + ] } } ] }, { - "_type": "ComplexTypeStmt", + "_type": "TypeStmt", "name": "Difference", - "template": { - "_type": "TemplateExpr", - "type": { - "_type": "TypeExpr", + "params": [ + { "name": "GeoLocation", - "template": null, - "optional": false + "bound": null } - }, - "properties": [ - { - "_type": "PropertyStmt", - "name": "lat", - "type": { - "_type": "TypeExpr", - "name": "Difference", - "template": { - "_type": "TemplateExpr", + ], + "type": { + "_type": "ComplexType", + "properties": [ + { + "_type": "PropertyStmt", + "name": "lat", + "type": { + "_type": "GenericType", "type": { - "_type": "TypeExpr", - "name": "Latitude", - "template": null, - "optional": false - } - }, - "optional": false + "_type": "NamedType", + "name": "Difference" + }, + "params": [ + { + "_type": "NamedType", + "name": "Latitude" + } + ] + } }, - "constraint": null - }, - { - "_type": "PropertyStmt", - "name": "lon", - "type": { - "_type": "TypeExpr", - "name": "Difference", - "template": { - "_type": "TemplateExpr", + { + "_type": 
"PropertyStmt", + "name": "lon", + "type": { + "_type": "GenericType", "type": { - "_type": "TypeExpr", - "name": "Longitude", - "template": null, - "optional": false - } - }, - "optional": false - }, - "constraint": null - } - ] + "_type": "NamedType", + "name": "Difference" + }, + "params": [ + { + "_type": "NamedType", + "name": "Longitude" + } + ] + } + } + ] + } }, { "_type": "ExtendStmt", "type": { - "_type": "TypeExpr", - "name": "Latitude", - "template": null, - "optional": false + "_type": "NamedType", + "name": "Latitude" }, "operations": [ { "_type": "OpStmt", "name": "__sub__", "operand": { - "_type": "TypeExpr", - "name": "Latitude", - "template": null, - "optional": false + "_type": "NamedType", + "name": "Latitude" }, "result": { - "_type": "TypeExpr", - "name": "Difference", - "template": { - "_type": "TemplateExpr", - "type": { - "_type": "TypeExpr", - "name": "Latitude", - "template": null, - "optional": false - } + "_type": "GenericType", + "type": { + "_type": "NamedType", + "name": "Difference" }, - "optional": false + "params": [ + { + "_type": "NamedType", + "name": "Latitude" + } + ] } } ] @@ -2429,34 +2477,29 @@ { "_type": "ExtendStmt", "type": { - "_type": "TypeExpr", - "name": "Longitude", - "template": null, - "optional": false + "_type": "NamedType", + "name": "Longitude" }, "operations": [ { "_type": "OpStmt", "name": "__sub__", "operand": { - "_type": "TypeExpr", - "name": "Longitude", - "template": null, - "optional": false + "_type": "NamedType", + "name": "Longitude" }, "result": { - "_type": "TypeExpr", - "name": "Difference", - "template": { - "_type": "TemplateExpr", - "type": { - "_type": "TypeExpr", - "name": "Longitude", - "template": null, - "optional": false - } + "_type": "GenericType", + "type": { + "_type": "NamedType", + "name": "Difference" }, - "optional": false + "params": [ + { + "_type": "NamedType", + "name": "Longitude" + } + ] } } ] @@ -2466,10 +2509,8 @@ "name": "Positive", "subject": "v", "type": { - "_type": 
"TypeExpr", - "name": "float", - "template": null, - "optional": false + "_type": "NamedType", + "name": "float" }, "condition": { "_type": "BinaryExpr", @@ -2489,10 +2530,8 @@ "name": "StrictlyPositive", "subject": "v", "type": { - "_type": "TypeExpr", - "name": "float", - "template": null, - "optional": false + "_type": "NamedType", + "name": "float" }, "condition": { "_type": "BinaryExpr", @@ -2512,10 +2551,8 @@ "name": "Equatorial", "subject": "loc", "type": { - "_type": "TypeExpr", - "name": "GeoLocation", - "template": null, - "optional": false + "_type": "NamedType", + "name": "GeoLocation" }, "condition": { "_type": "GroupingExpr", @@ -2554,10 +2591,8 @@ "name": "Arctic", "subject": "loc", "type": { - "_type": "TypeExpr", - "name": "GeoLocation", - "template": null, - "optional": false + "_type": "NamedType", + "name": "GeoLocation" }, "condition": { "_type": "GroupingExpr", @@ -2580,79 +2615,87 @@ } }, { - "_type": "ComplexTypeStmt", + "_type": "TypeStmt", "name": "Person", - "template": null, - "properties": [ - { - "_type": "PropertyStmt", - "name": "name", - "type": { - "_type": "TypeExpr", - "name": "str", - "template": null, - "optional": false + "params": [], + "type": { + "_type": "ComplexType", + "properties": [ + { + "_type": "PropertyStmt", + "name": "name", + "type": { + "_type": "NamedType", + "name": "str" + } }, - "constraint": null - }, - { - "_type": "PropertyStmt", - "name": "age", - "type": { - "_type": "TypeExpr", - "name": "int", - "template": null, - "optional": true - }, - "constraint": { - "_type": "GroupingExpr", - "expr": { - "_type": "BinaryExpr", - "left": { - "_type": "BinaryExpr", - "left": { - "_type": "LiteralExpr", - "value": 0.0 - }, - "operator": "<=", - "right": { - "_type": "WildcardExpr" - } + { + "_type": "PropertyStmt", + "name": "age", + "type": { + "_type": "GenericType", + "type": { + "_type": "NamedType", + "name": "Optional" }, - "operator": "<", - "right": { - "_type": "LiteralExpr", - "value": 150.0 + "params": 
[ + { + "_type": "ConstraintType", + "type": { + "_type": "NamedType", + "name": "int" + }, + "constraint": { + "_type": "GroupingExpr", + "expr": { + "_type": "BinaryExpr", + "left": { + "_type": "BinaryExpr", + "left": { + "_type": "LiteralExpr", + "value": 0.0 + }, + "operator": "<=", + "right": { + "_type": "WildcardExpr" + } + }, + "operator": "<", + "right": { + "_type": "LiteralExpr", + "value": 150.0 + } + } + } + } + ] + } + }, + { + "_type": "PropertyStmt", + "name": "height", + "type": { + "_type": "ConstraintType", + "type": { + "_type": "NamedType", + "name": "float" + }, + "constraint": { + "_type": "VariableExpr", + "name": "StrictlyPositive" } } - } - }, - { - "_type": "PropertyStmt", - "name": "height", - "type": { - "_type": "TypeExpr", - "name": "float", - "template": null, - "optional": false }, - "constraint": { - "_type": "VariableExpr", - "name": "StrictlyPositive" + { + "_type": "PropertyStmt", + "name": "home", + "type": { + "_type": "NamedType", + "name": "GeoLocation" + } } - }, - { - "_type": "PropertyStmt", - "name": "home", - "type": { - "_type": "TypeExpr", - "name": "GeoLocation", - "template": null, - "optional": false - }, - "constraint": null - } - ] + ] + } } ], "errors": [] diff --git a/tests/cases/python-parser/01_simple_types.py b/tests/cases/python-parser/01_simple_types.py new file mode 100644 index 0000000..3566f9a --- /dev/null +++ b/tests/cases/python-parser/01_simple_types.py @@ -0,0 +1,14 @@ +# type: ignore +# ruff: disable[F821] +from __future__ import annotations + +df: Frame[ + verified: bool, + birth_year: int, + height: float + ( _ > 0 ) + ( _ < 250 ), + name: str, + date: datetime, + float, + unknown: _, + _ +] diff --git a/tests/cases/python-parser/01_simple_types.py.ref.json b/tests/cases/python-parser/01_simple_types.py.ref.json new file mode 100644 index 0000000..e4fd591 --- /dev/null +++ b/tests/cases/python-parser/01_simple_types.py.ref.json @@ -0,0 +1,85 @@ +{ + "stmts": [ + { + "_type": "TypeAssign", + 
"name": "df", + "type": { + "_type": "FrameType", + "columns": [ + { + "_type": "FrameColumn", + "name": "verified", + "type": { + "_type": "BaseType", + "base": "bool", + "param": null + } + }, + { + "_type": "FrameColumn", + "name": "birth_year", + "type": { + "_type": "BaseType", + "base": "int", + "param": null + } + }, + { + "_type": "FrameColumn", + "name": "height", + "type": { + "_type": "ConstraintType", + "type": { + "_type": "BaseType", + "base": "float", + "param": null + }, + "constraint": "(_ > 0) + (_ < 250)" + } + }, + { + "_type": "FrameColumn", + "name": "name", + "type": { + "_type": "BaseType", + "base": "str", + "param": null + } + }, + { + "_type": "FrameColumn", + "name": "date", + "type": { + "_type": "BaseType", + "base": "datetime", + "param": null + } + }, + { + "_type": "FrameColumn", + "name": null, + "type": { + "_type": "BaseType", + "base": "float", + "param": null + } + }, + { + "_type": "FrameColumn", + "name": "unknown", + "type": null + }, + { + "_type": "FrameColumn", + "name": null, + "type": { + "_type": "BaseType", + "base": "_", + "param": null + } + } + ] + } + } + ] +} \ No newline at end of file diff --git a/tests/cases/python-parser/02_custom_types.py b/tests/cases/python-parser/02_custom_types.py new file mode 100644 index 0000000..1725bf4 --- /dev/null +++ b/tests/cases/python-parser/02_custom_types.py @@ -0,0 +1,25 @@ +# type: ignore +# ruff: disable[F821] +from __future__ import annotations + +df: Frame[ + location: GeoLocation +] + +lat: Column[GeoLocation] = df["location"].lat +lon: Column[GeoLocation] = df["location"].lon + +lat + lon + +lat1: Latitude = lat[0] +lat2: Latitude = lat[1] +lat_diff: Difference[Latitude] = lat2 - lat1 + +df2: Frame[ + age: int + (_ >= 0), + height: float + (_ >= 0), +] +df2_bis: Frame[ + age: int + Positive, + height: float + Positive, +] diff --git a/tests/cases/python-parser/02_custom_types.py.ref.json b/tests/cases/python-parser/02_custom_types.py.ref.json new file mode 100644 
index 0000000..639610d --- /dev/null +++ b/tests/cases/python-parser/02_custom_types.py.ref.json @@ -0,0 +1,141 @@ +{ + "stmts": [ + { + "_type": "TypeAssign", + "name": "df", + "type": { + "_type": "FrameType", + "columns": [ + { + "_type": "FrameColumn", + "name": "location", + "type": { + "_type": "BaseType", + "base": "GeoLocation", + "param": null + } + } + ] + } + }, + { + "_type": "ExpressionStmt", + "expr": { + "_type": "BinaryExpr", + "left": { + "_type": "VariableExpr", + "name": "lat" + }, + "operator": "+", + "right": { + "_type": "VariableExpr", + "name": "lon" + } + } + }, + { + "_type": "TypeAssign", + "name": "lat_diff", + "type": { + "_type": "BaseType", + "base": "Difference", + "param": { + "_type": "BaseType", + "base": "Latitude", + "param": null + } + } + }, + { + "_type": "AssignStmt", + "targets": [ + { + "_type": "VariableExpr", + "name": "lat_diff" + } + ], + "value": { + "_type": "BinaryExpr", + "left": { + "_type": "VariableExpr", + "name": "lat2" + }, + "operator": "-", + "right": { + "_type": "VariableExpr", + "name": "lat1" + } + } + }, + { + "_type": "TypeAssign", + "name": "df2", + "type": { + "_type": "FrameType", + "columns": [ + { + "_type": "FrameColumn", + "name": "age", + "type": { + "_type": "ConstraintType", + "type": { + "_type": "BaseType", + "base": "int", + "param": null + }, + "constraint": "_ >= 0" + } + }, + { + "_type": "FrameColumn", + "name": "height", + "type": { + "_type": "ConstraintType", + "type": { + "_type": "BaseType", + "base": "float", + "param": null + }, + "constraint": "_ >= 0" + } + } + ] + } + }, + { + "_type": "TypeAssign", + "name": "df2_bis", + "type": { + "_type": "FrameType", + "columns": [ + { + "_type": "FrameColumn", + "name": "age", + "type": { + "_type": "ConstraintType", + "type": { + "_type": "BaseType", + "base": "int", + "param": null + }, + "constraint": "Positive" + } + }, + { + "_type": "FrameColumn", + "name": "height", + "type": { + "_type": "ConstraintType", + "type": { + "_type": 
"BaseType", + "base": "float", + "param": null + }, + "constraint": "Positive" + } + } + ] + } + } + ] +} \ No newline at end of file diff --git a/tests/cases/python-parser/03_functions.py b/tests/cases/python-parser/03_functions.py new file mode 100644 index 0000000..3b07899 --- /dev/null +++ b/tests/cases/python-parser/03_functions.py @@ -0,0 +1,15 @@ +# type: ignore +# ruff: disable[F821] +from __future__ import annotations + + +def func( + col1: Column[float + (0 <= _ <= 1)], + col2: Column[float + (0 <= _ <= 1)], +) -> Column[float + (0 <= _ <= 2)]: + result: Column[float + (0 <= _ <= 2)] = col1 + col2 + return result + + +def func2(a: int, /, b: float, *, c: str): + pass diff --git a/tests/cases/python-parser/03_functions.py.ref.json b/tests/cases/python-parser/03_functions.py.ref.json new file mode 100644 index 0000000..529455b --- /dev/null +++ b/tests/cases/python-parser/03_functions.py.ref.json @@ -0,0 +1,149 @@ +{ + "stmts": [ + { + "_type": "Function", + "name": "func", + "posonlyargs": [], + "args": [ + { + "name": "col1", + "type": { + "_type": "BaseType", + "base": "Column", + "param": { + "_type": "ConstraintType", + "type": { + "_type": "BaseType", + "base": "float", + "param": null + }, + "constraint": "0 <= _ <= 1" + } + }, + "default": null + }, + { + "name": "col2", + "type": { + "_type": "BaseType", + "base": "Column", + "param": { + "_type": "ConstraintType", + "type": { + "_type": "BaseType", + "base": "float", + "param": null + }, + "constraint": "0 <= _ <= 1" + } + }, + "default": null + } + ], + "sink": null, + "kwonlyargs": [], + "kw_sink": null, + "returns": { + "_type": "BaseType", + "base": "Column", + "param": { + "_type": "ConstraintType", + "type": { + "_type": "BaseType", + "base": "float", + "param": null + }, + "constraint": "0 <= _ <= 2" + } + }, + "body": [ + { + "_type": "TypeAssign", + "name": "result", + "type": { + "_type": "BaseType", + "base": "Column", + "param": { + "_type": "ConstraintType", + "type": { + "_type": 
"BaseType", + "base": "float", + "param": null + }, + "constraint": "0 <= _ <= 2" + } + } + }, + { + "_type": "AssignStmt", + "targets": [ + { + "_type": "VariableExpr", + "name": "result" + } + ], + "value": { + "_type": "BinaryExpr", + "left": { + "_type": "VariableExpr", + "name": "col1" + }, + "operator": "+", + "right": { + "_type": "VariableExpr", + "name": "col2" + } + } + }, + { + "_type": "ReturnStmt", + "value": { + "_type": "VariableExpr", + "name": "result" + } + } + ] + }, + { + "_type": "Function", + "name": "func2", + "posonlyargs": [ + { + "name": "a", + "type": { + "_type": "BaseType", + "base": "int", + "param": null + }, + "default": null + } + ], + "args": [ + { + "name": "b", + "type": { + "_type": "BaseType", + "base": "float", + "param": null + }, + "default": null + } + ], + "sink": null, + "kwonlyargs": [ + { + "name": "c", + "type": { + "_type": "BaseType", + "base": "str", + "param": null + }, + "default": null + } + ], + "kw_sink": null, + "returns": null, + "body": [] + } + ] +} \ No newline at end of file diff --git a/tests/checker.py b/tests/checker.py new file mode 100644 index 0000000..d0a7b3e --- /dev/null +++ b/tests/checker.py @@ -0,0 +1,75 @@ +import ast +import json +from dataclasses import asdict, dataclass, field +from pathlib import Path + +import midas.ast.python as p +from midas.checker.checker import Checker +from midas.checker.diagnostic import Diagnostic +from midas.parser.python import PythonParser +from midas.resolver.resolver import Resolver +from tests.base import Tester + + +@dataclass +class CaseResult: + diagnostics: list[dict] = field(default_factory=list) + + def dumps(self) -> str: + return json.dumps(asdict(self), indent=2) + + +class CheckerTester(Tester): + @property + def namespace(self) -> str: + return "checker" + + def _list_tests(self) -> list[Path]: + return list(self.base_dir.rglob("*.py")) + + def _exec_case(self, path: Path) -> CaseResult: + if not path.exists(): + raise FileNotFoundError(f"Could 
not find test '{path}'") + if not path.is_file(): + raise TypeError(f"Test '{path}' is not a file") + + types_paths: list[Path] = [] + types_path: Path = path.with_suffix(".midas") + if types_path.exists(): + types_paths.append(types_path) + source: str = path.read_text() + tree: ast.Module = ast.parse(source, filename=path) + parser = PythonParser() + stmts: list[p.Stmt] = parser.parse_module(tree) + resolver = Resolver() + resolver.resolve(*stmts) + result: CaseResult = CaseResult() + checker = Checker( + resolver.locals, + source_path=path, + types_paths=types_paths, + ) + diagnostics: list[Diagnostic] = checker.check(stmts) + for diagnostic in diagnostics: + result.diagnostics.append( + { + "type": str(diagnostic.type), + "location": { + "start": ( + diagnostic.location.lineno, + diagnostic.location.col_offset, + ), + "end": ( + diagnostic.location.end_lineno, + diagnostic.location.end_col_offset, + ), + }, + "message": diagnostic.message, + } + ) + + return result + + +if __name__ == "__main__": + CheckerTester.main() diff --git a/tests/midas.py b/tests/midas.py new file mode 100644 index 0000000..4474976 --- /dev/null +++ b/tests/midas.py @@ -0,0 +1,82 @@ +import json +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Optional + +from midas.ast.midas import Stmt +from midas.lexer.base import MidasSyntaxError +from midas.lexer.midas import MidasLexer +from midas.lexer.token import Token +from midas.parser.midas import MidasParser +from tests.base import Tester +from tests.serializer.midas import MidasAstJsonSerializer + + +@dataclass +class CaseResult: + tokens: Optional[list[dict]] = None + stmts: Optional[list[dict]] = None + errors: list[dict] = field(default_factory=list) + + def dumps(self) -> str: + return json.dumps(asdict(self), indent=2) + + +class MidasTester(Tester): + @property + def namespace(self) -> str: + return "midas-parser" + + def _list_tests(self) -> list[Path]: + return 
list(self.base_dir.rglob("*.midas")) + + def _exec_case(self, path: Path) -> CaseResult: + if not path.exists(): + raise FileNotFoundError(f"Could not find test '{path}'") + if not path.is_file(): + raise TypeError(f"Test '{path}' is not a file") + + result: CaseResult = CaseResult() + content: str = path.read_text() + lexer: MidasLexer = MidasLexer(content) + tokens: list[Token] = [] + try: + tokens = lexer.process() + result.tokens = [ + { + "type": token.type.name, + "lexeme": token.lexeme, + "line": token.position.line, + "column": token.position.column, + } + for token in tokens + ] + except MidasSyntaxError as e: + result.errors.append( + { + "type": "SyntaxError", + "line": e.pos.line, + "column": e.pos.column, + "message": e.message, + } + ) + return result + + parser: MidasParser = MidasParser(tokens) + stmts: list[Stmt] = parser.parse() + result.stmts = MidasAstJsonSerializer().serialize(stmts) + result.errors.extend( + [ + { + "line": e.token.position.line, + "column": e.token.position.column, + "message": e.message, + } + for e in parser.errors + ] + ) + return result + + +if __name__ == "__main__": + MidasTester.main() diff --git a/tests/python.py b/tests/python.py new file mode 100644 index 0000000..a6cbb7b --- /dev/null +++ b/tests/python.py @@ -0,0 +1,46 @@ +import ast +import json +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Optional + +from midas.ast.python import Stmt +from midas.parser.python import PythonParser +from tests.base import Tester +from tests.serializer.python import PythonAstJsonSerializer + + +@dataclass +class CaseResult: + stmts: Optional[list[dict]] = None + + def dumps(self) -> str: + return json.dumps(asdict(self), indent=2) + + +class PythonTester(Tester): + @property + def namespace(self) -> str: + return "python-parser" + + def _list_tests(self) -> list[Path]: + return list(self.base_dir.rglob("*.py")) + + def _exec_case(self, path: Path) -> CaseResult: + if not path.exists(): + 
raise FileNotFoundError(f"Could not find test '{path}'") + if not path.is_file(): + raise TypeError(f"Test '{path}' is not a file") + + result: CaseResult = CaseResult() + content: str = path.read_text() + tree: ast.Module = ast.parse(content) + + parser: PythonParser = PythonParser() + stmts: list[Stmt] = parser.parse_module(tree) + result.stmts = PythonAstJsonSerializer().serialize(stmts) + return result + + +if __name__ == "__main__": + PythonTester.main() diff --git a/midas/ast/json_serializer.py b/tests/serializer/midas.py similarity index 68% rename from midas/ast/json_serializer.py rename to tests/serializer/midas.py index d602117..919dc66 100644 --- a/midas/ast/json_serializer.py +++ b/tests/serializer/midas.py @@ -2,56 +2,60 @@ from typing import Optional, Sequence from midas.ast.midas import ( BinaryExpr, - ComplexTypeStmt, + ComplexType, + ConstraintType, Expr, ExtendStmt, + GenericType, GetExpr, GroupingExpr, LiteralExpr, LogicalExpr, + NamedType, OpStmt, PredicateStmt, PropertyStmt, - SimpleTypeExpr, - SimpleTypeStmt, Stmt, - TemplateExpr, - TypeExpr, + Type, + TypeStmt, UnaryExpr, VariableExpr, WildcardExpr, ) -class AstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]): +class MidasAstJsonSerializer( + Stmt.Visitor[dict], Expr.Visitor[dict], Type.Visitor[dict] +): """An AST serializer which produces a JSON-compatible structure""" def serialize(self, stmts: list[Stmt]) -> list[dict]: return [stmt.accept(self) for stmt in stmts] - def _serialize_optional(self, element: Optional[Stmt | Expr]) -> Optional[dict]: + def _serialize_optional( + self, element: Optional[Stmt | Expr | Type] + ) -> Optional[dict]: if element is None: return None return element.accept(self) - def _serialize_list(self, elements: Sequence[Stmt | Expr]) -> list[dict]: + def _serialize_list(self, elements: Sequence[Stmt | Expr | Type]) -> list[dict]: return [element.accept(self) for element in elements] - def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> dict: + def 
visit_type_stmt(self, stmt: TypeStmt) -> dict: return { - "_type": "SimpleTypeStmt", - "template": self._serialize_optional(stmt.template), + "_type": "TypeStmt", "name": stmt.name.lexeme, - "base": stmt.base.accept(self), - "constraint": self._serialize_optional(stmt.constraint), + "params": [ + self._serialize_type_stmt_template_param(param) for param in stmt.params + ], + "type": stmt.type.accept(self), } - def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> dict: + def _serialize_type_stmt_template_param(self, param: TypeStmt.Param) -> dict: return { - "_type": "ComplexTypeStmt", - "name": stmt.name.lexeme, - "template": self._serialize_optional(stmt.template), - "properties": self._serialize_list(stmt.properties), + "name": param.name.lexeme, + "bound": self._serialize_optional(param.bound), } def visit_property_stmt(self, stmt: PropertyStmt) -> dict: @@ -59,7 +63,6 @@ class AstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]): "_type": "PropertyStmt", "name": stmt.name.lexeme, "type": stmt.type.accept(self), - "constraint": self._serialize_optional(stmt.constraint), } def visit_extend_stmt(self, stmt: ExtendStmt) -> dict: @@ -86,13 +89,6 @@ class AstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]): "condition": stmt.condition.accept(self), } - def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> dict: - return { - "_type": "SimpleTypeExpr", - "name": expr.name.lexeme, - "optional": expr.optional, - } - def visit_logical_expr(self, expr: LogicalExpr) -> dict: return { "_type": "LogicalExpr", @@ -144,16 +140,28 @@ class AstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]): def visit_wildcard_expr(self, expr: WildcardExpr) -> dict: return {"_type": "WildcardExpr"} - def visit_template_expr(self, expr: TemplateExpr) -> dict: + def visit_named_type(self, type: NamedType) -> dict: return { - "_type": "TemplateExpr", - "type": expr.type.accept(self), + "_type": "NamedType", + "name": type.name.lexeme, } - def visit_type_expr(self, expr: 
TypeExpr) -> dict: + def visit_generic_type(self, type: GenericType) -> dict: return { - "_type": "TypeExpr", - "name": expr.name.lexeme, - "template": self._serialize_optional(expr.template), - "optional": expr.optional, + "_type": "GenericType", + "type": type.type.accept(self), + "params": self._serialize_list(type.params), + } + + def visit_constraint_type(self, type: ConstraintType) -> dict: + return { + "_type": "ConstraintType", + "type": type.type.accept(self), + "constraint": type.constraint.accept(self), + } + + def visit_complex_type(self, type: ComplexType) -> dict: + return { + "_type": "ComplexType", + "properties": self._serialize_list(type.properties), } diff --git a/tests/serializer/python.py b/tests/serializer/python.py new file mode 100644 index 0000000..786b15b --- /dev/null +++ b/tests/serializer/python.py @@ -0,0 +1,256 @@ +import ast +from typing import Optional, Sequence, Type + +from midas.ast.python import ( + AssignStmt, + BaseType, + BinaryExpr, + CallExpr, + CastExpr, + CompareExpr, + ConstraintType, + Expr, + ExpressionStmt, + FrameColumn, + FrameType, + Function, + GetExpr, + IfStmt, + LiteralExpr, + LogicalExpr, + MidasType, + ReturnStmt, + SetExpr, + Stmt, + TernaryExpr, + TypeAssign, + UnaryExpr, + VariableExpr, +) + +unary_ops: dict[Type[ast.unaryop], str] = { + ast.Invert: "~", + ast.Not: "not", + ast.UAdd: "+", + ast.USub: "-", +} +binary_ops: dict[Type[ast.operator], str] = { + ast.Add: "+", + ast.Sub: "-", + ast.Mult: "*", + ast.MatMult: "@", + ast.Div: "/", + ast.Mod: "%", + ast.LShift: "<<", + ast.RShift: ">>", + ast.BitOr: "|", + ast.BitXor: "^", + ast.BitAnd: "&", + ast.FloorDiv: "//", + ast.Pow: "**", +} +compare_ops: dict[Type[ast.cmpop], str] = { + ast.Eq: "==", + ast.NotEq: "!=", + ast.Lt: "<", + ast.LtE: "<=", + ast.Gt: ">", + ast.GtE: ">=", + ast.Is: "is", + ast.IsNot: "is not", + ast.In: "in", + ast.NotIn: "not in", +} +boolean_ops: dict[Type[ast.boolop], str] = { + ast.And: "and", + ast.Or: "or", +} + + +class 
PythonAstJsonSerializer( + Stmt.Visitor[dict], Expr.Visitor[dict], MidasType.Visitor[dict] +): + """An AST serializer which produces a JSON-compatible structure""" + + def serialize(self, stmts: list[Stmt]) -> list[dict]: + return [stmt.accept(self) for stmt in stmts] + + def _serialize_optional( + self, element: Optional[Stmt | Expr | MidasType] + ) -> Optional[dict]: + if element is None: + return None + return element.accept(self) + + def _serialize_list( + self, elements: Sequence[Stmt | Expr | MidasType] + ) -> list[dict]: + return [element.accept(self) for element in elements] + + def visit_base_type(self, node: BaseType) -> dict: + return { + "_type": "BaseType", + "base": node.base, + "param": self._serialize_optional(node.param), + } + + def visit_constraint_type(self, node: ConstraintType) -> dict: + return { + "_type": "ConstraintType", + "type": node.type.accept(self), + "constraint": ast.unparse(node.constraint), + } + + def visit_frame_column(self, node: FrameColumn) -> dict: + return { + "_type": "FrameColumn", + "name": node.name, + "type": self._serialize_optional(node.type), + } + + def visit_frame_type(self, node: FrameType) -> dict: + return { + "_type": "FrameType", + "columns": self._serialize_list(node.columns), + } + + def visit_expression_stmt(self, stmt: ExpressionStmt) -> dict: + return { + "_type": "ExpressionStmt", + "expr": stmt.expr.accept(self), + } + + def _serialize_argument(self, arg: Function.Argument) -> dict: + return { + "name": arg.name, + "type": self._serialize_optional(arg.type), + "default": self._serialize_optional(arg.default), + } + + def visit_function(self, stmt: Function) -> dict: + return { + "_type": "Function", + "name": stmt.name, + "posonlyargs": [self._serialize_argument(arg) for arg in stmt.posonlyargs], + "args": [self._serialize_argument(arg) for arg in stmt.args], + "sink": ( + self._serialize_argument(stmt.sink) if stmt.sink is not None else None + ), + "kwonlyargs": [self._serialize_argument(arg) for arg 
in stmt.kwonlyargs], + "kw_sink": ( + self._serialize_argument(stmt.kw_sink) + if stmt.kw_sink is not None + else None + ), + "returns": self._serialize_optional(stmt.returns), + "body": self._serialize_list(stmt.body), + } + + def visit_type_assign(self, stmt: TypeAssign) -> dict: + return { + "_type": "TypeAssign", + "name": stmt.name, + "type": stmt.type.accept(self), + } + + def visit_assign_stmt(self, stmt: AssignStmt) -> dict: + return { + "_type": "AssignStmt", + "targets": self._serialize_list(stmt.targets), + "value": stmt.value.accept(self), + } + + def visit_return_stmt(self, stmt: ReturnStmt) -> dict: + return { + "_type": "ReturnStmt", + "value": self._serialize_optional(stmt.value), + } + + def visit_if_stmt(self, stmt: IfStmt) -> dict: + return { + "_type": "IfStmt", + "test": stmt.test.accept(self), + "body": self._serialize_list(stmt.body), + "orelse": self._serialize_list(stmt.orelse), + } + + def visit_binary_expr(self, expr: BinaryExpr) -> dict: + return { + "_type": "BinaryExpr", + "left": expr.left.accept(self), + "operator": binary_ops[expr.operator.__class__], + "right": expr.right.accept(self), + } + + def visit_compare_expr(self, expr: CompareExpr) -> dict: + return { + "_type": "CompareExpr", + "left": expr.left.accept(self), + "operator": compare_ops[expr.operator.__class__], + "right": expr.right.accept(self), + } + + def visit_unary_expr(self, expr: UnaryExpr) -> dict: + return { + "_type": "UnaryExpr", + "operator": unary_ops[expr.operator.__class__], + "right": expr.right.accept(self), + } + + def visit_call_expr(self, expr: CallExpr) -> dict: + return { + "_type": "CallExpr", + "callee": expr.callee.accept(self), + "arguments": self._serialize_list(expr.arguments), + "keywords": {name: arg.accept(self) for name, arg in expr.keywords.items()}, + } + + def visit_get_expr(self, expr: GetExpr) -> dict: + return { + "_type": "GetExpr", + "object": expr.object.accept(self), + "name": expr.name, + } + + def visit_literal_expr(self, expr: 
LiteralExpr) -> dict: + return { + "_type": "LiteralExpr", + "value": expr.value, + } + + def visit_variable_expr(self, expr: VariableExpr) -> dict: + return { + "_type": "VariableExpr", + "name": expr.name, + } + + def visit_logical_expr(self, expr: LogicalExpr) -> dict: + return { + "_type": "LogicalExpr", + "left": expr.left.accept(self), + "operator": boolean_ops[expr.operator.__class__], + "right": expr.right.accept(self), + } + + def visit_set_expr(self, expr: SetExpr) -> dict: + return { + "_type": "SetExpr", + "object": expr.object.accept(self), + "name": expr.name, + "value": expr.value.accept(self), + } + + def visit_cast_expr(self, expr: CastExpr) -> dict: + return { + "_type": "CastExpr", + "type": expr.type.accept(self), + "expr": expr.expr.accept(self), + } + + def visit_ternary_expr(self, expr: TernaryExpr) -> dict: + return { + "_type": "TernaryExpr", + "test": expr.test.accept(self), + "if_true": expr.if_true.accept(self), + "if_false": expr.if_false.accept(self), + } diff --git a/vscode-ext/syntaxes/midas.tmLanguage.json b/vscode-ext/syntaxes/midas.tmLanguage.json index 44745b0..20d1ded 100644 --- a/vscode-ext/syntaxes/midas.tmLanguage.json +++ b/vscode-ext/syntaxes/midas.tmLanguage.json @@ -31,22 +31,32 @@ ] }, "type-base": { - "begin": "<", - "end": ">", + "begin": "(\\()([a-zA-Z_][a-zA-Z_\\d]*)(\\))", + "end": "$", "beginCaptures": { - "0": { + "1": { "name": "punctuation.definition.base.begin.midas" - } - }, - "endCaptures": { - "0": { + }, + "2": { + "name": "variable.name" + }, + "3": { "name": "punctuation.definition.base.end.midas" } }, "patterns": [ - {"include": "source.python"} + { "include": "#type-cond" } ] }, + "type-cond": { + "begin": "where", + "end": "$", + "beginCaptures": { + "0": { + "name": "keyword.control.where.midas" + } + } + }, "type-body": { "begin": "\\{", "end": "\\}", @@ -61,7 +71,8 @@ } }, "patterns": [ - {"include": "#type-prop"} + {"include": "#type-prop"}, + {"include": "#comment"} ] }, "type-prop": { @@ -78,44 
+89,67 @@ } } }, - "op-def": { - "match": "\\b(op)\\s+<([a-zA-Z_][a-zA-Z_\\d]*)>\\s+(\\S+)\\s+<([a-zA-Z_][a-zA-Z_\\d]*)>\\s+(=)\\s+<([a-zA-Z_][a-zA-Z_\\d]*)>", - "captures": { - "1": { - "name": "keyword.control.op.midas" - }, - "2": { - "name" : "variable.name" - }, - "3": { - "name" : "keyword.operator" - }, - "4": { - "name" : "variable.name" - }, - "5": { - "name" : "keyword.operator.assignment" - }, - "6": { - "name" : "variable.name" - } - }, - "patterns": [ - { "include": "#type-base" }, - { "include": "#type-body" } - ] - }, - "constr-def": { - "begin": "(constraint)\\s+([a-zA-Z_][a-zA-Z_\\d]*)\\s*(=)", - "end": "$", + "extend-def": { + "begin": "\\b(extend)\\s*([a-zA-Z_][a-zA-Z_\\d]*)\\s+(\\{)", + "end": "\\}", "beginCaptures": { "1": { - "name": "keyword.control.constr.midas" + "name": "keyword.control.extend.midas" }, "2": { "name": "variable.name" }, "3": { + "name": "punctuation.definition.extend-body.begin.midas" + } + }, + "endCaptures": { + "0": { + "name": "punctuation.definition.extend-body.end.midas" + } + }, + "patterns": [ + {"include": "#op-def"}, + {"include": "#comment"} + ] + }, + "op-def": { + "match": "\\b(op)\\s+(\\S+)\\s*\\(\\s*([a-zA-Z_][a-zA-Z_\\d]*)\\s*\\)\\s*(->)\\s*([a-zA-Z_][a-zA-Z_\\d]*)", + "captures": { + "1": { + "name": "keyword.control.op.midas" + }, + "2": { + "name" : "keyword.operator" + }, + "3": { + "name" : "variable.name" + }, + "4": { + "name" : "keyword.operator.assignment" + }, + "5": { + "name" : "variable.name" + } + } + }, + "pred-def": { + "begin": "(predicate)\\s+([a-zA-Z_][a-zA-Z_\\d]*)\\(([a-zA-Z_][a-zA-Z_\\d]*):\\s*([a-zA-Z_][a-zA-Z_\\d]*)\\)\\s*(=)", + "end": "$", + "beginCaptures": { + "1": { + "name": "keyword.control.pred.midas" + }, + "2": { + "name": "variable.name" + }, + "3": { + "name": "variable.name" + }, + "4": { + "name": "variable.name" + }, + "5": { "name": "keyword.operator.assignment" } }, @@ -127,8 +161,8 @@ "patterns": [ { "include": "#comment" }, { "include": "#type-def" }, - { 
"include": "#op-def" }, - { "include": "#constr-def" } + { "include": "#extend-def" }, + { "include": "#pred-def" } ] } }