diff --git a/midas/generator/stubs.py b/midas/generator/stubs.py index 4c075fc..abbcf06 100644 --- a/midas/generator/stubs.py +++ b/midas/generator/stubs.py @@ -1,5 +1,5 @@ import ast -from typing import Optional +from typing import Optional, assert_never import midas.ast.midas as m from midas.checker.registry import Member, TypesRegistry @@ -8,6 +8,7 @@ from midas.checker.types import ( AppliedType, BaseType, ComplexType, + ConstraintType, ExtensionType, Function, GenericType, @@ -84,6 +85,7 @@ class StubsGenerator: match type: case AliasType(type=base): return [self.dump_type(base)], {} + case GenericType(params=params, body=body): self.add_typing_import("Generic") type_vars: ast.expr @@ -111,6 +113,7 @@ class StubsGenerator: ], body_subsitutions | substitutions, ) + case _: return [], {} @@ -148,15 +151,20 @@ class StubsGenerator: case TopType() | UnknownType(): self.add_typing_import("Any") return ast.Name(id="Any") + case BaseType(name=name): return ast.Name(id=name) + case AliasType(name=name): return ast.Name(id=name) + case UnitType(): return ast.Constant(value=None) + case Function(): name: str = self.define_protocol(type) return ast.Name(id=name) + case OverloadedFunction(overloads=overloads): if len(overloads) == 1: return self.dump_type(overloads[0]) @@ -176,6 +184,7 @@ class StubsGenerator: case TypeVar(): return ast.Name(id=type.name) + case GenericType(name=name): params: ast.expr if len(type.params) == 1: @@ -188,6 +197,7 @@ class StubsGenerator: value=ast.Name(id=type.name), slice=params, ) + case AppliedType(): args: ast.expr if len(type.args) == 1: @@ -199,6 +209,12 @@ class StubsGenerator: slice=args, ) + case ConstraintType(): + return self.dump_type(type.type) + + case _: + assert_never(type) + def dump_method( self, name: str, method: Type, overloaded: bool = False ) -> list[ast.stmt]: