diff --git a/gen/midas.py b/gen/midas.py index 7187554..dc0efd0 100644 --- a/gen/midas.py +++ b/gen/midas.py @@ -13,40 +13,38 @@ from midas.lexer.token import Token ###> Stmt | Statements -class SimpleTypeStmt: +class TypeStmt: name: Token - template: Optional[TemplateExpr] - base: TypeExpr - constraint: Optional[Expr] + params: list[Param] + type: Type - -class ComplexTypeStmt: - name: Token - template: Optional[TemplateExpr] - properties: list[PropertyStmt] + @dataclass(frozen=True, kw_only=True) + class Param: + location: Location + name: Token + bound: Optional[Type] class PropertyStmt: name: Token - type: TypeExpr - constraint: Optional[Expr] + type: Type class ExtendStmt: - type: TypeExpr + type: Type operations: list[OpStmt] class OpStmt: name: Token - operand: TypeExpr - result: TypeExpr + operand: Type + result: Type class PredicateStmt: name: Token subject: Token - type: TypeExpr + type: Type condition: Expr @@ -54,9 +52,6 @@ class PredicateStmt: ###> Expr | Expressions -class SimpleTypeExpr: - name: Token - optional: bool class LogicalExpr: @@ -97,14 +92,31 @@ class WildcardExpr: token: Token -class TemplateExpr: - type: TypeExpr +###< + +###> Type | Types -class TypeExpr: +class NamedType: name: Token - template: Optional[TemplateExpr] - optional: bool + + +class GenericType: + type: Type + params: list[Type] + + +class ConstraintType: + type: Type + constraint: Expr + + +class UnionType: + types: list[Type] + + +class ComplexType: + properties: list[PropertyStmt] ###< diff --git a/midas/ast/midas.py b/midas/ast/midas.py index c307e85..a18b460 100644 --- a/midas/ast/midas.py +++ b/midas/ast/midas.py @@ -28,10 +28,7 @@ class Stmt(ABC): class Visitor(ABC, Generic[T]): @abstractmethod - def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> T: ... - - @abstractmethod - def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> T: ... + def visit_type_stmt(self, stmt: TypeStmt) -> T: ... @abstractmethod def visit_property_stmt(self, stmt: PropertyStmt) -> T: ... @@ -47,31 +44,25 @@ class Stmt(ABC): @dataclass(frozen=True) -class SimpleTypeStmt(Stmt): +class TypeStmt(Stmt): name: Token - template: Optional[TemplateExpr] - base: TypeExpr - constraint: Optional[Expr] + params: list[Param] + type: Type + + @dataclass(frozen=True, kw_only=True) + class Param: + location: Location + name: Token + bound: Optional[Type] def accept(self, visitor: Stmt.Visitor[T]) -> T: - return visitor.visit_simple_type_stmt(self) - - -@dataclass(frozen=True) -class ComplexTypeStmt(Stmt): - name: Token - template: Optional[TemplateExpr] - properties: list[PropertyStmt] - - def accept(self, visitor: Stmt.Visitor[T]) -> T: - return visitor.visit_complex_type_stmt(self) + return visitor.visit_type_stmt(self) @dataclass(frozen=True) class PropertyStmt(Stmt): name: Token - type: TypeExpr - constraint: Optional[Expr] + type: Type def accept(self, visitor: Stmt.Visitor[T]) -> T: return visitor.visit_property_stmt(self) @@ -79,7 +70,7 @@ class PropertyStmt(Stmt): @dataclass(frozen=True) class ExtendStmt(Stmt): - type: TypeExpr + type: Type operations: list[OpStmt] def accept(self, visitor: Stmt.Visitor[T]) -> T: @@ -89,8 +80,8 @@ class ExtendStmt(Stmt): @dataclass(frozen=True) class OpStmt(Stmt): name: Token - operand: TypeExpr - result: TypeExpr + operand: Type + result: Type def accept(self, visitor: Stmt.Visitor[T]) -> T: return visitor.visit_op_stmt(self) @@ -100,7 +91,7 @@ class OpStmt(Stmt): class PredicateStmt(Stmt): name: Token subject: Token - type: TypeExpr + type: Type condition: Expr def accept(self, visitor: Stmt.Visitor[T]) -> T: @@ -120,9 +111,6 @@ class Expr(ABC): def accept(self, visitor: Visitor[T]) -> T: ... class Visitor(ABC, Generic[T]): - @abstractmethod - def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> T: ... - @abstractmethod def visit_logical_expr(self, expr: LogicalExpr) -> T: ... @@ -147,21 +135,6 @@ class Expr(ABC): @abstractmethod def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ... - @abstractmethod - def visit_template_expr(self, expr: TemplateExpr) -> T: ... - - @abstractmethod - def visit_type_expr(self, expr: TypeExpr) -> T: ... - - -@dataclass(frozen=True) -class SimpleTypeExpr(Expr): - name: Token - optional: bool - - def accept(self, visitor: Expr.Visitor[T]) -> T: - return visitor.visit_simple_type_expr(self) - @dataclass(frozen=True) class LogicalExpr(Expr): @@ -233,19 +206,72 @@ class WildcardExpr(Expr): return visitor.visit_wildcard_expr(self) -@dataclass(frozen=True) -class TemplateExpr(Expr): - type: TypeExpr +######### +# Types # +######### - def accept(self, visitor: Expr.Visitor[T]) -> T: - return visitor.visit_template_expr(self) + +@dataclass(frozen=True, kw_only=True) +class Type(ABC): + location: Location + + @abstractmethod + def accept(self, visitor: Visitor[T]) -> T: ... + + class Visitor(ABC, Generic[T]): + @abstractmethod + def visit_named_type(self, type: NamedType) -> T: ... + + @abstractmethod + def visit_generic_type(self, type: GenericType) -> T: ... + + @abstractmethod + def visit_constraint_type(self, type: ConstraintType) -> T: ... + + @abstractmethod + def visit_union_type(self, type: UnionType) -> T: ... + + @abstractmethod + def visit_complex_type(self, type: ComplexType) -> T: ... @dataclass(frozen=True) -class TypeExpr(Expr): +class NamedType(Type): name: Token - template: Optional[TemplateExpr] - optional: bool - def accept(self, visitor: Expr.Visitor[T]) -> T: - return visitor.visit_type_expr(self) + def accept(self, visitor: Type.Visitor[T]) -> T: + return visitor.visit_named_type(self) + + +@dataclass(frozen=True) +class GenericType(Type): + type: Type + params: list[Type] + + def accept(self, visitor: Type.Visitor[T]) -> T: + return visitor.visit_generic_type(self) + + +@dataclass(frozen=True) +class ConstraintType(Type): + type: Type + constraint: Expr + + def accept(self, visitor: Type.Visitor[T]) -> T: + return visitor.visit_constraint_type(self) + + +@dataclass(frozen=True) +class UnionType(Type): + types: list[Type] + + def accept(self, visitor: Type.Visitor[T]) -> T: + return visitor.visit_union_type(self) + + +@dataclass(frozen=True) +class ComplexType(Type): + properties: list[PropertyStmt] + + def accept(self, visitor: Type.Visitor[T]) -> T: + return visitor.visit_complex_type(self) diff --git a/midas/ast/printer.py b/midas/ast/printer.py index 45e4a64..ed9e069 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -85,40 +85,39 @@ class AstPrinter(Generic[T]): child.accept(self) -class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]): +class MidasAstPrinter( + AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None], m.Type.Visitor[None] +): # Statements - def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt): - self._write_line("SimpleTypeStmt") + def visit_type_stmt(self, stmt: m.TypeStmt) -> None: + self._write_line("TypeStmt") with self._child_level(): self._write_line(f'name: "{stmt.name.lexeme}"') - self._write_optional_child("template", stmt.template) - self._write_line("base") - with self._child_level(single=True): - stmt.base.accept(self) - self._write_optional_child("constraint", stmt.constraint, last=True) - - def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt): - self._write_line("ComplexTypeStmt") - with self._child_level(): - self._write_line(f'name: "{stmt.name.lexeme}"') - self._write_optional_child("template", stmt.template) - self._write_line("properties", last=True) + self._write_line("params") with self._child_level(): - for i, prop in enumerate(stmt.properties): + for i, param in enumerate(stmt.params): self._idx = i - if i == len(stmt.properties) - 1: + if i == len(stmt.params) - 1: self._mark_last() - prop.accept(self) + self._print_type_stmt_param(param) + self._write_line("type", last=True) + with self._child_level(single=True): + stmt.type.accept(self) + + def _print_type_stmt_param(self, param: m.TypeStmt.Param) -> None: + self._write_line("Param") + with self._child_level(): + self._write_line(f'name: "{param.name.lexeme}"') + self._write_optional_child("bound", param.bound, last=True) def visit_property_stmt(self, stmt: m.PropertyStmt): self._write_line("PropertyStmt") with self._child_level(): self._write_line(f'name: "{stmt.name.lexeme}"') - self._write_line("type") + self._write_line("type", last=True) with self._child_level(single=True): stmt.type.accept(self) - self._write_optional_child("constraint", stmt.constraint, last=True) def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: self._write_line("ExtendStmt") @@ -161,12 +160,6 @@ class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]): # Expressions - def visit_simple_type_expr(self, expr: m.SimpleTypeExpr): - self._write_line("SimpleTypeExpr") - with self._child_level(): - self._write_line(f'name: "{expr.name.lexeme}"') - self._write_line(f"optional: {expr.optional}", last=True) - def visit_logical_expr(self, expr: m.LogicalExpr): self._write_line("LogicalExpr") with self._child_level(): @@ -230,22 +223,59 @@ class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]): def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: self._write_line("WildcardExpr") - def visit_template_expr(self, expr: m.TemplateExpr) -> None: - self._write_line("TemplateExpr") - with self._child_level(single=True): + def visit_named_type(self, type: m.NamedType) -> None: + self._write_line("NamedType") + with self._child_level(): + self._write_line(f'name: "{type.name.lexeme}"', last=True) + + def visit_generic_type(self, type: m.GenericType) -> None: + self._write_line("GenericType") + with self._child_level(): + self._write_line("type") + with self._child_level(): + type.type.accept(self) + self._write_line("params", last=True) + with self._child_level(): + for i, param in enumerate(type.params): + self._idx = i + if i == len(type.params) - 1: + self._mark_last() + param.accept(self) + + def visit_constraint_type(self, type: m.ConstraintType) -> None: + self._write_line("ConstraintType") + with self._child_level(): self._write_line("type") with self._child_level(single=True): - expr.type.accept(self) + type.type.accept(self) + self._write_line("constraint", last=True) + with self._child_level(single=True): + type.constraint.accept(self) - def visit_type_expr(self, expr: m.TypeExpr): - self._write_line("TypeExpr") + def visit_union_type(self, type: m.UnionType) -> None: + self._write_line("UnionType") with self._child_level(): - self._write_line(f'name: "{expr.name.lexeme}"') - self._write_optional_child("template", expr.template) - self._write_line(f"optional: {expr.optional}", last=True) + self._write_line("types", last=True) + with self._child_level(): + for i, type_ in enumerate(type.types): + self._idx = i + if i == len(type.types) - 1: + self._mark_last() + type_.accept(self) + + def visit_complex_type(self, type: m.ComplexType) -> None: + self._write_line("ComplexType") + with self._child_level(): + self._write_line("properties", last=True) + with self._child_level(): + for i, prop in enumerate(type.properties): + self._idx = i + if i == len(type.properties) - 1: + self._mark_last() + prop.accept(self) -class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]): +class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]): def __init__(self, indent: int = 4): self.indent: int = indent self.level: int = 0 @@ -257,29 +287,24 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]): self.level = 0 return expr.accept(self) - def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt): - template: str = stmt.template.accept(self) if stmt.template is not None else "" - res: str = f"type {stmt.name.lexeme}{template}({stmt.base.accept(self)})" - if stmt.constraint is not None: - res += " where " + stmt.constraint.accept(self) + def visit_type_stmt(self, stmt: m.TypeStmt) -> str: + template: str = "" + if len(stmt.params) != 0: + params: list[str] = [ + self._print_type_template_param(param) for param in stmt.params + ] + template = f"[{', '.join(params)}]" + res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}" return self.indented(res) - def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt): - template: str = stmt.template.accept(self) if stmt.template is not None else "" - res: str = self.indented(f"type {stmt.name.lexeme}{template}") - res += " {\n" - self.level += 1 - for prop in stmt.properties: - res += prop.accept(self) - res += "\n" - self.level -= 1 - res += self.indented("}") + def _print_type_template_param(self, param: m.TypeStmt.Param) -> str: + res: str = param.name.lexeme + if param.bound is not None: + res += "<:" + param.bound.accept(self) return res def visit_property_stmt(self, stmt: m.PropertyStmt): res: str = f"{stmt.name.lexeme}: {stmt.type.accept(self)}" - if stmt.constraint is not None: - res += " where " + stmt.constraint.accept(self) return self.indented(res) def visit_extend_stmt(self, stmt: m.ExtendStmt): @@ -304,9 +329,6 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]): condition: str = stmt.condition.accept(self) return self.indented(f"predicate {name}({subject}: {type}) = {condition}") - def visit_simple_type_expr(self, expr: m.SimpleTypeExpr): - return f"{expr.name.lexeme}{'?' if expr.optional else ''}" - def visit_logical_expr(self, expr: m.LogicalExpr): left: str = expr.left.accept(self) operator: str = expr.operator.lexeme @@ -342,12 +364,34 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]): def visit_wildcard_expr(self, expr: m.WildcardExpr): return "_" - def visit_template_expr(self, expr: m.TemplateExpr): - return f"[{expr.type.accept(self)}]" + def visit_named_type(self, type: m.NamedType) -> str: + return type.name.lexeme - def visit_type_expr(self, expr: m.TypeExpr): - template: str = expr.template.accept(self) if expr.template is not None else "" - return f"{expr.name.lexeme}{template}{'?' if expr.optional else ''}" + def visit_generic_type(self, type: m.GenericType) -> str: + res: str = type.type.accept(self) + if len(type.params) != 0: + params: list[str] = [param.accept(self) for param in type.params] + res += f"[{', '.join(params)}]" + return res + + def visit_constraint_type(self, type: m.ConstraintType) -> str: + res: str = type.type.accept(self) + res += " where " + type.constraint.accept(self) + return res + + def visit_union_type(self, type: m.UnionType) -> str: + types: list[str] = [type_.accept(self) for type_ in type.types] + return " | ".join(types) + + def visit_complex_type(self, type: m.ComplexType) -> str: + res: str = "{\n" + self.level += 1 + for prop in type.properties: + res += prop.accept(self) + res += "\n" + self.level -= 1 + res += self.indented("}") + return res class PythonAstPrinter( @@ -600,11 +644,11 @@ class PythonAstPrinter( self._write_line("test") with self._child_level(single=True): expr.test.accept(self) - + self._write_line("if_true") with self._child_level(single=True): expr.if_true.accept(self) - + self._write_line("if_false", last=True) with self._child_level(single=True): expr.if_false.accept(self)