From 31158df2a9c82efe1651fe331f0b4fae052ba14a Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Thu, 11 Jun 2026 13:42:19 +0200 Subject: [PATCH] feat(parser): add extension type and rename properties --- gen/midas.py | 9 +++++++-- midas/ast/midas.py | 20 ++++++++++++++++---- midas/ast/printer.py | 35 ++++++++++++++++++++++++----------- 3 files changed, 47 insertions(+), 17 deletions(-) diff --git a/gen/midas.py b/gen/midas.py index 16a4dd8..4141217 100644 --- a/gen/midas.py +++ b/gen/midas.py @@ -30,7 +30,7 @@ class TypeStmt: type: Type -class PropertyStmt: +class MemberStmt: name: Token type: Type @@ -118,7 +118,12 @@ class ConstraintType: class ComplexType: - properties: list[PropertyStmt] + members: list[MemberStmt] + + +class ExtensionType: + base: Type + extension: ComplexType class FunctionType: diff --git a/midas/ast/midas.py b/midas/ast/midas.py index 00e71c8..36d959b 100644 --- a/midas/ast/midas.py +++ b/midas/ast/midas.py @@ -38,7 +38,7 @@ class Stmt(ABC): def visit_type_stmt(self, stmt: TypeStmt) -> T: ... @abstractmethod - def visit_property_stmt(self, stmt: PropertyStmt) -> T: ... + def visit_member_stmt(self, stmt: MemberStmt) -> T: ... @abstractmethod def visit_extend_stmt(self, stmt: ExtendStmt) -> T: ... @@ -61,12 +61,12 @@ class TypeStmt(Stmt): @dataclass(frozen=True) -class PropertyStmt(Stmt): +class MemberStmt(Stmt): name: Token type: Type def accept(self, visitor: Stmt.Visitor[T]) -> T: - return visitor.visit_property_stmt(self) + return visitor.visit_member_stmt(self) @dataclass(frozen=True) @@ -233,6 +233,9 @@ class Type(ABC): @abstractmethod def visit_complex_type(self, type: ComplexType) -> T: ... + @abstractmethod + def visit_extension_type(self, type: ExtensionType) -> T: ... + @abstractmethod def visit_function_type(self, type: FunctionType) -> T: ... @@ -265,12 +268,21 @@ class ConstraintType(Type): @dataclass(frozen=True) class ComplexType(Type): - properties: list[PropertyStmt] + members: list[MemberStmt] def accept(self, visitor: Type.Visitor[T]) -> T: return visitor.visit_complex_type(self) +@dataclass(frozen=True) +class ExtensionType(Type): + base: Type + extension: ComplexType + + def accept(self, visitor: Type.Visitor[T]) -> T: + return visitor.visit_extension_type(self) + + @dataclass(frozen=True) class FunctionType(Type): pos_args: list[Argument] diff --git a/midas/ast/printer.py b/midas/ast/printer.py index 5d109ef..2a5eec3 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -111,8 +111,8 @@ class MidasAstPrinter( self._write_line(f'name: "{param.name.lexeme}"') self._write_optional_child("bound", param.bound, last=True) - def visit_property_stmt(self, stmt: m.PropertyStmt): - self._write_line("PropertyStmt") + def visit_member_stmt(self, stmt: m.MemberStmt): + self._write_line("MemberStmt") with self._child_level(): self._write_line(f'name: "{stmt.name.lexeme}"') self._write_line("type", last=True) @@ -262,13 +262,23 @@ class MidasAstPrinter( def visit_complex_type(self, type: m.ComplexType) -> None: self._write_line("ComplexType") with self._child_level(): - self._write_line("properties", last=True) + self._write_line("members", last=True) with self._child_level(): - for i, prop in enumerate(type.properties): + for i, member in enumerate(type.members): self._idx = i - if i == len(type.properties) - 1: + if i == len(type.members) - 1: self._mark_last() - prop.accept(self) + member.accept(self) + + def visit_extension_type(self, type: m.ExtensionType) -> None: + self._write_line("ExtensionType") + with self._child_level(): + self._write_line("base") + with self._child_level(single=True): + type.base.accept(self) + self._write_line("extension", last=True) + with self._child_level(single=True): + type.extension.accept(self) def visit_function_type(self, type: m.FunctionType) -> None: self._write_line("FunctionType") @@ -332,7 +342,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] res += "<:" + param.bound.accept(self) return res - def visit_property_stmt(self, stmt: m.PropertyStmt): + def visit_member_stmt(self, stmt: m.MemberStmt): res: str = f"{stmt.name.lexeme}: {stmt.type.accept(self)}" return self.indented(res) @@ -411,16 +421,19 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] def visit_complex_type(self, type: m.ComplexType) -> str: res: str = "{\n" self.level += 1 - for prop in type.properties: - res += prop.accept(self) + for member in type.members: + res += member.accept(self) res += "\n" self.level -= 1 res += self.indented("}") return res + def visit_extension_type(self, type: m.ExtensionType) -> str: + return f"{type.base.accept(self)} & {type.extension.accept(self)}" + def visit_function_type(self, type: m.FunctionType) -> str: pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args] - kw_args: list[str] = [self._print_arg(arg) for arg in type.pos_args] + kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args] args: list[str] = pos_args if len(pos_args) != 0: @@ -429,7 +442,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] args.append("*") args += kw_args - return f"({', '.join(args)}) -> {type.returns.accept(self)}" + return f"fn ({', '.join(args)}) -> {type.returns.accept(self)}" def _print_arg(self, arg: m.FunctionType.Argument) -> str: res: str = ""