feat: add function type to midas syntax

This commit is contained in:
2026-06-09 23:48:06 +02:00
parent 9de03bf2b5
commit c6ead886ec
9 changed files with 226 additions and 15 deletions

View File

@@ -121,4 +121,17 @@ class ComplexType:
properties: list[PropertyStmt]
class FunctionType:
pos_args: list[Argument]
kw_args: list[Argument]
returns: Type
@dataclass(frozen=True, kw_only=True)
class Argument:
location: Optional[Location] = None
name: Optional[Token]
type: Type
required: bool
###<

View File

@@ -233,6 +233,9 @@ class Type(ABC):
@abstractmethod
def visit_complex_type(self, type: ComplexType) -> T: ...
@abstractmethod
def visit_function_type(self, type: FunctionType) -> T: ...
@dataclass(frozen=True)
class NamedType(Type):
@@ -266,3 +269,20 @@ class ComplexType(Type):
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_complex_type(self)
@dataclass(frozen=True)
class FunctionType(Type):
pos_args: list[Argument]
kw_args: list[Argument]
returns: Type
@dataclass(frozen=True, kw_only=True)
class Argument:
location: Optional[Location] = None
name: Optional[Token]
type: Type
required: bool
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_function_type(self)

View File

@@ -270,6 +270,41 @@ class MidasAstPrinter(
self._mark_last()
prop.accept(self)
def visit_function_type(self, type: m.FunctionType) -> None:
self._write_line("FunctionType")
with self._child_level():
self._write_line("pos_args")
with self._child_level():
for i, arg in enumerate(type.pos_args):
self._idx = i
if i == len(type.pos_args) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("kw_args")
with self._child_level():
for i, arg in enumerate(type.kw_args):
self._idx = i
if i == len(type.kw_args) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("returns", last=True)
with self._child_level(single=True):
type.returns.accept(self)
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
self._write_line("Argument")
with self._child_level():
name: str = "None"
if arg.name is not None:
name = f'"{arg.name.lexeme}"'
self._write_line(f"name: {name}")
self._write_line("type")
with self._child_level(single=True):
arg.type.accept(self)
self._write_line(f"required: {arg.required}", last=True)
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
def __init__(self, indent: int = 4):
@@ -383,6 +418,29 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
res += self.indented("}")
return res
def visit_function_type(self, type: m.FunctionType) -> str:
pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
kw_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
args: list[str] = pos_args
if len(pos_args) != 0:
args.append("/")
if len(kw_args) != 0:
args.append("*")
args += kw_args
return f"({', '.join(args)}) -> {type.returns.accept(self)}"
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
res: str = ""
if arg.name is not None:
res += arg.name.lexeme
res += ": "
res += arg.type.accept(self)
if not arg.required:
res += "?"
return res
class PythonAstPrinter(
AstPrinter,

View File

@@ -8,6 +8,7 @@ from midas.checker.reporter import FileReporter, Reporter
from midas.checker.types import (
AliasType,
ComplexType,
Function,
GenericType,
Type,
TypeVar,
@@ -131,3 +132,40 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
prop.name.lexeme: prop.type.accept(self) for prop in type.properties
}
)
def visit_function_type(self, type: m.FunctionType) -> Type:
return Function(
name="<anonymous>",
pos_args=[
Function.Argument(
pos=i,
name=arg.name.lexeme if arg.name is not None else str(i),
type=arg.type.accept(self),
required=arg.required,
)
for i, arg in enumerate(type.pos_args)
],
args=[],
kw_args=[
Function.Argument(
pos=i,
name=arg.name.lexeme if arg.name is not None else str(i),
type=arg.type.accept(self),
required=arg.required,
)
for i, arg in enumerate(type.kw_args, start=len(type.pos_args))
],
returns=type.returns.accept(self),
)
def _resolve_type_params(self, params: list[m.TypeParam]):
vars: list[TypeVar] = []
for param in params:
name: str = param.name.lexeme
bound: Optional[Type] = None
if param.bound is not None:
bound = param.bound.accept(self)
var = TypeVar(name=name, bound=bound)
self._local_variables[name] = var
vars.append(var)
return vars

View File

@@ -301,6 +301,12 @@ class MidasHighlighter(
for prop in type.properties:
prop.accept(self)
def visit_function_type(self, type: m.FunctionType) -> None:
self.wrap(type, "function")
for arg in type.pos_args + type.kw_args:
arg.type.accept(self)
type.returns.accept(self)
class DiagnosticsHighlighter(Highlighter):
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"

View File

@@ -50,12 +50,14 @@ class MidasLexer(Lexer):
# self.add_token(TokenType.PLUS)
case "-":
self.add_token(TokenType.MINUS)
# case "*":
# self.add_token(TokenType.STAR)
case "*":
self.add_token(TokenType.STAR)
case "/" if self.match("/"):
self.scan_comment()
case "/" if self.match("*"):
self.scan_comment_multiline()
case "/":
self.add_token(TokenType.SLASH)
case "\n":
self.add_token(TokenType.NEWLINE)
case " " | "\r" | "\t":

View File

@@ -27,8 +27,8 @@ class TokenType(Enum):
# Operators
# PLUS = auto()
MINUS = auto()
# STAR = auto()
# SLASH = auto()
STAR = auto()
SLASH = auto()
GREATER = auto()
GREATER_EQUAL = auto()
LESS = auto()

View File

@@ -7,6 +7,7 @@ from midas.ast.midas import (
ConstraintType,
Expr,
ExtendStmt,
FunctionType,
GenericType,
GetExpr,
GroupingExpr,
@@ -24,7 +25,7 @@ from midas.ast.midas import (
VariableExpr,
WildcardExpr,
)
from midas.lexer.token import Token, TokenType
from midas.lexer.token import KEYWORDS, Token, TokenType
from midas.parser.base import Parser
from midas.parser.errors import ParsingError
@@ -108,7 +109,7 @@ class MidasParser(Parser):
TypeStmt: the parsed type declaration statement
"""
keyword: Token = self.previous()
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
name: Token = self.consume_identifier("Expected type name")
params: list[TypeParam] = self.type_params()
self.consume(TokenType.EQUAL, "Expected '=' before type definition")
@@ -136,7 +137,7 @@ class MidasParser(Parser):
params: list[TypeParam] = []
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable")
name: Token = self.consume_identifier("Expected type variable")
bound: Optional[Type] = None
if self.match(TokenType.LESS):
self.consume(TokenType.COLON, "Expected ':' after '<'")
@@ -208,7 +209,7 @@ class MidasParser(Parser):
return args
def named_type(self) -> Type:
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
name: Token = self.consume_identifier("Expected type name")
return NamedType(
location=name.get_location(),
name=name,
@@ -324,9 +325,7 @@ class MidasParser(Parser):
"""
expr: Expr = self.primary()
while self.match(TokenType.DOT):
name: Token = self.consume(
TokenType.IDENTIFIER, "Expected property name after '.'"
)
name: Token = self.consume_identifier("Expected property name after '.'")
location: Location = Location.span(expr.location, name.get_location())
expr = GetExpr(location=location, expr=expr, name=name)
return expr
@@ -350,7 +349,7 @@ class MidasParser(Parser):
if self.match(TokenType.NUMBER):
return LiteralExpr(location=token.get_location(), value=token.value)
if self.match(TokenType.IDENTIFIER):
if self.match_identifier():
return VariableExpr(location=token.get_location(), name=token)
if self.match(TokenType.UNDERSCORE):
@@ -363,6 +362,20 @@ class MidasParser(Parser):
raise self.error(self.peek(), "Expected expression")
def consume_identifier(self, message: str = "Expected identifier") -> Token:
if not self.match_identifier():
raise self.error(self.peek(), message)
return self.previous()
def match_identifier(self) -> bool:
return self.match(TokenType.IDENTIFIER, *KEYWORDS.values())
def check_identifier(self) -> bool:
for tt in [TokenType.IDENTIFIER, *KEYWORDS.values()]:
if self.check(tt):
return True
return False
def property_stmt(self) -> PropertyStmt:
"""Parse a property statement
@@ -371,7 +384,7 @@ class MidasParser(Parser):
Returns:
PropertyStmt: the parsed property statement
"""
name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name")
name: Token = self.consume_identifier("Expected property name")
self.consume(TokenType.COLON, "Expected ':' after property name")
type: Type = self.type_expr()
return PropertyStmt(
@@ -439,9 +452,9 @@ class MidasParser(Parser):
PredicateStmt: the parsed predicate declaration statement
"""
keyword: Token = self.previous()
name: Token = self.consume(TokenType.IDENTIFIER, "Expected predicate name")
name: Token = self.consume_identifier("Expected predicate name")
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name")
subject: Token = self.consume_identifier("Expected subject name")
self.consume(TokenType.COLON, "Expected ':' after subject name")
type: Type = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
@@ -454,3 +467,48 @@ class MidasParser(Parser):
type=type,
condition=condition,
)
def function(self) -> FunctionType:
l_paren: Token = self.consume(
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
)
pos_args: list[FunctionType.Argument] = []
kw_args: list[FunctionType.Argument] = []
positional: bool = True
while not self.is_at_end() and not self.check(TokenType.RIGHT_PAREN):
if positional and (
self.match(TokenType.STAR) or self.match(TokenType.SLASH)
):
positional = False
else:
name: Optional[Token] = None
if self.check_identifier() and self.check_next(TokenType.COLON):
name = self.advance()
self.advance()
type: Type = self.type_expr()
required: bool = self.match(TokenType.QMARK)
arg = FunctionType.Argument(
location=None,
name=name,
type=type,
required=required,
)
if positional:
pos_args.append(arg)
else:
kw_args.append(arg)
if not self.match(TokenType.COMMA):
break
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
self.consume(TokenType.ARROW, "Expected '->' before result type")
result: Type = self.type_expr()
return FunctionType(
location=l_paren.location_to(self.previous()),
pos_args=pos_args,
kw_args=kw_args,
returns=result,
)

View File

@@ -6,6 +6,7 @@ from midas.ast.midas import (
ConstraintType,
Expr,
ExtendStmt,
FunctionType,
GenericType,
GetExpr,
GroupingExpr,
@@ -164,3 +165,18 @@ class MidasAstJsonSerializer(
"_type": "ComplexType",
"properties": self._serialize_list(type.properties),
}
def visit_function_type(self, type: FunctionType) -> dict:
return {
"_type": "FunctionType",
"pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args],
"kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args],
"returns": type.returns.accept(self),
}
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
return {
"name": arg.name,
"type": arg.type.accept(self),
"required": arg.required,
}