refactor: add param spec for FunctionType

This commit is contained in:
2026-06-18 11:06:02 +02:00
parent a4a2ed5d64
commit 8381f4f31d
6 changed files with 106 additions and 58 deletions

View File

@@ -26,6 +26,14 @@ class MemberKind(Enum):
METHOD = auto()
@dataclass(frozen=True, kw_only=True)
class ParamSpec:
l_paren: Token
pos: list[FunctionType.Argument]
mixed: list[FunctionType.Argument]
kw: list[FunctionType.Argument]
###<
@@ -128,9 +136,7 @@ class ExtensionType:
class FunctionType:
pos_args: list[Argument]
args: list[Argument]
kw_args: list[Argument]
params: ParamSpec
returns: Type
@dataclass(frozen=True, kw_only=True)

View File

@@ -27,6 +27,14 @@ class MemberKind(Enum):
METHOD = auto()
@dataclass(frozen=True, kw_only=True)
class ParamSpec:
l_paren: Token
pos: list[FunctionType.Argument]
mixed: list[FunctionType.Argument]
kw: list[FunctionType.Argument]
##############
# Statements #
##############
@@ -279,9 +287,7 @@ class ExtensionType(Type):
@dataclass(frozen=True)
class FunctionType(Type):
pos_args: list[Argument]
args: list[Argument]
kw_args: list[Argument]
params: ParamSpec
returns: Type
@dataclass(frozen=True, kw_only=True)

View File

@@ -276,34 +276,41 @@ class MidasAstPrinter(
def visit_function_type(self, type: m.FunctionType) -> None:
self._write_line("FunctionType")
with self._child_level():
self._write_line("pos_args")
with self._child_level():
for i, arg in enumerate(type.pos_args):
self._idx = i
if i == len(type.pos_args) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("args")
with self._child_level():
for i, arg in enumerate(type.args):
self._idx = i
if i == len(type.args) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("kw_args")
with self._child_level():
for i, arg in enumerate(type.kw_args):
self._idx = i
if i == len(type.kw_args) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("params")
with self._child_level(single=True):
self._visit_param_spec(type.params)
self._write_line("returns", last=True)
with self._child_level(single=True):
type.returns.accept(self)
def _visit_param_spec(self, spec: m.ParamSpec) -> None:
self._write_line("ParamSpec")
with self._child_level():
self._write_line("pos")
with self._child_level():
for i, arg in enumerate(spec.pos):
self._idx = i
if i == len(spec.pos) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("mixed")
with self._child_level():
for i, arg in enumerate(spec.mixed):
self._idx = i
if i == len(spec.mixed) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("kw", last=True)
with self._child_level():
for i, arg in enumerate(spec.kw):
self._idx = i
if i == len(spec.kw) - 1:
self._mark_last()
self._print_function_arg(arg)
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
self._write_line("Argument")
with self._child_level():
@@ -436,9 +443,13 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
return f"{type.base.accept(self)} & {type.extension.accept(self)}"
def visit_function_type(self, type: m.FunctionType) -> str:
pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
mixed_args: list[str] = [self._print_arg(arg) for arg in type.args]
kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args]
spec: str = self._visit_param_spec(type.params)
return f"fn {spec} -> {type.returns.accept(self)}"
def _visit_param_spec(self, spec: m.ParamSpec) -> str:
pos_args: list[str] = [self._print_arg(arg) for arg in spec.pos]
mixed_args: list[str] = [self._print_arg(arg) for arg in spec.mixed]
kw_args: list[str] = [self._print_arg(arg) for arg in spec.kw]
args: list[str] = pos_args
if len(pos_args) != 0:
@@ -447,8 +458,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
if len(kw_args) != 0:
args.append("*")
args += kw_args
return f"fn ({', '.join(args)}) -> {type.returns.accept(self)}"
return f"({', '.join(args)})"
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
res: str = ""

View File

@@ -1,4 +1,5 @@
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
@@ -21,6 +22,13 @@ from midas.lexer.token import Token
from midas.parser.midas import MidasParser
@dataclass(frozen=True, kw_only=True)
class TypedParamSpec:
pos: list[Function.Argument]
mixed: list[Function.Argument]
kw: list[Function.Argument]
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
"""A resolver which evaluates Midas type definitions and build a registry"""
@@ -172,8 +180,17 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
)
def visit_function_type(self, type: m.FunctionType) -> Type:
n_pos_args: int = len(type.pos_args)
n_args: int = len(type.args)
params: TypedParamSpec = self._visit_param_spec(type.params)
return Function(
pos_args=params.pos,
args=params.mixed,
kw_args=params.kw,
returns=type.returns.accept(self),
)
def _visit_param_spec(self, spec: m.ParamSpec) -> TypedParamSpec:
n_pos: int = len(spec.pos)
n_mixed: int = len(spec.mixed)
def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument:
return Function.Argument(
@@ -183,14 +200,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
required=arg.required,
)
return Function(
pos_args=[process_arg(arg, i) for i, arg in enumerate(type.pos_args)],
args=[process_arg(arg, i + n_pos_args) for i, arg in enumerate(type.args)],
kw_args=[
process_arg(arg, i + n_pos_args + n_args)
for i, arg in enumerate(type.kw_args)
],
returns=type.returns.accept(self),
return TypedParamSpec(
pos=[process_arg(arg, i) for i, arg in enumerate(spec.pos)],
mixed=[process_arg(arg, i + n_pos) for i, arg in enumerate(spec.mixed)],
kw=[process_arg(arg, i + n_pos + n_mixed) for i, arg in enumerate(spec.kw)],
)
def _resolve_type_params(self, params: list[m.TypeParam]):

View File

@@ -17,6 +17,7 @@ from midas.ast.midas import (
MemberKind,
MemberStmt,
NamedType,
ParamSpec,
PredicateStmt,
Stmt,
Type,
@@ -265,6 +266,9 @@ class MidasParser(Parser):
Returns:
Expr: the parsed constraint expression
"""
return self.expression()
def expression(self) -> Expr:
return self.and_()
def and_(self) -> Expr:
@@ -470,6 +474,18 @@ class MidasParser(Parser):
)
def function(self) -> FunctionType:
params: ParamSpec = self.function_args()
self.consume(TokenType.ARROW, "Expected '->' before result type")
result: Type = self.type_expr()
return FunctionType(
location=params.l_paren.location_to(self.previous()),
params=params,
returns=result,
)
def function_args(self) -> ParamSpec:
l_paren: Token = self.consume(
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
)
@@ -526,14 +542,4 @@ class MidasParser(Parser):
self.error(token, "Unnamed mixed argument")
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
self.consume(TokenType.ARROW, "Expected '->' before result type")
result: Type = self.type_expr()
return FunctionType(
location=l_paren.location_to(self.previous()),
pos_args=pos_args,
args=args,
kw_args=kw_args,
returns=result,
)
return ParamSpec(l_paren=l_paren, pos=pos_args, mixed=args, kw=kw_args)

View File

@@ -15,6 +15,7 @@ from midas.ast.midas import (
LogicalExpr,
MemberStmt,
NamedType,
ParamSpec,
PredicateStmt,
Stmt,
Type,
@@ -163,12 +164,18 @@ class MidasAstJsonSerializer(
def visit_function_type(self, type: FunctionType) -> dict:
return {
"_type": "FunctionType",
"pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args],
"args": [self._serialize_func_arg(arg) for arg in type.args],
"kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args],
"params": self._serialize_param_spec(type.params),
"returns": type.returns.accept(self),
}
def _serialize_param_spec(self, spec: ParamSpec) -> dict:
return {
"_type": "ParamSpec",
"pos": [self._serialize_func_arg(arg) for arg in spec.pos],
"mixed": [self._serialize_func_arg(arg) for arg in spec.mixed],
"kw": [self._serialize_func_arg(arg) for arg in spec.kw],
}
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
return {
"name": arg.name,