feat(gen): generate type hints for functions

This commit is contained in:
2026-06-19 14:11:38 +02:00
parent 657406ea01
commit eb5bf19c61
2 changed files with 81 additions and 4 deletions

View File

@@ -238,6 +238,58 @@ def unfold_type(type: Type) -> Type:
return type
def to_annotation(type: Type) -> str:
def _args_annotation(func: Function) -> str:
if len(func.kw_args) != 0:
return "..."
args: str = ", ".join(
to_annotation(arg.type) for arg in func.pos_args + func.args
)
return f"[{args}]"
match type:
case TopType():
return "Any"
case BaseType(name=name):
return name
case AliasType(name=name):
return name
case UnknownType():
return "Any"
case UnitType():
return "None"
case Function(returns=returns):
params_annot: str = _args_annotation(type)
return f"Callable[{params_annot}, {to_annotation(returns)}]"
case OverloadedFunction():
return "Callable"
case ComplexType() | ExtensionType():
raise NotImplementedError
case TypeVar(name=name):
return name
case GenericType(name=name, params=params):
return f"{name}[{', '.join(map(to_annotation, params))}]"
case AppliedType(name=name, args=args):
return f"{name}[{', '.join(map(to_annotation, args))}]"
case ConstraintType():
return str(type)
case _:
assert_never(type)
@dataclass(frozen=True, kw_only=True)
class Predicate:
type: Type

View File

@@ -3,7 +3,12 @@ from typing import Optional
import midas.ast.midas as m
from midas.checker.registry import TypesRegistry
from midas.checker.types import Function, Predicate, Type
from midas.checker.types import (
Function,
Predicate,
Type,
to_annotation,
)
from midas.lexer.token import TokenType
LOGICAL_OPERATORS: dict[TokenType, type[ast.boolop]] = {
@@ -91,9 +96,27 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
def make_args(self, func: Function) -> ast.arguments:
return ast.arguments(
posonlyargs=[ast.arg(arg=arg.name) for arg in func.pos_args],
args=[ast.arg(arg=arg.name) for arg in func.args],
kwonlyargs=[ast.arg(arg=arg.name) for arg in func.kw_args],
posonlyargs=[
ast.arg(
arg=arg.name,
annotation=ast.Constant(value=to_annotation(arg.type)),
)
for arg in func.pos_args
],
args=[
ast.arg(
arg=arg.name,
annotation=ast.Constant(value=to_annotation(arg.type)),
)
for arg in func.args
],
kwonlyargs=[
ast.arg(
arg=arg.name,
annotation=ast.Constant(value=to_annotation(arg.type)),
)
for arg in func.kw_args
],
defaults=[],
kw_defaults=[],
)
@@ -111,6 +134,7 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
self.make_func(inner_name, inner_body, type.returns, level + 1),
ast.Return(value=ast.Name(id=inner_name)),
],
returns=ast.Constant(value=to_annotation(type.returns)),
decorator_list=[],
)
@@ -119,6 +143,7 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
name=name,
args=self.make_args(type),
body=inner_body,
returns=ast.Constant(value=to_annotation(type.returns)),
decorator_list=[],
)