fix(gen): handle ConstraintType in stubs generator

This commit is contained in:
2026-06-20 17:34:22 +02:00
parent 9e83079910
commit b02ecc6326

View File

@@ -1,5 +1,5 @@
import ast
from typing import Optional
from typing import Optional, assert_never
import midas.ast.midas as m
from midas.checker.registry import Member, TypesRegistry
@@ -8,6 +8,7 @@ from midas.checker.types import (
AppliedType,
BaseType,
ComplexType,
ConstraintType,
ExtensionType,
Function,
GenericType,
@@ -84,6 +85,7 @@ class StubsGenerator:
match type:
case AliasType(type=base):
return [self.dump_type(base)], {}
case GenericType(params=params, body=body):
self.add_typing_import("Generic")
type_vars: ast.expr
@@ -111,6 +113,7 @@ class StubsGenerator:
],
body_subsitutions | substitutions,
)
case _:
return [], {}
@@ -148,15 +151,20 @@ class StubsGenerator:
case TopType() | UnknownType():
self.add_typing_import("Any")
return ast.Name(id="Any")
case BaseType(name=name):
return ast.Name(id=name)
case AliasType(name=name):
return ast.Name(id=name)
case UnitType():
return ast.Constant(value=None)
case Function():
name: str = self.define_protocol(type)
return ast.Name(id=name)
case OverloadedFunction(overloads=overloads):
if len(overloads) == 1:
return self.dump_type(overloads[0])
@@ -176,6 +184,7 @@ class StubsGenerator:
case TypeVar():
return ast.Name(id=type.name)
case GenericType(name=name):
params: ast.expr
if len(type.params) == 1:
@@ -188,6 +197,7 @@ class StubsGenerator:
value=ast.Name(id=type.name),
slice=params,
)
case AppliedType():
args: ast.expr
if len(type.args) == 1:
@@ -199,6 +209,12 @@ class StubsGenerator:
slice=args,
)
case ConstraintType():
return self.dump_type(type.type)
case _:
assert_never(type)
def dump_method(
self, name: str, method: Type, overloaded: bool = False
) -> list[ast.stmt]: