feat(midas): generalize param spec of predicate and parse
This commit is contained in:
@@ -58,9 +58,8 @@ class ExtendStmt:
|
|||||||
|
|
||||||
class PredicateStmt:
|
class PredicateStmt:
|
||||||
name: Token
|
name: Token
|
||||||
subject: Token
|
params: list[ParamSpec]
|
||||||
type: Type
|
body: Expr
|
||||||
condition: Expr
|
|
||||||
|
|
||||||
|
|
||||||
###<
|
###<
|
||||||
|
|||||||
@@ -94,9 +94,8 @@ class ExtendStmt(Stmt):
|
|||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class PredicateStmt(Stmt):
|
class PredicateStmt(Stmt):
|
||||||
name: Token
|
name: Token
|
||||||
subject: Token
|
params: list[ParamSpec]
|
||||||
type: Type
|
body: Expr
|
||||||
condition: Expr
|
|
||||||
|
|
||||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
return visitor.visit_predicate_stmt(self)
|
return visitor.visit_predicate_stmt(self)
|
||||||
|
|||||||
@@ -150,13 +150,17 @@ class MidasAstPrinter(
|
|||||||
self._write_line("PredicateStmt")
|
self._write_line("PredicateStmt")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||||
self._write_line(f'subject: "{stmt.subject.lexeme}"')
|
self._write_line("params")
|
||||||
self._write_line("type")
|
with self._child_level():
|
||||||
|
for i, spec in enumerate(stmt.params):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(stmt.params) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._visit_param_spec(spec)
|
||||||
|
|
||||||
|
self._write_line("body", last=True)
|
||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
stmt.type.accept(self)
|
stmt.body.accept(self)
|
||||||
self._write_line("condition", last=True)
|
|
||||||
with self._child_level(single=True):
|
|
||||||
stmt.condition.accept(self)
|
|
||||||
|
|
||||||
# Expressions
|
# Expressions
|
||||||
|
|
||||||
@@ -397,10 +401,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
|||||||
|
|
||||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
||||||
name: str = stmt.name.lexeme
|
name: str = stmt.name.lexeme
|
||||||
subject: str = stmt.subject.lexeme
|
sig: str = "".join(self._visit_param_spec(spec) for spec in stmt.params)
|
||||||
type: str = stmt.type.accept(self)
|
body: str = stmt.body.accept(self)
|
||||||
condition: str = stmt.condition.accept(self)
|
return self.indented(f"predicate {name}{sig} = {body}")
|
||||||
return self.indented(f"predicate {name}({subject}: {type}) = {condition}")
|
|
||||||
|
|
||||||
def visit_logical_expr(self, expr: m.LogicalExpr):
|
def visit_logical_expr(self, expr: m.LogicalExpr):
|
||||||
left: str = expr.left.accept(self)
|
left: str = expr.left.accept(self)
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from midas.checker.types import (
|
|||||||
ExtensionType,
|
ExtensionType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
|
Predicate,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
UnknownType,
|
UnknownType,
|
||||||
@@ -45,6 +46,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve()
|
builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve()
|
||||||
self.process(builtins_path.read_text(), str(builtins_path))
|
self.process(builtins_path.read_text(), str(builtins_path))
|
||||||
|
|
||||||
|
self._bool: Type = self.get_type("bool")
|
||||||
|
|
||||||
def process(self, source: str, path: Optional[str]):
|
def process(self, source: str, path: Optional[str]):
|
||||||
self.reporter = self.reporter.for_file(path)
|
self.reporter = self.reporter.for_file(path)
|
||||||
lexer: MidasLexer = MidasLexer(source)
|
lexer: MidasLexer = MidasLexer(source)
|
||||||
@@ -114,7 +117,24 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
)
|
)
|
||||||
|
|
||||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
||||||
self.reporter.warning(stmt.location, "PredicateStmt not yet supported")
|
params: list[TypedParamSpec] = [
|
||||||
|
self._visit_param_spec(spec) for spec in stmt.params
|
||||||
|
]
|
||||||
|
type: Type = self._bool
|
||||||
|
for spec in reversed(params):
|
||||||
|
type = Function(
|
||||||
|
pos_args=spec.pos,
|
||||||
|
args=spec.mixed,
|
||||||
|
kw_args=spec.kw,
|
||||||
|
returns=type,
|
||||||
|
)
|
||||||
|
self.types.define_predicate(
|
||||||
|
stmt.name.lexeme,
|
||||||
|
Predicate(
|
||||||
|
type=type,
|
||||||
|
body=stmt.body,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
|
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
|
||||||
self.reporter.warning(expr.location, "LogicalExpr not yet supported")
|
self.reporter.warning(expr.location, "LogicalExpr not yet supported")
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from midas.checker.types import (
|
|||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
OverloadedFunction,
|
OverloadedFunction,
|
||||||
|
Predicate,
|
||||||
TopType,
|
TopType,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
@@ -24,6 +25,7 @@ class TypesRegistry:
|
|||||||
self.logger: logging.Logger = logging.getLogger("TypesRegistry")
|
self.logger: logging.Logger = logging.getLogger("TypesRegistry")
|
||||||
self._types: dict[str, Type] = {}
|
self._types: dict[str, Type] = {}
|
||||||
self._members: dict[str, dict[str, Type]] = {}
|
self._members: dict[str, dict[str, Type]] = {}
|
||||||
|
self._predicates: dict[str, Predicate] = {}
|
||||||
|
|
||||||
def get_type(self, name: str) -> Type:
|
def get_type(self, name: str) -> Type:
|
||||||
"""Get a type from its name
|
"""Get a type from its name
|
||||||
@@ -81,6 +83,11 @@ class TypesRegistry:
|
|||||||
else:
|
else:
|
||||||
members[member_name] = member_type
|
members[member_name] = member_type
|
||||||
|
|
||||||
|
def define_predicate(self, name: str, predicate: Predicate):
|
||||||
|
if name in self._predicates:
|
||||||
|
raise ValueError(f"Predicate {name} already defined")
|
||||||
|
self._predicates[name] = predicate
|
||||||
|
|
||||||
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||||
"""Check whether `type1` is a subtype of `type2`
|
"""Check whether `type1` is a subtype of `type2`
|
||||||
|
|
||||||
|
|||||||
@@ -215,6 +215,12 @@ def unfold_type(type: Type) -> Type:
|
|||||||
return type
|
return type
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Predicate:
|
||||||
|
type: Type
|
||||||
|
body: m.Expr
|
||||||
|
|
||||||
|
|
||||||
Type = (
|
Type = (
|
||||||
TopType
|
TopType
|
||||||
| BaseType
|
| BaseType
|
||||||
|
|||||||
@@ -506,20 +506,20 @@ class MidasParser(Parser):
|
|||||||
PredicateStmt: the parsed predicate declaration statement
|
PredicateStmt: the parsed predicate declaration statement
|
||||||
"""
|
"""
|
||||||
keyword: Token = self.previous()
|
keyword: Token = self.previous()
|
||||||
|
|
||||||
name: Token = self.consume_identifier("Expected predicate name")
|
name: Token = self.consume_identifier("Expected predicate name")
|
||||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
|
|
||||||
subject: Token = self.consume_identifier("Expected subject name")
|
params: list[ParamSpec] = []
|
||||||
self.consume(TokenType.COLON, "Expected ':' after subject name")
|
while self.check(TokenType.LEFT_PAREN):
|
||||||
type: Type = self.type_expr()
|
params.append(self.function_args())
|
||||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
|
|
||||||
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
|
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
|
||||||
condition: Expr = self.constraint()
|
body: Expr = self.constraint()
|
||||||
return PredicateStmt(
|
return PredicateStmt(
|
||||||
location=keyword.location_to(self.previous()),
|
location=keyword.location_to(self.previous()),
|
||||||
name=name,
|
name=name,
|
||||||
subject=subject,
|
params=params,
|
||||||
type=type,
|
body=body,
|
||||||
condition=condition,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def function(self) -> FunctionType:
|
def function(self) -> FunctionType:
|
||||||
|
|||||||
@@ -80,9 +80,8 @@ class MidasAstJsonSerializer(
|
|||||||
return {
|
return {
|
||||||
"_type": "PredicateStmt",
|
"_type": "PredicateStmt",
|
||||||
"name": stmt.name.lexeme,
|
"name": stmt.name.lexeme,
|
||||||
"subject": stmt.subject.lexeme,
|
"params": [self._serialize_param_spec(spec) for spec in stmt.params],
|
||||||
"type": stmt.type.accept(self),
|
"body": stmt.body.accept(self),
|
||||||
"condition": stmt.condition.accept(self),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
|
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
|
||||||
|
|||||||
Reference in New Issue
Block a user