feat(midas): generalize param spec of predicate and parse

This commit is contained in:
2026-06-18 12:38:24 +02:00
parent 94d84ab170
commit ad86446a2d
8 changed files with 62 additions and 29 deletions

View File

@@ -58,9 +58,8 @@ class ExtendStmt:
class PredicateStmt:
name: Token
subject: Token
type: Type
condition: Expr
params: list[ParamSpec]
body: Expr
###<

View File

@@ -94,9 +94,8 @@ class ExtendStmt(Stmt):
@dataclass(frozen=True)
class PredicateStmt(Stmt):
name: Token
subject: Token
type: Type
condition: Expr
params: list[ParamSpec]
body: Expr
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_predicate_stmt(self)

View File

@@ -150,13 +150,17 @@ class MidasAstPrinter(
self._write_line("PredicateStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line(f'subject: "{stmt.subject.lexeme}"')
self._write_line("type")
self._write_line("params")
with self._child_level():
for i, spec in enumerate(stmt.params):
self._idx = i
if i == len(stmt.params) - 1:
self._mark_last()
self._visit_param_spec(spec)
self._write_line("body", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
self._write_line("condition", last=True)
with self._child_level(single=True):
stmt.condition.accept(self)
stmt.body.accept(self)
# Expressions
@@ -397,10 +401,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
name: str = stmt.name.lexeme
subject: str = stmt.subject.lexeme
type: str = stmt.type.accept(self)
condition: str = stmt.condition.accept(self)
return self.indented(f"predicate {name}({subject}: {type}) = {condition}")
sig: str = "".join(self._visit_param_spec(spec) for spec in stmt.params)
body: str = stmt.body.accept(self)
return self.indented(f"predicate {name}{sig} = {body}")
def visit_logical_expr(self, expr: m.LogicalExpr):
left: str = expr.left.accept(self)

View File

@@ -13,6 +13,7 @@ from midas.checker.types import (
ExtensionType,
Function,
GenericType,
Predicate,
Type,
TypeVar,
UnknownType,
@@ -45,6 +46,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve()
self.process(builtins_path.read_text(), str(builtins_path))
self._bool: Type = self.get_type("bool")
def process(self, source: str, path: Optional[str]):
self.reporter = self.reporter.for_file(path)
lexer: MidasLexer = MidasLexer(source)
@@ -114,7 +117,24 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
)
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
self.reporter.warning(stmt.location, "PredicateStmt not yet supported")
params: list[TypedParamSpec] = [
self._visit_param_spec(spec) for spec in stmt.params
]
type: Type = self._bool
for spec in reversed(params):
type = Function(
pos_args=spec.pos,
args=spec.mixed,
kw_args=spec.kw,
returns=type,
)
self.types.define_predicate(
stmt.name.lexeme,
Predicate(
type=type,
body=stmt.body,
),
)
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
self.reporter.warning(expr.location, "LogicalExpr not yet supported")

View File

@@ -11,6 +11,7 @@ from midas.checker.types import (
Function,
GenericType,
OverloadedFunction,
Predicate,
TopType,
Type,
TypeVar,
@@ -24,6 +25,7 @@ class TypesRegistry:
self.logger: logging.Logger = logging.getLogger("TypesRegistry")
self._types: dict[str, Type] = {}
self._members: dict[str, dict[str, Type]] = {}
self._predicates: dict[str, Predicate] = {}
def get_type(self, name: str) -> Type:
"""Get a type from its name
@@ -81,6 +83,11 @@ class TypesRegistry:
else:
members[member_name] = member_type
def define_predicate(self, name: str, predicate: Predicate):
if name in self._predicates:
raise ValueError(f"Predicate {name} already defined")
self._predicates[name] = predicate
def is_subtype(self, type1: Type, type2: Type) -> bool:
"""Check whether `type1` is a subtype of `type2`

View File

@@ -215,6 +215,12 @@ def unfold_type(type: Type) -> Type:
return type
@dataclass(frozen=True, kw_only=True)
class Predicate:
type: Type
body: m.Expr
Type = (
TopType
| BaseType

View File

@@ -506,20 +506,20 @@ class MidasParser(Parser):
PredicateStmt: the parsed predicate declaration statement
"""
keyword: Token = self.previous()
name: Token = self.consume_identifier("Expected predicate name")
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
subject: Token = self.consume_identifier("Expected subject name")
self.consume(TokenType.COLON, "Expected ':' after subject name")
type: Type = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
params: list[ParamSpec] = []
while self.check(TokenType.LEFT_PAREN):
params.append(self.function_args())
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
condition: Expr = self.constraint()
body: Expr = self.constraint()
return PredicateStmt(
location=keyword.location_to(self.previous()),
name=name,
subject=subject,
type=type,
condition=condition,
params=params,
body=body,
)
def function(self) -> FunctionType:

View File

@@ -80,9 +80,8 @@ class MidasAstJsonSerializer(
return {
"_type": "PredicateStmt",
"name": stmt.name.lexeme,
"subject": stmt.subject.lexeme,
"type": stmt.type.accept(self),
"condition": stmt.condition.accept(self),
"params": [self._serialize_param_spec(spec) for spec in stmt.params],
"body": stmt.body.accept(self),
}
def visit_logical_expr(self, expr: LogicalExpr) -> dict: