feat(midas): add CallExpr

This commit is contained in:
2026-06-18 12:34:29 +02:00
parent 8381f4f31d
commit 94d84ab170
6 changed files with 110 additions and 1 deletions

View File

@@ -86,6 +86,12 @@ class UnaryExpr:
right: Expr right: Expr
class CallExpr:
callee: Expr
arguments: list[Expr]
keywords: dict[str, Expr]
class GetExpr: class GetExpr:
expr: Expr expr: Expr
name: Token name: Token

View File

@@ -124,6 +124,9 @@ class Expr(ABC):
@abstractmethod @abstractmethod
def visit_unary_expr(self, expr: UnaryExpr) -> T: ... def visit_unary_expr(self, expr: UnaryExpr) -> T: ...
@abstractmethod
def visit_call_expr(self, expr: CallExpr) -> T: ...
@abstractmethod @abstractmethod
def visit_get_expr(self, expr: GetExpr) -> T: ... def visit_get_expr(self, expr: GetExpr) -> T: ...
@@ -169,6 +172,16 @@ class UnaryExpr(Expr):
return visitor.visit_unary_expr(self) return visitor.visit_unary_expr(self)
@dataclass(frozen=True)
class CallExpr(Expr):
callee: Expr
arguments: list[Expr]
keywords: dict[str, Expr]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_call_expr(self)
@dataclass(frozen=True) @dataclass(frozen=True)
class GetExpr(Expr): class GetExpr(Expr):
expr: Expr expr: Expr

View File

@@ -195,6 +195,29 @@ class MidasAstPrinter(
with self._child_level(single=True): with self._child_level(single=True):
expr.right.accept(self) expr.right.accept(self)
def visit_call_expr(self, expr: m.CallExpr) -> None:
self._write_line("CallExpr")
with self._child_level():
self._write_line("callee")
with self._child_level(single=True):
expr.callee.accept(self)
self._write_line("arguments")
with self._child_level():
for i, arg in enumerate(expr.arguments):
self._idx = i
if i == len(expr.arguments) - 1:
self._mark_last()
arg.accept(self)
self._write_line("keywords", last=True)
with self._child_level():
for i, (name, arg) in enumerate(expr.keywords.items()):
self._idx = i
if i == len(expr.keywords) - 1:
self._mark_last()
self._write_line(name)
with self._child_level(single=True):
arg.accept(self)
def visit_get_expr(self, expr: m.GetExpr): def visit_get_expr(self, expr: m.GetExpr):
self._write_line("GetExpr") self._write_line("GetExpr")
with self._child_level(): with self._child_level():
@@ -396,6 +419,12 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
right: str = expr.right.accept(self) right: str = expr.right.accept(self)
return f"{operator}{right}" return f"{operator}{right}"
def visit_call_expr(self, expr: m.CallExpr) -> str:
args: list[str] = [arg.accept(self) for arg in expr.arguments] + [
f"{name}={arg.accept(self)}" for name, arg in expr.keywords.items()
]
return f"{expr.callee.accept(self)}({', '.join(args)})"
def visit_get_expr(self, expr: m.GetExpr): def visit_get_expr(self, expr: m.GetExpr):
expr_: str = expr.expr.accept(self) expr_: str = expr.expr.accept(self)
name: str = expr.name.lexeme name: str = expr.name.lexeme

View File

@@ -125,6 +125,9 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
def visit_unary_expr(self, expr: m.UnaryExpr) -> None: def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
self.reporter.warning(expr.location, "UnaryExpr not yet supported") self.reporter.warning(expr.location, "UnaryExpr not yet supported")
def visit_call_expr(self, expr: m.CallExpr) -> None:
self.reporter.warning(expr.location, "CallExpr not yet supported")
def visit_get_expr(self, expr: m.GetExpr) -> None: def visit_get_expr(self, expr: m.GetExpr) -> None:
self.reporter.warning(expr.location, "GetExpr not yet supported") self.reporter.warning(expr.location, "GetExpr not yet supported")

View File

@@ -3,6 +3,7 @@ from typing import Optional
from midas.ast.location import Location from midas.ast.location import Location
from midas.ast.midas import ( from midas.ast.midas import (
BinaryExpr, BinaryExpr,
CallExpr,
ComplexType, ComplexType,
ConstraintType, ConstraintType,
Expr, Expr,
@@ -335,7 +336,55 @@ class MidasParser(Parser):
right: Expr = self.unary() right: Expr = self.unary()
location: Location = Location.span(operator.get_location(), right.location) location: Location = Location.span(operator.get_location(), right.location)
return UnaryExpr(location=location, operator=operator, right=right) return UnaryExpr(location=location, operator=operator, right=right)
return self.reference() return self.call()
def call(self) -> Expr:
expr: Expr = self.reference()
if self.match(TokenType.LEFT_PAREN):
expr = self.finish_call(expr)
return expr
def finish_call(self, callee: Expr) -> Expr:
l_paren: Token = self.previous()
pos_args: list[Expr] = []
kw_args: dict[str, Expr] = {}
keywords: bool = False
while not self.match(TokenType.RIGHT_PAREN):
if self.check_identifier() and self.check_next(TokenType.EQUAL):
keywords = True
keyword: Token = self.advance()
value: Expr = self.expression()
name: str = keyword.lexeme
if name in kw_args:
self.error(
self.peek(),
f"Multiple values passed for '{name}', only the last occurrence will be used",
)
kw_args[name] = value
else:
value = self.expression()
if self.check(TokenType.EQUAL):
if keywords:
raise self.error(self.peek(), "Invalid keyword argument name")
else:
raise self.error(
self.peek(),
"Cannot pass positional arguments after a keyword argument",
)
pos_args.append(value)
if not self.match(TokenType.COMMA):
break
r_paren: Token = self.consume(
TokenType.RIGHT_PAREN, "Expected ')' after arguments."
)
return CallExpr(
location=l_paren.location_to(r_paren),
callee=callee,
arguments=pos_args,
keywords=kw_args,
)
def reference(self) -> Expr: def reference(self) -> Expr:
"""Parse an attribute access expression or a simpler expression """Parse an attribute access expression or a simpler expression

View File

@@ -2,6 +2,7 @@ from typing import Optional, Sequence
from midas.ast.midas import ( from midas.ast.midas import (
BinaryExpr, BinaryExpr,
CallExpr,
ComplexType, ComplexType,
ConstraintType, ConstraintType,
Expr, Expr,
@@ -107,6 +108,14 @@ class MidasAstJsonSerializer(
"right": expr.right.accept(self), "right": expr.right.accept(self),
} }
def visit_call_expr(self, expr: CallExpr) -> dict:
return {
"_type": "CallExpr",
"callee": expr.callee.accept(self),
"arguments": self._serialize_list(expr.arguments),
"keywords": {name: arg.accept(self) for name, arg in expr.keywords.items()},
}
def visit_get_expr(self, expr: GetExpr) -> dict: def visit_get_expr(self, expr: GetExpr) -> dict:
return { return {
"_type": "GetExpr", "_type": "GetExpr",