diff --git a/gen/midas.py b/gen/midas.py index fad64c2..287fcc3 100644 --- a/gen/midas.py +++ b/gen/midas.py @@ -58,9 +58,8 @@ class ExtendStmt: class PredicateStmt: name: Token - subject: Token - type: Type - condition: Expr + params: list[ParamSpec] + body: Expr ###< diff --git a/midas/ast/midas.py b/midas/ast/midas.py index 2381f65..1ece261 100644 --- a/midas/ast/midas.py +++ b/midas/ast/midas.py @@ -94,9 +94,8 @@ class ExtendStmt(Stmt): @dataclass(frozen=True) class PredicateStmt(Stmt): name: Token - subject: Token - type: Type - condition: Expr + params: list[ParamSpec] + body: Expr def accept(self, visitor: Stmt.Visitor[T]) -> T: return visitor.visit_predicate_stmt(self) diff --git a/midas/ast/printer.py b/midas/ast/printer.py index cdbd3a3..1c75a44 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -150,13 +150,17 @@ class MidasAstPrinter( self._write_line("PredicateStmt") with self._child_level(): self._write_line(f'name: "{stmt.name.lexeme}"') - self._write_line(f'subject: "{stmt.subject.lexeme}"') - self._write_line("type") + self._write_line("params") + with self._child_level(): + for i, spec in enumerate(stmt.params): + self._idx = i + if i == len(stmt.params) - 1: + self._mark_last() + self._visit_param_spec(spec) + + self._write_line("body", last=True) with self._child_level(single=True): - stmt.type.accept(self) - self._write_line("condition", last=True) - with self._child_level(single=True): - stmt.condition.accept(self) + stmt.body.accept(self) # Expressions @@ -397,10 +401,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] def visit_predicate_stmt(self, stmt: m.PredicateStmt): name: str = stmt.name.lexeme - subject: str = stmt.subject.lexeme - type: str = stmt.type.accept(self) - condition: str = stmt.condition.accept(self) - return self.indented(f"predicate {name}({subject}: {type}) = {condition}") + sig: str = "".join(self._visit_param_spec(spec) for spec in stmt.params) + body: str = stmt.body.accept(self) + return self.indented(f"predicate {name}{sig} = {body}") def visit_logical_expr(self, expr: m.LogicalExpr): left: str = expr.left.accept(self) diff --git a/midas/checker/midas.py b/midas/checker/midas.py index ab43292..fabdd59 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -13,6 +13,7 @@ from midas.checker.types import ( ExtensionType, Function, GenericType, + Predicate, Type, TypeVar, UnknownType, @@ -45,6 +46,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve() self.process(builtins_path.read_text(), str(builtins_path)) + self._bool: Type = self.get_type("bool") + def process(self, source: str, path: Optional[str]): self.reporter = self.reporter.for_file(path) lexer: MidasLexer = MidasLexer(source) @@ -114,7 +117,24 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type ) def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: - self.reporter.warning(stmt.location, "PredicateStmt not yet supported") + params: list[TypedParamSpec] = [ + self._visit_param_spec(spec) for spec in stmt.params + ] + type: Type = self._bool + for spec in reversed(params): + type = Function( + pos_args=spec.pos, + args=spec.mixed, + kw_args=spec.kw, + returns=type, + ) + self.types.define_predicate( + stmt.name.lexeme, + Predicate( + type=type, + body=stmt.body, + ), + ) def visit_logical_expr(self, expr: m.LogicalExpr) -> None: self.reporter.warning(expr.location, "LogicalExpr not yet supported") diff --git a/midas/checker/registry.py b/midas/checker/registry.py index 6591548..f97d04f 100644 --- a/midas/checker/registry.py +++ b/midas/checker/registry.py @@ -11,6 +11,7 @@ from midas.checker.types import ( Function, GenericType, OverloadedFunction, + Predicate, TopType, Type, TypeVar, @@ -24,6 +25,7 @@ class TypesRegistry: self.logger: logging.Logger = logging.getLogger("TypesRegistry") self._types: dict[str, Type] = {} self._members: dict[str, dict[str, Type]] = {} + self._predicates: dict[str, Predicate] = {} def get_type(self, name: str) -> Type: """Get a type from its name @@ -81,6 +83,11 @@ class TypesRegistry: else: members[member_name] = member_type + def define_predicate(self, name: str, predicate: Predicate): + if name in self._predicates: + raise ValueError(f"Predicate {name} already defined") + self._predicates[name] = predicate + def is_subtype(self, type1: Type, type2: Type) -> bool: """Check whether `type1` is a subtype of `type2` diff --git a/midas/checker/types.py b/midas/checker/types.py index 708d68b..5cb1beb 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -215,6 +215,12 @@ def unfold_type(type: Type) -> Type: return type +@dataclass(frozen=True, kw_only=True) +class Predicate: + type: Type + body: m.Expr + + Type = ( TopType | BaseType diff --git a/midas/parser/midas.py b/midas/parser/midas.py index 67f3859..1b60dc1 100644 --- a/midas/parser/midas.py +++ b/midas/parser/midas.py @@ -506,20 +506,20 @@ class MidasParser(Parser): PredicateStmt: the parsed predicate declaration statement """ keyword: Token = self.previous() + name: Token = self.consume_identifier("Expected predicate name") - self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject") - subject: Token = self.consume_identifier("Expected subject name") - self.consume(TokenType.COLON, "Expected ':' after subject name") - type: Type = self.type_expr() - self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject") + + params: list[ParamSpec] = [] + while self.check(TokenType.LEFT_PAREN): + params.append(self.function_args()) + self.consume(TokenType.EQUAL, "Expected '=' after predicate subject") - condition: Expr = self.constraint() + body: Expr = self.constraint() return PredicateStmt( location=keyword.location_to(self.previous()), name=name, - subject=subject, - type=type, - condition=condition, + params=params, + body=body, ) def function(self) -> FunctionType: diff --git a/tests/serializer/midas.py b/tests/serializer/midas.py index c3df47c..c3d09cf 100644 --- a/tests/serializer/midas.py +++ b/tests/serializer/midas.py @@ -80,9 +80,8 @@ class MidasAstJsonSerializer( return { "_type": "PredicateStmt", "name": stmt.name.lexeme, - "subject": stmt.subject.lexeme, - "type": stmt.type.accept(self), - "condition": stmt.condition.accept(self), + "params": [self._serialize_param_spec(spec) for spec in stmt.params], + "body": stmt.body.accept(self), } def visit_logical_expr(self, expr: LogicalExpr) -> dict: