feat: add function type to midas syntax
This commit is contained in:
13
gen/midas.py
13
gen/midas.py
@@ -121,4 +121,17 @@ class ComplexType:
|
||||
properties: list[PropertyStmt]
|
||||
|
||||
|
||||
class FunctionType:
|
||||
pos_args: list[Argument]
|
||||
kw_args: list[Argument]
|
||||
returns: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
location: Optional[Location] = None
|
||||
name: Optional[Token]
|
||||
type: Type
|
||||
required: bool
|
||||
|
||||
|
||||
###<
|
||||
|
||||
@@ -233,6 +233,9 @@ class Type(ABC):
|
||||
@abstractmethod
|
||||
def visit_complex_type(self, type: ComplexType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_function_type(self, type: FunctionType) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NamedType(Type):
|
||||
@@ -266,3 +269,20 @@ class ComplexType(Type):
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_complex_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FunctionType(Type):
|
||||
pos_args: list[Argument]
|
||||
kw_args: list[Argument]
|
||||
returns: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
location: Optional[Location] = None
|
||||
name: Optional[Token]
|
||||
type: Type
|
||||
required: bool
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_function_type(self)
|
||||
|
||||
@@ -270,6 +270,41 @@ class MidasAstPrinter(
|
||||
self._mark_last()
|
||||
prop.accept(self)
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> None:
|
||||
self._write_line("FunctionType")
|
||||
with self._child_level():
|
||||
self._write_line("pos_args")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(type.pos_args):
|
||||
self._idx = i
|
||||
if i == len(type.pos_args) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("kw_args")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(type.kw_args):
|
||||
self._idx = i
|
||||
if i == len(type.kw_args) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("returns", last=True)
|
||||
with self._child_level(single=True):
|
||||
type.returns.accept(self)
|
||||
|
||||
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
|
||||
self._write_line("Argument")
|
||||
with self._child_level():
|
||||
name: str = "None"
|
||||
if arg.name is not None:
|
||||
name = f'"{arg.name.lexeme}"'
|
||||
self._write_line(f"name: {name}")
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
arg.type.accept(self)
|
||||
self._write_line(f"required: {arg.required}", last=True)
|
||||
|
||||
|
||||
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
|
||||
def __init__(self, indent: int = 4):
|
||||
@@ -383,6 +418,29 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
||||
res += self.indented("}")
|
||||
return res
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> str:
|
||||
pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
|
||||
kw_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
|
||||
args: list[str] = pos_args
|
||||
|
||||
if len(pos_args) != 0:
|
||||
args.append("/")
|
||||
if len(kw_args) != 0:
|
||||
args.append("*")
|
||||
args += kw_args
|
||||
|
||||
return f"({', '.join(args)}) -> {type.returns.accept(self)}"
|
||||
|
||||
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
|
||||
res: str = ""
|
||||
if arg.name is not None:
|
||||
res += arg.name.lexeme
|
||||
res += ": "
|
||||
res += arg.type.accept(self)
|
||||
if not arg.required:
|
||||
res += "?"
|
||||
return res
|
||||
|
||||
|
||||
class PythonAstPrinter(
|
||||
AstPrinter,
|
||||
|
||||
@@ -8,6 +8,7 @@ from midas.checker.reporter import FileReporter, Reporter
|
||||
from midas.checker.types import (
|
||||
AliasType,
|
||||
ComplexType,
|
||||
Function,
|
||||
GenericType,
|
||||
Type,
|
||||
TypeVar,
|
||||
@@ -131,3 +132,40 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
||||
prop.name.lexeme: prop.type.accept(self) for prop in type.properties
|
||||
}
|
||||
)
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> Type:
|
||||
return Function(
|
||||
name="<anonymous>",
|
||||
pos_args=[
|
||||
Function.Argument(
|
||||
pos=i,
|
||||
name=arg.name.lexeme if arg.name is not None else str(i),
|
||||
type=arg.type.accept(self),
|
||||
required=arg.required,
|
||||
)
|
||||
for i, arg in enumerate(type.pos_args)
|
||||
],
|
||||
args=[],
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
pos=i,
|
||||
name=arg.name.lexeme if arg.name is not None else str(i),
|
||||
type=arg.type.accept(self),
|
||||
required=arg.required,
|
||||
)
|
||||
for i, arg in enumerate(type.kw_args, start=len(type.pos_args))
|
||||
],
|
||||
returns=type.returns.accept(self),
|
||||
)
|
||||
|
||||
def _resolve_type_params(self, params: list[m.TypeParam]):
|
||||
vars: list[TypeVar] = []
|
||||
for param in params:
|
||||
name: str = param.name.lexeme
|
||||
bound: Optional[Type] = None
|
||||
if param.bound is not None:
|
||||
bound = param.bound.accept(self)
|
||||
var = TypeVar(name=name, bound=bound)
|
||||
self._local_variables[name] = var
|
||||
vars.append(var)
|
||||
return vars
|
||||
|
||||
@@ -301,6 +301,12 @@ class MidasHighlighter(
|
||||
for prop in type.properties:
|
||||
prop.accept(self)
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> None:
|
||||
self.wrap(type, "function")
|
||||
for arg in type.pos_args + type.kw_args:
|
||||
arg.type.accept(self)
|
||||
type.returns.accept(self)
|
||||
|
||||
|
||||
class DiagnosticsHighlighter(Highlighter):
|
||||
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"
|
||||
|
||||
@@ -50,12 +50,14 @@ class MidasLexer(Lexer):
|
||||
# self.add_token(TokenType.PLUS)
|
||||
case "-":
|
||||
self.add_token(TokenType.MINUS)
|
||||
# case "*":
|
||||
# self.add_token(TokenType.STAR)
|
||||
case "*":
|
||||
self.add_token(TokenType.STAR)
|
||||
case "/" if self.match("/"):
|
||||
self.scan_comment()
|
||||
case "/" if self.match("*"):
|
||||
self.scan_comment_multiline()
|
||||
case "/":
|
||||
self.add_token(TokenType.SLASH)
|
||||
case "\n":
|
||||
self.add_token(TokenType.NEWLINE)
|
||||
case " " | "\r" | "\t":
|
||||
|
||||
@@ -27,8 +27,8 @@ class TokenType(Enum):
|
||||
# Operators
|
||||
# PLUS = auto()
|
||||
MINUS = auto()
|
||||
# STAR = auto()
|
||||
# SLASH = auto()
|
||||
STAR = auto()
|
||||
SLASH = auto()
|
||||
GREATER = auto()
|
||||
GREATER_EQUAL = auto()
|
||||
LESS = auto()
|
||||
|
||||
@@ -7,6 +7,7 @@ from midas.ast.midas import (
|
||||
ConstraintType,
|
||||
Expr,
|
||||
ExtendStmt,
|
||||
FunctionType,
|
||||
GenericType,
|
||||
GetExpr,
|
||||
GroupingExpr,
|
||||
@@ -24,7 +25,7 @@ from midas.ast.midas import (
|
||||
VariableExpr,
|
||||
WildcardExpr,
|
||||
)
|
||||
from midas.lexer.token import Token, TokenType
|
||||
from midas.lexer.token import KEYWORDS, Token, TokenType
|
||||
from midas.parser.base import Parser
|
||||
from midas.parser.errors import ParsingError
|
||||
|
||||
@@ -108,7 +109,7 @@ class MidasParser(Parser):
|
||||
TypeStmt: the parsed type declaration statement
|
||||
"""
|
||||
keyword: Token = self.previous()
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
name: Token = self.consume_identifier("Expected type name")
|
||||
params: list[TypeParam] = self.type_params()
|
||||
|
||||
self.consume(TokenType.EQUAL, "Expected '=' before type definition")
|
||||
@@ -136,7 +137,7 @@ class MidasParser(Parser):
|
||||
|
||||
params: list[TypeParam] = []
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable")
|
||||
name: Token = self.consume_identifier("Expected type variable")
|
||||
bound: Optional[Type] = None
|
||||
if self.match(TokenType.LESS):
|
||||
self.consume(TokenType.COLON, "Expected ':' after '<'")
|
||||
@@ -208,7 +209,7 @@ class MidasParser(Parser):
|
||||
return args
|
||||
|
||||
def named_type(self) -> Type:
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
name: Token = self.consume_identifier("Expected type name")
|
||||
return NamedType(
|
||||
location=name.get_location(),
|
||||
name=name,
|
||||
@@ -324,9 +325,7 @@ class MidasParser(Parser):
|
||||
"""
|
||||
expr: Expr = self.primary()
|
||||
while self.match(TokenType.DOT):
|
||||
name: Token = self.consume(
|
||||
TokenType.IDENTIFIER, "Expected property name after '.'"
|
||||
)
|
||||
name: Token = self.consume_identifier("Expected property name after '.'")
|
||||
location: Location = Location.span(expr.location, name.get_location())
|
||||
expr = GetExpr(location=location, expr=expr, name=name)
|
||||
return expr
|
||||
@@ -350,7 +349,7 @@ class MidasParser(Parser):
|
||||
if self.match(TokenType.NUMBER):
|
||||
return LiteralExpr(location=token.get_location(), value=token.value)
|
||||
|
||||
if self.match(TokenType.IDENTIFIER):
|
||||
if self.match_identifier():
|
||||
return VariableExpr(location=token.get_location(), name=token)
|
||||
|
||||
if self.match(TokenType.UNDERSCORE):
|
||||
@@ -363,6 +362,20 @@ class MidasParser(Parser):
|
||||
|
||||
raise self.error(self.peek(), "Expected expression")
|
||||
|
||||
def consume_identifier(self, message: str = "Expected identifier") -> Token:
|
||||
if not self.match_identifier():
|
||||
raise self.error(self.peek(), message)
|
||||
return self.previous()
|
||||
|
||||
def match_identifier(self) -> bool:
|
||||
return self.match(TokenType.IDENTIFIER, *KEYWORDS.values())
|
||||
|
||||
def check_identifier(self) -> bool:
|
||||
for tt in [TokenType.IDENTIFIER, *KEYWORDS.values()]:
|
||||
if self.check(tt):
|
||||
return True
|
||||
return False
|
||||
|
||||
def property_stmt(self) -> PropertyStmt:
|
||||
"""Parse a property statement
|
||||
|
||||
@@ -371,7 +384,7 @@ class MidasParser(Parser):
|
||||
Returns:
|
||||
PropertyStmt: the parsed property statement
|
||||
"""
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name")
|
||||
name: Token = self.consume_identifier("Expected property name")
|
||||
self.consume(TokenType.COLON, "Expected ':' after property name")
|
||||
type: Type = self.type_expr()
|
||||
return PropertyStmt(
|
||||
@@ -439,9 +452,9 @@ class MidasParser(Parser):
|
||||
PredicateStmt: the parsed predicate declaration statement
|
||||
"""
|
||||
keyword: Token = self.previous()
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected predicate name")
|
||||
name: Token = self.consume_identifier("Expected predicate name")
|
||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
|
||||
subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name")
|
||||
subject: Token = self.consume_identifier("Expected subject name")
|
||||
self.consume(TokenType.COLON, "Expected ':' after subject name")
|
||||
type: Type = self.type_expr()
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
|
||||
@@ -454,3 +467,48 @@ class MidasParser(Parser):
|
||||
type=type,
|
||||
condition=condition,
|
||||
)
|
||||
|
||||
def function(self) -> FunctionType:
|
||||
l_paren: Token = self.consume(
|
||||
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
|
||||
)
|
||||
pos_args: list[FunctionType.Argument] = []
|
||||
kw_args: list[FunctionType.Argument] = []
|
||||
|
||||
positional: bool = True
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_PAREN):
|
||||
if positional and (
|
||||
self.match(TokenType.STAR) or self.match(TokenType.SLASH)
|
||||
):
|
||||
positional = False
|
||||
else:
|
||||
name: Optional[Token] = None
|
||||
if self.check_identifier() and self.check_next(TokenType.COLON):
|
||||
name = self.advance()
|
||||
self.advance()
|
||||
type: Type = self.type_expr()
|
||||
required: bool = self.match(TokenType.QMARK)
|
||||
arg = FunctionType.Argument(
|
||||
location=None,
|
||||
name=name,
|
||||
type=type,
|
||||
required=required,
|
||||
)
|
||||
if positional:
|
||||
pos_args.append(arg)
|
||||
else:
|
||||
kw_args.append(arg)
|
||||
|
||||
if not self.match(TokenType.COMMA):
|
||||
break
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
|
||||
|
||||
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
||||
result: Type = self.type_expr()
|
||||
|
||||
return FunctionType(
|
||||
location=l_paren.location_to(self.previous()),
|
||||
pos_args=pos_args,
|
||||
kw_args=kw_args,
|
||||
returns=result,
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ from midas.ast.midas import (
|
||||
ConstraintType,
|
||||
Expr,
|
||||
ExtendStmt,
|
||||
FunctionType,
|
||||
GenericType,
|
||||
GetExpr,
|
||||
GroupingExpr,
|
||||
@@ -164,3 +165,18 @@ class MidasAstJsonSerializer(
|
||||
"_type": "ComplexType",
|
||||
"properties": self._serialize_list(type.properties),
|
||||
}
|
||||
|
||||
def visit_function_type(self, type: FunctionType) -> dict:
|
||||
return {
|
||||
"_type": "FunctionType",
|
||||
"pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args],
|
||||
"kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args],
|
||||
"returns": type.returns.accept(self),
|
||||
}
|
||||
|
||||
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
|
||||
return {
|
||||
"name": arg.name,
|
||||
"type": arg.type.accept(self),
|
||||
"required": arg.required,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user