fix(gen): handle ConstraintType in stubs generator

This commit is contained in:
2026-06-20 17:34:22 +02:00
parent 9e83079910
commit b02ecc6326

View File

@@ -1,5 +1,5 @@
import ast import ast
from typing import Optional from typing import Optional, assert_never
import midas.ast.midas as m import midas.ast.midas as m
from midas.checker.registry import Member, TypesRegistry from midas.checker.registry import Member, TypesRegistry
@@ -8,6 +8,7 @@ from midas.checker.types import (
AppliedType, AppliedType,
BaseType, BaseType,
ComplexType, ComplexType,
ConstraintType,
ExtensionType, ExtensionType,
Function, Function,
GenericType, GenericType,
@@ -84,6 +85,7 @@ class StubsGenerator:
match type: match type:
case AliasType(type=base): case AliasType(type=base):
return [self.dump_type(base)], {} return [self.dump_type(base)], {}
case GenericType(params=params, body=body): case GenericType(params=params, body=body):
self.add_typing_import("Generic") self.add_typing_import("Generic")
type_vars: ast.expr type_vars: ast.expr
@@ -111,6 +113,7 @@ class StubsGenerator:
], ],
body_subsitutions | substitutions, body_subsitutions | substitutions,
) )
case _: case _:
return [], {} return [], {}
@@ -148,15 +151,20 @@ class StubsGenerator:
case TopType() | UnknownType(): case TopType() | UnknownType():
self.add_typing_import("Any") self.add_typing_import("Any")
return ast.Name(id="Any") return ast.Name(id="Any")
case BaseType(name=name): case BaseType(name=name):
return ast.Name(id=name) return ast.Name(id=name)
case AliasType(name=name): case AliasType(name=name):
return ast.Name(id=name) return ast.Name(id=name)
case UnitType(): case UnitType():
return ast.Constant(value=None) return ast.Constant(value=None)
case Function(): case Function():
name: str = self.define_protocol(type) name: str = self.define_protocol(type)
return ast.Name(id=name) return ast.Name(id=name)
case OverloadedFunction(overloads=overloads): case OverloadedFunction(overloads=overloads):
if len(overloads) == 1: if len(overloads) == 1:
return self.dump_type(overloads[0]) return self.dump_type(overloads[0])
@@ -176,6 +184,7 @@ class StubsGenerator:
case TypeVar(): case TypeVar():
return ast.Name(id=type.name) return ast.Name(id=type.name)
case GenericType(name=name): case GenericType(name=name):
params: ast.expr params: ast.expr
if len(type.params) == 1: if len(type.params) == 1:
@@ -188,6 +197,7 @@ class StubsGenerator:
value=ast.Name(id=type.name), value=ast.Name(id=type.name),
slice=params, slice=params,
) )
case AppliedType(): case AppliedType():
args: ast.expr args: ast.expr
if len(type.args) == 1: if len(type.args) == 1:
@@ -199,6 +209,12 @@ class StubsGenerator:
slice=args, slice=args,
) )
case ConstraintType():
return self.dump_type(type.type)
case _:
assert_never(type)
def dump_method( def dump_method(
self, name: str, method: Type, overloaded: bool = False self, name: str, method: Type, overloaded: bool = False
) -> list[ast.stmt]: ) -> list[ast.stmt]: