diff --git a/gen/midas.py b/gen/midas.py index 42caf4f..287fcc3 100644 --- a/gen/midas.py +++ b/gen/midas.py @@ -26,6 +26,14 @@ class MemberKind(Enum): METHOD = auto() +@dataclass(frozen=True, kw_only=True) +class ParamSpec: + l_paren: Token + pos: list[FunctionType.Argument] + mixed: list[FunctionType.Argument] + kw: list[FunctionType.Argument] + + ###< @@ -50,9 +58,8 @@ class ExtendStmt: class PredicateStmt: name: Token - subject: Token - type: Type - condition: Expr + params: list[ParamSpec] + body: Expr ###< @@ -78,6 +85,12 @@ class UnaryExpr: right: Expr +class CallExpr: + callee: Expr + arguments: list[Expr] + keywords: dict[str, Expr] + + class GetExpr: expr: Expr name: Token @@ -128,9 +141,7 @@ class ExtensionType: class FunctionType: - pos_args: list[Argument] - args: list[Argument] - kw_args: list[Argument] + params: ParamSpec returns: Type @dataclass(frozen=True, kw_only=True) diff --git a/midas/ast/midas.py b/midas/ast/midas.py index e71aff9..1ece261 100644 --- a/midas/ast/midas.py +++ b/midas/ast/midas.py @@ -27,6 +27,14 @@ class MemberKind(Enum): METHOD = auto() +@dataclass(frozen=True, kw_only=True) +class ParamSpec: + l_paren: Token + pos: list[FunctionType.Argument] + mixed: list[FunctionType.Argument] + kw: list[FunctionType.Argument] + + ############## # Statements # ############## @@ -86,9 +94,8 @@ class ExtendStmt(Stmt): @dataclass(frozen=True) class PredicateStmt(Stmt): name: Token - subject: Token - type: Type - condition: Expr + params: list[ParamSpec] + body: Expr def accept(self, visitor: Stmt.Visitor[T]) -> T: return visitor.visit_predicate_stmt(self) @@ -116,6 +123,9 @@ class Expr(ABC): @abstractmethod def visit_unary_expr(self, expr: UnaryExpr) -> T: ... + @abstractmethod + def visit_call_expr(self, expr: CallExpr) -> T: ... + @abstractmethod def visit_get_expr(self, expr: GetExpr) -> T: ... @@ -161,6 +171,16 @@ class UnaryExpr(Expr): return visitor.visit_unary_expr(self) +@dataclass(frozen=True) +class CallExpr(Expr): + callee: Expr + arguments: list[Expr] + keywords: dict[str, Expr] + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_call_expr(self) + + @dataclass(frozen=True) class GetExpr(Expr): expr: Expr @@ -279,9 +299,7 @@ class ExtensionType(Type): @dataclass(frozen=True) class FunctionType(Type): - pos_args: list[Argument] - args: list[Argument] - kw_args: list[Argument] + params: ParamSpec returns: Type @dataclass(frozen=True, kw_only=True) diff --git a/midas/ast/printer.py b/midas/ast/printer.py index 694c272..1c75a44 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -150,13 +150,17 @@ class MidasAstPrinter( self._write_line("PredicateStmt") with self._child_level(): self._write_line(f'name: "{stmt.name.lexeme}"') - self._write_line(f'subject: "{stmt.subject.lexeme}"') - self._write_line("type") + self._write_line("params") + with self._child_level(): + for i, spec in enumerate(stmt.params): + self._idx = i + if i == len(stmt.params) - 1: + self._mark_last() + self._visit_param_spec(spec) + + self._write_line("body", last=True) with self._child_level(single=True): - stmt.type.accept(self) - self._write_line("condition", last=True) - with self._child_level(single=True): - stmt.condition.accept(self) + stmt.body.accept(self) # Expressions @@ -195,6 +199,29 @@ class MidasAstPrinter( with self._child_level(single=True): expr.right.accept(self) + def visit_call_expr(self, expr: m.CallExpr) -> None: + self._write_line("CallExpr") + with self._child_level(): + self._write_line("callee") + with self._child_level(single=True): + expr.callee.accept(self) + self._write_line("arguments") + with self._child_level(): + for i, arg in enumerate(expr.arguments): + self._idx = i + if i == len(expr.arguments) - 1: + self._mark_last() + arg.accept(self) + self._write_line("keywords", last=True) + with self._child_level(): + for i, (name, arg) in enumerate(expr.keywords.items()): + self._idx = i + if i == len(expr.keywords) - 1: + self._mark_last() + self._write_line(name) + with self._child_level(single=True): + arg.accept(self) + def visit_get_expr(self, expr: m.GetExpr): self._write_line("GetExpr") with self._child_level(): @@ -276,34 +303,41 @@ class MidasAstPrinter( def visit_function_type(self, type: m.FunctionType) -> None: self._write_line("FunctionType") with self._child_level(): - self._write_line("pos_args") - with self._child_level(): - for i, arg in enumerate(type.pos_args): - self._idx = i - if i == len(type.pos_args) - 1: - self._mark_last() - self._print_function_arg(arg) - - self._write_line("args") - with self._child_level(): - for i, arg in enumerate(type.args): - self._idx = i - if i == len(type.args) - 1: - self._mark_last() - self._print_function_arg(arg) - - self._write_line("kw_args") - with self._child_level(): - for i, arg in enumerate(type.kw_args): - self._idx = i - if i == len(type.kw_args) - 1: - self._mark_last() - self._print_function_arg(arg) + self._write_line("params") + with self._child_level(single=True): + self._visit_param_spec(type.params) self._write_line("returns", last=True) with self._child_level(single=True): type.returns.accept(self) + def _visit_param_spec(self, spec: m.ParamSpec) -> None: + self._write_line("ParamSpec") + with self._child_level(): + self._write_line("pos") + with self._child_level(): + for i, arg in enumerate(spec.pos): + self._idx = i + if i == len(spec.pos) - 1: + self._mark_last() + self._print_function_arg(arg) + + self._write_line("mixed") + with self._child_level(): + for i, arg in enumerate(spec.mixed): + self._idx = i + if i == len(spec.mixed) - 1: + self._mark_last() + self._print_function_arg(arg) + + self._write_line("kw", last=True) + with self._child_level(): + for i, arg in enumerate(spec.kw): + self._idx = i + if i == len(spec.kw) - 1: + self._mark_last() + self._print_function_arg(arg) + def _print_function_arg(self, arg: m.FunctionType.Argument) -> None: self._write_line("Argument") with self._child_level(): @@ -367,10 +401,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] def visit_predicate_stmt(self, stmt: m.PredicateStmt): name: str = stmt.name.lexeme - subject: str = stmt.subject.lexeme - type: str = stmt.type.accept(self) - condition: str = stmt.condition.accept(self) - return self.indented(f"predicate {name}({subject}: {type}) = {condition}") + sig: str = "".join(self._visit_param_spec(spec) for spec in stmt.params) + body: str = stmt.body.accept(self) + return self.indented(f"predicate {name}{sig} = {body}") def visit_logical_expr(self, expr: m.LogicalExpr): left: str = expr.left.accept(self) @@ -389,6 +422,12 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] right: str = expr.right.accept(self) return f"{operator}{right}" + def visit_call_expr(self, expr: m.CallExpr) -> str: + args: list[str] = [arg.accept(self) for arg in expr.arguments] + [ + f"{name}={arg.accept(self)}" for name, arg in expr.keywords.items() + ] + return f"{expr.callee.accept(self)}({', '.join(args)})" + def visit_get_expr(self, expr: m.GetExpr): expr_: str = expr.expr.accept(self) name: str = expr.name.lexeme @@ -436,9 +475,13 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] return f"{type.base.accept(self)} & {type.extension.accept(self)}" def visit_function_type(self, type: m.FunctionType) -> str: - pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args] - mixed_args: list[str] = [self._print_arg(arg) for arg in type.args] - kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args] + spec: str = self._visit_param_spec(type.params) + return f"fn {spec} -> {type.returns.accept(self)}" + + def _visit_param_spec(self, spec: m.ParamSpec) -> str: + pos_args: list[str] = [self._print_arg(arg) for arg in spec.pos] + mixed_args: list[str] = [self._print_arg(arg) for arg in spec.mixed] + kw_args: list[str] = [self._print_arg(arg) for arg in spec.kw] args: list[str] = pos_args if len(pos_args) != 0: @@ -447,8 +490,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] if len(kw_args) != 0: args.append("*") args += kw_args - - return f"fn ({', '.join(args)}) -> {type.returns.accept(self)}" + return f"({', '.join(args)})" def _print_arg(self, arg: m.FunctionType.Argument) -> str: res: str = "" diff --git a/midas/checker/builtins.py b/midas/checker/builtins.py index 7bf1d98..daf0015 100644 --- a/midas/checker/builtins.py +++ b/midas/checker/builtins.py @@ -15,6 +15,7 @@ if TYPE_CHECKING: BUILTIN_SUBTYPES: dict[str, set[str]] = { + "object": {"float", "list", "dict"}, "float": {"int"}, "int": {"bool"}, } diff --git a/midas/checker/midas.py b/midas/checker/midas.py index 3764c03..60d2b85 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -1,27 +1,64 @@ import logging +from dataclasses import dataclass from pathlib import Path from typing import Optional import midas.ast.midas as m +from midas.ast.location import Location from midas.checker.builtins import define_builtins +from midas.checker.environment import Environment +from midas.checker.operators import MIDAS_BINARY_METHODS, MIDAS_UNARY_METHODS +from midas.checker.preamble import Preamble from midas.checker.registry import TypesRegistry from midas.checker.reporter import FileReporter, Reporter from midas.checker.types import ( AliasType, + AppliedType, ComplexType, + ConstraintType, ExtensionType, Function, GenericType, + OverloadedFunction, + Predicate, Type, TypeVar, UnknownType, + unfold_type, ) from midas.lexer.midas import MidasLexer from midas.lexer.token import Token from midas.parser.midas import MidasParser -class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]): +@dataclass(frozen=True, kw_only=True) +class TypedParamSpec: + pos: list[Function.Argument] + mixed: list[Function.Argument] + kw: list[Function.Argument] + + +TypedExpr = tuple[m.Expr, Type] + + +class ReturnException(Exception): + pass + + +@dataclass(frozen=True, kw_only=True) +class MappedArgument: + expr: m.Expr + type: Type + argument: Function.Argument + + +@dataclass(frozen=True, kw_only=True) +class OverloadCandidate: + function: Function + mapped: list[MappedArgument] + + +class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type]): """A resolver which evaluates Midas type definitions and build a registry""" def __init__(self, types: TypesRegistry, reporter: Reporter) -> None: @@ -31,12 +68,18 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type self.types: TypesRegistry = types self._local_variables: dict[str, TypeVar] = {} + self._predicate_params: dict[str, Type] = {} + self._current_name: Optional[str] = None define_builtins(self.types) builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve() self.process(builtins_path.read_text(), str(builtins_path)) + self._bool: Type = self.get_type("bool") + + self._preamble: Environment = Preamble(self.types) + def process(self, source: str, path: Optional[str]): self.reporter = self.reporter.for_file(path) lexer: MidasLexer = MidasLexer(source) @@ -47,6 +90,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type self.reporter.error(error.token.get_location(), error.message) self.resolve(stmts) + def type_of(self, expr: m.Expr) -> Type: + type: Type = expr.accept(self) + return type + def get_type(self, name: str) -> Type: """Get a type from its name @@ -63,6 +110,19 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type return self._local_variables[name] return self.types.get_type(name) + def get_variable(self, name: str) -> Type: + if name in self._predicate_params: + return self._predicate_params[name] + predicate: Optional[Predicate] = self.types.lookup_predicate(name) + if predicate is not None: + return predicate.type + + global_: Optional[Type] = self._preamble.get(name) + if global_ is not None: + return global_ + + raise NameError(f"Unknown variable '{name}'") + def resolve(self, stmts: list[m.Stmt]): """Process a sequence of statements @@ -72,6 +132,11 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type for stmt in stmts: stmt.accept(self) + def assert_bool(self, expr: m.Expr): + type: Type = self.type_of(expr) + if not self.types.is_subtype(type, self._bool): + self.reporter.error(expr.location, f"Must be a boolean but is {type}") + def visit_type_stmt(self, stmt: m.TypeStmt) -> None: name: str = stmt.name.lexeme self._current_name = name @@ -106,31 +171,163 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type ) def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: - self.reporter.warning(stmt.location, "PredicateStmt not yet supported") + for spec in stmt.params: + for param in spec.mixed: + assert param.name is not None + self._predicate_params[param.name.lexeme] = param.type.accept(self) - def visit_logical_expr(self, expr: m.LogicalExpr) -> None: - self.reporter.warning(expr.location, "LogicalExpr not yet supported") + type: Type = self.type_of(stmt.body) + params: list[TypedParamSpec] = [ + self._visit_param_spec(spec) for spec in stmt.params + ] - def visit_binary_expr(self, expr: m.BinaryExpr) -> None: - self.reporter.warning(expr.location, "BinaryExpr not yet supported") + if not self._is_valid_predicate(type): + self.reporter.error( + stmt.body.location, + f"Predicate function body must evaluate to a boolean, got {type}", + ) + if len(params) != 0: + type = self._bool + for spec in reversed(params): + type = Function( + pos_args=spec.pos, + args=spec.mixed, + kw_args=spec.kw, + returns=type, + ) + self._predicate_params = {} + self.types.define_predicate( + stmt.name.lexeme, + Predicate( + type=type, + body=stmt.body, + alias=len(params) == 0, + ), + ) - def visit_unary_expr(self, expr: m.UnaryExpr) -> None: - self.reporter.warning(expr.location, "UnaryExpr not yet supported") + def _is_valid_predicate(self, body: Type) -> bool: + match body: + case Function(returns=returns): + return self._is_valid_predicate(returns) + case _ if self.types.is_subtype(body, self._bool): + return True + case _: + return False - def visit_get_expr(self, expr: m.GetExpr) -> None: - self.reporter.warning(expr.location, "GetExpr not yet supported") + def visit_logical_expr(self, expr: m.LogicalExpr) -> Type: + self.assert_bool(expr.left) + self.assert_bool(expr.right) + return self._bool - def visit_variable_expr(self, expr: m.VariableExpr) -> None: - self.reporter.warning(expr.location, "VariableExpr not yet supported") + def visit_binary_expr(self, expr: m.BinaryExpr) -> Type: + method: Optional[str] = MIDAS_BINARY_METHODS.get(expr.operator.type) + if method is None: + self.logger.warning(f"Unsupported operator {expr.operator.lexeme}") + self.reporter.warning( + expr.location, f"Unsupported operator {expr.operator.lexeme}" + ) + return UnknownType() - def visit_grouping_expr(self, expr: m.GroupingExpr) -> None: + return self._visit_binary_expr(expr.location, expr.left, expr.right, method) + + def _visit_binary_expr( + self, location: Location, left_expr: m.Expr, right_expr: m.Expr, method: str + ) -> Type: + left: Type = self.type_of(left_expr) + right: Type = self.type_of(right_expr) + + operation: Optional[Type] = self.types.lookup_member(left, method) + if operation is None: + self.reporter.error( + location, + f"Undefined operation {method} between {left} and {right}", + ) + return UnknownType() + + result: Optional[Type] = self._get_call_result( + location, + operation, + [(right_expr, right)], + {}, + ) + return result or UnknownType() + + def visit_unary_expr(self, expr: m.UnaryExpr) -> Type: + method: Optional[str] = MIDAS_UNARY_METHODS.get(expr.operator.type) + if method is None: + self.logger.warning(f"Unsupported operator {expr.operator.lexeme}") + self.reporter.warning( + expr.location, f"Unsupported operator {expr.operator.lexeme}" + ) + return UnknownType() + + operand: Type = self.type_of(expr.right) + operation: Optional[Type] = self.types.lookup_member(operand, method) + if operation is None: + self.reporter.error( + expr.location, + f"Undefined operation {method} for {operand}", + ) + return UnknownType() + + result: Optional[Type] = self._get_call_result( + expr.location, + operation, + [], + {}, + ) + return result or UnknownType() + + def visit_call_expr(self, expr: m.CallExpr) -> Type: + callee: Type = expr.callee.accept(self) + positional: list[TypedExpr] = [ + (arg, self.type_of(arg)) for arg in expr.arguments + ] + keywords: dict[str, TypedExpr] = { + name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items() + } + return ( + self._get_call_result( + expr.location, + callee, + positional, + keywords, + ) + or UnknownType() + ) + + def visit_get_expr(self, expr: m.GetExpr) -> Type: + object: Type = expr.expr.accept(self) + member: Optional[Type] = self.types.lookup_member(object, expr.name.lexeme) + if member is None: + self.reporter.error( + expr.location, f"Unknown member '{expr.name.lexeme}' of {object}" + ) + return UnknownType() + return member + + def visit_variable_expr(self, expr: m.VariableExpr) -> Type: + return self.get_variable(expr.name.lexeme) + + def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type: return expr.expr.accept(self) - def visit_literal_expr(self, expr: m.LiteralExpr) -> None: - self.reporter.warning(expr.location, "LiteralExpr not yet supported") + def visit_literal_expr(self, expr: m.LiteralExpr) -> Type: + match expr.value: + case bool(): # Must be before int + return self.types.get_type("bool") + case int(): + return self.types.get_type("int") + case float(): + return self.types.get_type("float") + case str(): + return self.types.get_type("str") + case _: + self.reporter.warning(expr.location, f"Unknown literal {expr}") + return UnknownType() - def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: - self.reporter.warning(expr.location, "WildcardExpr not yet supported") + def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type: + return self.get_variable("_") def visit_named_type(self, type: m.NamedType) -> Type: name: str = type.name.lexeme @@ -153,10 +350,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type return UnknownType() def visit_constraint_type(self, type: m.ConstraintType) -> Type: - type_: Type = type.type.accept(self) - type.constraint.accept(self) - # TODO - return UnknownType() + return ConstraintType( + type=type.type.accept(self), + constraint=type.constraint, + ) def visit_complex_type(self, type: m.ComplexType) -> ComplexType: return ComplexType( @@ -172,8 +369,17 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type ) def visit_function_type(self, type: m.FunctionType) -> Type: - n_pos_args: int = len(type.pos_args) - n_args: int = len(type.args) + params: TypedParamSpec = self._visit_param_spec(type.params) + return Function( + pos_args=params.pos, + args=params.mixed, + kw_args=params.kw, + returns=type.returns.accept(self), + ) + + def _visit_param_spec(self, spec: m.ParamSpec) -> TypedParamSpec: + n_pos: int = len(spec.pos) + n_mixed: int = len(spec.mixed) def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument: return Function.Argument( @@ -183,14 +389,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type required=arg.required, ) - return Function( - pos_args=[process_arg(arg, i) for i, arg in enumerate(type.pos_args)], - args=[process_arg(arg, i + n_pos_args) for i, arg in enumerate(type.args)], - kw_args=[ - process_arg(arg, i + n_pos_args + n_args) - for i, arg in enumerate(type.kw_args) - ], - returns=type.returns.accept(self), + return TypedParamSpec( + pos=[process_arg(arg, i) for i, arg in enumerate(spec.pos)], + mixed=[process_arg(arg, i + n_pos) for i, arg in enumerate(spec.mixed)], + kw=[process_arg(arg, i + n_pos + n_mixed) for i, arg in enumerate(spec.kw)], ) def _resolve_type_params(self, params: list[m.TypeParam]): @@ -204,3 +406,343 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type self._local_variables[name] = var vars.append(var) return vars + + def _get_call_result( + self, + location: Location, + callee: Type, + positional: list[TypedExpr], + keywords: dict[str, TypedExpr], + report_errors: bool = True, + ) -> Optional[Type]: + """Get the result type of a function call + + If the function has overloads, the function will try to resolve the + appropriate signature. + Argument types are matched to the defined parameters. + The function doesn't take the raw expression as a parameter to accommodate + for desugared calls such as for operators. + + Args: + location (Location): the call location + callee (Type): the called function + positional (list[TypedExpr]): the list positional arguments + keywords (dict[str, TypedExpr]): the map of keyword arguments + report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True. + + Returns: + Type: the return type of the call, or `None` if either + the call is invalid or no overload matched the arguments uniquely + """ + match callee: + case Function() as function: + valid: bool + mapped: list[MappedArgument] + valid, mapped = self.map_call_arguments( + function, location, positional, keywords + ) + valid = valid and self._are_arguments_valid(mapped, report_errors) + if not valid: + return None + return function.returns + + case OverloadedFunction(overloads=overloads): + function = self._match_overload( + overloads, location, positional, keywords, report_errors + ) + if function is None: + return None + return function.returns + + case AppliedType(body=body): + return self._get_call_result( + location, body, positional, keywords, report_errors + ) + + case UnknownType(): + return UnknownType() + + case _: + if report_errors: + self.reporter.error(location, f"{callee} is not callable") + return None + + def _are_arguments_valid( + self, + arguments: list[MappedArgument], + report_errors: bool = True, + ) -> bool: + """Check whether the passed argument types correspond to their matched parameter definitions + + Args: + arguments (list[MappedArgument]): the list of argument/parameter pairs + report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True. + + Returns: + bool: True if all arguments fit the matching parameter definitions, False otherwise + """ + valid: bool = True + for arg in arguments: + if not self.types.is_subtype(arg.type, arg.argument.type): + if report_errors: + self.reporter.error( + arg.expr.location, + f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}", + ) + valid = False + return valid + + def _match_overload( + self, + overloads: list[Type], + location: Location, + positional: list[TypedExpr], + keywords: dict[str, TypedExpr], + report_errors: bool = True, + ) -> Optional[Function]: + """Try and resolve the appropriate overload for the given arguments + + Args: + overloads (list[Type]): the list of possible overloads + location (Location): the call location + positional (list[TypedExpr]): the list of positional arguments + keywords (dict[str, TypedExpr]): the map of keywords arguments + report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True. + + Returns: + Optional[Function]: the resolved function signature if it can be + determined unambiguously, or `None`. + """ + candidates: list[OverloadCandidate] = [] + for overload in overloads: + function: Type = unfold_type(overload) + if not isinstance(function, Function): + if report_errors: + self.logger.error( + f"Overload is not a function: {overload} is {function}" + ) + continue + valid, mapped = self.map_call_arguments( + function=function, + location=location, + positional=positional, + keywords=keywords, + report_errors=False, + ) + if valid and self._are_arguments_valid(mapped, report_errors=False): + candidates.append( + OverloadCandidate( + function=function, + mapped=mapped, + ) + ) + + pos_types: str = ", ".join(str(type) for _, type in positional) + kw_types: str = ", ".join( + f"{name}: {type}" for name, (_, type) in keywords.items() + ) + for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}" + + n_candidates: int = len(candidates) + + # Exactly 1 match -> return it + if n_candidates == 1: + return candidates[0].function + + # No match -> invalid call + if n_candidates == 0: + overloads_str: str = ", ".join(map(str, overloads)) + if report_errors: + self.reporter.error( + location, + f"No matching overload in [{overloads_str}] {for_args}", + ) + return None + + # Multiple matches -> see if one <: all others (more specific) + for i1, c1 in enumerate(candidates): + mapped1: list[MappedArgument] = c1.mapped + best_match: bool = True + for i2, c2 in enumerate(candidates): + if i1 == i2: + continue + mapped2: list[MappedArgument] = c2.mapped + if not self._are_mapped_subtypes(mapped1, mapped2): + best_match = False + break + self.logger.debug(f"{c1.function} is a full overload of {c2.function}") + if best_match: + return c1.function + + candidates_str: str = ", ".join( + str(candidate.function) for candidate in candidates + ) + if report_errors: + self.reporter.error( + location, + f"Multiple matching overloads {for_args}: {candidates_str}", + ) + return None + + def map_call_arguments( + self, + function: Function, + location: Location, + positional: list[TypedExpr], + keywords: dict[str, TypedExpr], + report_errors: bool = True, + ) -> tuple[bool, list[MappedArgument]]: + """Map call arguments to a function's parameters as defined in its signature + + This method maps positional-only, keyword-only and mixed parameter definitions + with the arguments passed at the call site + + Any mismatched, missing or unexpected argument is reported as a diagnostic, + unless `report_errors` is set to `False` + + Args: + function (Function): the function definition + location (Location): the call location + positional (list[TypedExpr]): the list of positional arguments + keywords (dict[str, TypedExpr]): the map of keyword arguments + report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True. + + Returns: + tuple[bool, list[MappedArgument]]: a boolean reporting whether + the call is valid and the list of mapped arguments + """ + set_args: set[str] = set() + + required_positional: list[str] = [ + arg.name for arg in function.pos_args + function.args if arg.required + ] + required_keyword: list[str] = [ + arg.name for arg in function.kw_args if arg.required + ] + + mapped: list[MappedArgument] = [] + + pos_params: list[Function.Argument] = list(function.pos_args) + mixed_params: list[Function.Argument] = list(function.args) + kw_params: dict[str, Function.Argument] = { + arg.name: arg for arg in function.kw_args + } + + valid_call: bool = True + + # TODO: handle *args and **kwargs sinks + for arg in positional: + param: Function.Argument + if len(pos_params) != 0: + param = pos_params.pop(0) + elif len(mixed_params) != 0: + param = mixed_params.pop(0) + else: + if report_errors: + self.reporter.error( + arg[0].location, "Too many positional arguments" + ) + valid_call = False + break + name: str = param.name + if name in required_positional: + required_positional.remove(name) + if name in required_keyword: + required_keyword.remove(name) + set_args.add(name) + mapped.append( + MappedArgument( + expr=arg[0], + type=arg[1], + argument=param, + ) + ) + + kw_params.update({arg.name: arg for arg in mixed_params}) + for name, arg in keywords.items(): + param: Function.Argument + if name not in kw_params: + if report_errors: + if name in set_args: + self.reporter.error( + arg[0].location, f"Multiple values for argument '{name}'" + ) + else: + self.reporter.error( + arg[0].location, f"Unknown keyword argument '{name}'" + ) + valid_call = False + continue + param = kw_params.pop(name) + if name in required_positional: + required_positional.remove(name) + if name in required_keyword: + required_keyword.remove(name) + set_args.add(name) + mapped.append( + MappedArgument( + expr=arg[0], + type=arg[1], + argument=param, + ) + ) + + def join_args(args: list[str]) -> str: + args = list(map(lambda a: f"'{a}'", args)) + if len(args) == 0: + return "" + if len(args) == 1: + return args[0] + return ", ".join(args[:-1]) + " and " + args[-1] + + if len(required_positional) != 0: + plural: str = "" if len(required_positional) == 1 else "s" + args: str = join_args(required_positional) + if report_errors: + self.reporter.error( + location, + f"Missing required positional argument{plural}: {args}", + ) + valid_call = False + + if len(required_keyword) != 0: + plural: str = "" if len(required_keyword) == 1 else "s" + args: str = join_args(required_keyword) + if report_errors: + self.reporter.error( + location, + f"Missing required keyword argument{plural}: {args}", + ) + valid_call = False + + return valid_call, mapped + + def _are_mapped_subtypes( + self, mapped1: list[MappedArgument], mapped2: list[MappedArgument] + ) -> bool: + """Check whether the given argument mappings are subtype/supertype of one another + + This function checks whether the argument mappings `mapped1` are subtypes + of `mapped2`. If any of the parameter type in `mapped1` is not a subtype + of the corresponding parameter in `mapped2`, `False` is returned. + + This is used to check whether a given overload is + a more specific function/ a subtype of another. + + Args: + mapped1 (list[MappedArgument]): the first argument mappings (subtype) + mapped2 (list[MappedArgument]): the second argument mappings (supertype) + + Returns: + bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise + """ + by_expr: dict[m.Expr, Type] = {} + for arg in mapped1: + by_expr[arg.expr] = arg.argument.type + + for arg in mapped2: + type2: Type = arg.argument.type + type1: Type = by_expr[arg.expr] + if not self.types.is_subtype(type1, type2): + return False + return True diff --git a/midas/checker/operators.py b/midas/checker/operators.py index 58af88c..f9354a0 100644 --- a/midas/checker/operators.py +++ b/midas/checker/operators.py @@ -1,7 +1,9 @@ import ast from typing import Type -OPERATOR_METHODS: dict[Type[ast.operator], str] = { +from midas.lexer.token import TokenType + +PY_OPERATOR_METHODS: dict[Type[ast.operator], str] = { ast.Add: "__add__", ast.Sub: "__sub__", ast.Mult: "__mul__", @@ -17,9 +19,9 @@ OPERATOR_METHODS: dict[Type[ast.operator], str] = { ast.FloorDiv: "__floordiv__", } -COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = { +PY_COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = { ast.Eq: "__eq__", - # ast.NotEq: "__noteq__", + ast.NotEq: "__eq__", ast.Lt: "__lt__", ast.LtE: "__le__", ast.Gt: "__gt__", @@ -30,9 +32,40 @@ COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = { # ast.NotIn: "__notin__", } -UNARY_METHODS: dict[Type[ast.unaryop], str] = { +PY_UNARY_METHODS: dict[Type[ast.unaryop], str] = { ast.Invert: "__invert__", # ast.Not: "", ast.UAdd: "__pos__", ast.USub: "__neg__", } + + +MIDAS_BINARY_METHODS: dict[TokenType, str] = { + # TokenType.PLUS: "__add__", + TokenType.MINUS: "__sub__", + TokenType.STAR: "__mul__", + TokenType.SLASH: "__truediv__", + # TokenType.MODULO: "__mod__", + # TokenType.POW: "__pow__", + # ast.BitOr: "__or__", + # ast.BitXor: "__xor__", + # ast.BitAnd: "__and__", + # ast.FloorDiv: "__floordiv__", + TokenType.EQUAL_EQUAL: "__eq__", + TokenType.BANG_EQUAL: "__eq__", + TokenType.LESS: "__lt__", + TokenType.LESS_EQUAL: "__le__", + TokenType.GREATER: "__gt__", + TokenType.GREATER_EQUAL: "__ge__", + # ast.Is: "__is__", + # ast.IsNot: "__isnot__", + # ast.In: "__in__", + # ast.NotIn: "__notin__", +} + +MIDAS_UNARY_METHODS: dict[TokenType, str] = { + # ast.Invert: "__invert__", + # ast.Not: "", + # TokenType.PLUS: "__pos__", + TokenType.MINUS: "__neg__", +} diff --git a/midas/checker/python.py b/midas/checker/python.py index c4bffff..6b2892b 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -6,7 +6,11 @@ from typing import Optional import midas.ast.python as p from midas.ast.location import Location from midas.checker.environment import Environment -from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS +from midas.checker.operators import ( + PY_COMPARATOR_METHODS, + PY_OPERATOR_METHODS, + PY_UNARY_METHODS, +) from midas.checker.preamble import Preamble from midas.checker.registry import TypesRegistry from midas.checker.reporter import FileReporter, Reporter @@ -376,7 +380,7 @@ class PythonTyper( pass def visit_binary_expr(self, expr: p.BinaryExpr) -> Type: - method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__) + method: Optional[str] = PY_OPERATOR_METHODS.get(expr.operator.__class__) if method is None: self.logger.warning(f"Unsupported operator {expr.operator}") self.reporter.warning( @@ -387,7 +391,7 @@ class PythonTyper( return self._visit_binary_expr(expr.location, expr.left, expr.right, method) def visit_compare_expr(self, expr: p.CompareExpr) -> Type: - method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__) + method: Optional[str] = PY_COMPARATOR_METHODS.get(expr.operator.__class__) if method is None: self.logger.warning(f"Unsupported operator {expr.operator}") self.reporter.warning( @@ -420,7 +424,7 @@ class PythonTyper( return result or UnknownType() def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: - method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__) + method: Optional[str] = PY_UNARY_METHODS.get(expr.operator.__class__) if method is None: self.logger.warning(f"Unsupported operator {expr.operator}") self.reporter.warning( @@ -652,7 +656,7 @@ class PythonTyper( If the function has overloads, the function will try to resolve the appropriate signature. Argument types are matched to the defined parameters. - The function doesn't take the raw expression as a parameter to accomodate + The function doesn't take the raw expression as a parameter to accommodate for desugared calls such as for operators. Args: @@ -743,7 +747,7 @@ class PythonTyper( Returns: Optional[Function]: the resolved function signature if it can be - determined unambigously, or `None`. + determined unambiguously, or `None`. """ candidates: list[OverloadCandidate] = [] for overload in overloads: diff --git a/midas/checker/registry.py b/midas/checker/registry.py index 6591548..fa2d1bd 100644 --- a/midas/checker/registry.py +++ b/midas/checker/registry.py @@ -7,10 +7,12 @@ from midas.checker.types import ( AppliedType, BaseType, ComplexType, + ConstraintType, ExtensionType, Function, GenericType, OverloadedFunction, + Predicate, TopType, Type, TypeVar, @@ -24,6 +26,7 @@ class TypesRegistry: self.logger: logging.Logger = logging.getLogger("TypesRegistry") self._types: dict[str, Type] = {} self._members: dict[str, dict[str, Type]] = {} + self._predicates: dict[str, Predicate] = {} def get_type(self, name: str) -> Type: """Get a type from its name @@ -81,6 +84,11 @@ class TypesRegistry: else: members[member_name] = member_type + def define_predicate(self, name: str, predicate: Predicate): + if name in self._predicates: + raise ValueError(f"Predicate {name} already defined") + self._predicates[name] = predicate + def is_subtype(self, type1: Type, type2: Type) -> bool: """Check whether `type1` is a subtype of `type2` @@ -123,6 +131,9 @@ class TypesRegistry: return False return self.is_subtype(bound, type2) + case (ConstraintType(type=base1), _): + return self.is_subtype(base1, type2) + return False # TODO: verify the logic in here @@ -345,3 +356,6 @@ class TypesRegistry: case _: self.logger.debug(f"Can't get member on {type}") return None + + def lookup_predicate(self, name: str) -> Optional[Predicate]: + return self._predicates.get(name) diff --git a/midas/checker/types.py b/midas/checker/types.py index 708d68b..82a08ba 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -1,7 +1,10 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Optional +from typing import Optional, assert_never + +import midas.ast.midas as m +from midas.ast.printer import MidasPrinter @dataclass(frozen=True, kw_only=True) @@ -130,6 +133,16 @@ class AppliedType: return f"{self.name}[{', '.join(map(str, self.args))}]" +@dataclass(frozen=True, kw_only=True) +class ConstraintType: + type: Type + constraint: m.Expr + + def __str__(self) -> str: + printer = MidasPrinter() + return f"{self.type} where {printer.print(self.constraint)}" + + def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type: def sub_argument(arg: Function.Argument): return Function.Argument( @@ -195,6 +208,12 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type: body=substitute_typevars(body, substitutions), ) + case ConstraintType(): + return ConstraintType( + type=substitute_typevars(type.type, substitutions), + constraint=type.constraint, + ) + case TypeVar(name=name): if name in substitutions: return substitutions[name] @@ -203,9 +222,13 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type: case UnknownType() | UnitType(): return type - case _: + case TopType() | GenericType(): raise NotImplementedError(f"Unsupported type {type}") + # Ensure exhaustiveness + case _: + assert_never(type) + def unfold_type(type: Type) -> Type: match type: @@ -215,6 +238,65 @@ def unfold_type(type: Type) -> Type: return type +def to_annotation(type: Type) -> str: + def _args_annotation(func: Function) -> str: + if len(func.kw_args) != 0: + return "..." + + args: str = ", ".join( + to_annotation(arg.type) for arg in func.pos_args + func.args + ) + return f"[{args}]" + + match type: + case TopType(): + return "Any" + + case BaseType(name=name): + return name + + case AliasType(name=name): + return name + + case UnknownType(): + return "Any" + + case UnitType(): + return "None" + + case Function(returns=returns): + params_annot: str = _args_annotation(type) + return f"Callable[{params_annot}, {to_annotation(returns)}]" + + case OverloadedFunction(): + return "Callable" + + case ComplexType() | ExtensionType(): + raise NotImplementedError + + case TypeVar(name=name): + return name + + case GenericType(name=name, params=params): + return f"{name}[{', '.join(map(to_annotation, params))}]" + + case AppliedType(name=name, args=args): + return f"{name}[{', '.join(map(to_annotation, args))}]" + + case ConstraintType(): + return str(type) + + case _: + assert_never(type) + + +@dataclass(frozen=True, kw_only=True) +class Predicate: + type: Type + body: m.Expr + alias: bool + + Type = ( TopType | BaseType @@ -228,4 +310,5 @@ Type = ( | TypeVar | GenericType | AppliedType + | ConstraintType ) diff --git a/midas/cli/commands/compile.py b/midas/cli/commands/compile.py index 5a410f7..5a623ec 100644 --- a/midas/cli/commands/compile.py +++ b/midas/cli/commands/compile.py @@ -38,5 +38,5 @@ def compile( if any(map(lambda d: d.type == DiagnosticType.ERROR, diagnostics)): sys.exit(1) - generator = Generator(workdir=source_path.parent) + generator = Generator(workdir=source_path.parent, types=checker.types) generator.generate(typed_ast, source_path) diff --git a/midas/cli/commands/registry.py b/midas/cli/commands/registry.py index d978ad9..41fc616 100644 --- a/midas/cli/commands/registry.py +++ b/midas/cli/commands/registry.py @@ -8,6 +8,7 @@ from typing import TextIO import click +from midas.ast.printer import MidasPrinter from midas.checker.checker import TypeChecker from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type @@ -35,6 +36,7 @@ def dump_registry( for types_file in types: checker.import_midas(Path(types_file.name).resolve()) + print("##### Types #####") for name, type in checker.types._types.items(): members: dict[str, Type] = checker.types._members.get(name, {}) print(f"{name} = {base_type(type)}") @@ -42,3 +44,17 @@ def dump_registry( print(" " * 4 + "Members:") for member_name, member_type in members.items(): print(" " * 8 + f"{member_name}: {member_type}") + + print("##### Predicates #####") + printer = MidasPrinter() + for name, predicate in checker.types._predicates.items(): + body: str = printer.print(predicate.body) + if predicate.alias: + print(f"{name}: {predicate.type} = {body}") + else: + print(f"{name}{predicate.type}:") + body = "\n".join( + " " + ("return " if i == 0 else "") + line + for i, line in enumerate(body.split("\n")) + ) + print(body) diff --git a/midas/generator/constraints.py b/midas/generator/constraints.py new file mode 100644 index 0000000..e840b42 --- /dev/null +++ b/midas/generator/constraints.py @@ -0,0 +1,224 @@ +import ast +from typing import Optional + +import midas.ast.midas as m +from midas.checker.registry import TypesRegistry +from midas.checker.types import ( + Function, + Predicate, + Type, + to_annotation, +) +from midas.lexer.token import TokenType + +LOGICAL_OPERATORS: dict[TokenType, type[ast.boolop]] = { + TokenType.AND: ast.And, + # TokenType.OR: ast.Or, +} + +BINARY_OPERATORS: dict[TokenType, type[ast.operator]] = { + # TokenType.PLUS: ast.Add, + TokenType.MINUS: ast.Sub, + TokenType.STAR: ast.Mult, + TokenType.SLASH: ast.Div, +} + +UNARY_OPERATORS: dict[TokenType, type[ast.unaryop]] = { + # TokenType.PLUS: ast.UAdd, + TokenType.MINUS: ast.USub, +} + +COMPARISON_OPERATORS: dict[TokenType, type[ast.cmpop]] = { + TokenType.GREATER: ast.Gt, + TokenType.GREATER_EQUAL: ast.GtE, + TokenType.LESS: ast.Lt, + TokenType.LESS_EQUAL: ast.LtE, + TokenType.EQUAL_EQUAL: ast.Eq, + TokenType.BANG_EQUAL: ast.NotEq, +} + + +class ConstraintGenerator(m.Expr.Visitor[ast.expr]): + def __init__(self, types: TypesRegistry): + self.types: TypesRegistry = types + self._id: int = 0 + self._definitions: list[ast.stmt] = [] + self._aliases: dict[str, str] = {} + + def get_definitions(self) -> list[ast.stmt]: + return self._definitions + + def generate(self, expr: m.Expr) -> ast.expr: + match expr: + case m.VariableExpr(): + return expr.accept(self) + case _: + func = Function( + pos_args=[], + args=[ + Function.Argument( + pos=0, + name="_", + type=self.types.get_type("Any"), + required=True, + ) + ], + kw_args=[], + returns=self.types.get_type("bool"), + ) + alias: str = self.make_alias(None) + definition: ast.stmt = self.make_definition( + alias, Predicate(type=func, body=expr, alias=False) + ) + self._definitions.append(definition) + return ast.Name(id=alias) + + def make_alias(self, name: Optional[str]) -> str: + suffix: str + if name is None: + suffix = f"p{self._id}" + self._id += 1 + else: + suffix = name + alias: str = f"__midas_{suffix}__" + return alias + + def make_definition(self, name: str, predicate: Predicate) -> ast.stmt: + body: ast.expr = predicate.body.accept(self) + if predicate.alias: + return ast.Assign( + targets=[ + ast.Name(id=name), + ], + value=body, + ) + return self.make_func(name, [ast.Return(value=body)], predicate.type) + + def make_args(self, func: Function) -> ast.arguments: + return ast.arguments( + posonlyargs=[ + ast.arg( + arg=arg.name, + annotation=ast.Constant(value=to_annotation(arg.type)), + ) + for arg in func.pos_args + ], + args=[ + ast.arg( + arg=arg.name, + annotation=ast.Constant(value=to_annotation(arg.type)), + ) + for arg in func.args + ], + kwonlyargs=[ + ast.arg( + arg=arg.name, + annotation=ast.Constant(value=to_annotation(arg.type)), + ) + for arg in func.kw_args + ], + defaults=[], + kw_defaults=[], + ) + + def make_func( + self, name: str, inner_body: list[ast.stmt], type: Type, level: int = 0 + ) -> ast.stmt: + match type: + case Function(returns=Function()): + inner_name: str = f"inner{level}" + return ast.FunctionDef( + name=name, + args=self.make_args(type), + body=[ + self.make_func(inner_name, inner_body, type.returns, level + 1), + ast.Return(value=ast.Name(id=inner_name)), + ], + returns=ast.Constant(value=to_annotation(type.returns)), + decorator_list=[], + ) + + case Function(): + return ast.FunctionDef( + name=name, + args=self.make_args(type), + body=inner_body, + returns=ast.Constant(value=to_annotation(type.returns)), + decorator_list=[], + ) + + case _: + raise ValueError(f"Expected function, got {type!r}") + + def get_predicate(self, name: str) -> Optional[ast.expr]: + if name not in self._aliases: + predicate: Optional[Predicate] = self.types.lookup_predicate(name) + if predicate is None: + return None + alias: str = self.make_alias(name) + self._aliases[name] = alias + self._definitions.append(self.make_definition(alias, predicate)) + + return ast.Name(id=self._aliases[name]) + + def visit_logical_expr(self, expr: m.LogicalExpr) -> ast.expr: + return ast.BoolOp( + op=LOGICAL_OPERATORS[expr.operator.type](), + values=[ + expr.left.accept(self), + expr.right.accept(self), + ], + ) + + def visit_binary_expr(self, expr: m.BinaryExpr) -> ast.expr: + op: TokenType = expr.operator.type + if op in BINARY_OPERATORS: + return ast.BinOp( + left=expr.left.accept(self), + op=BINARY_OPERATORS[op](), + right=expr.right.accept(self), + ) + if op in COMPARISON_OPERATORS: + return ast.Compare( + left=expr.left.accept(self), + ops=[COMPARISON_OPERATORS[op]()], + comparators=[expr.right.accept(self)], + ) + raise ValueError(f"Unexpected binary operator {op}") + + def visit_unary_expr(self, expr: m.UnaryExpr) -> ast.expr: + return ast.UnaryOp( + op=UNARY_OPERATORS[expr.operator.type](), + operand=expr.right.accept(self), + ) + + def visit_call_expr(self, expr: m.CallExpr) -> ast.expr: + return ast.Call( + func=expr.callee.accept(self), + args=[arg.accept(self) for arg in expr.arguments], + keywords=[ + ast.keyword(arg=name, value=arg.accept(self)) + for name, arg in expr.keywords.items() + ], + ) + + def visit_get_expr(self, expr: m.GetExpr) -> ast.expr: + return ast.Attribute( + value=expr.expr.accept(self), + attr=expr.name.lexeme, + ) + + def visit_variable_expr(self, expr: m.VariableExpr) -> ast.expr: + name: str = expr.name.lexeme + if (p := self.get_predicate(name)) is not None: + return p + return ast.Name(id=name) + + def visit_grouping_expr(self, expr: m.GroupingExpr) -> ast.expr: + return expr.accept(self) + + def visit_literal_expr(self, expr: m.LiteralExpr) -> ast.expr: + return ast.Constant(value=expr.value) + + def visit_wildcard_expr(self, expr: m.WildcardExpr) -> ast.expr: + return ast.Name(id="_") diff --git a/midas/generator/generator.py b/midas/generator/generator.py index 7575ca5..22eab41 100644 --- a/midas/generator/generator.py +++ b/midas/generator/generator.py @@ -2,15 +2,19 @@ import ast import shutil from dataclasses import dataclass, field from pathlib import Path -from typing import Optional +from typing import Optional, assert_never +import midas.ast.midas as m import midas.ast.python as p from midas.ast.location import Location +from midas.ast.printer import MidasPrinter +from midas.checker.registry import TypesRegistry from midas.checker.types import ( AliasType, AppliedType, BaseType, ComplexType, + ConstraintType, ExtensionType, Function, GenericType, @@ -19,7 +23,9 @@ from midas.checker.types import ( Type, TypeVar, UnitType, + UnknownType, ) +from midas.generator.constraints import ConstraintGenerator from midas.utils import TypedAST @@ -30,12 +36,9 @@ class Scope: class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): - def __init__(self, workdir: Path) -> None: + def __init__(self, workdir: Path, types: TypesRegistry) -> None: self.workdir: Path = workdir.resolve() self.build_dir: Path = self.workdir / "build" / "midas" - if self.build_dir.exists(): - shutil.rmtree(self.build_dir) - self.build_dir.mkdir(parents=True, exist_ok=True) self.rel_src_path: Path = Path() self._typed_ast: TypedAST = TypedAST( @@ -43,13 +46,18 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): judgements=[], ) self._alias_count: int = 0 + self._predicate_count: int = 0 self._scopes: list[Scope] = [] + self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types) + self._constraints: list[tuple[m.Expr, ast.expr]] = [] + def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST: - self.rel_src_path = src_path.relative_to(self.workdir) + self.rel_src_path = src_path.resolve().relative_to(self.workdir) self._typed_ast = typed_ast body: list[ast.stmt] = self._visit_body(typed_ast.stmts) - module = ast.Module(body=body, type_ignores=[]) + predicates: list[ast.stmt] = self._constraint_generator.get_definitions() + module = ast.Module(body=predicates + body, type_ignores=[]) module = ast.fix_missing_locations(module) return module @@ -59,6 +67,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): module: ast.AST = self.generate_ast(typed_ast, src_path) compiled: str = ast.unparse(module) if out_path is None: + if self.build_dir.exists(): + shutil.rmtree(self.build_dir) + self.build_dir.mkdir(parents=True, exist_ok=True) out_path = (self.build_dir / self.rel_src_path).resolve() try: _ = out_path.relative_to(self.build_dir) @@ -246,7 +257,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): return generated def _make_alias(self, expr: ast.expr) -> ast.expr: - name: str = f"__midas_alias_{self._alias_count}__" + name: str = f"__midas_a{self._alias_count}__" alias = ast.Name(id=name) self._alias_count += 1 self._scopes[-1].aliases.append(name) @@ -276,6 +287,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type): match type: + case UnknownType(): + pass + case BaseType(name=name): self._add_assert( ast.Call( @@ -301,8 +315,15 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): self._make_cast_assert_message(src_location, expr, type), ) - case AppliedType(): - self._make_cast_asserts(src_location, expr, type.body) + case AppliedType(body=body): + self._make_cast_asserts(src_location, expr, body) + + case ConstraintType(type=base, constraint=constraint): + self._make_cast_asserts(src_location, expr, base) + self._make_constraint_assert(src_location, expr, constraint) + + case TypeVar(): + raise RuntimeError("Unexpected TypeVar") case ( TopType() @@ -314,8 +335,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): ): raise NotImplementedError(f"Can't make assertion for type {type}") - case TypeVar(): - raise RuntimeError("Unexpected TypeVar") + # Ensure exhaustiveness + case _: + assert_never(type) def _make_cast_assert_message( self, location: Location, expr: ast.expr, type: Type @@ -339,3 +361,36 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): ast.Constant(f" to {type}"), ] ) + + def _make_constraint_assert( + self, src_location: Location, expr: ast.expr, constraint: m.Expr + ): + test_func: ast.expr = self._get_constraint(constraint) + self._add_assert( + ast.Call( + func=test_func, + args=[expr], + keywords=[], + ), + self._make_constraint_assert_message(src_location, expr, constraint), + ) + + def _make_constraint_assert_message( + self, location: Location, expr: ast.expr, constraint: m.Expr + ) -> ast.expr: + printer = MidasPrinter() + constraint_str: str = printer.print(constraint) + loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}" + # f"file.py:L1:1: ConstraintError: Value does not fit constraint 'v > 0'" + return ast.Constant( + f"{loc_str}: ConstraintError: Value does not fit constraint '{constraint_str}'" + ) + + def _get_constraint(self, expr: m.Expr) -> ast.expr: + for expr2, constraint in self._constraints: + if expr2 == expr: + return constraint + + constraint: ast.expr = self._constraint_generator.generate(expr) + self._constraints.append((expr, constraint)) + return constraint diff --git a/midas/lexer/midas.py b/midas/lexer/midas.py index c3246fc..0510a6e 100644 --- a/midas/lexer/midas.py +++ b/midas/lexer/midas.py @@ -69,6 +69,8 @@ class MidasLexer(Lexer): ): self.advance() self.add_token(TokenType.WHITESPACE) + case '"' | "'": + self.scan_string(char) case _: if char.isdigit(): self.scan_number() @@ -78,6 +80,17 @@ class MidasLexer(Lexer): self.error("Unexpected character") return None + def scan_string(self, opening: str): + while self.peek() != opening and not self.is_at_end(): + self.advance() + + if self.is_at_end(): + self.error("Unterminated string") + + self.advance() + value: str = self.source[self.start + 1 : self.idx - 1] + self.add_token(TokenType.STRING, value) + def scan_number(self): """Scan the rest of number and add it as a token diff --git a/midas/lexer/token.py b/midas/lexer/token.py index f0c08a1..ce73aef 100644 --- a/midas/lexer/token.py +++ b/midas/lexer/token.py @@ -43,6 +43,7 @@ class TokenType(Enum): TRUE = auto() FALSE = auto() NONE = auto() + STRING = auto() # Keywords TYPE = auto() diff --git a/midas/parser/midas.py b/midas/parser/midas.py index 33069f3..fdb58c2 100644 --- a/midas/parser/midas.py +++ b/midas/parser/midas.py @@ -3,6 +3,7 @@ from typing import Optional from midas.ast.location import Location from midas.ast.midas import ( BinaryExpr, + CallExpr, ComplexType, ConstraintType, Expr, @@ -17,6 +18,7 @@ from midas.ast.midas import ( MemberKind, MemberStmt, NamedType, + ParamSpec, PredicateStmt, Stmt, Type, @@ -265,6 +267,9 @@ class MidasParser(Parser): Returns: Expr: the parsed constraint expression """ + return self.expression() + + def expression(self) -> Expr: return self.and_() def and_(self) -> Expr: @@ -331,7 +336,55 @@ class MidasParser(Parser): right: Expr = self.unary() location: Location = Location.span(operator.get_location(), right.location) return UnaryExpr(location=location, operator=operator, right=right) - return self.reference() + return self.call() + + def call(self) -> Expr: + expr: Expr = self.reference() + while self.match(TokenType.LEFT_PAREN): + expr = self.finish_call(expr) + return expr + + def finish_call(self, callee: Expr) -> Expr: + pos_args: list[Expr] = [] + kw_args: dict[str, Expr] = {} + keywords: bool = False + while not self.match(TokenType.RIGHT_PAREN): + if self.check_identifier() and self.check_next(TokenType.EQUAL): + keywords = True + keyword: Token = self.advance() + self.advance() + value: Expr = self.expression() + name: str = keyword.lexeme + if name in kw_args: + self.error( + self.peek(), + f"Multiple values passed for '{name}', only the last occurrence will be used", + ) + kw_args[name] = value + else: + value = self.expression() + if self.check(TokenType.EQUAL): + if keywords: + raise self.error(self.peek(), "Invalid keyword argument name") + else: + raise self.error( + self.peek(), + "Cannot pass positional arguments after a keyword argument", + ) + pos_args.append(value) + + if not self.match(TokenType.COMMA): + break + + r_paren: Token = self.consume( + TokenType.RIGHT_PAREN, "Expected ')' after arguments." + ) + return CallExpr( + location=Location.span(callee.location, r_paren.get_location()), + callee=callee, + arguments=pos_args, + keywords=kw_args, + ) def reference(self) -> Expr: """Parse an attribute access expression or a simpler expression @@ -365,6 +418,9 @@ class MidasParser(Parser): if self.match(TokenType.NUMBER): return LiteralExpr(location=token.get_location(), value=token.value) + if self.match(TokenType.STRING): + return LiteralExpr(location=token.get_location(), value=token.value) + if self.match_identifier(): return VariableExpr(location=token.get_location(), name=token) @@ -453,23 +509,35 @@ class MidasParser(Parser): PredicateStmt: the parsed predicate declaration statement """ keyword: Token = self.previous() + name: Token = self.consume_identifier("Expected predicate name") - self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject") - subject: Token = self.consume_identifier("Expected subject name") - self.consume(TokenType.COLON, "Expected ':' after subject name") - type: Type = self.type_expr() - self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject") + + params: list[ParamSpec] = [] + while self.check(TokenType.LEFT_PAREN): + params.append(self.function_args()) + self.consume(TokenType.EQUAL, "Expected '=' after predicate subject") - condition: Expr = self.constraint() + body: Expr = self.constraint() return PredicateStmt( location=keyword.location_to(self.previous()), name=name, - subject=subject, - type=type, - condition=condition, + params=params, + body=body, ) def function(self) -> FunctionType: + params: ParamSpec = self.function_args() + + self.consume(TokenType.ARROW, "Expected '->' before result type") + result: Type = self.type_expr() + + return FunctionType( + location=params.l_paren.location_to(self.previous()), + params=params, + returns=result, + ) + + def function_args(self) -> ParamSpec: l_paren: Token = self.consume( TokenType.LEFT_PAREN, "Expected '(' before function parameters" ) @@ -526,14 +594,4 @@ class MidasParser(Parser): self.error(token, "Unnamed mixed argument") self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters") - - self.consume(TokenType.ARROW, "Expected '->' before result type") - result: Type = self.type_expr() - - return FunctionType( - location=l_paren.location_to(self.previous()), - pos_args=pos_args, - args=args, - kw_args=kw_args, - returns=result, - ) + return ParamSpec(l_paren=l_paren, pos=pos_args, mixed=args, kw=kw_args) diff --git a/tests/cases/generator/01_simple_types.py.ref.txt b/tests/cases/generator/01_simple_types.py.ref.txt index aa5f964..f8da6a6 100644 --- a/tests/cases/generator/01_simple_types.py.ref.txt +++ b/tests/cases/generator/01_simple_types.py.ref.txt @@ -9,13 +9,13 @@ Module( level=0), Assign( targets=[ - Name(id='__midas_alias_0__')], + Name(id='__midas_a0__')], value=Constant(value=123.45)), Assert( test=Call( func=Name(id='isinstance'), args=[ - Name(id='__midas_alias_0__'), + Name(id='__midas_a0__'), Name(id='float')], keywords=[]), msg=JoinedStr( @@ -26,7 +26,7 @@ Module( value=Call( func=Name(id='type'), args=[ - Name(id='__midas_alias_0__')], + Name(id='__midas_a0__')], keywords=[]), attr='__name__'), conversion=-1), @@ -34,19 +34,19 @@ Module( Assign( targets=[ Name(id='distance')], - value=Name(id='__midas_alias_0__')), + value=Name(id='__midas_a0__')), Delete( targets=[ - Name(id='__midas_alias_0__')]), + Name(id='__midas_a0__')]), Assign( targets=[ - Name(id='__midas_alias_1__')], + Name(id='__midas_a1__')], value=Constant(value=6.7)), Assert( test=Call( func=Name(id='isinstance'), args=[ - Name(id='__midas_alias_1__'), + Name(id='__midas_a1__'), Name(id='float')], keywords=[]), msg=JoinedStr( @@ -57,7 +57,7 @@ Module( value=Call( func=Name(id='type'), args=[ - Name(id='__midas_alias_1__')], + Name(id='__midas_a1__')], keywords=[]), attr='__name__'), conversion=-1), @@ -65,10 +65,10 @@ Module( Assign( targets=[ Name(id='time')], - value=Name(id='__midas_alias_1__')), + value=Name(id='__midas_a1__')), Delete( targets=[ - Name(id='__midas_alias_1__')]), + Name(id='__midas_a1__')]), Assign( targets=[ Name(id='speed')], diff --git a/tests/cases/generator/02_constraints.midas b/tests/cases/generator/02_constraints.midas new file mode 100644 index 0000000..2096221 --- /dev/null +++ b/tests/cases/generator/02_constraints.midas @@ -0,0 +1,14 @@ +// Inline +type T1 = float where _ > 0 + +// Named +predicate is_positive(v: float) = v > 0 +type T2 = float where is_positive(_) + +// Curried +predicate in_range(mn: float, mx: float)(v: float) = v >= mn & v < mx +type T3 = float where in_range(100, 200)(_) + +// Alias +predicate minor = in_range(0, 18) +type T4 = float where minor(_) diff --git a/tests/cases/generator/02_constraints.py b/tests/cases/generator/02_constraints.py new file mode 100644 index 0000000..3d6168b --- /dev/null +++ b/tests/cases/generator/02_constraints.py @@ -0,0 +1,8 @@ +from midas import T1, T2, T3, T4, cast + +t: float = 12.5 + +t1: T1 = cast(T1, t) +t2: T2 = cast(T2, t) +t3: T3 = cast(T3, t) +t4: T4 = cast(T4, t) diff --git a/tests/cases/generator/02_constraints.py.ref.txt b/tests/cases/generator/02_constraints.py.ref.txt new file mode 100644 index 0000000..7adaed7 --- /dev/null +++ b/tests/cases/generator/02_constraints.py.ref.txt @@ -0,0 +1,333 @@ +Module( + body=[ + FunctionDef( + name='__midas_p0__', + args=arguments( + posonlyargs=[], + args=[ + arg( + arg='_', + annotation=Constant(value='Any'))], + kwonlyargs=[], + kw_defaults=[], + defaults=[]), + body=[ + Return( + value=Compare( + left=Name(id='_'), + ops=[ + Gt()], + comparators=[ + Constant(value=0.0)]))], + decorator_list=[], + returns=Constant(value='bool')), + FunctionDef( + name='__midas_is_positive__', + args=arguments( + posonlyargs=[], + args=[ + arg( + arg='v', + annotation=Constant(value='float'))], + kwonlyargs=[], + kw_defaults=[], + defaults=[]), + body=[ + Return( + value=Compare( + left=Name(id='v'), + ops=[ + Gt()], + comparators=[ + Constant(value=0.0)]))], + decorator_list=[], + returns=Constant(value='bool')), + FunctionDef( + name='__midas_p1__', + args=arguments( + posonlyargs=[], + args=[ + arg( + arg='_', + annotation=Constant(value='Any'))], + kwonlyargs=[], + kw_defaults=[], + defaults=[]), + body=[ + Return( + value=Call( + func=Name(id='__midas_is_positive__'), + args=[ + Name(id='_')], + keywords=[]))], + decorator_list=[], + returns=Constant(value='bool')), + FunctionDef( + name='__midas_in_range__', + args=arguments( + posonlyargs=[], + args=[ + arg( + arg='mn', + annotation=Constant(value='float')), + arg( + arg='mx', + annotation=Constant(value='float'))], + kwonlyargs=[], + kw_defaults=[], + defaults=[]), + body=[ + FunctionDef( + name='inner0', + args=arguments( + posonlyargs=[], + args=[ + arg( + arg='v', + annotation=Constant(value='float'))], + kwonlyargs=[], + kw_defaults=[], + defaults=[]), + body=[ + Return( + value=BoolOp( + op=And(), + values=[ + Compare( + left=Name(id='v'), + ops=[ + GtE()], + comparators=[ + Name(id='mn')]), + Compare( + left=Name(id='v'), + ops=[ + Lt()], + comparators=[ + Name(id='mx')])]))], + decorator_list=[], + returns=Constant(value='bool')), + Return( + value=Name(id='inner0'))], + decorator_list=[], + returns=Constant(value='Callable[[float], bool]')), + FunctionDef( + name='__midas_p2__', + args=arguments( + posonlyargs=[], + args=[ + arg( + arg='_', + annotation=Constant(value='Any'))], + kwonlyargs=[], + kw_defaults=[], + defaults=[]), + body=[ + Return( + value=Call( + func=Call( + func=Name(id='__midas_in_range__'), + args=[ + Constant(value=100.0), + Constant(value=200.0)], + keywords=[]), + args=[ + Name(id='_')], + keywords=[]))], + decorator_list=[], + returns=Constant(value='bool')), + Assign( + targets=[ + Name(id='__midas_minor__')], + value=Call( + func=Name(id='__midas_in_range__'), + args=[ + Constant(value=0.0), + Constant(value=18.0)], + keywords=[])), + FunctionDef( + name='__midas_p3__', + args=arguments( + posonlyargs=[], + args=[ + arg( + arg='_', + annotation=Constant(value='Any'))], + kwonlyargs=[], + kw_defaults=[], + defaults=[]), + body=[ + Return( + value=Call( + func=Name(id='__midas_minor__'), + args=[ + Name(id='_')], + keywords=[]))], + decorator_list=[], + returns=Constant(value='bool')), + ImportFrom( + module='midas', + names=[ + alias(name='T1'), + alias(name='T2'), + alias(name='T3'), + alias(name='T4'), + alias(name='cast')], + level=0), + Assign( + targets=[ + Name(id='t')], + value=Constant(value=12.5)), + Assign( + targets=[ + Name(id='__midas_a0__')], + value=Name(id='t')), + Assert( + test=Call( + func=Name(id='isinstance'), + args=[ + Name(id='__midas_a0__'), + Name(id='float')], + keywords=[]), + msg=JoinedStr( + values=[ + Constant(value='02_constraints.py:L5:10: CastError: Cannot cast '), + FormattedValue( + value=Attribute( + value=Call( + func=Name(id='type'), + args=[ + Name(id='__midas_a0__')], + keywords=[]), + attr='__name__'), + conversion=-1), + Constant(value=' to float')])), + Assert( + test=Call( + func=Name(id='__midas_p0__'), + args=[ + Name(id='__midas_a0__')], + keywords=[]), + msg=Constant(value="02_constraints.py:L5:10: ConstraintError: Value does not fit constraint '_ > 0.0'")), + Assign( + targets=[ + Name(id='t1')], + value=Name(id='__midas_a0__')), + Delete( + targets=[ + Name(id='__midas_a0__')]), + Assign( + targets=[ + Name(id='__midas_a1__')], + value=Name(id='t')), + Assert( + test=Call( + func=Name(id='isinstance'), + args=[ + Name(id='__midas_a1__'), + Name(id='float')], + keywords=[]), + msg=JoinedStr( + values=[ + Constant(value='02_constraints.py:L6:10: CastError: Cannot cast '), + FormattedValue( + value=Attribute( + value=Call( + func=Name(id='type'), + args=[ + Name(id='__midas_a1__')], + keywords=[]), + attr='__name__'), + conversion=-1), + Constant(value=' to float')])), + Assert( + test=Call( + func=Name(id='__midas_p1__'), + args=[ + Name(id='__midas_a1__')], + keywords=[]), + msg=Constant(value="02_constraints.py:L6:10: ConstraintError: Value does not fit constraint 'is_positive(_)'")), + Assign( + targets=[ + Name(id='t2')], + value=Name(id='__midas_a1__')), + Delete( + targets=[ + Name(id='__midas_a1__')]), + Assign( + targets=[ + Name(id='__midas_a2__')], + value=Name(id='t')), + Assert( + test=Call( + func=Name(id='isinstance'), + args=[ + Name(id='__midas_a2__'), + Name(id='float')], + keywords=[]), + msg=JoinedStr( + values=[ + Constant(value='02_constraints.py:L7:10: CastError: Cannot cast '), + FormattedValue( + value=Attribute( + value=Call( + func=Name(id='type'), + args=[ + Name(id='__midas_a2__')], + keywords=[]), + attr='__name__'), + conversion=-1), + Constant(value=' to float')])), + Assert( + test=Call( + func=Name(id='__midas_p2__'), + args=[ + Name(id='__midas_a2__')], + keywords=[]), + msg=Constant(value="02_constraints.py:L7:10: ConstraintError: Value does not fit constraint 'in_range(100.0, 200.0)(_)'")), + Assign( + targets=[ + Name(id='t3')], + value=Name(id='__midas_a2__')), + Delete( + targets=[ + Name(id='__midas_a2__')]), + Assign( + targets=[ + Name(id='__midas_a3__')], + value=Name(id='t')), + Assert( + test=Call( + func=Name(id='isinstance'), + args=[ + Name(id='__midas_a3__'), + Name(id='float')], + keywords=[]), + msg=JoinedStr( + values=[ + Constant(value='02_constraints.py:L8:10: CastError: Cannot cast '), + FormattedValue( + value=Attribute( + value=Call( + func=Name(id='type'), + args=[ + Name(id='__midas_a3__')], + keywords=[]), + attr='__name__'), + conversion=-1), + Constant(value=' to float')])), + Assert( + test=Call( + func=Name(id='__midas_p3__'), + args=[ + Name(id='__midas_a3__')], + keywords=[]), + msg=Constant(value="02_constraints.py:L8:10: ConstraintError: Value does not fit constraint 'minor(_)'")), + Assign( + targets=[ + Name(id='t4')], + value=Name(id='__midas_a3__')), + Delete( + targets=[ + Name(id='__midas_a3__')])], + type_ignores=[]) \ No newline at end of file diff --git a/tests/cases/midas-parser/01_simple_types.midas.ref.json b/tests/cases/midas-parser/01_simple_types.midas.ref.json index be45687..7f91213 100644 --- a/tests/cases/midas-parser/01_simple_types.midas.ref.json +++ b/tests/cases/midas-parser/01_simple_types.midas.ref.json @@ -2582,18 +2582,21 @@ "name": "__sub__", "type": { "_type": "FunctionType", - "pos_args": [ - { - "name": null, - "type": { - "_type": "NamedType", - "name": "GeoLocation" - }, - "required": true - } - ], - "args": [], - "kw_args": [], + "params": { + "_type": "ParamSpec", + "pos": [ + { + "name": null, + "type": { + "_type": "NamedType", + "name": "GeoLocation" + }, + "required": true + } + ], + "mixed": [], + "kw": [] + }, "returns": { "_type": "GenericType", "type": { @@ -2673,18 +2676,21 @@ "name": "__sub__", "type": { "_type": "FunctionType", - "pos_args": [ - { - "name": null, - "type": { - "_type": "NamedType", - "name": "Latitude" - }, - "required": true - } - ], - "args": [], - "kw_args": [], + "params": { + "_type": "ParamSpec", + "pos": [ + { + "name": null, + "type": { + "_type": "NamedType", + "name": "Latitude" + }, + "required": true + } + ], + "mixed": [], + "kw": [] + }, "returns": { "_type": "GenericType", "type": { @@ -2713,18 +2719,21 @@ "name": "__sub__", "type": { "_type": "FunctionType", - "pos_args": [ - { - "name": null, - "type": { - "_type": "NamedType", - "name": "Longitude" - }, - "required": true - } - ], - "args": [], - "kw_args": [], + "params": { + "_type": "ParamSpec", + "pos": [ + { + "name": null, + "type": { + "_type": "NamedType", + "name": "Longitude" + }, + "required": true + } + ], + "mixed": [], + "kw": [] + }, "returns": { "_type": "GenericType", "type": { @@ -2745,12 +2754,24 @@ { "_type": "PredicateStmt", "name": "Positive", - "subject": "v", - "type": { - "_type": "NamedType", - "name": "float" - }, - "condition": { + "params": [ + { + "_type": "ParamSpec", + "pos": [], + "mixed": [ + { + "name": "v", + "type": { + "_type": "NamedType", + "name": "float" + }, + "required": true + } + ], + "kw": [] + } + ], + "body": { "_type": "BinaryExpr", "left": { "_type": "VariableExpr", @@ -2766,12 +2787,24 @@ { "_type": "PredicateStmt", "name": "StrictlyPositive", - "subject": "v", - "type": { - "_type": "NamedType", - "name": "float" - }, - "condition": { + "params": [ + { + "_type": "ParamSpec", + "pos": [], + "mixed": [ + { + "name": "v", + "type": { + "_type": "NamedType", + "name": "float" + }, + "required": true + } + ], + "kw": [] + } + ], + "body": { "_type": "BinaryExpr", "left": { "_type": "VariableExpr", @@ -2787,12 +2820,24 @@ { "_type": "PredicateStmt", "name": "Equatorial", - "subject": "loc", - "type": { - "_type": "NamedType", - "name": "GeoLocation" - }, - "condition": { + "params": [ + { + "_type": "ParamSpec", + "pos": [], + "mixed": [ + { + "name": "loc", + "type": { + "_type": "NamedType", + "name": "GeoLocation" + }, + "required": true + } + ], + "kw": [] + } + ], + "body": { "_type": "GroupingExpr", "expr": { "_type": "BinaryExpr", @@ -2827,12 +2872,24 @@ { "_type": "PredicateStmt", "name": "Arctic", - "subject": "loc", - "type": { - "_type": "NamedType", - "name": "GeoLocation" - }, - "condition": { + "params": [ + { + "_type": "ParamSpec", + "pos": [], + "mixed": [ + { + "name": "loc", + "type": { + "_type": "NamedType", + "name": "GeoLocation" + }, + "required": true + } + ], + "kw": [] + } + ], + "body": { "_type": "GroupingExpr", "expr": { "_type": "BinaryExpr", diff --git a/tests/generator.py b/tests/generator.py index 72b7002..067790f 100644 --- a/tests/generator.py +++ b/tests/generator.py @@ -45,7 +45,7 @@ class GeneratorTester(Tester): typed_ast: TypedAST = checker.type_check(path) if not any(d.type == DiagnosticType.ERROR for d in checker.diagnostics): - generator = Generator(workdir=path.parent) + generator = Generator(workdir=path.parent, types=checker.types) result.compiled_ast = generator.generate_ast(typed_ast, path) return result diff --git a/tests/serializer/midas.py b/tests/serializer/midas.py index 8bffdb3..0ad86f3 100644 --- a/tests/serializer/midas.py +++ b/tests/serializer/midas.py @@ -2,6 +2,7 @@ from typing import Optional, Sequence from midas.ast.midas import ( BinaryExpr, + CallExpr, ComplexType, ConstraintType, Expr, @@ -15,6 +16,7 @@ from midas.ast.midas import ( LogicalExpr, MemberStmt, NamedType, + ParamSpec, PredicateStmt, Stmt, Type, @@ -78,9 +80,8 @@ class MidasAstJsonSerializer( return { "_type": "PredicateStmt", "name": stmt.name.lexeme, - "subject": stmt.subject.lexeme, - "type": stmt.type.accept(self), - "condition": stmt.condition.accept(self), + "params": [self._serialize_param_spec(spec) for spec in stmt.params], + "body": stmt.body.accept(self), } def visit_logical_expr(self, expr: LogicalExpr) -> dict: @@ -106,6 +107,14 @@ class MidasAstJsonSerializer( "right": expr.right.accept(self), } + def visit_call_expr(self, expr: CallExpr) -> dict: + return { + "_type": "CallExpr", + "callee": expr.callee.accept(self), + "arguments": self._serialize_list(expr.arguments), + "keywords": {name: arg.accept(self) for name, arg in expr.keywords.items()}, + } + def visit_get_expr(self, expr: GetExpr) -> dict: return { "_type": "GetExpr", @@ -163,15 +172,21 @@ class MidasAstJsonSerializer( def visit_function_type(self, type: FunctionType) -> dict: return { "_type": "FunctionType", - "pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args], - "args": [self._serialize_func_arg(arg) for arg in type.args], - "kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args], + "params": self._serialize_param_spec(type.params), "returns": type.returns.accept(self), } + def _serialize_param_spec(self, spec: ParamSpec) -> dict: + return { + "_type": "ParamSpec", + "pos": [self._serialize_func_arg(arg) for arg in spec.pos], + "mixed": [self._serialize_func_arg(arg) for arg in spec.mixed], + "kw": [self._serialize_func_arg(arg) for arg in spec.kw], + } + def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict: return { - "name": arg.name, + "name": arg.name.lexeme if arg.name is not None else None, "type": arg.type.accept(self), "required": arg.required, }