feat(parser): add extension type and rename properties

This commit is contained in:
2026-06-11 13:42:19 +02:00
parent c6ead886ec
commit 31158df2a9
3 changed files with 47 additions and 17 deletions

View File

@@ -30,7 +30,7 @@ class TypeStmt:
type: Type
class PropertyStmt:
class MemberStmt:
name: Token
type: Type
@@ -118,7 +118,12 @@ class ConstraintType:
class ComplexType:
properties: list[PropertyStmt]
members: list[MemberStmt]
class ExtensionType:
base: Type
extension: ComplexType
class FunctionType:

View File

@@ -38,7 +38,7 @@ class Stmt(ABC):
def visit_type_stmt(self, stmt: TypeStmt) -> T: ...
@abstractmethod
def visit_property_stmt(self, stmt: PropertyStmt) -> T: ...
def visit_member_stmt(self, stmt: MemberStmt) -> T: ...
@abstractmethod
def visit_extend_stmt(self, stmt: ExtendStmt) -> T: ...
@@ -61,12 +61,12 @@ class TypeStmt(Stmt):
@dataclass(frozen=True)
class PropertyStmt(Stmt):
class MemberStmt(Stmt):
name: Token
type: Type
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_property_stmt(self)
return visitor.visit_member_stmt(self)
@dataclass(frozen=True)
@@ -233,6 +233,9 @@ class Type(ABC):
@abstractmethod
def visit_complex_type(self, type: ComplexType) -> T: ...
@abstractmethod
def visit_extension_type(self, type: ExtensionType) -> T: ...
@abstractmethod
def visit_function_type(self, type: FunctionType) -> T: ...
@@ -265,12 +268,21 @@ class ConstraintType(Type):
@dataclass(frozen=True)
class ComplexType(Type):
properties: list[PropertyStmt]
members: list[MemberStmt]
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_complex_type(self)
@dataclass(frozen=True)
class ExtensionType(Type):
base: Type
extension: ComplexType
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_extension_type(self)
@dataclass(frozen=True)
class FunctionType(Type):
pos_args: list[Argument]

View File

@@ -111,8 +111,8 @@ class MidasAstPrinter(
self._write_line(f'name: "{param.name.lexeme}"')
self._write_optional_child("bound", param.bound, last=True)
def visit_property_stmt(self, stmt: m.PropertyStmt):
self._write_line("PropertyStmt")
def visit_member_stmt(self, stmt: m.MemberStmt):
self._write_line("MemberStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("type", last=True)
@@ -262,13 +262,23 @@ class MidasAstPrinter(
def visit_complex_type(self, type: m.ComplexType) -> None:
self._write_line("ComplexType")
with self._child_level():
self._write_line("properties", last=True)
self._write_line("members", last=True)
with self._child_level():
for i, prop in enumerate(type.properties):
for i, member in enumerate(type.members):
self._idx = i
if i == len(type.properties) - 1:
if i == len(type.members) - 1:
self._mark_last()
prop.accept(self)
member.accept(self)
def visit_extension_type(self, type: m.ExtensionType) -> None:
self._write_line("ExtensionType")
with self._child_level():
self._write_line("base")
with self._child_level(single=True):
type.base.accept(self)
self._write_line("extension", last=True)
with self._child_level(single=True):
type.extension.accept(self)
def visit_function_type(self, type: m.FunctionType) -> None:
self._write_line("FunctionType")
@@ -332,7 +342,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
res += "<:" + param.bound.accept(self)
return res
def visit_property_stmt(self, stmt: m.PropertyStmt):
def visit_member_stmt(self, stmt: m.MemberStmt):
res: str = f"{stmt.name.lexeme}: {stmt.type.accept(self)}"
return self.indented(res)
@@ -411,16 +421,19 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
def visit_complex_type(self, type: m.ComplexType) -> str:
res: str = "{\n"
self.level += 1
for prop in type.properties:
res += prop.accept(self)
for member in type.members:
res += member.accept(self)
res += "\n"
self.level -= 1
res += self.indented("}")
return res
def visit_extension_type(self, type: m.ExtensionType) -> str:
return f"{type.base.accept(self)} & {type.extension.accept(self)}"
def visit_function_type(self, type: m.FunctionType) -> str:
pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
kw_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args]
args: list[str] = pos_args
if len(pos_args) != 0:
@@ -429,7 +442,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
args.append("*")
args += kw_args
return f"({', '.join(args)}) -> {type.returns.accept(self)}"
return f"fn ({', '.join(args)}) -> {type.returns.accept(self)}"
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
res: str = ""