diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index cefe837c0..2079fce7e 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -82,6 +82,7 @@ if TYPE_CHECKING: from .core import Command from .hybrid import CommandCallback, ContextT, P from discord.client import _ClientOptions + from discord.shard import _AutoShardedClientOptions _Prefix = Union[Iterable[str], str] _PrefixCallable = MaybeAwaitableFunc[[BotT, Message], _Prefix] @@ -92,6 +93,9 @@ if TYPE_CHECKING: owner_ids: NotRequired[Optional[Collection[int]]] strip_after_prefix: NotRequired[bool] + class _AutoShardedBotOptions(_AutoShardedClientOptions, _BotOptions): + ... + __all__ = ( 'when_mentioned', @@ -1534,4 +1538,13 @@ class AutoShardedBot(BotBase, discord.AutoShardedClient): .. versionadded:: 2.0 """ - pass + if TYPE_CHECKING: + + def __init__( + self, + command_prefix: PrefixType[BotT], + *, + intents: discord.Intents, + **kwargs: Unpack[_AutoShardedBotOptions], + ) -> None: + ... diff --git a/discord/shard.py b/discord/shard.py index 454fd5e28..89d22588a 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -47,11 +47,17 @@ from .enums import Status from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict if TYPE_CHECKING: - from typing_extensions import Unpack + from typing_extensions import Unpack, NotRequired from .gateway import DiscordWebSocket from .activity import BaseActivity from .flags import Intents from .types.gateway import SessionStartLimit + from .client import _ClientOptions + + class _AutoShardedClientOptions(_ClientOptions): + shard_ids: NotRequired[Optional[List[int]]] + shard_connect_timeout: NotRequired[Optional[float]] + __all__ = ( 'AutoShardedClient', @@ -365,10 +371,14 @@ class AutoShardedClient(Client): if TYPE_CHECKING: _connection: AutoShardedConnectionState - def __init__(self, *args: Any, intents: Intents, **kwargs: Any) -> None: + def __init__(self, *args: Any, intents: Intents, **kwargs: Unpack[_AutoShardedClientOptions]) -> None: kwargs.pop('shard_id', None) - self.shard_ids: Optional[List[int]] = kwargs.pop('shard_ids', None) - self.shard_connect_timeout: Optional[float] = kwargs.pop('shard_connect_timeout', 180.0) + self.shard_ids: Optional[List[int]] = kwargs.pop( + 'shard_ids', None + ) # pyright: ignore[reportAttributeAccessIssue] # it's fine + self.shard_connect_timeout: Optional[float] = kwargs.pop( + 'shard_connect_timeout', 180.0 + ) # pyright: ignore[reportAttributeAccessIssue] # it's fine super().__init__(*args, intents=intents, **kwargs)