diff --git a/discord/_types.py b/discord/_types.py new file mode 100644 index 000000000..331063544 --- /dev/null +++ b/discord/_types.py @@ -0,0 +1,34 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations +from typing import TypeVar, TYPE_CHECKING + +if TYPE_CHECKING: + from typing_extensions import TypeVar + from .client import Client + + ClientT = TypeVar('ClientT', bound=Client, covariant=True, default=Client) +else: + ClientT = TypeVar('ClientT', bound='Client', covariant=True) diff --git a/discord/client.py b/discord/client.py index 2df4adf52..fce4e2900 100644 --- a/discord/client.py +++ b/discord/client.py @@ -282,7 +282,7 @@ class Client: } self._enable_debug_events: bool = options.pop('enable_debug_events', False) - self._connection: ConnectionState = self._get_state(intents=intents, **options) + self._connection: ConnectionState[Self] = self._get_state(intents=intents, **options) self._connection.shard_count = self.shard_count self._closed: bool = False self._ready: asyncio.Event = MISSING diff --git a/discord/interactions.py b/discord/interactions.py index 36a165e2e..478a1a8b3 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE. """ from __future__ import annotations -from typing import Any, Dict, Optional, TYPE_CHECKING, Sequence, Tuple, Union +from typing import Any, Dict, Optional, Generic, TYPE_CHECKING, Sequence, Tuple, Union import asyncio import datetime @@ -34,6 +34,7 @@ from .enums import try_enum, Locale, InteractionType, InteractionResponseType from .errors import InteractionResponded, HTTPException, ClientException, DiscordException from .flags import MessageFlags from .channel import PartialMessageable, ChannelType +from ._types import ClientT from .user import User from .member import Member @@ -59,7 +60,6 @@ if TYPE_CHECKING: from .types.webhook import ( Webhook as WebhookPayload, ) - from .client import Client from .guild import Guild from .state import ConnectionState from .file import File @@ -80,7 +80,7 @@ if TYPE_CHECKING: MISSING: Any = utils.MISSING -class Interaction: +class Interaction(Generic[ClientT]): """Represents a Discord interaction. An interaction happens when a user does an action that needs to @@ -151,9 +151,9 @@ class Interaction: '_cs_command', ) - def __init__(self, *, data: InteractionPayload, state: ConnectionState): - self._state: ConnectionState = state - self._client: Client = state._get_client() + def __init__(self, *, data: InteractionPayload, state: ConnectionState[ClientT]): + self._state: ConnectionState[ClientT] = state + self._client: ClientT = state._get_client() self._session: ClientSession = state.http._HTTPClient__session # type: ignore # Mangled attribute for __session self._original_response: Optional[InteractionMessage] = None # This baton is used for extra data that might be useful for the lifecycle of @@ -207,7 +207,7 @@ class Interaction: pass @property - def client(self) -> Client: + def client(self) -> ClientT: """:class:`Client`: The client that is handling this interaction. Note that :class:`AutoShardedClient`, :class:`~.commands.Bot`, and diff --git a/discord/state.py b/discord/state.py index 24eb70301..c47a9ab80 100644 --- a/discord/state.py +++ b/discord/state.py @@ -39,6 +39,7 @@ from typing import ( TypeVar, Coroutine, Sequence, + Generic, Tuple, Deque, Literal, @@ -75,6 +76,7 @@ from .threads import Thread, ThreadMember from .sticker import GuildSticker from .automod import AutoModRule, AutoModAction from .audit_logs import AuditLogEntry +from ._types import ClientT if TYPE_CHECKING: from .abc import PrivateChannel @@ -82,7 +84,6 @@ if TYPE_CHECKING: from .guild import GuildChannel from .http import HTTPClient from .voice_client import VoiceProtocol - from .client import Client from .gateway import DiscordWebSocket from .app_commands import CommandTree, Translator @@ -160,10 +161,10 @@ async def logging_coroutine(coroutine: Coroutine[Any, Any, T], *, info: str) -> _log.exception('Exception occurred during %s', info) -class ConnectionState: +class ConnectionState(Generic[ClientT]): if TYPE_CHECKING: _get_websocket: Callable[..., DiscordWebSocket] - _get_client: Callable[..., Client] + _get_client: Callable[..., ClientT] _parsers: Dict[str, Callable[[Dict[str, Any]], None]] def __init__( @@ -1612,7 +1613,7 @@ class ConnectionState: return Message(state=self, channel=channel, data=data) -class AutoShardedConnectionState(ConnectionState): +class AutoShardedConnectionState(ConnectionState[ClientT]): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs)