Browse Source

Remove generic from Interaction and ConnectionState

This results in poor ergonomics due to the lack of default generics
for the common case. For most users this ends up in a degraded
experience since the type will resolve to Unknown rather than at the
very least a Client.
pull/7492/head
Rapptz 3 years ago
parent
commit
f7315573aa
  1. 3
      discord/client.py
  2. 11
      discord/interactions.py
  3. 3
      discord/shard.py
  4. 8
      discord/state.py

3
discord/client.py

@ -77,7 +77,6 @@ from .threads import Thread
from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self
from .types.guild import Guild as GuildPayload from .types.guild import Guild as GuildPayload
from .abc import SnowflakeTime, Snowflake, PrivateChannel from .abc import SnowflakeTime, Snowflake, PrivateChannel
from .guild import GuildChannel from .guild import GuildChannel
@ -255,7 +254,7 @@ class Client:
} }
self._enable_debug_events: bool = options.pop('enable_debug_events', False) self._enable_debug_events: bool = options.pop('enable_debug_events', False)
self._connection: ConnectionState[Self] = self._get_state(**options) self._connection: ConnectionState = self._get_state(**options)
self._connection.shard_count = self.shard_count self._connection.shard_count = self.shard_count
self._closed: bool = False self._closed: bool = False
self._ready: asyncio.Event = asyncio.Event() self._ready: asyncio.Event = asyncio.Event()

11
discord/interactions.py

@ -71,10 +71,9 @@ if TYPE_CHECKING:
] ]
MISSING: Any = utils.MISSING MISSING: Any = utils.MISSING
ClientT = TypeVar('ClientT', bound='Client')
class Interaction(Generic[ClientT]): class Interaction:
"""Represents a Discord interaction. """Represents a Discord interaction.
An interaction happens when a user does an action that needs to An interaction happens when a user does an action that needs to
@ -126,9 +125,9 @@ class Interaction(Generic[ClientT]):
'_cs_channel', '_cs_channel',
) )
def __init__(self, *, data: InteractionPayload, state: ConnectionState[ClientT]): def __init__(self, *, data: InteractionPayload, state: ConnectionState):
self._state: ConnectionState[ClientT] = state self._state: ConnectionState = state
self._client: ClientT = state._get_client() self._client: Client = state._get_client()
self._session: ClientSession = state.http._HTTPClient__session # type: ignore - Mangled attribute for __session self._session: ClientSession = state.http._HTTPClient__session # type: ignore - Mangled attribute for __session
self._original_message: Optional[InteractionMessage] = None self._original_message: Optional[InteractionMessage] = None
self._from_data(data) self._from_data(data)
@ -171,7 +170,7 @@ class Interaction(Generic[ClientT]):
pass pass
@property @property
def client(self) -> ClientT: def client(self) -> Client:
""":class:`Client`: The client that is handling this interaction.""" """:class:`Client`: The client that is handling this interaction."""
return self._client return self._client

3
discord/shard.py

@ -46,7 +46,6 @@ from .enums import Status
from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self
from .gateway import DiscordWebSocket from .gateway import DiscordWebSocket
from .activity import BaseActivity from .activity import BaseActivity
from .enums import Status from .enums import Status
@ -317,7 +316,7 @@ class AutoShardedClient(Client):
""" """
if TYPE_CHECKING: if TYPE_CHECKING:
_connection: AutoShardedConnectionState[Self] _connection: AutoShardedConnectionState
def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None: def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None:
kwargs.pop('shard_id', None) kwargs.pop('shard_id', None)

8
discord/state.py

@ -98,8 +98,6 @@ if TYPE_CHECKING:
T = TypeVar('T') T = TypeVar('T')
Channel = Union[GuildChannel, VocalGuildChannel, PrivateChannel, PartialMessageable] Channel = Union[GuildChannel, VocalGuildChannel, PrivateChannel, PartialMessageable]
ClientT = TypeVar('ClientT', bound='Client')
class ChunkRequest: class ChunkRequest:
def __init__( def __init__(
@ -159,10 +157,10 @@ async def logging_coroutine(coroutine: Coroutine[Any, Any, T], *, info: str) ->
_log.exception('Exception occurred during %s', info) _log.exception('Exception occurred during %s', info)
class ConnectionState(Generic[ClientT]): class ConnectionState:
if TYPE_CHECKING: if TYPE_CHECKING:
_get_websocket: Callable[..., DiscordWebSocket] _get_websocket: Callable[..., DiscordWebSocket]
_get_client: Callable[..., ClientT] _get_client: Callable[..., Client]
_parsers: Dict[str, Callable[[Dict[str, Any]], None]] _parsers: Dict[str, Callable[[Dict[str, Any]], None]]
def __init__( def __init__(
@ -1487,7 +1485,7 @@ class ConnectionState(Generic[ClientT]):
return Message(state=self, channel=channel, data=data) return Message(state=self, channel=channel, data=data)
class AutoShardedConnectionState(ConnectionState[ClientT]): class AutoShardedConnectionState(ConnectionState):
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.shard_ids: Union[List[int], range] = [] self.shard_ids: Union[List[int], range] = []

Loading…
Cancel
Save