Browse Source

Add Interaction.client property

pull/7492/head
Rapptz 3 years ago
parent
commit
f435d160dd
  1. 10
      discord/app_commands/tree.py
  2. 3
      discord/client.py
  3. 17
      discord/interactions.py
  4. 3
      discord/shard.py
  5. 24
      discord/state.py

10
discord/app_commands/tree.py

@ -26,7 +26,7 @@ from __future__ import annotations
import inspect import inspect
import sys import sys
import traceback import traceback
from typing import Callable, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload from typing import Callable, Dict, Generic, List, Literal, Optional, TYPE_CHECKING, Tuple, TypeVar, Union, overload
from .namespace import Namespace, ResolveKey from .namespace import Namespace, ResolveKey
@ -52,8 +52,10 @@ if TYPE_CHECKING:
__all__ = ('CommandTree',) __all__ = ('CommandTree',)
ClientT = TypeVar('ClientT', bound='Client')
class CommandTree:
class CommandTree(Generic[ClientT]):
"""Represents a container that holds application command information. """Represents a container that holds application command information.
Parameters Parameters
@ -62,8 +64,8 @@ class CommandTree:
The client instance to get application command information from. The client instance to get application command information from.
""" """
def __init__(self, client: Client): def __init__(self, client: ClientT):
self.client = client self.client: ClientT = client
self._http = client.http self._http = client.http
self._state = client._connection self._state = client._connection
self._state._command_tree = self self._state._command_tree = self

3
discord/client.py

@ -77,6 +77,7 @@ from .threads import Thread
from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self
from .types.guild import Guild as GuildPayload from .types.guild import Guild as GuildPayload
from .abc import SnowflakeTime, Snowflake, PrivateChannel from .abc import SnowflakeTime, Snowflake, PrivateChannel
from .guild import GuildChannel from .guild import GuildChannel
@ -254,7 +255,7 @@ class Client:
} }
self._enable_debug_events: bool = options.pop('enable_debug_events', False) self._enable_debug_events: bool = options.pop('enable_debug_events', False)
self._connection: ConnectionState = self._get_state(**options) self._connection: ConnectionState[Self] = self._get_state(**options)
self._connection.shard_count = self.shard_count self._connection.shard_count = self.shard_count
self._closed: bool = False self._closed: bool = False
self._ready: asyncio.Event = asyncio.Event() self._ready: asyncio.Event = asyncio.Event()

17
discord/interactions.py

@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union
import asyncio import asyncio
from . import utils from . import utils
@ -53,6 +53,7 @@ if TYPE_CHECKING:
Interaction as InteractionPayload, Interaction as InteractionPayload,
InteractionData, InteractionData,
) )
from .client import Client
from .guild import Guild from .guild import Guild
from .state import ConnectionState from .state import ConnectionState
from .file import File from .file import File
@ -70,9 +71,10 @@ if TYPE_CHECKING:
] ]
MISSING: Any = utils.MISSING MISSING: Any = utils.MISSING
ClientT = TypeVar('ClientT', bound='Client')
class Interaction: class Interaction(Generic[ClientT]):
"""Represents a Discord interaction. """Represents a Discord interaction.
An interaction happens when a user does an action that needs to An interaction happens when a user does an action that needs to
@ -116,6 +118,7 @@ class Interaction:
'version', 'version',
'_permissions', '_permissions',
'_state', '_state',
'_client',
'_session', '_session',
'_original_message', '_original_message',
'_cs_response', '_cs_response',
@ -123,8 +126,9 @@ class Interaction:
'_cs_channel', '_cs_channel',
) )
def __init__(self, *, data: InteractionPayload, state: ConnectionState): def __init__(self, *, data: InteractionPayload, state: ConnectionState[ClientT]):
self._state: ConnectionState = state self._state: ConnectionState[ClientT] = state
self._client: ClientT = state._get_client()
self._session: ClientSession = state.http._HTTPClient__session # type: ignore - Mangled attribute for __session self._session: ClientSession = state.http._HTTPClient__session # type: ignore - Mangled attribute for __session
self._original_message: Optional[InteractionMessage] = None self._original_message: Optional[InteractionMessage] = None
self._from_data(data) self._from_data(data)
@ -166,6 +170,11 @@ class Interaction:
except KeyError: except KeyError:
pass pass
@property
def client(self) -> ClientT:
""":class:`Client`: The client that is handling this interaction."""
return self._client
@property @property
def guild(self) -> Optional[Guild]: def guild(self) -> Optional[Guild]:
"""Optional[:class:`Guild`]: The guild the interaction was sent from.""" """Optional[:class:`Guild`]: The guild the interaction was sent from."""

3
discord/shard.py

@ -46,6 +46,7 @@ from .enums import Status
from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self
from .gateway import DiscordWebSocket from .gateway import DiscordWebSocket
from .activity import BaseActivity from .activity import BaseActivity
from .enums import Status from .enums import Status
@ -316,7 +317,7 @@ class AutoShardedClient(Client):
""" """
if TYPE_CHECKING: if TYPE_CHECKING:
_connection: AutoShardedConnectionState _connection: AutoShardedConnectionState[Self]
def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None: def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None:
kwargs.pop('shard_id', None) kwargs.pop('shard_id', None)

24
discord/state.py

@ -30,7 +30,21 @@ import copy
import datetime import datetime
import itertools import itertools
import logging import logging
from typing import Dict, Optional, TYPE_CHECKING, Union, Callable, Any, List, TypeVar, Coroutine, Sequence, Tuple, Deque from typing import (
Dict,
Generic,
Optional,
TYPE_CHECKING,
Union,
Callable,
Any,
List,
TypeVar,
Coroutine,
Sequence,
Tuple,
Deque,
)
import weakref import weakref
import inspect import inspect
@ -84,6 +98,8 @@ if TYPE_CHECKING:
T = TypeVar('T') T = TypeVar('T')
Channel = Union[GuildChannel, VocalGuildChannel, PrivateChannel, PartialMessageable] Channel = Union[GuildChannel, VocalGuildChannel, PrivateChannel, PartialMessageable]
ClientT = TypeVar('ClientT', bound='Client')
class ChunkRequest: class ChunkRequest:
def __init__( def __init__(
@ -143,10 +159,10 @@ async def logging_coroutine(coroutine: Coroutine[Any, Any, T], *, info: str) ->
_log.exception('Exception occurred during %s', info) _log.exception('Exception occurred during %s', info)
class ConnectionState: class ConnectionState(Generic[ClientT]):
if TYPE_CHECKING: if TYPE_CHECKING:
_get_websocket: Callable[..., DiscordWebSocket] _get_websocket: Callable[..., DiscordWebSocket]
_get_client: Callable[..., Client] _get_client: Callable[..., ClientT]
_parsers: Dict[str, Callable[[Dict[str, Any]], None]] _parsers: Dict[str, Callable[[Dict[str, Any]], None]]
def __init__( def __init__(
@ -1471,7 +1487,7 @@ class ConnectionState:
return Message(state=self, channel=channel, data=data) return Message(state=self, channel=channel, data=data)
class AutoShardedConnectionState(ConnectionState): class AutoShardedConnectionState(ConnectionState[ClientT]):
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.shard_ids: Union[List[int], range] = [] self.shard_ids: Union[List[int], range] = []

Loading…
Cancel
Save