diff --git a/gen/midas.py b/gen/midas.py index 2184f86..16a4dd8 100644 --- a/gen/midas.py +++ b/gen/midas.py @@ -121,4 +121,17 @@ class ComplexType: properties: list[PropertyStmt] +class FunctionType: + pos_args: list[Argument] + kw_args: list[Argument] + returns: Type + + @dataclass(frozen=True, kw_only=True) + class Argument: + location: Optional[Location] = None + name: Optional[Token] + type: Type + required: bool + + ###< diff --git a/midas/ast/midas.py b/midas/ast/midas.py index d759079..00e71c8 100644 --- a/midas/ast/midas.py +++ b/midas/ast/midas.py @@ -233,6 +233,9 @@ class Type(ABC): @abstractmethod def visit_complex_type(self, type: ComplexType) -> T: ... + @abstractmethod + def visit_function_type(self, type: FunctionType) -> T: ... + @dataclass(frozen=True) class NamedType(Type): @@ -266,3 +269,20 @@ class ComplexType(Type): def accept(self, visitor: Type.Visitor[T]) -> T: return visitor.visit_complex_type(self) + + +@dataclass(frozen=True) +class FunctionType(Type): + pos_args: list[Argument] + kw_args: list[Argument] + returns: Type + + @dataclass(frozen=True, kw_only=True) + class Argument: + location: Optional[Location] = None + name: Optional[Token] + type: Type + required: bool + + def accept(self, visitor: Type.Visitor[T]) -> T: + return visitor.visit_function_type(self) diff --git a/midas/ast/printer.py b/midas/ast/printer.py index 82bd0b4..5d109ef 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -270,6 +270,41 @@ class MidasAstPrinter( self._mark_last() prop.accept(self) + def visit_function_type(self, type: m.FunctionType) -> None: + self._write_line("FunctionType") + with self._child_level(): + self._write_line("pos_args") + with self._child_level(): + for i, arg in enumerate(type.pos_args): + self._idx = i + if i == len(type.pos_args) - 1: + self._mark_last() + self._print_function_arg(arg) + + self._write_line("kw_args") + with self._child_level(): + for i, arg in enumerate(type.kw_args): + self._idx = i + if i == len(type.kw_args) - 1: + self._mark_last() + self._print_function_arg(arg) + + self._write_line("returns", last=True) + with self._child_level(single=True): + type.returns.accept(self) + + def _print_function_arg(self, arg: m.FunctionType.Argument) -> None: + self._write_line("Argument") + with self._child_level(): + name: str = "None" + if arg.name is not None: + name = f'"{arg.name.lexeme}"' + self._write_line(f"name: {name}") + self._write_line("type") + with self._child_level(single=True): + arg.type.accept(self) + self._write_line(f"required: {arg.required}", last=True) + class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]): def __init__(self, indent: int = 4): @@ -383,6 +418,29 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] res += self.indented("}") return res + def visit_function_type(self, type: m.FunctionType) -> str: + pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args] + kw_args: list[str] = [self._print_arg(arg) for arg in type.pos_args] + args: list[str] = pos_args + + if len(pos_args) != 0: + args.append("/") + if len(kw_args) != 0: + args.append("*") + args += kw_args + + return f"({', '.join(args)}) -> {type.returns.accept(self)}" + + def _print_arg(self, arg: m.FunctionType.Argument) -> str: + res: str = "" + if arg.name is not None: + res += arg.name.lexeme + res += ": " + res += arg.type.accept(self) + if not arg.required: + res += "?" + return res + class PythonAstPrinter( AstPrinter, diff --git a/midas/checker/midas.py b/midas/checker/midas.py index a6d86a9..decb40c 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -8,6 +8,7 @@ from midas.checker.reporter import FileReporter, Reporter from midas.checker.types import ( AliasType, ComplexType, + Function, GenericType, Type, TypeVar, @@ -131,3 +132,40 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type prop.name.lexeme: prop.type.accept(self) for prop in type.properties } ) + + def visit_function_type(self, type: m.FunctionType) -> Type: + return Function( + name="", + pos_args=[ + Function.Argument( + pos=i, + name=arg.name.lexeme if arg.name is not None else str(i), + type=arg.type.accept(self), + required=arg.required, + ) + for i, arg in enumerate(type.pos_args) + ], + args=[], + kw_args=[ + Function.Argument( + pos=i, + name=arg.name.lexeme if arg.name is not None else str(i), + type=arg.type.accept(self), + required=arg.required, + ) + for i, arg in enumerate(type.kw_args, start=len(type.pos_args)) + ], + returns=type.returns.accept(self), + ) + + def _resolve_type_params(self, params: list[m.TypeParam]): + vars: list[TypeVar] = [] + for param in params: + name: str = param.name.lexeme + bound: Optional[Type] = None + if param.bound is not None: + bound = param.bound.accept(self) + var = TypeVar(name=name, bound=bound) + self._local_variables[name] = var + vars.append(var) + return vars diff --git a/midas/cli/highlighter.py b/midas/cli/highlighter.py index af0fb4d..16fdf94 100644 --- a/midas/cli/highlighter.py +++ b/midas/cli/highlighter.py @@ -301,6 +301,12 @@ class MidasHighlighter( for prop in type.properties: prop.accept(self) + def visit_function_type(self, type: m.FunctionType) -> None: + self.wrap(type, "function") + for arg in type.pos_args + type.kw_args: + arg.type.accept(self) + type.returns.accept(self) + class DiagnosticsHighlighter(Highlighter): EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css" diff --git a/midas/lexer/midas.py b/midas/lexer/midas.py index 124ea09..c3246fc 100644 --- a/midas/lexer/midas.py +++ b/midas/lexer/midas.py @@ -50,12 +50,14 @@ class MidasLexer(Lexer): # self.add_token(TokenType.PLUS) case "-": self.add_token(TokenType.MINUS) - # case "*": - # self.add_token(TokenType.STAR) + case "*": + self.add_token(TokenType.STAR) case "/" if self.match("/"): self.scan_comment() case "/" if self.match("*"): self.scan_comment_multiline() + case "/": + self.add_token(TokenType.SLASH) case "\n": self.add_token(TokenType.NEWLINE) case " " | "\r" | "\t": diff --git a/midas/lexer/token.py b/midas/lexer/token.py index f08964a..74bf7b0 100644 --- a/midas/lexer/token.py +++ b/midas/lexer/token.py @@ -27,8 +27,8 @@ class TokenType(Enum): # Operators # PLUS = auto() MINUS = auto() - # STAR = auto() - # SLASH = auto() + STAR = auto() + SLASH = auto() GREATER = auto() GREATER_EQUAL = auto() LESS = auto() diff --git a/midas/parser/midas.py b/midas/parser/midas.py index 35c7a97..ce5d3f9 100644 --- a/midas/parser/midas.py +++ b/midas/parser/midas.py @@ -7,6 +7,7 @@ from midas.ast.midas import ( ConstraintType, Expr, ExtendStmt, + FunctionType, GenericType, GetExpr, GroupingExpr, @@ -24,7 +25,7 @@ from midas.ast.midas import ( VariableExpr, WildcardExpr, ) -from midas.lexer.token import Token, TokenType +from midas.lexer.token import KEYWORDS, Token, TokenType from midas.parser.base import Parser from midas.parser.errors import ParsingError @@ -108,7 +109,7 @@ class MidasParser(Parser): TypeStmt: the parsed type declaration statement """ keyword: Token = self.previous() - name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") + name: Token = self.consume_identifier("Expected type name") params: list[TypeParam] = self.type_params() self.consume(TokenType.EQUAL, "Expected '=' before type definition") @@ -136,7 +137,7 @@ class MidasParser(Parser): params: list[TypeParam] = [] while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET): - name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable") + name: Token = self.consume_identifier("Expected type variable") bound: Optional[Type] = None if self.match(TokenType.LESS): self.consume(TokenType.COLON, "Expected ':' after '<'") @@ -208,7 +209,7 @@ class MidasParser(Parser): return args def named_type(self) -> Type: - name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") + name: Token = self.consume_identifier("Expected type name") return NamedType( location=name.get_location(), name=name, @@ -324,9 +325,7 @@ class MidasParser(Parser): """ expr: Expr = self.primary() while self.match(TokenType.DOT): - name: Token = self.consume( - TokenType.IDENTIFIER, "Expected property name after '.'" - ) + name: Token = self.consume_identifier("Expected property name after '.'") location: Location = Location.span(expr.location, name.get_location()) expr = GetExpr(location=location, expr=expr, name=name) return expr @@ -350,7 +349,7 @@ class MidasParser(Parser): if self.match(TokenType.NUMBER): return LiteralExpr(location=token.get_location(), value=token.value) - if self.match(TokenType.IDENTIFIER): + if self.match_identifier(): return VariableExpr(location=token.get_location(), name=token) if self.match(TokenType.UNDERSCORE): @@ -363,6 +362,20 @@ class MidasParser(Parser): raise self.error(self.peek(), "Expected expression") + def consume_identifier(self, message: str = "Expected identifier") -> Token: + if not self.match_identifier(): + raise self.error(self.peek(), message) + return self.previous() + + def match_identifier(self) -> bool: + return self.match(TokenType.IDENTIFIER, *KEYWORDS.values()) + + def check_identifier(self) -> bool: + for tt in [TokenType.IDENTIFIER, *KEYWORDS.values()]: + if self.check(tt): + return True + return False + def property_stmt(self) -> PropertyStmt: """Parse a property statement @@ -371,7 +384,7 @@ class MidasParser(Parser): Returns: PropertyStmt: the parsed property statement """ - name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name") + name: Token = self.consume_identifier("Expected property name") self.consume(TokenType.COLON, "Expected ':' after property name") type: Type = self.type_expr() return PropertyStmt( @@ -439,9 +452,9 @@ class MidasParser(Parser): PredicateStmt: the parsed predicate declaration statement """ keyword: Token = self.previous() - name: Token = self.consume(TokenType.IDENTIFIER, "Expected predicate name") + name: Token = self.consume_identifier("Expected predicate name") self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject") - subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name") + subject: Token = self.consume_identifier("Expected subject name") self.consume(TokenType.COLON, "Expected ':' after subject name") type: Type = self.type_expr() self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject") @@ -454,3 +467,48 @@ class MidasParser(Parser): type=type, condition=condition, ) + + def function(self) -> FunctionType: + l_paren: Token = self.consume( + TokenType.LEFT_PAREN, "Expected '(' before function parameters" + ) + pos_args: list[FunctionType.Argument] = [] + kw_args: list[FunctionType.Argument] = [] + + positional: bool = True + while not self.is_at_end() and not self.check(TokenType.RIGHT_PAREN): + if positional and ( + self.match(TokenType.STAR) or self.match(TokenType.SLASH) + ): + positional = False + else: + name: Optional[Token] = None + if self.check_identifier() and self.check_next(TokenType.COLON): + name = self.advance() + self.advance() + type: Type = self.type_expr() + required: bool = self.match(TokenType.QMARK) + arg = FunctionType.Argument( + location=None, + name=name, + type=type, + required=required, + ) + if positional: + pos_args.append(arg) + else: + kw_args.append(arg) + + if not self.match(TokenType.COMMA): + break + self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters") + + self.consume(TokenType.ARROW, "Expected '->' before result type") + result: Type = self.type_expr() + + return FunctionType( + location=l_paren.location_to(self.previous()), + pos_args=pos_args, + kw_args=kw_args, + returns=result, + ) diff --git a/tests/serializer/midas.py b/tests/serializer/midas.py index 947641e..2a5daf5 100644 --- a/tests/serializer/midas.py +++ b/tests/serializer/midas.py @@ -6,6 +6,7 @@ from midas.ast.midas import ( ConstraintType, Expr, ExtendStmt, + FunctionType, GenericType, GetExpr, GroupingExpr, @@ -164,3 +165,18 @@ class MidasAstJsonSerializer( "_type": "ComplexType", "properties": self._serialize_list(type.properties), } + + def visit_function_type(self, type: FunctionType) -> dict: + return { + "_type": "FunctionType", + "pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args], + "kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args], + "returns": type.returns.accept(self), + } + + def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict: + return { + "name": arg.name, + "type": arg.type.accept(self), + "required": arg.required, + }