diff --git a/discord/app_commands/tree.py b/discord/app_commands/tree.py index 52057d2c8..7f96af9fd 100644 --- a/discord/app_commands/tree.py +++ b/discord/app_commands/tree.py @@ -26,7 +26,7 @@ from __future__ import annotations import inspect import sys import traceback -from typing import Callable, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload +from typing import Callable, Dict, Generic, List, Literal, Optional, TYPE_CHECKING, Tuple, TypeVar, Union, overload from .namespace import Namespace, ResolveKey @@ -52,8 +52,10 @@ if TYPE_CHECKING: __all__ = ('CommandTree',) +ClientT = TypeVar('ClientT', bound='Client') -class CommandTree: + +class CommandTree(Generic[ClientT]): """Represents a container that holds application command information. Parameters @@ -62,8 +64,8 @@ class CommandTree: The client instance to get application command information from. """ - def __init__(self, client: Client): - self.client = client + def __init__(self, client: ClientT): + self.client: ClientT = client self._http = client.http self._state = client._connection self._state._command_tree = self diff --git a/discord/client.py b/discord/client.py index 07d86703f..3b897dfff 100644 --- a/discord/client.py +++ b/discord/client.py @@ -77,6 +77,7 @@ from .threads import Thread from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory if TYPE_CHECKING: + from typing_extensions import Self from .types.guild import Guild as GuildPayload from .abc import SnowflakeTime, Snowflake, PrivateChannel from .guild import GuildChannel @@ -254,7 +255,7 @@ class Client: } self._enable_debug_events: bool = options.pop('enable_debug_events', False) - self._connection: ConnectionState = self._get_state(**options) + self._connection: ConnectionState[Self] = self._get_state(**options) self._connection.shard_count = self.shard_count self._closed: bool = False self._ready: asyncio.Event = asyncio.Event() diff --git a/discord/interactions.py b/discord/interactions.py index 4e9db132a..d506a3347 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE. """ from __future__ import annotations -from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union +from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union import asyncio from . import utils @@ -53,6 +53,7 @@ if TYPE_CHECKING: Interaction as InteractionPayload, InteractionData, ) + from .client import Client from .guild import Guild from .state import ConnectionState from .file import File @@ -70,9 +71,10 @@ if TYPE_CHECKING: ] MISSING: Any = utils.MISSING +ClientT = TypeVar('ClientT', bound='Client') -class Interaction: +class Interaction(Generic[ClientT]): """Represents a Discord interaction. An interaction happens when a user does an action that needs to @@ -116,6 +118,7 @@ class Interaction: 'version', '_permissions', '_state', + '_client', '_session', '_original_message', '_cs_response', @@ -123,8 +126,9 @@ class Interaction: '_cs_channel', ) - def __init__(self, *, data: InteractionPayload, state: ConnectionState): - self._state: ConnectionState = state + def __init__(self, *, data: InteractionPayload, state: ConnectionState[ClientT]): + self._state: ConnectionState[ClientT] = state + self._client: ClientT = state._get_client() self._session: ClientSession = state.http._HTTPClient__session # type: ignore - Mangled attribute for __session self._original_message: Optional[InteractionMessage] = None self._from_data(data) @@ -166,6 +170,11 @@ class Interaction: except KeyError: pass + @property + def client(self) -> ClientT: + """:class:`Client`: The client that is handling this interaction.""" + return self._client + @property def guild(self) -> Optional[Guild]: """Optional[:class:`Guild`]: The guild the interaction was sent from.""" diff --git a/discord/shard.py b/discord/shard.py index 9093b9c56..923647763 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -46,6 +46,7 @@ from .enums import Status from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict if TYPE_CHECKING: + from typing_extensions import Self from .gateway import DiscordWebSocket from .activity import BaseActivity from .enums import Status @@ -316,7 +317,7 @@ class AutoShardedClient(Client): """ if TYPE_CHECKING: - _connection: AutoShardedConnectionState + _connection: AutoShardedConnectionState[Self] def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None: kwargs.pop('shard_id', None) diff --git a/discord/state.py b/discord/state.py index bc6638c35..f23a8402f 100644 --- a/discord/state.py +++ b/discord/state.py @@ -30,7 +30,21 @@ import copy import datetime import itertools import logging -from typing import Dict, Optional, TYPE_CHECKING, Union, Callable, Any, List, TypeVar, Coroutine, Sequence, Tuple, Deque +from typing import ( + Dict, + Generic, + Optional, + TYPE_CHECKING, + Union, + Callable, + Any, + List, + TypeVar, + Coroutine, + Sequence, + Tuple, + Deque, +) import weakref import inspect @@ -84,6 +98,8 @@ if TYPE_CHECKING: T = TypeVar('T') Channel = Union[GuildChannel, VocalGuildChannel, PrivateChannel, PartialMessageable] +ClientT = TypeVar('ClientT', bound='Client') + class ChunkRequest: def __init__( @@ -143,10 +159,10 @@ async def logging_coroutine(coroutine: Coroutine[Any, Any, T], *, info: str) -> _log.exception('Exception occurred during %s', info) -class ConnectionState: +class ConnectionState(Generic[ClientT]): if TYPE_CHECKING: _get_websocket: Callable[..., DiscordWebSocket] - _get_client: Callable[..., Client] + _get_client: Callable[..., ClientT] _parsers: Dict[str, Callable[[Dict[str, Any]], None]] def __init__( @@ -1471,7 +1487,7 @@ class ConnectionState: return Message(state=self, channel=channel, data=data) -class AutoShardedConnectionState(ConnectionState): +class AutoShardedConnectionState(ConnectionState[ClientT]): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.shard_ids: Union[List[int], range] = []