diff --git a/midas/lexer/midas.py b/midas/lexer/midas.py index acc97d6..eec44d5 100644 --- a/midas/lexer/midas.py +++ b/midas/lexer/midas.py @@ -18,6 +18,8 @@ class MidasLexer(Lexer): self.add_token(TokenType.LEFT_BRACE) case "}": self.add_token(TokenType.RIGHT_BRACE) + case "|": + self.add_token(TokenType.PIPE) case "<": self.add_token( TokenType.LESS_EQUAL if self.match("=") else TokenType.LESS @@ -40,8 +42,8 @@ class MidasLexer(Lexer): self.add_token(TokenType.AND) case "?": self.add_token(TokenType.QMARK) - # case ",": - # self.add_token(TokenType.COMMA) + case ",": + self.add_token(TokenType.COMMA) case "_" if not self.is_identifier_char(self.peek_next(), start=False): self.add_token(TokenType.UNDERSCORE) case "-" if self.match(">"): diff --git a/midas/lexer/token.py b/midas/lexer/token.py index a518a8b..9b30940 100644 --- a/midas/lexer/token.py +++ b/midas/lexer/token.py @@ -17,12 +17,13 @@ class TokenType(Enum): LEFT_BRACE = auto() RIGHT_BRACE = auto() COLON = auto() - # COMMA = auto() + COMMA = auto() UNDERSCORE = auto() ARROW = auto() AND = auto() QMARK = auto() DOT = auto() + PIPE = auto() # Operators # PLUS = auto() diff --git a/midas/parser/midas.py b/midas/parser/midas.py index 5a9d649..9af46da 100644 --- a/midas/parser/midas.py +++ b/midas/parser/midas.py @@ -3,22 +3,24 @@ from typing import Optional from midas.ast.location import Location from midas.ast.midas import ( BinaryExpr, - ComplexTypeStmt, + ComplexType, + ConstraintType, Expr, ExtendStmt, + GenericType, GetExpr, GroupingExpr, LiteralExpr, LogicalExpr, + NamedType, OpStmt, PredicateStmt, PropertyStmt, - SimpleTypeExpr, - SimpleTypeStmt, Stmt, - TemplateExpr, - TypeExpr, + Type, + TypeStmt, UnaryExpr, + UnionType, VariableExpr, WildcardExpr, ) @@ -81,7 +83,7 @@ class MidasParser(Parser): self.synchronize() return None - def type_declaration(self) -> SimpleTypeStmt | ComplexTypeStmt: + def type_declaration(self) -> TypeStmt: """Parse a type declaration A type declaration can either be a simple 
type alias or a new complex type. @@ -107,33 +109,22 @@ class MidasParser(Parser): """ keyword: Token = self.previous() name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") - template: Optional[TemplateExpr] = None + params: list[TypeStmt.Param] = [] if self.check(TokenType.LEFT_BRACKET): - template = self.template_expr() + params = self.type_stmt_params() - if self.match(TokenType.LEFT_PAREN): - base: TypeExpr = self.type_expr() - self.consume(TokenType.RIGHT_PAREN, "Unclosed base type parenthesis") - constraint: Optional[Expr] = None - if self.match(TokenType.WHERE): - constraint = self.constraint() - return SimpleTypeStmt( - location=keyword.location_to(self.previous()), - name=name, - template=template, - base=base, - constraint=constraint, - ) - else: - properties: list[PropertyStmt] = self.type_properties() - return ComplexTypeStmt( - location=keyword.location_to(self.previous()), - name=name, - template=template, - properties=properties, - ) + self.consume(TokenType.EQUAL, "Expected '=' before type definition") - def template_expr(self) -> TemplateExpr: + type: Type = self.type_expr() + + return TypeStmt( + location=keyword.location_to(self.previous()), + name=name, + params=params, + type=type, + ) + + def type_stmt_params(self) -> list[TypeStmt.Param]: """Parse a generic template expression A template is written `[TypeExpr]` @@ -141,16 +132,27 @@ class MidasParser(Parser): Returns: TemplateExpr: the parsed template expression """ - left: Token = self.consume( - TokenType.LEFT_BRACKET, "Missing '[' before template expression" - ) - type: TypeExpr = self.type_expr() - right: Token = self.consume( - TokenType.RIGHT_BRACKET, "Missing ']' after template expression" - ) - return TemplateExpr(location=left.location_to(right), type=type) + self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression") + params: list[TypeStmt.Param] = [] + while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET): + name: Token = 
self.consume(TokenType.IDENTIFIER, "Expected type variable") + bound: Optional[Type] = None + if self.match(TokenType.LESS): + self.consume(TokenType.COLON, "Expected ':' after '<'") + bound = self.type_expr() + params.append( + TypeStmt.Param( + location=name.location_to(self.previous()), + name=name, + bound=bound, + ) + ) + if not self.match(TokenType.COMMA): + break + self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression") + return params - def type_expr(self) -> TypeExpr: + def type_expr(self) -> Type: """Parse a type expression A type is an identifier, optionally followed by a template expression. @@ -159,30 +161,93 @@ class MidasParser(Parser): Returns: TypeExpr: the parsed type expression """ - name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") - template: Optional[TemplateExpr] = None - if self.check(TokenType.LEFT_BRACKET): - template = self.template_expr() - optional: bool = self.match(TokenType.QMARK) - return TypeExpr( - location=name.location_to(self.previous()), - name=name, - template=template, - optional=optional, + return self.union_type() + + def union_type(self) -> Type: + types: list[Type] = [self.constraint_type()] + while self.match(TokenType.PIPE): + types.append(self.constraint_type()) + if len(types) == 1: + return types[0] + return UnionType( + location=Location.span(types[0].location, types[-1].location), + types=types, ) - def simple_type_expr(self) -> SimpleTypeExpr: - """Parse a simple type expression + def constraint_type(self) -> Type: + type: Type = self.base_type() + if self.match(TokenType.WHERE): + constraint: Expr = self.constraint() + return ConstraintType( + location=Location.span(type.location, constraint.location), + type=type, + constraint=constraint, + ) + return type - A simple type is just an identifier optionally followed by a '?' 
+ def base_type(self) -> Type: + if self.match(TokenType.LEFT_PAREN): + type: Type = self.type_expr() + self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis") + return type + + if self.check(TokenType.LEFT_BRACE): + return self.complex_type() + + return self.generic_type() + + def generic_type(self) -> Type: + type: Type = self.named_type() + if self.check(TokenType.LEFT_BRACKET): + params: list[Type] = self.type_params() + return GenericType( + location=Location.span(type.location, self.previous().get_location()), + type=type, + params=params, + ) + return type + + def type_params(self) -> list[Type]: + params: list[Type] = [] + self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic parameters") + while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET): + params.append(self.type_expr()) + if not self.match(TokenType.COMMA): + break + self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic parameters") + return params + + def named_type(self) -> Type: + name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") + return NamedType( + location=name.get_location(), + name=name, + ) + + def complex_type(self) -> Type: + """Parse a type definition body + + A type definition body is a set of whitespace-separated + property statements enclosed in curly braces Returns: - SimpleTypeExpr: the parsed simple type expression + ComplexType: the parsed complex type holding its properties """ - name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") - optional: bool = self.match(TokenType.QMARK) - return SimpleTypeExpr( - location=name.location_to(self.previous()), name=name, optional=optional + left: Token = self.consume( + TokenType.LEFT_BRACE, "Expected '{' to start type body" + ) + properties: list[PropertyStmt] = [] + names: set[str] = set() + while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end(): + prop: PropertyStmt = self.property_stmt() + if prop.name.lexeme in names: + raise self.error(prop.name, 
"Duplicate property") + names.add(prop.name.lexeme) + properties.append(prop) + right: Token = self.consume(TokenType.RIGHT_BRACE, "Unclosed type body") + return ComplexType( + location=left.location_to(right), + properties=properties, ) def constraint(self) -> Expr: @@ -308,27 +373,6 @@ class MidasParser(Parser): raise self.error(self.peek(), "Expected expression") - def type_properties(self) -> list[PropertyStmt]: - """Parse a type definition body - - A type definition body is a set of whitespace-separated - property statements enclosed in curly braces - - Returns: - list[PropertyStmt]: the parsed type properties - """ - self.consume(TokenType.LEFT_BRACE, "Expected '{' to start type body") - properties: list[PropertyStmt] = [] - names: set[str] = set() - while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end(): - prop: PropertyStmt = self.property_stmt() - if prop.name.lexeme in names: - raise self.error(prop.name, "Duplicate property") - names.add(prop.name.lexeme) - properties.append(prop) - self.consume(TokenType.RIGHT_BRACE, "Unclosed type body") - return properties - def property_stmt(self) -> PropertyStmt: """Parse a property statement @@ -339,15 +383,11 @@ class MidasParser(Parser): """ name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name") self.consume(TokenType.COLON, "Expected ':' after property name") - type: TypeExpr = self.type_expr() - constraint: Optional[Expr] = None - if self.match(TokenType.WHERE): - constraint = self.constraint() + type: Type = self.type_expr() return PropertyStmt( location=name.location_to(self.previous()), name=name, type=type, - constraint=constraint, ) def extend_declaration(self) -> ExtendStmt: @@ -359,7 +399,7 @@ class MidasParser(Parser): ExtendStmt: the parsed extension statement """ keyword: Token = self.previous() - type: TypeExpr = self.type_expr() + type: Type = self.type_expr() self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body") operations: list[OpStmt] = [] while 
not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE): @@ -380,11 +420,11 @@ class MidasParser(Parser): name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name") self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type") - operand: TypeExpr = self.type_expr() + operand: Type = self.type_expr() self.consume(TokenType.RIGHT_PAREN, "Expected ')' after operand type") self.consume(TokenType.ARROW, "Expected '->' before result type") - result: TypeExpr = self.type_expr() + result: Type = self.type_expr() return OpStmt( location=keyword.location_to(self.previous()), @@ -406,7 +446,7 @@ class MidasParser(Parser): self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject") subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name") self.consume(TokenType.COLON, "Expected ':' after subject name") - type: TypeExpr = self.type_expr() + type: Type = self.type_expr() self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject") self.consume(TokenType.EQUAL, "Expected '=' after predicate subject") condition: Expr = self.constraint()