diff --git a/gen/midas.py b/gen/midas.py index 4141217..72813d4 100644 --- a/gen/midas.py +++ b/gen/midas.py @@ -4,6 +4,7 @@ ###> Imports from abc import ABC, abstractmethod from dataclasses import dataclass +from enum import Enum, auto from typing import Any, Generic, Optional, TypeVar from midas.ast.location import Location @@ -20,6 +21,11 @@ class TypeParam: bound: Optional[Type] +class MemberKind(Enum): + PROPERTY = auto() + METHOD = auto() + + ###< @@ -33,12 +39,13 @@ class TypeStmt: class MemberStmt: name: Token type: Type + kind: MemberKind class ExtendStmt: + name: Token params: list[TypeParam] - type: Type - operations: list[OpStmt] + members: list[MemberStmt] class OpStmt: diff --git a/midas/ast/midas.py b/midas/ast/midas.py index 36d959b..affd768 100644 --- a/midas/ast/midas.py +++ b/midas/ast/midas.py @@ -7,6 +7,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from dataclasses import dataclass +from enum import Enum, auto from typing import Any, Generic, Optional, TypeVar from midas.ast.location import Location @@ -21,6 +22,11 @@ class TypeParam: bound: Optional[Type] +class MemberKind(Enum): + PROPERTY = auto() + METHOD = auto() + + ############## # Statements # ############## @@ -64,6 +70,7 @@ class TypeStmt(Stmt): class MemberStmt(Stmt): name: Token type: Type + kind: MemberKind def accept(self, visitor: Stmt.Visitor[T]) -> T: return visitor.visit_member_stmt(self) @@ -71,9 +78,9 @@ class MemberStmt(Stmt): @dataclass(frozen=True) class ExtendStmt(Stmt): + name: Token params: list[TypeParam] - type: Type - operations: list[OpStmt] + members: list[MemberStmt] def accept(self, visitor: Stmt.Visitor[T]) -> T: return visitor.visit_extend_stmt(self) diff --git a/midas/ast/printer.py b/midas/ast/printer.py index 2a5eec3..c9a9d33 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -114,6 +114,7 @@ class MidasAstPrinter( def visit_member_stmt(self, stmt: m.MemberStmt): self._write_line("MemberStmt") with self._child_level(): + self._write_line(f"kind: {stmt.kind.name}") self._write_line(f'name: "{stmt.name.lexeme}"') self._write_line("type", last=True) with self._child_level(single=True): @@ -129,16 +130,21 @@ class MidasAstPrinter( if i == len(stmt.params) - 1: self._mark_last() self._print_type_param(param) - self._write_line("type") - with self._child_level(single=True): - stmt.type.accept(self) - self._write_line("operations", last=True) + self._write_line(f'name: "{stmt.name.lexeme}"') + self._write_line("params") with self._child_level(): - for i, op in enumerate(stmt.operations): + for i, param in enumerate(stmt.params): self._idx = i - if i == len(stmt.operations) - 1: + if i == len(stmt.params) - 1: self._mark_last() - op.accept(self) + self._print_type_param(param) + self._write_line("members", last=True) + with self._child_level(): + for i, member in enumerate(stmt.members): + self._idx = i + if i == len(stmt.members) - 1: + self._mark_last() + member.accept(self) def visit_op_stmt(self, stmt: m.OpStmt) -> None: self._write_line("OpStmt") @@ -343,15 +349,23 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] return res def visit_member_stmt(self, stmt: m.MemberStmt): - res: str = f"{stmt.name.lexeme}: {stmt.type.accept(self)}" + keyword: str = { + m.MemberKind.PROPERTY: "prop", + m.MemberKind.METHOD: "def", + }.get(stmt.kind, "") + res: str = f"{keyword} {stmt.name.lexeme}: {stmt.type.accept(self)}" return self.indented(res) def visit_extend_stmt(self, stmt: m.ExtendStmt): - res: str = self.indented(f"extend {stmt.type.accept(self)}") + template: str = "" + if len(stmt.params) != 0: + params: list[str] = [self._print_type_param(param) for param in stmt.params] + template = f"[{', '.join(params)}]" + res: str = self.indented(f"extend {stmt.name.lexeme}{template}") res += " {\n" self.level += 1 - for op in stmt.operations: - res += op.accept(self) + for member in stmt.members: + res += member.accept(self) + "\n" self.level -= 1 res += self.indented("}") return res diff --git a/midas/lexer/token.py b/midas/lexer/token.py index 60b6e47..95e0c18 100644 --- a/midas/lexer/token.py +++ b/midas/lexer/token.py @@ -50,6 +50,8 @@ class TokenType(Enum): PREDICATE = auto() EXTEND = auto() WHERE = auto() + PROP = auto() + DEF = auto() FUNC = auto() # Misc @@ -68,6 +70,8 @@ KEYWORDS: dict[str, TokenType] = { "true": TokenType.TRUE, "false": TokenType.FALSE, "none": TokenType.NONE, + "prop": TokenType.PROP, + "def": TokenType.DEF, "fn": TokenType.FUNC, } diff --git a/midas/parser/midas.py b/midas/parser/midas.py index 0d0cbde..ce94b2d 100644 --- a/midas/parser/midas.py +++ b/midas/parser/midas.py @@ -14,6 +14,7 @@ from midas.ast.midas import ( GroupingExpr, LiteralExpr, LogicalExpr, + MemberKind, MemberStmt, NamedType, OpStmt, @@ -394,18 +395,28 @@ class MidasParser(Parser): def member_stmt(self) -> MemberStmt: """Parse a member statement - A type member statement is written `name: Type` + A type member statement is written `prop name: Type` or `def name: Type` Returns: MemberStmt: the parsed member statement """ + kind: MemberKind + if self.match(TokenType.PROP): + kind = MemberKind.PROPERTY + elif self.match(TokenType.DEF): + kind = MemberKind.METHOD + else: + raise self.error(self.peek(), "Expected 'prop' or 'def'") + name: Token = self.consume_identifier("Expected member name") self.consume(TokenType.COLON, "Expected ':' after member name") + type: Type = self.type_expr() return MemberStmt( location=name.location_to(self.previous()), name=name, type=type, + kind=kind, ) def extend_declaration(self) -> ExtendStmt: @@ -417,20 +428,20 @@ class MidasParser(Parser): ExtendStmt: the parsed extension statement """ keyword: Token = self.previous() + name: Token = self.consume_identifier("Expected type name") params: list[TypeParam] = self.type_params() - type: Type = self.type_expr() self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body") - operations: list[OpStmt] = [] + members: list[MemberStmt] = [] while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE): - operations.append(self.op_declaration()) + members.append(self.member_stmt()) self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body") location: Location = keyword.location_to(self.previous()) return ExtendStmt( location=location, + name=name, params=params, - type=type, - operations=operations, + members=members, ) def op_declaration(self) -> OpStmt: