Browse Source

Add Interaction.client property

pull/7492/head
Rapptz 3 years ago
parent
commit
f435d160dd
  1. 10
      discord/app_commands/tree.py
  2. 3
      discord/client.py
  3. 17
      discord/interactions.py
  4. 3
      discord/shard.py
  5. 24
      discord/state.py

10
discord/app_commands/tree.py

@ -26,7 +26,7 @@ from __future__ import annotations
import inspect
import sys
import traceback
from typing import Callable, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload
from typing import Callable, Dict, Generic, List, Literal, Optional, TYPE_CHECKING, Tuple, TypeVar, Union, overload
from .namespace import Namespace, ResolveKey
@ -52,8 +52,10 @@ if TYPE_CHECKING:
__all__ = ('CommandTree',)
ClientT = TypeVar('ClientT', bound='Client')
class CommandTree:
class CommandTree(Generic[ClientT]):
"""Represents a container that holds application command information.
Parameters
@ -62,8 +64,8 @@ class CommandTree:
The client instance to get application command information from.
"""
def __init__(self, client: Client):
self.client = client
def __init__(self, client: ClientT):
self.client: ClientT = client
self._http = client.http
self._state = client._connection
self._state._command_tree = self

3
discord/client.py

@ -77,6 +77,7 @@ from .threads import Thread
from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory
if TYPE_CHECKING:
from typing_extensions import Self
from .types.guild import Guild as GuildPayload
from .abc import SnowflakeTime, Snowflake, PrivateChannel
from .guild import GuildChannel
@ -254,7 +255,7 @@ class Client:
}
self._enable_debug_events: bool = options.pop('enable_debug_events', False)
self._connection: ConnectionState = self._get_state(**options)
self._connection: ConnectionState[Self] = self._get_state(**options)
self._connection.shard_count = self.shard_count
self._closed: bool = False
self._ready: asyncio.Event = asyncio.Event()

17
discord/interactions.py

@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union
from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union
import asyncio
from . import utils
@ -53,6 +53,7 @@ if TYPE_CHECKING:
Interaction as InteractionPayload,
InteractionData,
)
from .client import Client
from .guild import Guild
from .state import ConnectionState
from .file import File
@ -70,9 +71,10 @@ if TYPE_CHECKING:
]
MISSING: Any = utils.MISSING
ClientT = TypeVar('ClientT', bound='Client')
class Interaction:
class Interaction(Generic[ClientT]):
"""Represents a Discord interaction.
An interaction happens when a user does an action that needs to
@ -116,6 +118,7 @@ class Interaction:
'version',
'_permissions',
'_state',
'_client',
'_session',
'_original_message',
'_cs_response',
@ -123,8 +126,9 @@ class Interaction:
'_cs_channel',
)
def __init__(self, *, data: InteractionPayload, state: ConnectionState):
self._state: ConnectionState = state
def __init__(self, *, data: InteractionPayload, state: ConnectionState[ClientT]):
self._state: ConnectionState[ClientT] = state
self._client: ClientT = state._get_client()
self._session: ClientSession = state.http._HTTPClient__session # type: ignore - Mangled attribute for __session
self._original_message: Optional[InteractionMessage] = None
self._from_data(data)
@ -166,6 +170,11 @@ class Interaction:
except KeyError:
pass
@property
def client(self) -> ClientT:
""":class:`Client`: The client that is handling this interaction."""
return self._client
@property
def guild(self) -> Optional[Guild]:
"""Optional[:class:`Guild`]: The guild the interaction was sent from."""

3
discord/shard.py

@ -46,6 +46,7 @@ from .enums import Status
from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict
if TYPE_CHECKING:
from typing_extensions import Self
from .gateway import DiscordWebSocket
from .activity import BaseActivity
from .enums import Status
@ -316,7 +317,7 @@ class AutoShardedClient(Client):
"""
if TYPE_CHECKING:
_connection: AutoShardedConnectionState
_connection: AutoShardedConnectionState[Self]
def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None:
kwargs.pop('shard_id', None)

24
discord/state.py

@ -30,7 +30,21 @@ import copy
import datetime
import itertools
import logging
from typing import Dict, Optional, TYPE_CHECKING, Union, Callable, Any, List, TypeVar, Coroutine, Sequence, Tuple, Deque
from typing import (
Dict,
Generic,
Optional,
TYPE_CHECKING,
Union,
Callable,
Any,
List,
TypeVar,
Coroutine,
Sequence,
Tuple,
Deque,
)
import weakref
import inspect
@ -84,6 +98,8 @@ if TYPE_CHECKING:
T = TypeVar('T')
Channel = Union[GuildChannel, VocalGuildChannel, PrivateChannel, PartialMessageable]
ClientT = TypeVar('ClientT', bound='Client')
class ChunkRequest:
def __init__(
@ -143,10 +159,10 @@ async def logging_coroutine(coroutine: Coroutine[Any, Any, T], *, info: str) ->
_log.exception('Exception occurred during %s', info)
class ConnectionState:
class ConnectionState(Generic[ClientT]):
if TYPE_CHECKING:
_get_websocket: Callable[..., DiscordWebSocket]
_get_client: Callable[..., Client]
_get_client: Callable[..., ClientT]
_parsers: Dict[str, Callable[[Dict[str, Any]], None]]
def __init__(
@ -1471,7 +1487,7 @@ class ConnectionState:
return Message(state=self, channel=channel, data=data)
class AutoShardedConnectionState(ConnectionState):
class AutoShardedConnectionState(ConnectionState[ClientT]):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.shard_ids: Union[List[int], range] = []

Loading…
Cancel
Save