feat(midas): generalize param spec of predicate and parse

This commit is contained in:
2026-06-18 12:38:24 +02:00
parent 94d84ab170
commit ad86446a2d
8 changed files with 62 additions and 29 deletions

View File

@@ -58,9 +58,8 @@ class ExtendStmt:
class PredicateStmt: class PredicateStmt:
name: Token name: Token
subject: Token params: list[ParamSpec]
type: Type body: Expr
condition: Expr
###< ###<

View File

@@ -94,9 +94,8 @@ class ExtendStmt(Stmt):
@dataclass(frozen=True) @dataclass(frozen=True)
class PredicateStmt(Stmt): class PredicateStmt(Stmt):
name: Token name: Token
subject: Token params: list[ParamSpec]
type: Type body: Expr
condition: Expr
def accept(self, visitor: Stmt.Visitor[T]) -> T: def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_predicate_stmt(self) return visitor.visit_predicate_stmt(self)

View File

@@ -150,13 +150,17 @@ class MidasAstPrinter(
self._write_line("PredicateStmt") self._write_line("PredicateStmt")
with self._child_level(): with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"') self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line(f'subject: "{stmt.subject.lexeme}"') self._write_line("params")
self._write_line("type") with self._child_level():
for i, spec in enumerate(stmt.params):
self._idx = i
if i == len(stmt.params) - 1:
self._mark_last()
self._visit_param_spec(spec)
self._write_line("body", last=True)
with self._child_level(single=True): with self._child_level(single=True):
stmt.type.accept(self) stmt.body.accept(self)
self._write_line("condition", last=True)
with self._child_level(single=True):
stmt.condition.accept(self)
# Expressions # Expressions
@@ -397,10 +401,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
def visit_predicate_stmt(self, stmt: m.PredicateStmt): def visit_predicate_stmt(self, stmt: m.PredicateStmt):
name: str = stmt.name.lexeme name: str = stmt.name.lexeme
subject: str = stmt.subject.lexeme sig: str = "".join(self._visit_param_spec(spec) for spec in stmt.params)
type: str = stmt.type.accept(self) body: str = stmt.body.accept(self)
condition: str = stmt.condition.accept(self) return self.indented(f"predicate {name}{sig} = {body}")
return self.indented(f"predicate {name}({subject}: {type}) = {condition}")
def visit_logical_expr(self, expr: m.LogicalExpr): def visit_logical_expr(self, expr: m.LogicalExpr):
left: str = expr.left.accept(self) left: str = expr.left.accept(self)

View File

@@ -13,6 +13,7 @@ from midas.checker.types import (
ExtensionType, ExtensionType,
Function, Function,
GenericType, GenericType,
Predicate,
Type, Type,
TypeVar, TypeVar,
UnknownType, UnknownType,
@@ -45,6 +46,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve() builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve()
self.process(builtins_path.read_text(), str(builtins_path)) self.process(builtins_path.read_text(), str(builtins_path))
self._bool: Type = self.get_type("bool")
def process(self, source: str, path: Optional[str]): def process(self, source: str, path: Optional[str]):
self.reporter = self.reporter.for_file(path) self.reporter = self.reporter.for_file(path)
lexer: MidasLexer = MidasLexer(source) lexer: MidasLexer = MidasLexer(source)
@@ -114,7 +117,24 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
) )
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
self.reporter.warning(stmt.location, "PredicateStmt not yet supported") params: list[TypedParamSpec] = [
self._visit_param_spec(spec) for spec in stmt.params
]
type: Type = self._bool
for spec in reversed(params):
type = Function(
pos_args=spec.pos,
args=spec.mixed,
kw_args=spec.kw,
returns=type,
)
self.types.define_predicate(
stmt.name.lexeme,
Predicate(
type=type,
body=stmt.body,
),
)
def visit_logical_expr(self, expr: m.LogicalExpr) -> None: def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
self.reporter.warning(expr.location, "LogicalExpr not yet supported") self.reporter.warning(expr.location, "LogicalExpr not yet supported")

View File

@@ -11,6 +11,7 @@ from midas.checker.types import (
Function, Function,
GenericType, GenericType,
OverloadedFunction, OverloadedFunction,
Predicate,
TopType, TopType,
Type, Type,
TypeVar, TypeVar,
@@ -24,6 +25,7 @@ class TypesRegistry:
self.logger: logging.Logger = logging.getLogger("TypesRegistry") self.logger: logging.Logger = logging.getLogger("TypesRegistry")
self._types: dict[str, Type] = {} self._types: dict[str, Type] = {}
self._members: dict[str, dict[str, Type]] = {} self._members: dict[str, dict[str, Type]] = {}
self._predicates: dict[str, Predicate] = {}
def get_type(self, name: str) -> Type: def get_type(self, name: str) -> Type:
"""Get a type from its name """Get a type from its name
@@ -81,6 +83,11 @@ class TypesRegistry:
else: else:
members[member_name] = member_type members[member_name] = member_type
def define_predicate(self, name: str, predicate: Predicate):
if name in self._predicates:
raise ValueError(f"Predicate {name} already defined")
self._predicates[name] = predicate
def is_subtype(self, type1: Type, type2: Type) -> bool: def is_subtype(self, type1: Type, type2: Type) -> bool:
"""Check whether `type1` is a subtype of `type2` """Check whether `type1` is a subtype of `type2`

View File

@@ -215,6 +215,12 @@ def unfold_type(type: Type) -> Type:
return type return type
@dataclass(frozen=True, kw_only=True)
class Predicate:
type: Type
body: m.Expr
Type = ( Type = (
TopType TopType
| BaseType | BaseType

View File

@@ -506,20 +506,20 @@ class MidasParser(Parser):
PredicateStmt: the parsed predicate declaration statement PredicateStmt: the parsed predicate declaration statement
""" """
keyword: Token = self.previous() keyword: Token = self.previous()
name: Token = self.consume_identifier("Expected predicate name") name: Token = self.consume_identifier("Expected predicate name")
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
subject: Token = self.consume_identifier("Expected subject name") params: list[ParamSpec] = []
self.consume(TokenType.COLON, "Expected ':' after subject name") while self.check(TokenType.LEFT_PAREN):
type: Type = self.type_expr() params.append(self.function_args())
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject") self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
condition: Expr = self.constraint() body: Expr = self.constraint()
return PredicateStmt( return PredicateStmt(
location=keyword.location_to(self.previous()), location=keyword.location_to(self.previous()),
name=name, name=name,
subject=subject, params=params,
type=type, body=body,
condition=condition,
) )
def function(self) -> FunctionType: def function(self) -> FunctionType:

View File

@@ -80,9 +80,8 @@ class MidasAstJsonSerializer(
return { return {
"_type": "PredicateStmt", "_type": "PredicateStmt",
"name": stmt.name.lexeme, "name": stmt.name.lexeme,
"subject": stmt.subject.lexeme, "params": [self._serialize_param_spec(spec) for spec in stmt.params],
"type": stmt.type.accept(self), "body": stmt.body.accept(self),
"condition": stmt.condition.accept(self),
} }
def visit_logical_expr(self, expr: LogicalExpr) -> dict: def visit_logical_expr(self, expr: LogicalExpr) -> dict: