feat: add function type to midas syntax
This commit is contained in:
13
gen/midas.py
13
gen/midas.py
@@ -121,4 +121,17 @@ class ComplexType:
|
|||||||
properties: list[PropertyStmt]
|
properties: list[PropertyStmt]
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionType:
|
||||||
|
pos_args: list[Argument]
|
||||||
|
kw_args: list[Argument]
|
||||||
|
returns: Type
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Argument:
|
||||||
|
location: Optional[Location] = None
|
||||||
|
name: Optional[Token]
|
||||||
|
type: Type
|
||||||
|
required: bool
|
||||||
|
|
||||||
|
|
||||||
###<
|
###<
|
||||||
|
|||||||
@@ -233,6 +233,9 @@ class Type(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_complex_type(self, type: ComplexType) -> T: ...
|
def visit_complex_type(self, type: ComplexType) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_function_type(self, type: FunctionType) -> T: ...
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class NamedType(Type):
|
class NamedType(Type):
|
||||||
@@ -266,3 +269,20 @@ class ComplexType(Type):
|
|||||||
|
|
||||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||||
return visitor.visit_complex_type(self)
|
return visitor.visit_complex_type(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class FunctionType(Type):
|
||||||
|
pos_args: list[Argument]
|
||||||
|
kw_args: list[Argument]
|
||||||
|
returns: Type
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Argument:
|
||||||
|
location: Optional[Location] = None
|
||||||
|
name: Optional[Token]
|
||||||
|
type: Type
|
||||||
|
required: bool
|
||||||
|
|
||||||
|
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_function_type(self)
|
||||||
|
|||||||
@@ -270,6 +270,41 @@ class MidasAstPrinter(
|
|||||||
self._mark_last()
|
self._mark_last()
|
||||||
prop.accept(self)
|
prop.accept(self)
|
||||||
|
|
||||||
|
def visit_function_type(self, type: m.FunctionType) -> None:
|
||||||
|
self._write_line("FunctionType")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("pos_args")
|
||||||
|
with self._child_level():
|
||||||
|
for i, arg in enumerate(type.pos_args):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(type.pos_args) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._print_function_arg(arg)
|
||||||
|
|
||||||
|
self._write_line("kw_args")
|
||||||
|
with self._child_level():
|
||||||
|
for i, arg in enumerate(type.kw_args):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(type.kw_args) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._print_function_arg(arg)
|
||||||
|
|
||||||
|
self._write_line("returns", last=True)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
type.returns.accept(self)
|
||||||
|
|
||||||
|
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
|
||||||
|
self._write_line("Argument")
|
||||||
|
with self._child_level():
|
||||||
|
name: str = "None"
|
||||||
|
if arg.name is not None:
|
||||||
|
name = f'"{arg.name.lexeme}"'
|
||||||
|
self._write_line(f"name: {name}")
|
||||||
|
self._write_line("type")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
arg.type.accept(self)
|
||||||
|
self._write_line(f"required: {arg.required}", last=True)
|
||||||
|
|
||||||
|
|
||||||
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
|
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
|
||||||
def __init__(self, indent: int = 4):
|
def __init__(self, indent: int = 4):
|
||||||
@@ -383,6 +418,29 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
|||||||
res += self.indented("}")
|
res += self.indented("}")
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
def visit_function_type(self, type: m.FunctionType) -> str:
|
||||||
|
pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
|
||||||
|
kw_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
|
||||||
|
args: list[str] = pos_args
|
||||||
|
|
||||||
|
if len(pos_args) != 0:
|
||||||
|
args.append("/")
|
||||||
|
if len(kw_args) != 0:
|
||||||
|
args.append("*")
|
||||||
|
args += kw_args
|
||||||
|
|
||||||
|
return f"({', '.join(args)}) -> {type.returns.accept(self)}"
|
||||||
|
|
||||||
|
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
|
||||||
|
res: str = ""
|
||||||
|
if arg.name is not None:
|
||||||
|
res += arg.name.lexeme
|
||||||
|
res += ": "
|
||||||
|
res += arg.type.accept(self)
|
||||||
|
if not arg.required:
|
||||||
|
res += "?"
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
class PythonAstPrinter(
|
class PythonAstPrinter(
|
||||||
AstPrinter,
|
AstPrinter,
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from midas.checker.reporter import FileReporter, Reporter
|
|||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AliasType,
|
AliasType,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
@@ -131,3 +132,40 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
prop.name.lexeme: prop.type.accept(self) for prop in type.properties
|
prop.name.lexeme: prop.type.accept(self) for prop in type.properties
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def visit_function_type(self, type: m.FunctionType) -> Type:
|
||||||
|
return Function(
|
||||||
|
name="<anonymous>",
|
||||||
|
pos_args=[
|
||||||
|
Function.Argument(
|
||||||
|
pos=i,
|
||||||
|
name=arg.name.lexeme if arg.name is not None else str(i),
|
||||||
|
type=arg.type.accept(self),
|
||||||
|
required=arg.required,
|
||||||
|
)
|
||||||
|
for i, arg in enumerate(type.pos_args)
|
||||||
|
],
|
||||||
|
args=[],
|
||||||
|
kw_args=[
|
||||||
|
Function.Argument(
|
||||||
|
pos=i,
|
||||||
|
name=arg.name.lexeme if arg.name is not None else str(i),
|
||||||
|
type=arg.type.accept(self),
|
||||||
|
required=arg.required,
|
||||||
|
)
|
||||||
|
for i, arg in enumerate(type.kw_args, start=len(type.pos_args))
|
||||||
|
],
|
||||||
|
returns=type.returns.accept(self),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _resolve_type_params(self, params: list[m.TypeParam]):
|
||||||
|
vars: list[TypeVar] = []
|
||||||
|
for param in params:
|
||||||
|
name: str = param.name.lexeme
|
||||||
|
bound: Optional[Type] = None
|
||||||
|
if param.bound is not None:
|
||||||
|
bound = param.bound.accept(self)
|
||||||
|
var = TypeVar(name=name, bound=bound)
|
||||||
|
self._local_variables[name] = var
|
||||||
|
vars.append(var)
|
||||||
|
return vars
|
||||||
|
|||||||
@@ -301,6 +301,12 @@ class MidasHighlighter(
|
|||||||
for prop in type.properties:
|
for prop in type.properties:
|
||||||
prop.accept(self)
|
prop.accept(self)
|
||||||
|
|
||||||
|
def visit_function_type(self, type: m.FunctionType) -> None:
|
||||||
|
self.wrap(type, "function")
|
||||||
|
for arg in type.pos_args + type.kw_args:
|
||||||
|
arg.type.accept(self)
|
||||||
|
type.returns.accept(self)
|
||||||
|
|
||||||
|
|
||||||
class DiagnosticsHighlighter(Highlighter):
|
class DiagnosticsHighlighter(Highlighter):
|
||||||
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"
|
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"
|
||||||
|
|||||||
@@ -50,12 +50,14 @@ class MidasLexer(Lexer):
|
|||||||
# self.add_token(TokenType.PLUS)
|
# self.add_token(TokenType.PLUS)
|
||||||
case "-":
|
case "-":
|
||||||
self.add_token(TokenType.MINUS)
|
self.add_token(TokenType.MINUS)
|
||||||
# case "*":
|
case "*":
|
||||||
# self.add_token(TokenType.STAR)
|
self.add_token(TokenType.STAR)
|
||||||
case "/" if self.match("/"):
|
case "/" if self.match("/"):
|
||||||
self.scan_comment()
|
self.scan_comment()
|
||||||
case "/" if self.match("*"):
|
case "/" if self.match("*"):
|
||||||
self.scan_comment_multiline()
|
self.scan_comment_multiline()
|
||||||
|
case "/":
|
||||||
|
self.add_token(TokenType.SLASH)
|
||||||
case "\n":
|
case "\n":
|
||||||
self.add_token(TokenType.NEWLINE)
|
self.add_token(TokenType.NEWLINE)
|
||||||
case " " | "\r" | "\t":
|
case " " | "\r" | "\t":
|
||||||
|
|||||||
@@ -27,8 +27,8 @@ class TokenType(Enum):
|
|||||||
# Operators
|
# Operators
|
||||||
# PLUS = auto()
|
# PLUS = auto()
|
||||||
MINUS = auto()
|
MINUS = auto()
|
||||||
# STAR = auto()
|
STAR = auto()
|
||||||
# SLASH = auto()
|
SLASH = auto()
|
||||||
GREATER = auto()
|
GREATER = auto()
|
||||||
GREATER_EQUAL = auto()
|
GREATER_EQUAL = auto()
|
||||||
LESS = auto()
|
LESS = auto()
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from midas.ast.midas import (
|
|||||||
ConstraintType,
|
ConstraintType,
|
||||||
Expr,
|
Expr,
|
||||||
ExtendStmt,
|
ExtendStmt,
|
||||||
|
FunctionType,
|
||||||
GenericType,
|
GenericType,
|
||||||
GetExpr,
|
GetExpr,
|
||||||
GroupingExpr,
|
GroupingExpr,
|
||||||
@@ -24,7 +25,7 @@ from midas.ast.midas import (
|
|||||||
VariableExpr,
|
VariableExpr,
|
||||||
WildcardExpr,
|
WildcardExpr,
|
||||||
)
|
)
|
||||||
from midas.lexer.token import Token, TokenType
|
from midas.lexer.token import KEYWORDS, Token, TokenType
|
||||||
from midas.parser.base import Parser
|
from midas.parser.base import Parser
|
||||||
from midas.parser.errors import ParsingError
|
from midas.parser.errors import ParsingError
|
||||||
|
|
||||||
@@ -108,7 +109,7 @@ class MidasParser(Parser):
|
|||||||
TypeStmt: the parsed type declaration statement
|
TypeStmt: the parsed type declaration statement
|
||||||
"""
|
"""
|
||||||
keyword: Token = self.previous()
|
keyword: Token = self.previous()
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
name: Token = self.consume_identifier("Expected type name")
|
||||||
params: list[TypeParam] = self.type_params()
|
params: list[TypeParam] = self.type_params()
|
||||||
|
|
||||||
self.consume(TokenType.EQUAL, "Expected '=' before type definition")
|
self.consume(TokenType.EQUAL, "Expected '=' before type definition")
|
||||||
@@ -136,7 +137,7 @@ class MidasParser(Parser):
|
|||||||
|
|
||||||
params: list[TypeParam] = []
|
params: list[TypeParam] = []
|
||||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable")
|
name: Token = self.consume_identifier("Expected type variable")
|
||||||
bound: Optional[Type] = None
|
bound: Optional[Type] = None
|
||||||
if self.match(TokenType.LESS):
|
if self.match(TokenType.LESS):
|
||||||
self.consume(TokenType.COLON, "Expected ':' after '<'")
|
self.consume(TokenType.COLON, "Expected ':' after '<'")
|
||||||
@@ -208,7 +209,7 @@ class MidasParser(Parser):
|
|||||||
return args
|
return args
|
||||||
|
|
||||||
def named_type(self) -> Type:
|
def named_type(self) -> Type:
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
name: Token = self.consume_identifier("Expected type name")
|
||||||
return NamedType(
|
return NamedType(
|
||||||
location=name.get_location(),
|
location=name.get_location(),
|
||||||
name=name,
|
name=name,
|
||||||
@@ -324,9 +325,7 @@ class MidasParser(Parser):
|
|||||||
"""
|
"""
|
||||||
expr: Expr = self.primary()
|
expr: Expr = self.primary()
|
||||||
while self.match(TokenType.DOT):
|
while self.match(TokenType.DOT):
|
||||||
name: Token = self.consume(
|
name: Token = self.consume_identifier("Expected property name after '.'")
|
||||||
TokenType.IDENTIFIER, "Expected property name after '.'"
|
|
||||||
)
|
|
||||||
location: Location = Location.span(expr.location, name.get_location())
|
location: Location = Location.span(expr.location, name.get_location())
|
||||||
expr = GetExpr(location=location, expr=expr, name=name)
|
expr = GetExpr(location=location, expr=expr, name=name)
|
||||||
return expr
|
return expr
|
||||||
@@ -350,7 +349,7 @@ class MidasParser(Parser):
|
|||||||
if self.match(TokenType.NUMBER):
|
if self.match(TokenType.NUMBER):
|
||||||
return LiteralExpr(location=token.get_location(), value=token.value)
|
return LiteralExpr(location=token.get_location(), value=token.value)
|
||||||
|
|
||||||
if self.match(TokenType.IDENTIFIER):
|
if self.match_identifier():
|
||||||
return VariableExpr(location=token.get_location(), name=token)
|
return VariableExpr(location=token.get_location(), name=token)
|
||||||
|
|
||||||
if self.match(TokenType.UNDERSCORE):
|
if self.match(TokenType.UNDERSCORE):
|
||||||
@@ -363,6 +362,20 @@ class MidasParser(Parser):
|
|||||||
|
|
||||||
raise self.error(self.peek(), "Expected expression")
|
raise self.error(self.peek(), "Expected expression")
|
||||||
|
|
||||||
|
def consume_identifier(self, message: str = "Expected identifier") -> Token:
|
||||||
|
if not self.match_identifier():
|
||||||
|
raise self.error(self.peek(), message)
|
||||||
|
return self.previous()
|
||||||
|
|
||||||
|
def match_identifier(self) -> bool:
|
||||||
|
return self.match(TokenType.IDENTIFIER, *KEYWORDS.values())
|
||||||
|
|
||||||
|
def check_identifier(self) -> bool:
|
||||||
|
for tt in [TokenType.IDENTIFIER, *KEYWORDS.values()]:
|
||||||
|
if self.check(tt):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def property_stmt(self) -> PropertyStmt:
|
def property_stmt(self) -> PropertyStmt:
|
||||||
"""Parse a property statement
|
"""Parse a property statement
|
||||||
|
|
||||||
@@ -371,7 +384,7 @@ class MidasParser(Parser):
|
|||||||
Returns:
|
Returns:
|
||||||
PropertyStmt: the parsed property statement
|
PropertyStmt: the parsed property statement
|
||||||
"""
|
"""
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name")
|
name: Token = self.consume_identifier("Expected property name")
|
||||||
self.consume(TokenType.COLON, "Expected ':' after property name")
|
self.consume(TokenType.COLON, "Expected ':' after property name")
|
||||||
type: Type = self.type_expr()
|
type: Type = self.type_expr()
|
||||||
return PropertyStmt(
|
return PropertyStmt(
|
||||||
@@ -439,9 +452,9 @@ class MidasParser(Parser):
|
|||||||
PredicateStmt: the parsed predicate declaration statement
|
PredicateStmt: the parsed predicate declaration statement
|
||||||
"""
|
"""
|
||||||
keyword: Token = self.previous()
|
keyword: Token = self.previous()
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected predicate name")
|
name: Token = self.consume_identifier("Expected predicate name")
|
||||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
|
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
|
||||||
subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name")
|
subject: Token = self.consume_identifier("Expected subject name")
|
||||||
self.consume(TokenType.COLON, "Expected ':' after subject name")
|
self.consume(TokenType.COLON, "Expected ':' after subject name")
|
||||||
type: Type = self.type_expr()
|
type: Type = self.type_expr()
|
||||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
|
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
|
||||||
@@ -454,3 +467,48 @@ class MidasParser(Parser):
|
|||||||
type=type,
|
type=type,
|
||||||
condition=condition,
|
condition=condition,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def function(self) -> FunctionType:
|
||||||
|
l_paren: Token = self.consume(
|
||||||
|
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
|
||||||
|
)
|
||||||
|
pos_args: list[FunctionType.Argument] = []
|
||||||
|
kw_args: list[FunctionType.Argument] = []
|
||||||
|
|
||||||
|
positional: bool = True
|
||||||
|
while not self.is_at_end() and not self.check(TokenType.RIGHT_PAREN):
|
||||||
|
if positional and (
|
||||||
|
self.match(TokenType.STAR) or self.match(TokenType.SLASH)
|
||||||
|
):
|
||||||
|
positional = False
|
||||||
|
else:
|
||||||
|
name: Optional[Token] = None
|
||||||
|
if self.check_identifier() and self.check_next(TokenType.COLON):
|
||||||
|
name = self.advance()
|
||||||
|
self.advance()
|
||||||
|
type: Type = self.type_expr()
|
||||||
|
required: bool = self.match(TokenType.QMARK)
|
||||||
|
arg = FunctionType.Argument(
|
||||||
|
location=None,
|
||||||
|
name=name,
|
||||||
|
type=type,
|
||||||
|
required=required,
|
||||||
|
)
|
||||||
|
if positional:
|
||||||
|
pos_args.append(arg)
|
||||||
|
else:
|
||||||
|
kw_args.append(arg)
|
||||||
|
|
||||||
|
if not self.match(TokenType.COMMA):
|
||||||
|
break
|
||||||
|
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
|
||||||
|
|
||||||
|
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
||||||
|
result: Type = self.type_expr()
|
||||||
|
|
||||||
|
return FunctionType(
|
||||||
|
location=l_paren.location_to(self.previous()),
|
||||||
|
pos_args=pos_args,
|
||||||
|
kw_args=kw_args,
|
||||||
|
returns=result,
|
||||||
|
)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from midas.ast.midas import (
|
|||||||
ConstraintType,
|
ConstraintType,
|
||||||
Expr,
|
Expr,
|
||||||
ExtendStmt,
|
ExtendStmt,
|
||||||
|
FunctionType,
|
||||||
GenericType,
|
GenericType,
|
||||||
GetExpr,
|
GetExpr,
|
||||||
GroupingExpr,
|
GroupingExpr,
|
||||||
@@ -164,3 +165,18 @@ class MidasAstJsonSerializer(
|
|||||||
"_type": "ComplexType",
|
"_type": "ComplexType",
|
||||||
"properties": self._serialize_list(type.properties),
|
"properties": self._serialize_list(type.properties),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def visit_function_type(self, type: FunctionType) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "FunctionType",
|
||||||
|
"pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args],
|
||||||
|
"kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args],
|
||||||
|
"returns": type.returns.accept(self),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
|
||||||
|
return {
|
||||||
|
"name": arg.name,
|
||||||
|
"type": arg.type.accept(self),
|
||||||
|
"required": arg.required,
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user