Merge pull request 'Constraint types' (#15) from feat/constraint-type into main

Reviewed-on: #15
This commit was merged in pull request #15.
This commit is contained in:
2026-06-19 20:21:04 +00:00
23 changed files with 1747 additions and 205 deletions

View File

@@ -26,6 +26,14 @@ class MemberKind(Enum):
METHOD = auto()
@dataclass(frozen=True, kw_only=True)
class ParamSpec:
l_paren: Token
pos: list[FunctionType.Argument]
mixed: list[FunctionType.Argument]
kw: list[FunctionType.Argument]
###<
@@ -50,9 +58,8 @@ class ExtendStmt:
class PredicateStmt:
name: Token
subject: Token
type: Type
condition: Expr
params: list[ParamSpec]
body: Expr
###<
@@ -78,6 +85,12 @@ class UnaryExpr:
right: Expr
class CallExpr:
callee: Expr
arguments: list[Expr]
keywords: dict[str, Expr]
class GetExpr:
expr: Expr
name: Token
@@ -128,9 +141,7 @@ class ExtensionType:
class FunctionType:
pos_args: list[Argument]
args: list[Argument]
kw_args: list[Argument]
params: ParamSpec
returns: Type
@dataclass(frozen=True, kw_only=True)

View File

@@ -27,6 +27,14 @@ class MemberKind(Enum):
METHOD = auto()
@dataclass(frozen=True, kw_only=True)
class ParamSpec:
l_paren: Token
pos: list[FunctionType.Argument]
mixed: list[FunctionType.Argument]
kw: list[FunctionType.Argument]
##############
# Statements #
##############
@@ -86,9 +94,8 @@ class ExtendStmt(Stmt):
@dataclass(frozen=True)
class PredicateStmt(Stmt):
name: Token
subject: Token
type: Type
condition: Expr
params: list[ParamSpec]
body: Expr
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_predicate_stmt(self)
@@ -116,6 +123,9 @@ class Expr(ABC):
@abstractmethod
def visit_unary_expr(self, expr: UnaryExpr) -> T: ...
@abstractmethod
def visit_call_expr(self, expr: CallExpr) -> T: ...
@abstractmethod
def visit_get_expr(self, expr: GetExpr) -> T: ...
@@ -161,6 +171,16 @@ class UnaryExpr(Expr):
return visitor.visit_unary_expr(self)
@dataclass(frozen=True)
class CallExpr(Expr):
callee: Expr
arguments: list[Expr]
keywords: dict[str, Expr]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_call_expr(self)
@dataclass(frozen=True)
class GetExpr(Expr):
expr: Expr
@@ -279,9 +299,7 @@ class ExtensionType(Type):
@dataclass(frozen=True)
class FunctionType(Type):
pos_args: list[Argument]
args: list[Argument]
kw_args: list[Argument]
params: ParamSpec
returns: Type
@dataclass(frozen=True, kw_only=True)

View File

@@ -150,13 +150,17 @@ class MidasAstPrinter(
self._write_line("PredicateStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line(f'subject: "{stmt.subject.lexeme}"')
self._write_line("type")
self._write_line("params")
with self._child_level():
for i, spec in enumerate(stmt.params):
self._idx = i
if i == len(stmt.params) - 1:
self._mark_last()
self._visit_param_spec(spec)
self._write_line("body", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
self._write_line("condition", last=True)
with self._child_level(single=True):
stmt.condition.accept(self)
stmt.body.accept(self)
# Expressions
@@ -195,6 +199,29 @@ class MidasAstPrinter(
with self._child_level(single=True):
expr.right.accept(self)
def visit_call_expr(self, expr: m.CallExpr) -> None:
self._write_line("CallExpr")
with self._child_level():
self._write_line("callee")
with self._child_level(single=True):
expr.callee.accept(self)
self._write_line("arguments")
with self._child_level():
for i, arg in enumerate(expr.arguments):
self._idx = i
if i == len(expr.arguments) - 1:
self._mark_last()
arg.accept(self)
self._write_line("keywords", last=True)
with self._child_level():
for i, (name, arg) in enumerate(expr.keywords.items()):
self._idx = i
if i == len(expr.keywords) - 1:
self._mark_last()
self._write_line(name)
with self._child_level(single=True):
arg.accept(self)
def visit_get_expr(self, expr: m.GetExpr):
self._write_line("GetExpr")
with self._child_level():
@@ -276,34 +303,41 @@ class MidasAstPrinter(
def visit_function_type(self, type: m.FunctionType) -> None:
self._write_line("FunctionType")
with self._child_level():
self._write_line("pos_args")
with self._child_level():
for i, arg in enumerate(type.pos_args):
self._idx = i
if i == len(type.pos_args) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("args")
with self._child_level():
for i, arg in enumerate(type.args):
self._idx = i
if i == len(type.args) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("kw_args")
with self._child_level():
for i, arg in enumerate(type.kw_args):
self._idx = i
if i == len(type.kw_args) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("params")
with self._child_level(single=True):
self._visit_param_spec(type.params)
self._write_line("returns", last=True)
with self._child_level(single=True):
type.returns.accept(self)
def _visit_param_spec(self, spec: m.ParamSpec) -> None:
self._write_line("ParamSpec")
with self._child_level():
self._write_line("pos")
with self._child_level():
for i, arg in enumerate(spec.pos):
self._idx = i
if i == len(spec.pos) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("mixed")
with self._child_level():
for i, arg in enumerate(spec.mixed):
self._idx = i
if i == len(spec.mixed) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("kw", last=True)
with self._child_level():
for i, arg in enumerate(spec.kw):
self._idx = i
if i == len(spec.kw) - 1:
self._mark_last()
self._print_function_arg(arg)
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
self._write_line("Argument")
with self._child_level():
@@ -367,10 +401,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
name: str = stmt.name.lexeme
subject: str = stmt.subject.lexeme
type: str = stmt.type.accept(self)
condition: str = stmt.condition.accept(self)
return self.indented(f"predicate {name}({subject}: {type}) = {condition}")
sig: str = "".join(self._visit_param_spec(spec) for spec in stmt.params)
body: str = stmt.body.accept(self)
return self.indented(f"predicate {name}{sig} = {body}")
def visit_logical_expr(self, expr: m.LogicalExpr):
left: str = expr.left.accept(self)
@@ -389,6 +422,12 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
right: str = expr.right.accept(self)
return f"{operator}{right}"
def visit_call_expr(self, expr: m.CallExpr) -> str:
args: list[str] = [arg.accept(self) for arg in expr.arguments] + [
f"{name}={arg.accept(self)}" for name, arg in expr.keywords.items()
]
return f"{expr.callee.accept(self)}({', '.join(args)})"
def visit_get_expr(self, expr: m.GetExpr):
expr_: str = expr.expr.accept(self)
name: str = expr.name.lexeme
@@ -436,9 +475,13 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
return f"{type.base.accept(self)} & {type.extension.accept(self)}"
def visit_function_type(self, type: m.FunctionType) -> str:
pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
mixed_args: list[str] = [self._print_arg(arg) for arg in type.args]
kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args]
spec: str = self._visit_param_spec(type.params)
return f"fn {spec} -> {type.returns.accept(self)}"
def _visit_param_spec(self, spec: m.ParamSpec) -> str:
pos_args: list[str] = [self._print_arg(arg) for arg in spec.pos]
mixed_args: list[str] = [self._print_arg(arg) for arg in spec.mixed]
kw_args: list[str] = [self._print_arg(arg) for arg in spec.kw]
args: list[str] = pos_args
if len(pos_args) != 0:
@@ -447,8 +490,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
if len(kw_args) != 0:
args.append("*")
args += kw_args
return f"fn ({', '.join(args)}) -> {type.returns.accept(self)}"
return f"({', '.join(args)})"
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
res: str = ""

View File

@@ -15,6 +15,7 @@ if TYPE_CHECKING:
BUILTIN_SUBTYPES: dict[str, set[str]] = {
"object": {"float", "list", "dict"},
"float": {"int"},
"int": {"bool"},
}

View File

@@ -1,27 +1,64 @@
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import midas.ast.midas as m
from midas.ast.location import Location
from midas.checker.builtins import define_builtins
from midas.checker.environment import Environment
from midas.checker.operators import MIDAS_BINARY_METHODS, MIDAS_UNARY_METHODS
from midas.checker.preamble import Preamble
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter
from midas.checker.types import (
AliasType,
AppliedType,
ComplexType,
ConstraintType,
ExtensionType,
Function,
GenericType,
OverloadedFunction,
Predicate,
Type,
TypeVar,
UnknownType,
unfold_type,
)
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token
from midas.parser.midas import MidasParser
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
@dataclass(frozen=True, kw_only=True)
class TypedParamSpec:
pos: list[Function.Argument]
mixed: list[Function.Argument]
kw: list[Function.Argument]
TypedExpr = tuple[m.Expr, Type]
class ReturnException(Exception):
pass
@dataclass(frozen=True, kw_only=True)
class MappedArgument:
expr: m.Expr
type: Type
argument: Function.Argument
@dataclass(frozen=True, kw_only=True)
class OverloadCandidate:
function: Function
mapped: list[MappedArgument]
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type]):
"""A resolver which evaluates Midas type definitions and build a registry"""
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
@@ -31,12 +68,18 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
self.types: TypesRegistry = types
self._local_variables: dict[str, TypeVar] = {}
self._predicate_params: dict[str, Type] = {}
self._current_name: Optional[str] = None
define_builtins(self.types)
builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve()
self.process(builtins_path.read_text(), str(builtins_path))
self._bool: Type = self.get_type("bool")
self._preamble: Environment = Preamble(self.types)
def process(self, source: str, path: Optional[str]):
self.reporter = self.reporter.for_file(path)
lexer: MidasLexer = MidasLexer(source)
@@ -47,6 +90,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
self.reporter.error(error.token.get_location(), error.message)
self.resolve(stmts)
def type_of(self, expr: m.Expr) -> Type:
type: Type = expr.accept(self)
return type
def get_type(self, name: str) -> Type:
"""Get a type from its name
@@ -63,6 +110,19 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
return self._local_variables[name]
return self.types.get_type(name)
def get_variable(self, name: str) -> Type:
if name in self._predicate_params:
return self._predicate_params[name]
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
if predicate is not None:
return predicate.type
global_: Optional[Type] = self._preamble.get(name)
if global_ is not None:
return global_
raise NameError(f"Unknown variable '{name}'")
def resolve(self, stmts: list[m.Stmt]):
"""Process a sequence of statements
@@ -72,6 +132,11 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
for stmt in stmts:
stmt.accept(self)
def assert_bool(self, expr: m.Expr):
type: Type = self.type_of(expr)
if not self.types.is_subtype(type, self._bool):
self.reporter.error(expr.location, f"Must be a boolean but is {type}")
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
name: str = stmt.name.lexeme
self._current_name = name
@@ -106,31 +171,163 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
)
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
self.reporter.warning(stmt.location, "PredicateStmt not yet supported")
for spec in stmt.params:
for param in spec.mixed:
assert param.name is not None
self._predicate_params[param.name.lexeme] = param.type.accept(self)
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
self.reporter.warning(expr.location, "LogicalExpr not yet supported")
type: Type = self.type_of(stmt.body)
params: list[TypedParamSpec] = [
self._visit_param_spec(spec) for spec in stmt.params
]
def visit_binary_expr(self, expr: m.BinaryExpr) -> None:
self.reporter.warning(expr.location, "BinaryExpr not yet supported")
if not self._is_valid_predicate(type):
self.reporter.error(
stmt.body.location,
f"Predicate function body must evaluate to a boolean, got {type}",
)
if len(params) != 0:
type = self._bool
for spec in reversed(params):
type = Function(
pos_args=spec.pos,
args=spec.mixed,
kw_args=spec.kw,
returns=type,
)
self._predicate_params = {}
self.types.define_predicate(
stmt.name.lexeme,
Predicate(
type=type,
body=stmt.body,
alias=len(params) == 0,
),
)
def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
self.reporter.warning(expr.location, "UnaryExpr not yet supported")
def _is_valid_predicate(self, body: Type) -> bool:
match body:
case Function(returns=returns):
return self._is_valid_predicate(returns)
case _ if self.types.is_subtype(body, self._bool):
return True
case _:
return False
def visit_get_expr(self, expr: m.GetExpr) -> None:
self.reporter.warning(expr.location, "GetExpr not yet supported")
def visit_logical_expr(self, expr: m.LogicalExpr) -> Type:
self.assert_bool(expr.left)
self.assert_bool(expr.right)
return self._bool
def visit_variable_expr(self, expr: m.VariableExpr) -> None:
self.reporter.warning(expr.location, "VariableExpr not yet supported")
def visit_binary_expr(self, expr: m.BinaryExpr) -> Type:
method: Optional[str] = MIDAS_BINARY_METHODS.get(expr.operator.type)
if method is None:
self.logger.warning(f"Unsupported operator {expr.operator.lexeme}")
self.reporter.warning(
expr.location, f"Unsupported operator {expr.operator.lexeme}"
)
return UnknownType()
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
def _visit_binary_expr(
self, location: Location, left_expr: m.Expr, right_expr: m.Expr, method: str
) -> Type:
left: Type = self.type_of(left_expr)
right: Type = self.type_of(right_expr)
operation: Optional[Type] = self.types.lookup_member(left, method)
if operation is None:
self.reporter.error(
location,
f"Undefined operation {method} between {left} and {right}",
)
return UnknownType()
result: Optional[Type] = self._get_call_result(
location,
operation,
[(right_expr, right)],
{},
)
return result or UnknownType()
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
method: Optional[str] = MIDAS_UNARY_METHODS.get(expr.operator.type)
if method is None:
self.logger.warning(f"Unsupported operator {expr.operator.lexeme}")
self.reporter.warning(
expr.location, f"Unsupported operator {expr.operator.lexeme}"
)
return UnknownType()
operand: Type = self.type_of(expr.right)
operation: Optional[Type] = self.types.lookup_member(operand, method)
if operation is None:
self.reporter.error(
expr.location,
f"Undefined operation {method} for {operand}",
)
return UnknownType()
result: Optional[Type] = self._get_call_result(
expr.location,
operation,
[],
{},
)
return result or UnknownType()
def visit_call_expr(self, expr: m.CallExpr) -> Type:
callee: Type = expr.callee.accept(self)
positional: list[TypedExpr] = [
(arg, self.type_of(arg)) for arg in expr.arguments
]
keywords: dict[str, TypedExpr] = {
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
}
return (
self._get_call_result(
expr.location,
callee,
positional,
keywords,
)
or UnknownType()
)
def visit_get_expr(self, expr: m.GetExpr) -> Type:
object: Type = expr.expr.accept(self)
member: Optional[Type] = self.types.lookup_member(object, expr.name.lexeme)
if member is None:
self.reporter.error(
expr.location, f"Unknown member '{expr.name.lexeme}' of {object}"
)
return UnknownType()
return member
def visit_variable_expr(self, expr: m.VariableExpr) -> Type:
return self.get_variable(expr.name.lexeme)
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type:
return expr.expr.accept(self)
def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
self.reporter.warning(expr.location, "LiteralExpr not yet supported")
def visit_literal_expr(self, expr: m.LiteralExpr) -> Type:
match expr.value:
case bool(): # Must be before int
return self.types.get_type("bool")
case int():
return self.types.get_type("int")
case float():
return self.types.get_type("float")
case str():
return self.types.get_type("str")
case _:
self.reporter.warning(expr.location, f"Unknown literal {expr}")
return UnknownType()
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
self.reporter.warning(expr.location, "WildcardExpr not yet supported")
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type:
return self.get_variable("_")
def visit_named_type(self, type: m.NamedType) -> Type:
name: str = type.name.lexeme
@@ -153,10 +350,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
return UnknownType()
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
type_: Type = type.type.accept(self)
type.constraint.accept(self)
# TODO
return UnknownType()
return ConstraintType(
type=type.type.accept(self),
constraint=type.constraint,
)
def visit_complex_type(self, type: m.ComplexType) -> ComplexType:
return ComplexType(
@@ -172,8 +369,17 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
)
def visit_function_type(self, type: m.FunctionType) -> Type:
n_pos_args: int = len(type.pos_args)
n_args: int = len(type.args)
params: TypedParamSpec = self._visit_param_spec(type.params)
return Function(
pos_args=params.pos,
args=params.mixed,
kw_args=params.kw,
returns=type.returns.accept(self),
)
def _visit_param_spec(self, spec: m.ParamSpec) -> TypedParamSpec:
n_pos: int = len(spec.pos)
n_mixed: int = len(spec.mixed)
def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument:
return Function.Argument(
@@ -183,14 +389,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
required=arg.required,
)
return Function(
pos_args=[process_arg(arg, i) for i, arg in enumerate(type.pos_args)],
args=[process_arg(arg, i + n_pos_args) for i, arg in enumerate(type.args)],
kw_args=[
process_arg(arg, i + n_pos_args + n_args)
for i, arg in enumerate(type.kw_args)
],
returns=type.returns.accept(self),
return TypedParamSpec(
pos=[process_arg(arg, i) for i, arg in enumerate(spec.pos)],
mixed=[process_arg(arg, i + n_pos) for i, arg in enumerate(spec.mixed)],
kw=[process_arg(arg, i + n_pos + n_mixed) for i, arg in enumerate(spec.kw)],
)
def _resolve_type_params(self, params: list[m.TypeParam]):
@@ -204,3 +406,343 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
self._local_variables[name] = var
vars.append(var)
return vars
def _get_call_result(
self,
location: Location,
callee: Type,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
report_errors: bool = True,
) -> Optional[Type]:
"""Get the result type of a function call
If the function has overloads, the function will try to resolve the
appropriate signature.
Argument types are matched to the defined parameters.
The function doesn't take the raw expression as a parameter to accommodate
for desugared calls such as for operators.
Args:
location (Location): the call location
callee (Type): the called function
positional (list[TypedExpr]): the list positional arguments
keywords (dict[str, TypedExpr]): the map of keyword arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
Type: the return type of the call, or `None` if either
the call is invalid or no overload matched the arguments uniquely
"""
match callee:
case Function() as function:
valid: bool
mapped: list[MappedArgument]
valid, mapped = self.map_call_arguments(
function, location, positional, keywords
)
valid = valid and self._are_arguments_valid(mapped, report_errors)
if not valid:
return None
return function.returns
case OverloadedFunction(overloads=overloads):
function = self._match_overload(
overloads, location, positional, keywords, report_errors
)
if function is None:
return None
return function.returns
case AppliedType(body=body):
return self._get_call_result(
location, body, positional, keywords, report_errors
)
case UnknownType():
return UnknownType()
case _:
if report_errors:
self.reporter.error(location, f"{callee} is not callable")
return None
def _are_arguments_valid(
self,
arguments: list[MappedArgument],
report_errors: bool = True,
) -> bool:
"""Check whether the passed argument types correspond to their matched parameter definitions
Args:
arguments (list[MappedArgument]): the list of argument/parameter pairs
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
bool: True if all arguments fit the matching parameter definitions, False otherwise
"""
valid: bool = True
for arg in arguments:
if not self.types.is_subtype(arg.type, arg.argument.type):
if report_errors:
self.reporter.error(
arg.expr.location,
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
)
valid = False
return valid
def _match_overload(
self,
overloads: list[Type],
location: Location,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
report_errors: bool = True,
) -> Optional[Function]:
"""Try and resolve the appropriate overload for the given arguments
Args:
overloads (list[Type]): the list of possible overloads
location (Location): the call location
positional (list[TypedExpr]): the list of positional arguments
keywords (dict[str, TypedExpr]): the map of keywords arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
Optional[Function]: the resolved function signature if it can be
determined unambiguously, or `None`.
"""
candidates: list[OverloadCandidate] = []
for overload in overloads:
function: Type = unfold_type(overload)
if not isinstance(function, Function):
if report_errors:
self.logger.error(
f"Overload is not a function: {overload} is {function}"
)
continue
valid, mapped = self.map_call_arguments(
function=function,
location=location,
positional=positional,
keywords=keywords,
report_errors=False,
)
if valid and self._are_arguments_valid(mapped, report_errors=False):
candidates.append(
OverloadCandidate(
function=function,
mapped=mapped,
)
)
pos_types: str = ", ".join(str(type) for _, type in positional)
kw_types: str = ", ".join(
f"{name}: {type}" for name, (_, type) in keywords.items()
)
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
n_candidates: int = len(candidates)
# Exactly 1 match -> return it
if n_candidates == 1:
return candidates[0].function
# No match -> invalid call
if n_candidates == 0:
overloads_str: str = ", ".join(map(str, overloads))
if report_errors:
self.reporter.error(
location,
f"No matching overload in [{overloads_str}] {for_args}",
)
return None
# Multiple matches -> see if one <: all others (more specific)
for i1, c1 in enumerate(candidates):
mapped1: list[MappedArgument] = c1.mapped
best_match: bool = True
for i2, c2 in enumerate(candidates):
if i1 == i2:
continue
mapped2: list[MappedArgument] = c2.mapped
if not self._are_mapped_subtypes(mapped1, mapped2):
best_match = False
break
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
if best_match:
return c1.function
candidates_str: str = ", ".join(
str(candidate.function) for candidate in candidates
)
if report_errors:
self.reporter.error(
location,
f"Multiple matching overloads {for_args}: {candidates_str}",
)
return None
def map_call_arguments(
self,
function: Function,
location: Location,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
report_errors: bool = True,
) -> tuple[bool, list[MappedArgument]]:
"""Map call arguments to a function's parameters as defined in its signature
This method maps positional-only, keyword-only and mixed parameter definitions
with the arguments passed at the call site
Any mismatched, missing or unexpected argument is reported as a diagnostic,
unless `report_errors` is set to `False`
Args:
function (Function): the function definition
location (Location): the call location
positional (list[TypedExpr]): the list of positional arguments
keywords (dict[str, TypedExpr]): the map of keyword arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
tuple[bool, list[MappedArgument]]: a boolean reporting whether
the call is valid and the list of mapped arguments
"""
set_args: set[str] = set()
required_positional: list[str] = [
arg.name for arg in function.pos_args + function.args if arg.required
]
required_keyword: list[str] = [
arg.name for arg in function.kw_args if arg.required
]
mapped: list[MappedArgument] = []
pos_params: list[Function.Argument] = list(function.pos_args)
mixed_params: list[Function.Argument] = list(function.args)
kw_params: dict[str, Function.Argument] = {
arg.name: arg for arg in function.kw_args
}
valid_call: bool = True
# TODO: handle *args and **kwargs sinks
for arg in positional:
param: Function.Argument
if len(pos_params) != 0:
param = pos_params.pop(0)
elif len(mixed_params) != 0:
param = mixed_params.pop(0)
else:
if report_errors:
self.reporter.error(
arg[0].location, "Too many positional arguments"
)
valid_call = False
break
name: str = param.name
if name in required_positional:
required_positional.remove(name)
if name in required_keyword:
required_keyword.remove(name)
set_args.add(name)
mapped.append(
MappedArgument(
expr=arg[0],
type=arg[1],
argument=param,
)
)
kw_params.update({arg.name: arg for arg in mixed_params})
for name, arg in keywords.items():
param: Function.Argument
if name not in kw_params:
if report_errors:
if name in set_args:
self.reporter.error(
arg[0].location, f"Multiple values for argument '{name}'"
)
else:
self.reporter.error(
arg[0].location, f"Unknown keyword argument '{name}'"
)
valid_call = False
continue
param = kw_params.pop(name)
if name in required_positional:
required_positional.remove(name)
if name in required_keyword:
required_keyword.remove(name)
set_args.add(name)
mapped.append(
MappedArgument(
expr=arg[0],
type=arg[1],
argument=param,
)
)
def join_args(args: list[str]) -> str:
args = list(map(lambda a: f"'{a}'", args))
if len(args) == 0:
return ""
if len(args) == 1:
return args[0]
return ", ".join(args[:-1]) + " and " + args[-1]
if len(required_positional) != 0:
plural: str = "" if len(required_positional) == 1 else "s"
args: str = join_args(required_positional)
if report_errors:
self.reporter.error(
location,
f"Missing required positional argument{plural}: {args}",
)
valid_call = False
if len(required_keyword) != 0:
plural: str = "" if len(required_keyword) == 1 else "s"
args: str = join_args(required_keyword)
if report_errors:
self.reporter.error(
location,
f"Missing required keyword argument{plural}: {args}",
)
valid_call = False
return valid_call, mapped
def _are_mapped_subtypes(
self, mapped1: list[MappedArgument], mapped2: list[MappedArgument]
) -> bool:
"""Check whether the given argument mappings are subtype/supertype of one another
This function checks whether the argument mappings `mapped1` are subtypes
of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
of the corresponding parameter in `mapped2`, `False` is returned.
This is used to check whether a given overload is
a more specific function/ a subtype of another.
Args:
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
Returns:
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
"""
by_expr: dict[m.Expr, Type] = {}
for arg in mapped1:
by_expr[arg.expr] = arg.argument.type
for arg in mapped2:
type2: Type = arg.argument.type
type1: Type = by_expr[arg.expr]
if not self.types.is_subtype(type1, type2):
return False
return True

View File

@@ -1,7 +1,9 @@
import ast
from typing import Type
OPERATOR_METHODS: dict[Type[ast.operator], str] = {
from midas.lexer.token import TokenType
PY_OPERATOR_METHODS: dict[Type[ast.operator], str] = {
ast.Add: "__add__",
ast.Sub: "__sub__",
ast.Mult: "__mul__",
@@ -17,9 +19,9 @@ OPERATOR_METHODS: dict[Type[ast.operator], str] = {
ast.FloorDiv: "__floordiv__",
}
COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
PY_COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
ast.Eq: "__eq__",
# ast.NotEq: "__noteq__",
ast.NotEq: "__eq__",
ast.Lt: "__lt__",
ast.LtE: "__le__",
ast.Gt: "__gt__",
@@ -30,9 +32,40 @@ COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
# ast.NotIn: "__notin__",
}
UNARY_METHODS: dict[Type[ast.unaryop], str] = {
PY_UNARY_METHODS: dict[Type[ast.unaryop], str] = {
ast.Invert: "__invert__",
# ast.Not: "",
ast.UAdd: "__pos__",
ast.USub: "__neg__",
}
MIDAS_BINARY_METHODS: dict[TokenType, str] = {
# TokenType.PLUS: "__add__",
TokenType.MINUS: "__sub__",
TokenType.STAR: "__mul__",
TokenType.SLASH: "__truediv__",
# TokenType.MODULO: "__mod__",
# TokenType.POW: "__pow__",
# ast.BitOr: "__or__",
# ast.BitXor: "__xor__",
# ast.BitAnd: "__and__",
# ast.FloorDiv: "__floordiv__",
TokenType.EQUAL_EQUAL: "__eq__",
TokenType.BANG_EQUAL: "__eq__",
TokenType.LESS: "__lt__",
TokenType.LESS_EQUAL: "__le__",
TokenType.GREATER: "__gt__",
TokenType.GREATER_EQUAL: "__ge__",
# ast.Is: "__is__",
# ast.IsNot: "__isnot__",
# ast.In: "__in__",
# ast.NotIn: "__notin__",
}
MIDAS_UNARY_METHODS: dict[TokenType, str] = {
# ast.Invert: "__invert__",
# ast.Not: "",
# TokenType.PLUS: "__pos__",
TokenType.MINUS: "__neg__",
}

View File

@@ -6,7 +6,11 @@ from typing import Optional
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.environment import Environment
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS
from midas.checker.operators import (
PY_COMPARATOR_METHODS,
PY_OPERATOR_METHODS,
PY_UNARY_METHODS,
)
from midas.checker.preamble import Preamble
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter
@@ -376,7 +380,7 @@ class PythonTyper(
pass
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
method: Optional[str] = PY_OPERATOR_METHODS.get(expr.operator.__class__)
if method is None:
self.logger.warning(f"Unsupported operator {expr.operator}")
self.reporter.warning(
@@ -387,7 +391,7 @@ class PythonTyper(
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
method: Optional[str] = PY_COMPARATOR_METHODS.get(expr.operator.__class__)
if method is None:
self.logger.warning(f"Unsupported operator {expr.operator}")
self.reporter.warning(
@@ -420,7 +424,7 @@ class PythonTyper(
return result or UnknownType()
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__)
method: Optional[str] = PY_UNARY_METHODS.get(expr.operator.__class__)
if method is None:
self.logger.warning(f"Unsupported operator {expr.operator}")
self.reporter.warning(
@@ -652,7 +656,7 @@ class PythonTyper(
If the function has overloads, the function will try to resolve the
appropriate signature.
Argument types are matched to the defined parameters.
The function doesn't take the raw expression as a parameter to accomodate
The function doesn't take the raw expression as a parameter to accommodate
for desugared calls such as for operators.
Args:
@@ -743,7 +747,7 @@ class PythonTyper(
Returns:
Optional[Function]: the resolved function signature if it can be
determined unambigously, or `None`.
determined unambiguously, or `None`.
"""
candidates: list[OverloadCandidate] = []
for overload in overloads:

View File

@@ -7,10 +7,12 @@ from midas.checker.types import (
AppliedType,
BaseType,
ComplexType,
ConstraintType,
ExtensionType,
Function,
GenericType,
OverloadedFunction,
Predicate,
TopType,
Type,
TypeVar,
@@ -24,6 +26,7 @@ class TypesRegistry:
self.logger: logging.Logger = logging.getLogger("TypesRegistry")
self._types: dict[str, Type] = {}
self._members: dict[str, dict[str, Type]] = {}
self._predicates: dict[str, Predicate] = {}
def get_type(self, name: str) -> Type:
"""Get a type from its name
@@ -81,6 +84,11 @@ class TypesRegistry:
else:
members[member_name] = member_type
def define_predicate(self, name: str, predicate: Predicate):
if name in self._predicates:
raise ValueError(f"Predicate {name} already defined")
self._predicates[name] = predicate
def is_subtype(self, type1: Type, type2: Type) -> bool:
"""Check whether `type1` is a subtype of `type2`
@@ -123,6 +131,9 @@ class TypesRegistry:
return False
return self.is_subtype(bound, type2)
case (ConstraintType(type=base1), _):
return self.is_subtype(base1, type2)
return False
# TODO: verify the logic in here
@@ -345,3 +356,6 @@ class TypesRegistry:
case _:
self.logger.debug(f"Can't get member on {type}")
return None
def lookup_predicate(self, name: str) -> Optional[Predicate]:
return self._predicates.get(name)

View File

@@ -1,7 +1,10 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional
from typing import Optional, assert_never
import midas.ast.midas as m
from midas.ast.printer import MidasPrinter
@dataclass(frozen=True, kw_only=True)
@@ -130,6 +133,16 @@ class AppliedType:
return f"{self.name}[{', '.join(map(str, self.args))}]"
@dataclass(frozen=True, kw_only=True)
class ConstraintType:
type: Type
constraint: m.Expr
def __str__(self) -> str:
printer = MidasPrinter()
return f"{self.type} where {printer.print(self.constraint)}"
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
def sub_argument(arg: Function.Argument):
return Function.Argument(
@@ -195,6 +208,12 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
body=substitute_typevars(body, substitutions),
)
case ConstraintType():
return ConstraintType(
type=substitute_typevars(type.type, substitutions),
constraint=type.constraint,
)
case TypeVar(name=name):
if name in substitutions:
return substitutions[name]
@@ -203,9 +222,13 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
case UnknownType() | UnitType():
return type
case _:
case TopType() | GenericType():
raise NotImplementedError(f"Unsupported type {type}")
# Ensure exhaustiveness
case _:
assert_never(type)
def unfold_type(type: Type) -> Type:
match type:
@@ -215,6 +238,65 @@ def unfold_type(type: Type) -> Type:
return type
def to_annotation(type: Type) -> str:
def _args_annotation(func: Function) -> str:
if len(func.kw_args) != 0:
return "..."
args: str = ", ".join(
to_annotation(arg.type) for arg in func.pos_args + func.args
)
return f"[{args}]"
match type:
case TopType():
return "Any"
case BaseType(name=name):
return name
case AliasType(name=name):
return name
case UnknownType():
return "Any"
case UnitType():
return "None"
case Function(returns=returns):
params_annot: str = _args_annotation(type)
return f"Callable[{params_annot}, {to_annotation(returns)}]"
case OverloadedFunction():
return "Callable"
case ComplexType() | ExtensionType():
raise NotImplementedError
case TypeVar(name=name):
return name
case GenericType(name=name, params=params):
return f"{name}[{', '.join(map(to_annotation, params))}]"
case AppliedType(name=name, args=args):
return f"{name}[{', '.join(map(to_annotation, args))}]"
case ConstraintType():
return str(type)
case _:
assert_never(type)
@dataclass(frozen=True, kw_only=True)
class Predicate:
type: Type
body: m.Expr
alias: bool
Type = (
TopType
| BaseType
@@ -228,4 +310,5 @@ Type = (
| TypeVar
| GenericType
| AppliedType
| ConstraintType
)

View File

@@ -38,5 +38,5 @@ def compile(
if any(map(lambda d: d.type == DiagnosticType.ERROR, diagnostics)):
sys.exit(1)
generator = Generator(workdir=source_path.parent)
generator = Generator(workdir=source_path.parent, types=checker.types)
generator.generate(typed_ast, source_path)

View File

@@ -8,6 +8,7 @@ from typing import TextIO
import click
from midas.ast.printer import MidasPrinter
from midas.checker.checker import TypeChecker
from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type
@@ -35,6 +36,7 @@ def dump_registry(
for types_file in types:
checker.import_midas(Path(types_file.name).resolve())
print("##### Types #####")
for name, type in checker.types._types.items():
members: dict[str, Type] = checker.types._members.get(name, {})
print(f"{name} = {base_type(type)}")
@@ -42,3 +44,17 @@ def dump_registry(
print(" " * 4 + "Members:")
for member_name, member_type in members.items():
print(" " * 8 + f"{member_name}: {member_type}")
print("##### Predicates #####")
printer = MidasPrinter()
for name, predicate in checker.types._predicates.items():
body: str = printer.print(predicate.body)
if predicate.alias:
print(f"{name}: {predicate.type} = {body}")
else:
print(f"{name}{predicate.type}:")
body = "\n".join(
" " + ("return " if i == 0 else "") + line
for i, line in enumerate(body.split("\n"))
)
print(body)

View File

@@ -0,0 +1,224 @@
import ast
from typing import Optional
import midas.ast.midas as m
from midas.checker.registry import TypesRegistry
from midas.checker.types import (
Function,
Predicate,
Type,
to_annotation,
)
from midas.lexer.token import TokenType
LOGICAL_OPERATORS: dict[TokenType, type[ast.boolop]] = {
TokenType.AND: ast.And,
# TokenType.OR: ast.Or,
}
BINARY_OPERATORS: dict[TokenType, type[ast.operator]] = {
# TokenType.PLUS: ast.Add,
TokenType.MINUS: ast.Sub,
TokenType.STAR: ast.Mult,
TokenType.SLASH: ast.Div,
}
UNARY_OPERATORS: dict[TokenType, type[ast.unaryop]] = {
# TokenType.PLUS: ast.UAdd,
TokenType.MINUS: ast.USub,
}
COMPARISON_OPERATORS: dict[TokenType, type[ast.cmpop]] = {
TokenType.GREATER: ast.Gt,
TokenType.GREATER_EQUAL: ast.GtE,
TokenType.LESS: ast.Lt,
TokenType.LESS_EQUAL: ast.LtE,
TokenType.EQUAL_EQUAL: ast.Eq,
TokenType.BANG_EQUAL: ast.NotEq,
}
class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
def __init__(self, types: TypesRegistry):
self.types: TypesRegistry = types
self._id: int = 0
self._definitions: list[ast.stmt] = []
self._aliases: dict[str, str] = {}
def get_definitions(self) -> list[ast.stmt]:
return self._definitions
def generate(self, expr: m.Expr) -> ast.expr:
match expr:
case m.VariableExpr():
return expr.accept(self)
case _:
func = Function(
pos_args=[],
args=[
Function.Argument(
pos=0,
name="_",
type=self.types.get_type("Any"),
required=True,
)
],
kw_args=[],
returns=self.types.get_type("bool"),
)
alias: str = self.make_alias(None)
definition: ast.stmt = self.make_definition(
alias, Predicate(type=func, body=expr, alias=False)
)
self._definitions.append(definition)
return ast.Name(id=alias)
def make_alias(self, name: Optional[str]) -> str:
suffix: str
if name is None:
suffix = f"p{self._id}"
self._id += 1
else:
suffix = name
alias: str = f"__midas_{suffix}__"
return alias
def make_definition(self, name: str, predicate: Predicate) -> ast.stmt:
body: ast.expr = predicate.body.accept(self)
if predicate.alias:
return ast.Assign(
targets=[
ast.Name(id=name),
],
value=body,
)
return self.make_func(name, [ast.Return(value=body)], predicate.type)
def make_args(self, func: Function) -> ast.arguments:
return ast.arguments(
posonlyargs=[
ast.arg(
arg=arg.name,
annotation=ast.Constant(value=to_annotation(arg.type)),
)
for arg in func.pos_args
],
args=[
ast.arg(
arg=arg.name,
annotation=ast.Constant(value=to_annotation(arg.type)),
)
for arg in func.args
],
kwonlyargs=[
ast.arg(
arg=arg.name,
annotation=ast.Constant(value=to_annotation(arg.type)),
)
for arg in func.kw_args
],
defaults=[],
kw_defaults=[],
)
def make_func(
self, name: str, inner_body: list[ast.stmt], type: Type, level: int = 0
) -> ast.stmt:
match type:
case Function(returns=Function()):
inner_name: str = f"inner{level}"
return ast.FunctionDef(
name=name,
args=self.make_args(type),
body=[
self.make_func(inner_name, inner_body, type.returns, level + 1),
ast.Return(value=ast.Name(id=inner_name)),
],
returns=ast.Constant(value=to_annotation(type.returns)),
decorator_list=[],
)
case Function():
return ast.FunctionDef(
name=name,
args=self.make_args(type),
body=inner_body,
returns=ast.Constant(value=to_annotation(type.returns)),
decorator_list=[],
)
case _:
raise ValueError(f"Expected function, got {type!r}")
def get_predicate(self, name: str) -> Optional[ast.expr]:
if name not in self._aliases:
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
if predicate is None:
return None
alias: str = self.make_alias(name)
self._aliases[name] = alias
self._definitions.append(self.make_definition(alias, predicate))
return ast.Name(id=self._aliases[name])
def visit_logical_expr(self, expr: m.LogicalExpr) -> ast.expr:
return ast.BoolOp(
op=LOGICAL_OPERATORS[expr.operator.type](),
values=[
expr.left.accept(self),
expr.right.accept(self),
],
)
def visit_binary_expr(self, expr: m.BinaryExpr) -> ast.expr:
op: TokenType = expr.operator.type
if op in BINARY_OPERATORS:
return ast.BinOp(
left=expr.left.accept(self),
op=BINARY_OPERATORS[op](),
right=expr.right.accept(self),
)
if op in COMPARISON_OPERATORS:
return ast.Compare(
left=expr.left.accept(self),
ops=[COMPARISON_OPERATORS[op]()],
comparators=[expr.right.accept(self)],
)
raise ValueError(f"Unexpected binary operator {op}")
def visit_unary_expr(self, expr: m.UnaryExpr) -> ast.expr:
return ast.UnaryOp(
op=UNARY_OPERATORS[expr.operator.type](),
operand=expr.right.accept(self),
)
def visit_call_expr(self, expr: m.CallExpr) -> ast.expr:
return ast.Call(
func=expr.callee.accept(self),
args=[arg.accept(self) for arg in expr.arguments],
keywords=[
ast.keyword(arg=name, value=arg.accept(self))
for name, arg in expr.keywords.items()
],
)
def visit_get_expr(self, expr: m.GetExpr) -> ast.expr:
return ast.Attribute(
value=expr.expr.accept(self),
attr=expr.name.lexeme,
)
def visit_variable_expr(self, expr: m.VariableExpr) -> ast.expr:
name: str = expr.name.lexeme
if (p := self.get_predicate(name)) is not None:
return p
return ast.Name(id=name)
def visit_grouping_expr(self, expr: m.GroupingExpr) -> ast.expr:
return expr.accept(self)
def visit_literal_expr(self, expr: m.LiteralExpr) -> ast.expr:
return ast.Constant(value=expr.value)
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> ast.expr:
return ast.Name(id="_")

View File

@@ -2,15 +2,19 @@ import ast
import shutil
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
from typing import Optional, assert_never
import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.ast.printer import MidasPrinter
from midas.checker.registry import TypesRegistry
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ComplexType,
ConstraintType,
ExtensionType,
Function,
GenericType,
@@ -19,7 +23,9 @@ from midas.checker.types import (
Type,
TypeVar,
UnitType,
UnknownType,
)
from midas.generator.constraints import ConstraintGenerator
from midas.utils import TypedAST
@@ -30,12 +36,9 @@ class Scope:
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def __init__(self, workdir: Path) -> None:
def __init__(self, workdir: Path, types: TypesRegistry) -> None:
self.workdir: Path = workdir.resolve()
self.build_dir: Path = self.workdir / "build" / "midas"
if self.build_dir.exists():
shutil.rmtree(self.build_dir)
self.build_dir.mkdir(parents=True, exist_ok=True)
self.rel_src_path: Path = Path()
self._typed_ast: TypedAST = TypedAST(
@@ -43,13 +46,18 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
judgements=[],
)
self._alias_count: int = 0
self._predicate_count: int = 0
self._scopes: list[Scope] = []
self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
self._constraints: list[tuple[m.Expr, ast.expr]] = []
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
self.rel_src_path = src_path.relative_to(self.workdir)
self.rel_src_path = src_path.resolve().relative_to(self.workdir)
self._typed_ast = typed_ast
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
module = ast.Module(body=body, type_ignores=[])
predicates: list[ast.stmt] = self._constraint_generator.get_definitions()
module = ast.Module(body=predicates + body, type_ignores=[])
module = ast.fix_missing_locations(module)
return module
@@ -59,6 +67,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
module: ast.AST = self.generate_ast(typed_ast, src_path)
compiled: str = ast.unparse(module)
if out_path is None:
if self.build_dir.exists():
shutil.rmtree(self.build_dir)
self.build_dir.mkdir(parents=True, exist_ok=True)
out_path = (self.build_dir / self.rel_src_path).resolve()
try:
_ = out_path.relative_to(self.build_dir)
@@ -246,7 +257,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
return generated
def _make_alias(self, expr: ast.expr) -> ast.expr:
name: str = f"__midas_alias_{self._alias_count}__"
name: str = f"__midas_a{self._alias_count}__"
alias = ast.Name(id=name)
self._alias_count += 1
self._scopes[-1].aliases.append(name)
@@ -276,6 +287,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type):
match type:
case UnknownType():
pass
case BaseType(name=name):
self._add_assert(
ast.Call(
@@ -301,8 +315,15 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self._make_cast_assert_message(src_location, expr, type),
)
case AppliedType():
self._make_cast_asserts(src_location, expr, type.body)
case AppliedType(body=body):
self._make_cast_asserts(src_location, expr, body)
case ConstraintType(type=base, constraint=constraint):
self._make_cast_asserts(src_location, expr, base)
self._make_constraint_assert(src_location, expr, constraint)
case TypeVar():
raise RuntimeError("Unexpected TypeVar")
case (
TopType()
@@ -314,8 +335,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
):
raise NotImplementedError(f"Can't make assertion for type {type}")
case TypeVar():
raise RuntimeError("Unexpected TypeVar")
# Ensure exhaustiveness
case _:
assert_never(type)
def _make_cast_assert_message(
self, location: Location, expr: ast.expr, type: Type
@@ -339,3 +361,36 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
ast.Constant(f" to {type}"),
]
)
def _make_constraint_assert(
self, src_location: Location, expr: ast.expr, constraint: m.Expr
):
test_func: ast.expr = self._get_constraint(constraint)
self._add_assert(
ast.Call(
func=test_func,
args=[expr],
keywords=[],
),
self._make_constraint_assert_message(src_location, expr, constraint),
)
def _make_constraint_assert_message(
self, location: Location, expr: ast.expr, constraint: m.Expr
) -> ast.expr:
printer = MidasPrinter()
constraint_str: str = printer.print(constraint)
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
# f"file.py:L1:1: ConstraintError: Value does not fit constraint 'v > 0'"
return ast.Constant(
f"{loc_str}: ConstraintError: Value does not fit constraint '{constraint_str}'"
)
def _get_constraint(self, expr: m.Expr) -> ast.expr:
for expr2, constraint in self._constraints:
if expr2 == expr:
return constraint
constraint: ast.expr = self._constraint_generator.generate(expr)
self._constraints.append((expr, constraint))
return constraint

View File

@@ -69,6 +69,8 @@ class MidasLexer(Lexer):
):
self.advance()
self.add_token(TokenType.WHITESPACE)
case '"' | "'":
self.scan_string(char)
case _:
if char.isdigit():
self.scan_number()
@@ -78,6 +80,17 @@ class MidasLexer(Lexer):
self.error("Unexpected character")
return None
def scan_string(self, opening: str):
while self.peek() != opening and not self.is_at_end():
self.advance()
if self.is_at_end():
self.error("Unterminated string")
self.advance()
value: str = self.source[self.start + 1 : self.idx - 1]
self.add_token(TokenType.STRING, value)
def scan_number(self):
"""Scan the rest of number and add it as a token

View File

@@ -43,6 +43,7 @@ class TokenType(Enum):
TRUE = auto()
FALSE = auto()
NONE = auto()
STRING = auto()
# Keywords
TYPE = auto()

View File

@@ -3,6 +3,7 @@ from typing import Optional
from midas.ast.location import Location
from midas.ast.midas import (
BinaryExpr,
CallExpr,
ComplexType,
ConstraintType,
Expr,
@@ -17,6 +18,7 @@ from midas.ast.midas import (
MemberKind,
MemberStmt,
NamedType,
ParamSpec,
PredicateStmt,
Stmt,
Type,
@@ -265,6 +267,9 @@ class MidasParser(Parser):
Returns:
Expr: the parsed constraint expression
"""
return self.expression()
def expression(self) -> Expr:
return self.and_()
def and_(self) -> Expr:
@@ -331,7 +336,55 @@ class MidasParser(Parser):
right: Expr = self.unary()
location: Location = Location.span(operator.get_location(), right.location)
return UnaryExpr(location=location, operator=operator, right=right)
return self.reference()
return self.call()
def call(self) -> Expr:
expr: Expr = self.reference()
while self.match(TokenType.LEFT_PAREN):
expr = self.finish_call(expr)
return expr
def finish_call(self, callee: Expr) -> Expr:
pos_args: list[Expr] = []
kw_args: dict[str, Expr] = {}
keywords: bool = False
while not self.match(TokenType.RIGHT_PAREN):
if self.check_identifier() and self.check_next(TokenType.EQUAL):
keywords = True
keyword: Token = self.advance()
self.advance()
value: Expr = self.expression()
name: str = keyword.lexeme
if name in kw_args:
self.error(
self.peek(),
f"Multiple values passed for '{name}', only the last occurrence will be used",
)
kw_args[name] = value
else:
value = self.expression()
if self.check(TokenType.EQUAL):
if keywords:
raise self.error(self.peek(), "Invalid keyword argument name")
else:
raise self.error(
self.peek(),
"Cannot pass positional arguments after a keyword argument",
)
pos_args.append(value)
if not self.match(TokenType.COMMA):
break
r_paren: Token = self.consume(
TokenType.RIGHT_PAREN, "Expected ')' after arguments."
)
return CallExpr(
location=Location.span(callee.location, r_paren.get_location()),
callee=callee,
arguments=pos_args,
keywords=kw_args,
)
def reference(self) -> Expr:
"""Parse an attribute access expression or a simpler expression
@@ -365,6 +418,9 @@ class MidasParser(Parser):
if self.match(TokenType.NUMBER):
return LiteralExpr(location=token.get_location(), value=token.value)
if self.match(TokenType.STRING):
return LiteralExpr(location=token.get_location(), value=token.value)
if self.match_identifier():
return VariableExpr(location=token.get_location(), name=token)
@@ -453,23 +509,35 @@ class MidasParser(Parser):
PredicateStmt: the parsed predicate declaration statement
"""
keyword: Token = self.previous()
name: Token = self.consume_identifier("Expected predicate name")
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
subject: Token = self.consume_identifier("Expected subject name")
self.consume(TokenType.COLON, "Expected ':' after subject name")
type: Type = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
params: list[ParamSpec] = []
while self.check(TokenType.LEFT_PAREN):
params.append(self.function_args())
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
condition: Expr = self.constraint()
body: Expr = self.constraint()
return PredicateStmt(
location=keyword.location_to(self.previous()),
name=name,
subject=subject,
type=type,
condition=condition,
params=params,
body=body,
)
def function(self) -> FunctionType:
params: ParamSpec = self.function_args()
self.consume(TokenType.ARROW, "Expected '->' before result type")
result: Type = self.type_expr()
return FunctionType(
location=params.l_paren.location_to(self.previous()),
params=params,
returns=result,
)
def function_args(self) -> ParamSpec:
l_paren: Token = self.consume(
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
)
@@ -526,14 +594,4 @@ class MidasParser(Parser):
self.error(token, "Unnamed mixed argument")
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
self.consume(TokenType.ARROW, "Expected '->' before result type")
result: Type = self.type_expr()
return FunctionType(
location=l_paren.location_to(self.previous()),
pos_args=pos_args,
args=args,
kw_args=kw_args,
returns=result,
)
return ParamSpec(l_paren=l_paren, pos=pos_args, mixed=args, kw=kw_args)

View File

@@ -9,13 +9,13 @@ Module(
level=0),
Assign(
targets=[
Name(id='__midas_alias_0__')],
Name(id='__midas_a0__')],
value=Constant(value=123.45)),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_alias_0__'),
Name(id='__midas_a0__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
@@ -26,7 +26,7 @@ Module(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_alias_0__')],
Name(id='__midas_a0__')],
keywords=[]),
attr='__name__'),
conversion=-1),
@@ -34,19 +34,19 @@ Module(
Assign(
targets=[
Name(id='distance')],
value=Name(id='__midas_alias_0__')),
value=Name(id='__midas_a0__')),
Delete(
targets=[
Name(id='__midas_alias_0__')]),
Name(id='__midas_a0__')]),
Assign(
targets=[
Name(id='__midas_alias_1__')],
Name(id='__midas_a1__')],
value=Constant(value=6.7)),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_alias_1__'),
Name(id='__midas_a1__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
@@ -57,7 +57,7 @@ Module(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_alias_1__')],
Name(id='__midas_a1__')],
keywords=[]),
attr='__name__'),
conversion=-1),
@@ -65,10 +65,10 @@ Module(
Assign(
targets=[
Name(id='time')],
value=Name(id='__midas_alias_1__')),
value=Name(id='__midas_a1__')),
Delete(
targets=[
Name(id='__midas_alias_1__')]),
Name(id='__midas_a1__')]),
Assign(
targets=[
Name(id='speed')],

View File

@@ -0,0 +1,14 @@
// Inline
type T1 = float where _ > 0
// Named
predicate is_positive(v: float) = v > 0
type T2 = float where is_positive(_)
// Curried
predicate in_range(mn: float, mx: float)(v: float) = v >= mn & v < mx
type T3 = float where in_range(100, 200)(_)
// Alias
predicate minor = in_range(0, 18)
type T4 = float where minor(_)

View File

@@ -0,0 +1,8 @@
from midas import T1, T2, T3, T4, cast
t: float = 12.5
t1: T1 = cast(T1, t)
t2: T2 = cast(T2, t)
t3: T3 = cast(T3, t)
t4: T4 = cast(T4, t)

View File

@@ -0,0 +1,333 @@
Module(
body=[
FunctionDef(
name='__midas_p0__',
args=arguments(
posonlyargs=[],
args=[
arg(
arg='_',
annotation=Constant(value='Any'))],
kwonlyargs=[],
kw_defaults=[],
defaults=[]),
body=[
Return(
value=Compare(
left=Name(id='_'),
ops=[
Gt()],
comparators=[
Constant(value=0.0)]))],
decorator_list=[],
returns=Constant(value='bool')),
FunctionDef(
name='__midas_is_positive__',
args=arguments(
posonlyargs=[],
args=[
arg(
arg='v',
annotation=Constant(value='float'))],
kwonlyargs=[],
kw_defaults=[],
defaults=[]),
body=[
Return(
value=Compare(
left=Name(id='v'),
ops=[
Gt()],
comparators=[
Constant(value=0.0)]))],
decorator_list=[],
returns=Constant(value='bool')),
FunctionDef(
name='__midas_p1__',
args=arguments(
posonlyargs=[],
args=[
arg(
arg='_',
annotation=Constant(value='Any'))],
kwonlyargs=[],
kw_defaults=[],
defaults=[]),
body=[
Return(
value=Call(
func=Name(id='__midas_is_positive__'),
args=[
Name(id='_')],
keywords=[]))],
decorator_list=[],
returns=Constant(value='bool')),
FunctionDef(
name='__midas_in_range__',
args=arguments(
posonlyargs=[],
args=[
arg(
arg='mn',
annotation=Constant(value='float')),
arg(
arg='mx',
annotation=Constant(value='float'))],
kwonlyargs=[],
kw_defaults=[],
defaults=[]),
body=[
FunctionDef(
name='inner0',
args=arguments(
posonlyargs=[],
args=[
arg(
arg='v',
annotation=Constant(value='float'))],
kwonlyargs=[],
kw_defaults=[],
defaults=[]),
body=[
Return(
value=BoolOp(
op=And(),
values=[
Compare(
left=Name(id='v'),
ops=[
GtE()],
comparators=[
Name(id='mn')]),
Compare(
left=Name(id='v'),
ops=[
Lt()],
comparators=[
Name(id='mx')])]))],
decorator_list=[],
returns=Constant(value='bool')),
Return(
value=Name(id='inner0'))],
decorator_list=[],
returns=Constant(value='Callable[[float], bool]')),
FunctionDef(
name='__midas_p2__',
args=arguments(
posonlyargs=[],
args=[
arg(
arg='_',
annotation=Constant(value='Any'))],
kwonlyargs=[],
kw_defaults=[],
defaults=[]),
body=[
Return(
value=Call(
func=Call(
func=Name(id='__midas_in_range__'),
args=[
Constant(value=100.0),
Constant(value=200.0)],
keywords=[]),
args=[
Name(id='_')],
keywords=[]))],
decorator_list=[],
returns=Constant(value='bool')),
Assign(
targets=[
Name(id='__midas_minor__')],
value=Call(
func=Name(id='__midas_in_range__'),
args=[
Constant(value=0.0),
Constant(value=18.0)],
keywords=[])),
FunctionDef(
name='__midas_p3__',
args=arguments(
posonlyargs=[],
args=[
arg(
arg='_',
annotation=Constant(value='Any'))],
kwonlyargs=[],
kw_defaults=[],
defaults=[]),
body=[
Return(
value=Call(
func=Name(id='__midas_minor__'),
args=[
Name(id='_')],
keywords=[]))],
decorator_list=[],
returns=Constant(value='bool')),
ImportFrom(
module='midas',
names=[
alias(name='T1'),
alias(name='T2'),
alias(name='T3'),
alias(name='T4'),
alias(name='cast')],
level=0),
Assign(
targets=[
Name(id='t')],
value=Constant(value=12.5)),
Assign(
targets=[
Name(id='__midas_a0__')],
value=Name(id='t')),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_a0__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
values=[
Constant(value='02_constraints.py:L5:10: CastError: Cannot cast '),
FormattedValue(
value=Attribute(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_a0__')],
keywords=[]),
attr='__name__'),
conversion=-1),
Constant(value=' to float')])),
Assert(
test=Call(
func=Name(id='__midas_p0__'),
args=[
Name(id='__midas_a0__')],
keywords=[]),
msg=Constant(value="02_constraints.py:L5:10: ConstraintError: Value does not fit constraint '_ > 0.0'")),
Assign(
targets=[
Name(id='t1')],
value=Name(id='__midas_a0__')),
Delete(
targets=[
Name(id='__midas_a0__')]),
Assign(
targets=[
Name(id='__midas_a1__')],
value=Name(id='t')),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_a1__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
values=[
Constant(value='02_constraints.py:L6:10: CastError: Cannot cast '),
FormattedValue(
value=Attribute(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_a1__')],
keywords=[]),
attr='__name__'),
conversion=-1),
Constant(value=' to float')])),
Assert(
test=Call(
func=Name(id='__midas_p1__'),
args=[
Name(id='__midas_a1__')],
keywords=[]),
msg=Constant(value="02_constraints.py:L6:10: ConstraintError: Value does not fit constraint 'is_positive(_)'")),
Assign(
targets=[
Name(id='t2')],
value=Name(id='__midas_a1__')),
Delete(
targets=[
Name(id='__midas_a1__')]),
Assign(
targets=[
Name(id='__midas_a2__')],
value=Name(id='t')),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_a2__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
values=[
Constant(value='02_constraints.py:L7:10: CastError: Cannot cast '),
FormattedValue(
value=Attribute(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_a2__')],
keywords=[]),
attr='__name__'),
conversion=-1),
Constant(value=' to float')])),
Assert(
test=Call(
func=Name(id='__midas_p2__'),
args=[
Name(id='__midas_a2__')],
keywords=[]),
msg=Constant(value="02_constraints.py:L7:10: ConstraintError: Value does not fit constraint 'in_range(100.0, 200.0)(_)'")),
Assign(
targets=[
Name(id='t3')],
value=Name(id='__midas_a2__')),
Delete(
targets=[
Name(id='__midas_a2__')]),
Assign(
targets=[
Name(id='__midas_a3__')],
value=Name(id='t')),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_a3__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
values=[
Constant(value='02_constraints.py:L8:10: CastError: Cannot cast '),
FormattedValue(
value=Attribute(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_a3__')],
keywords=[]),
attr='__name__'),
conversion=-1),
Constant(value=' to float')])),
Assert(
test=Call(
func=Name(id='__midas_p3__'),
args=[
Name(id='__midas_a3__')],
keywords=[]),
msg=Constant(value="02_constraints.py:L8:10: ConstraintError: Value does not fit constraint 'minor(_)'")),
Assign(
targets=[
Name(id='t4')],
value=Name(id='__midas_a3__')),
Delete(
targets=[
Name(id='__midas_a3__')])],
type_ignores=[])

View File

@@ -2582,7 +2582,9 @@
"name": "__sub__",
"type": {
"_type": "FunctionType",
"pos_args": [
"params": {
"_type": "ParamSpec",
"pos": [
{
"name": null,
"type": {
@@ -2592,8 +2594,9 @@
"required": true
}
],
"args": [],
"kw_args": [],
"mixed": [],
"kw": []
},
"returns": {
"_type": "GenericType",
"type": {
@@ -2673,7 +2676,9 @@
"name": "__sub__",
"type": {
"_type": "FunctionType",
"pos_args": [
"params": {
"_type": "ParamSpec",
"pos": [
{
"name": null,
"type": {
@@ -2683,8 +2688,9 @@
"required": true
}
],
"args": [],
"kw_args": [],
"mixed": [],
"kw": []
},
"returns": {
"_type": "GenericType",
"type": {
@@ -2713,7 +2719,9 @@
"name": "__sub__",
"type": {
"_type": "FunctionType",
"pos_args": [
"params": {
"_type": "ParamSpec",
"pos": [
{
"name": null,
"type": {
@@ -2723,8 +2731,9 @@
"required": true
}
],
"args": [],
"kw_args": [],
"mixed": [],
"kw": []
},
"returns": {
"_type": "GenericType",
"type": {
@@ -2745,12 +2754,24 @@
{
"_type": "PredicateStmt",
"name": "Positive",
"subject": "v",
"params": [
{
"_type": "ParamSpec",
"pos": [],
"mixed": [
{
"name": "v",
"type": {
"_type": "NamedType",
"name": "float"
},
"condition": {
"required": true
}
],
"kw": []
}
],
"body": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
@@ -2766,12 +2787,24 @@
{
"_type": "PredicateStmt",
"name": "StrictlyPositive",
"subject": "v",
"params": [
{
"_type": "ParamSpec",
"pos": [],
"mixed": [
{
"name": "v",
"type": {
"_type": "NamedType",
"name": "float"
},
"condition": {
"required": true
}
],
"kw": []
}
],
"body": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
@@ -2787,12 +2820,24 @@
{
"_type": "PredicateStmt",
"name": "Equatorial",
"subject": "loc",
"params": [
{
"_type": "ParamSpec",
"pos": [],
"mixed": [
{
"name": "loc",
"type": {
"_type": "NamedType",
"name": "GeoLocation"
},
"condition": {
"required": true
}
],
"kw": []
}
],
"body": {
"_type": "GroupingExpr",
"expr": {
"_type": "BinaryExpr",
@@ -2827,12 +2872,24 @@
{
"_type": "PredicateStmt",
"name": "Arctic",
"subject": "loc",
"params": [
{
"_type": "ParamSpec",
"pos": [],
"mixed": [
{
"name": "loc",
"type": {
"_type": "NamedType",
"name": "GeoLocation"
},
"condition": {
"required": true
}
],
"kw": []
}
],
"body": {
"_type": "GroupingExpr",
"expr": {
"_type": "BinaryExpr",

View File

@@ -45,7 +45,7 @@ class GeneratorTester(Tester):
typed_ast: TypedAST = checker.type_check(path)
if not any(d.type == DiagnosticType.ERROR for d in checker.diagnostics):
generator = Generator(workdir=path.parent)
generator = Generator(workdir=path.parent, types=checker.types)
result.compiled_ast = generator.generate_ast(typed_ast, path)
return result

View File

@@ -2,6 +2,7 @@ from typing import Optional, Sequence
from midas.ast.midas import (
BinaryExpr,
CallExpr,
ComplexType,
ConstraintType,
Expr,
@@ -15,6 +16,7 @@ from midas.ast.midas import (
LogicalExpr,
MemberStmt,
NamedType,
ParamSpec,
PredicateStmt,
Stmt,
Type,
@@ -78,9 +80,8 @@ class MidasAstJsonSerializer(
return {
"_type": "PredicateStmt",
"name": stmt.name.lexeme,
"subject": stmt.subject.lexeme,
"type": stmt.type.accept(self),
"condition": stmt.condition.accept(self),
"params": [self._serialize_param_spec(spec) for spec in stmt.params],
"body": stmt.body.accept(self),
}
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
@@ -106,6 +107,14 @@ class MidasAstJsonSerializer(
"right": expr.right.accept(self),
}
def visit_call_expr(self, expr: CallExpr) -> dict:
return {
"_type": "CallExpr",
"callee": expr.callee.accept(self),
"arguments": self._serialize_list(expr.arguments),
"keywords": {name: arg.accept(self) for name, arg in expr.keywords.items()},
}
def visit_get_expr(self, expr: GetExpr) -> dict:
return {
"_type": "GetExpr",
@@ -163,15 +172,21 @@ class MidasAstJsonSerializer(
def visit_function_type(self, type: FunctionType) -> dict:
return {
"_type": "FunctionType",
"pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args],
"args": [self._serialize_func_arg(arg) for arg in type.args],
"kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args],
"params": self._serialize_param_spec(type.params),
"returns": type.returns.accept(self),
}
def _serialize_param_spec(self, spec: ParamSpec) -> dict:
return {
"_type": "ParamSpec",
"pos": [self._serialize_func_arg(arg) for arg in spec.pos],
"mixed": [self._serialize_func_arg(arg) for arg in spec.mixed],
"kw": [self._serialize_func_arg(arg) for arg in spec.kw],
}
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
return {
"name": arg.name,
"name": arg.name.lexeme if arg.name is not None else None,
"type": arg.type.accept(self),
"required": arg.required,
}