feat(parser): accept props and methods in extend

This commit is contained in:
2026-06-12 16:41:33 +02:00
parent 01d6e41893
commit 0461a4184c
5 changed files with 64 additions and 21 deletions

View File

@@ -4,6 +4,7 @@
###> Imports ###> Imports
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Generic, Optional, TypeVar from typing import Any, Generic, Optional, TypeVar
from midas.ast.location import Location from midas.ast.location import Location
@@ -20,6 +21,11 @@ class TypeParam:
bound: Optional[Type] bound: Optional[Type]
class MemberKind(Enum):
PROPERTY = auto()
METHOD = auto()
###< ###<
@@ -33,12 +39,13 @@ class TypeStmt:
class MemberStmt: class MemberStmt:
name: Token name: Token
type: Type type: Type
kind: MemberKind
class ExtendStmt: class ExtendStmt:
name: Token
params: list[TypeParam] params: list[TypeParam]
type: Type members: list[MemberStmt]
operations: list[OpStmt]
class OpStmt: class OpStmt:

View File

@@ -7,6 +7,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Generic, Optional, TypeVar from typing import Any, Generic, Optional, TypeVar
from midas.ast.location import Location from midas.ast.location import Location
@@ -21,6 +22,11 @@ class TypeParam:
bound: Optional[Type] bound: Optional[Type]
class MemberKind(Enum):
PROPERTY = auto()
METHOD = auto()
############## ##############
# Statements # # Statements #
############## ##############
@@ -64,6 +70,7 @@ class TypeStmt(Stmt):
class MemberStmt(Stmt): class MemberStmt(Stmt):
name: Token name: Token
type: Type type: Type
kind: MemberKind
def accept(self, visitor: Stmt.Visitor[T]) -> T: def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_member_stmt(self) return visitor.visit_member_stmt(self)
@@ -71,9 +78,9 @@ class MemberStmt(Stmt):
@dataclass(frozen=True) @dataclass(frozen=True)
class ExtendStmt(Stmt): class ExtendStmt(Stmt):
name: Token
params: list[TypeParam] params: list[TypeParam]
type: Type members: list[MemberStmt]
operations: list[OpStmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T: def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_extend_stmt(self) return visitor.visit_extend_stmt(self)

View File

@@ -114,6 +114,7 @@ class MidasAstPrinter(
def visit_member_stmt(self, stmt: m.MemberStmt): def visit_member_stmt(self, stmt: m.MemberStmt):
self._write_line("MemberStmt") self._write_line("MemberStmt")
with self._child_level(): with self._child_level():
self._write_line(f"kind: {stmt.kind.name}")
self._write_line(f'name: "{stmt.name.lexeme}"') self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("type", last=True) self._write_line("type", last=True)
with self._child_level(single=True): with self._child_level(single=True):
@@ -129,16 +130,21 @@ class MidasAstPrinter(
if i == len(stmt.params) - 1: if i == len(stmt.params) - 1:
self._mark_last() self._mark_last()
self._print_type_param(param) self._print_type_param(param)
self._write_line("type") self._write_line(f'name: "{stmt.name.lexeme}"')
with self._child_level(single=True): self._write_line("params")
stmt.type.accept(self)
self._write_line("operations", last=True)
with self._child_level(): with self._child_level():
for i, op in enumerate(stmt.operations): for i, param in enumerate(stmt.params):
self._idx = i self._idx = i
if i == len(stmt.operations) - 1: if i == len(stmt.params) - 1:
self._mark_last() self._mark_last()
op.accept(self) self._print_type_param(param)
self._write_line("members", last=True)
with self._child_level():
for i, member in enumerate(stmt.members):
self._idx = i
if i == len(stmt.members) - 1:
self._mark_last()
member.accept(self)
def visit_op_stmt(self, stmt: m.OpStmt) -> None: def visit_op_stmt(self, stmt: m.OpStmt) -> None:
self._write_line("OpStmt") self._write_line("OpStmt")
@@ -343,15 +349,23 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
return res return res
def visit_member_stmt(self, stmt: m.MemberStmt): def visit_member_stmt(self, stmt: m.MemberStmt):
res: str = f"{stmt.name.lexeme}: {stmt.type.accept(self)}" keyword: str = {
m.MemberKind.PROPERTY: "prop",
m.MemberKind.METHOD: "def",
}.get(stmt.kind, "")
res: str = f"{keyword} {stmt.name.lexeme}: {stmt.type.accept(self)}"
return self.indented(res) return self.indented(res)
def visit_extend_stmt(self, stmt: m.ExtendStmt): def visit_extend_stmt(self, stmt: m.ExtendStmt):
res: str = self.indented(f"extend {stmt.type.accept(self)}") template: str = ""
if len(stmt.params) != 0:
params: list[str] = [self._print_type_param(param) for param in stmt.params]
template = f"[{', '.join(params)}]"
res: str = self.indented(f"extend {stmt.name.lexeme}{template}")
res += " {\n" res += " {\n"
self.level += 1 self.level += 1
for op in stmt.operations: for member in stmt.members:
res += op.accept(self) res += member.accept(self) + "\n"
self.level -= 1 self.level -= 1
res += self.indented("}") res += self.indented("}")
return res return res

View File

@@ -50,6 +50,8 @@ class TokenType(Enum):
PREDICATE = auto() PREDICATE = auto()
EXTEND = auto() EXTEND = auto()
WHERE = auto() WHERE = auto()
PROP = auto()
DEF = auto()
FUNC = auto() FUNC = auto()
# Misc # Misc
@@ -68,6 +70,8 @@ KEYWORDS: dict[str, TokenType] = {
"true": TokenType.TRUE, "true": TokenType.TRUE,
"false": TokenType.FALSE, "false": TokenType.FALSE,
"none": TokenType.NONE, "none": TokenType.NONE,
"prop": TokenType.PROP,
"def": TokenType.DEF,
"fn": TokenType.FUNC, "fn": TokenType.FUNC,
} }

View File

@@ -14,6 +14,7 @@ from midas.ast.midas import (
GroupingExpr, GroupingExpr,
LiteralExpr, LiteralExpr,
LogicalExpr, LogicalExpr,
MemberKind,
MemberStmt, MemberStmt,
NamedType, NamedType,
OpStmt, OpStmt,
@@ -394,18 +395,28 @@ class MidasParser(Parser):
def member_stmt(self) -> MemberStmt: def member_stmt(self) -> MemberStmt:
"""Parse a member statement """Parse a member statement
A type member statement is written `name: Type` A type member statement is written `prop name: Type` or `def name: Type`
Returns: Returns:
MemberStmt: the parsed member statement MemberStmt: the parsed member statement
""" """
kind: MemberKind
if self.match(TokenType.PROP):
kind = MemberKind.PROPERTY
elif self.match(TokenType.DEF):
kind = MemberKind.METHOD
else:
raise self.error(self.peek(), "Expected 'prop' or 'def'")
name: Token = self.consume_identifier("Expected member name") name: Token = self.consume_identifier("Expected member name")
self.consume(TokenType.COLON, "Expected ':' after member name") self.consume(TokenType.COLON, "Expected ':' after member name")
type: Type = self.type_expr() type: Type = self.type_expr()
return MemberStmt( return MemberStmt(
location=name.location_to(self.previous()), location=name.location_to(self.previous()),
name=name, name=name,
type=type, type=type,
kind=kind,
) )
def extend_declaration(self) -> ExtendStmt: def extend_declaration(self) -> ExtendStmt:
@@ -417,20 +428,20 @@ class MidasParser(Parser):
ExtendStmt: the parsed extension statement ExtendStmt: the parsed extension statement
""" """
keyword: Token = self.previous() keyword: Token = self.previous()
name: Token = self.consume_identifier("Expected type name")
params: list[TypeParam] = self.type_params() params: list[TypeParam] = self.type_params()
type: Type = self.type_expr()
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body") self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
operations: list[OpStmt] = [] members: list[MemberStmt] = []
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE): while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
operations.append(self.op_declaration()) members.append(self.member_stmt())
self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body") self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body")
location: Location = keyword.location_to(self.previous()) location: Location = keyword.location_to(self.previous())
return ExtendStmt( return ExtendStmt(
location=location, location=location,
name=name,
params=params, params=params,
type=type, members=members,
operations=operations,
) )
def op_declaration(self) -> OpStmt: def op_declaration(self) -> OpStmt: