refactor: add param spec for FunctionType

This commit is contained in:
2026-06-18 11:06:02 +02:00
parent a4a2ed5d64
commit 8381f4f31d
6 changed files with 106 additions and 58 deletions

View File

@@ -26,6 +26,14 @@ class MemberKind(Enum):
METHOD = auto() METHOD = auto()
@dataclass(frozen=True, kw_only=True)
class ParamSpec:
l_paren: Token
pos: list[FunctionType.Argument]
mixed: list[FunctionType.Argument]
kw: list[FunctionType.Argument]
###< ###<
@@ -128,9 +136,7 @@ class ExtensionType:
class FunctionType: class FunctionType:
pos_args: list[Argument] params: ParamSpec
args: list[Argument]
kw_args: list[Argument]
returns: Type returns: Type
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)

View File

@@ -27,6 +27,14 @@ class MemberKind(Enum):
METHOD = auto() METHOD = auto()
@dataclass(frozen=True, kw_only=True)
class ParamSpec:
l_paren: Token
pos: list[FunctionType.Argument]
mixed: list[FunctionType.Argument]
kw: list[FunctionType.Argument]
############## ##############
# Statements # # Statements #
############## ##############
@@ -279,9 +287,7 @@ class ExtensionType(Type):
@dataclass(frozen=True) @dataclass(frozen=True)
class FunctionType(Type): class FunctionType(Type):
pos_args: list[Argument] params: ParamSpec
args: list[Argument]
kw_args: list[Argument]
returns: Type returns: Type
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)

View File

@@ -276,34 +276,41 @@ class MidasAstPrinter(
def visit_function_type(self, type: m.FunctionType) -> None: def visit_function_type(self, type: m.FunctionType) -> None:
self._write_line("FunctionType") self._write_line("FunctionType")
with self._child_level(): with self._child_level():
self._write_line("pos_args") self._write_line("params")
with self._child_level(): with self._child_level(single=True):
for i, arg in enumerate(type.pos_args): self._visit_param_spec(type.params)
self._idx = i
if i == len(type.pos_args) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("args")
with self._child_level():
for i, arg in enumerate(type.args):
self._idx = i
if i == len(type.args) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("kw_args")
with self._child_level():
for i, arg in enumerate(type.kw_args):
self._idx = i
if i == len(type.kw_args) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("returns", last=True) self._write_line("returns", last=True)
with self._child_level(single=True): with self._child_level(single=True):
type.returns.accept(self) type.returns.accept(self)
def _visit_param_spec(self, spec: m.ParamSpec) -> None:
self._write_line("ParamSpec")
with self._child_level():
self._write_line("pos")
with self._child_level():
for i, arg in enumerate(spec.pos):
self._idx = i
if i == len(spec.pos) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("mixed")
with self._child_level():
for i, arg in enumerate(spec.mixed):
self._idx = i
if i == len(spec.mixed) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("kw", last=True)
with self._child_level():
for i, arg in enumerate(spec.kw):
self._idx = i
if i == len(spec.kw) - 1:
self._mark_last()
self._print_function_arg(arg)
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None: def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
self._write_line("Argument") self._write_line("Argument")
with self._child_level(): with self._child_level():
@@ -436,9 +443,13 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
return f"{type.base.accept(self)} & {type.extension.accept(self)}" return f"{type.base.accept(self)} & {type.extension.accept(self)}"
def visit_function_type(self, type: m.FunctionType) -> str: def visit_function_type(self, type: m.FunctionType) -> str:
pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args] spec: str = self._visit_param_spec(type.params)
mixed_args: list[str] = [self._print_arg(arg) for arg in type.args] return f"fn {spec} -> {type.returns.accept(self)}"
kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args]
def _visit_param_spec(self, spec: m.ParamSpec) -> str:
pos_args: list[str] = [self._print_arg(arg) for arg in spec.pos]
mixed_args: list[str] = [self._print_arg(arg) for arg in spec.mixed]
kw_args: list[str] = [self._print_arg(arg) for arg in spec.kw]
args: list[str] = pos_args args: list[str] = pos_args
if len(pos_args) != 0: if len(pos_args) != 0:
@@ -447,8 +458,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
if len(kw_args) != 0: if len(kw_args) != 0:
args.append("*") args.append("*")
args += kw_args args += kw_args
return f"({', '.join(args)})"
return f"fn ({', '.join(args)}) -> {type.returns.accept(self)}"
def _print_arg(self, arg: m.FunctionType.Argument) -> str: def _print_arg(self, arg: m.FunctionType.Argument) -> str:
res: str = "" res: str = ""

View File

@@ -1,4 +1,5 @@
import logging import logging
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@@ -21,6 +22,13 @@ from midas.lexer.token import Token
from midas.parser.midas import MidasParser from midas.parser.midas import MidasParser
@dataclass(frozen=True, kw_only=True)
class TypedParamSpec:
pos: list[Function.Argument]
mixed: list[Function.Argument]
kw: list[Function.Argument]
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]): class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
"""A resolver which evaluates Midas type definitions and build a registry""" """A resolver which evaluates Midas type definitions and build a registry"""
@@ -172,8 +180,17 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
) )
def visit_function_type(self, type: m.FunctionType) -> Type: def visit_function_type(self, type: m.FunctionType) -> Type:
n_pos_args: int = len(type.pos_args) params: TypedParamSpec = self._visit_param_spec(type.params)
n_args: int = len(type.args) return Function(
pos_args=params.pos,
args=params.mixed,
kw_args=params.kw,
returns=type.returns.accept(self),
)
def _visit_param_spec(self, spec: m.ParamSpec) -> TypedParamSpec:
n_pos: int = len(spec.pos)
n_mixed: int = len(spec.mixed)
def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument: def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument:
return Function.Argument( return Function.Argument(
@@ -183,14 +200,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
required=arg.required, required=arg.required,
) )
return Function( return TypedParamSpec(
pos_args=[process_arg(arg, i) for i, arg in enumerate(type.pos_args)], pos=[process_arg(arg, i) for i, arg in enumerate(spec.pos)],
args=[process_arg(arg, i + n_pos_args) for i, arg in enumerate(type.args)], mixed=[process_arg(arg, i + n_pos) for i, arg in enumerate(spec.mixed)],
kw_args=[ kw=[process_arg(arg, i + n_pos + n_mixed) for i, arg in enumerate(spec.kw)],
process_arg(arg, i + n_pos_args + n_args)
for i, arg in enumerate(type.kw_args)
],
returns=type.returns.accept(self),
) )
def _resolve_type_params(self, params: list[m.TypeParam]): def _resolve_type_params(self, params: list[m.TypeParam]):

View File

@@ -17,6 +17,7 @@ from midas.ast.midas import (
MemberKind, MemberKind,
MemberStmt, MemberStmt,
NamedType, NamedType,
ParamSpec,
PredicateStmt, PredicateStmt,
Stmt, Stmt,
Type, Type,
@@ -265,6 +266,9 @@ class MidasParser(Parser):
Returns: Returns:
Expr: the parsed constraint expression Expr: the parsed constraint expression
""" """
return self.expression()
def expression(self) -> Expr:
return self.and_() return self.and_()
def and_(self) -> Expr: def and_(self) -> Expr:
@@ -470,6 +474,18 @@ class MidasParser(Parser):
) )
def function(self) -> FunctionType: def function(self) -> FunctionType:
params: ParamSpec = self.function_args()
self.consume(TokenType.ARROW, "Expected '->' before result type")
result: Type = self.type_expr()
return FunctionType(
location=params.l_paren.location_to(self.previous()),
params=params,
returns=result,
)
def function_args(self) -> ParamSpec:
l_paren: Token = self.consume( l_paren: Token = self.consume(
TokenType.LEFT_PAREN, "Expected '(' before function parameters" TokenType.LEFT_PAREN, "Expected '(' before function parameters"
) )
@@ -526,14 +542,4 @@ class MidasParser(Parser):
self.error(token, "Unnamed mixed argument") self.error(token, "Unnamed mixed argument")
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters") self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
return ParamSpec(l_paren=l_paren, pos=pos_args, mixed=args, kw=kw_args)
self.consume(TokenType.ARROW, "Expected '->' before result type")
result: Type = self.type_expr()
return FunctionType(
location=l_paren.location_to(self.previous()),
pos_args=pos_args,
args=args,
kw_args=kw_args,
returns=result,
)

View File

@@ -15,6 +15,7 @@ from midas.ast.midas import (
LogicalExpr, LogicalExpr,
MemberStmt, MemberStmt,
NamedType, NamedType,
ParamSpec,
PredicateStmt, PredicateStmt,
Stmt, Stmt,
Type, Type,
@@ -163,12 +164,18 @@ class MidasAstJsonSerializer(
def visit_function_type(self, type: FunctionType) -> dict: def visit_function_type(self, type: FunctionType) -> dict:
return { return {
"_type": "FunctionType", "_type": "FunctionType",
"pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args], "params": self._serialize_param_spec(type.params),
"args": [self._serialize_func_arg(arg) for arg in type.args],
"kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args],
"returns": type.returns.accept(self), "returns": type.returns.accept(self),
} }
def _serialize_param_spec(self, spec: ParamSpec) -> dict:
return {
"_type": "ParamSpec",
"pos": [self._serialize_func_arg(arg) for arg in spec.pos],
"mixed": [self._serialize_func_arg(arg) for arg in spec.mixed],
"kw": [self._serialize_func_arg(arg) for arg in spec.kw],
}
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict: def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
return { return {
"name": arg.name, "name": arg.name,