11 changed files with 103 additions and 1589 deletions
@ -1,546 +0,0 @@ |
|||||
""" |
|
||||
The MIT License (MIT) |
|
||||
|
|
||||
Copyright (c) 2015-present Rapptz |
|
||||
|
|
||||
Permission is hereby granted, free of charge, to any person obtaining a |
|
||||
copy of this software and associated documentation files (the "Software"), |
|
||||
to deal in the Software without restriction, including without limitation |
|
||||
the rights to use, copy, modify, merge, publish, distribute, sublicense, |
|
||||
and/or sell copies of the Software, and to permit persons to whom the |
|
||||
Software is furnished to do so, subject to the following conditions: |
|
||||
|
|
||||
The above copyright notice and this permission notice shall be included in |
|
||||
all copies or substantial portions of the Software. |
|
||||
|
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS |
|
||||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING |
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER |
|
||||
DEALINGS IN THE SOFTWARE. |
|
||||
""" |
|
||||
|
|
||||
from __future__ import annotations |
|
||||
|
|
||||
import asyncio |
|
||||
import logging |
|
||||
|
|
||||
import aiohttp |
|
||||
|
|
||||
from .state import AutoShardedConnectionState |
|
||||
from .client import Client |
|
||||
from .backoff import ExponentialBackoff |
|
||||
from .gateway import * |
|
||||
from .errors import ( |
|
||||
ClientException, |
|
||||
HTTPException, |
|
||||
GatewayNotFound, |
|
||||
ConnectionClosed, |
|
||||
PrivilegedIntentsRequired, |
|
||||
) |
|
||||
|
|
||||
from .enums import Status |
|
||||
|
|
||||
from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict, TypeVar |
|
||||
|
|
||||
if TYPE_CHECKING: |
|
||||
from .gateway import DiscordWebSocket |
|
||||
from .activity import BaseActivity |
|
||||
from .enums import Status |
|
||||
|
|
||||
EI = TypeVar('EI', bound='EventItem') |
|
||||
|
|
||||
__all__ = ( |
|
||||
'AutoShardedClient', |
|
||||
'ShardInfo', |
|
||||
) |
|
||||
|
|
||||
_log = logging.getLogger(__name__) |
|
||||
|
|
||||
|
|
||||
class EventType: |
|
||||
close = 0 |
|
||||
reconnect = 1 |
|
||||
resume = 2 |
|
||||
identify = 3 |
|
||||
terminate = 4 |
|
||||
clean_close = 5 |
|
||||
|
|
||||
|
|
||||
class EventItem: |
|
||||
__slots__ = ('type', 'shard', 'error') |
|
||||
|
|
||||
def __init__(self, etype: int, shard: Optional['Shard'], error: Optional[Exception]) -> None: |
|
||||
self.type: int = etype |
|
||||
self.shard: Optional['Shard'] = shard |
|
||||
self.error: Optional[Exception] = error |
|
||||
|
|
||||
def __lt__(self: EI, other: EI) -> bool: |
|
||||
if not isinstance(other, EventItem): |
|
||||
return NotImplemented |
|
||||
return self.type < other.type |
|
||||
|
|
||||
def __eq__(self: EI, other: EI) -> bool: |
|
||||
if not isinstance(other, EventItem): |
|
||||
return NotImplemented |
|
||||
return self.type == other.type |
|
||||
|
|
||||
def __hash__(self) -> int: |
|
||||
return hash(self.type) |
|
||||
|
|
||||
|
|
||||
class Shard: |
|
||||
def __init__(self, ws: DiscordWebSocket, client: AutoShardedClient, queue_put: Callable[[EventItem], None]) -> None: |
|
||||
self.ws: DiscordWebSocket = ws |
|
||||
self._client: Client = client |
|
||||
self._dispatch: Callable[..., None] = client.dispatch |
|
||||
self._queue_put: Callable[[EventItem], None] = queue_put |
|
||||
self.loop: asyncio.AbstractEventLoop = self._client.loop |
|
||||
self._disconnect: bool = False |
|
||||
self._reconnect = client._reconnect |
|
||||
self._backoff: ExponentialBackoff = ExponentialBackoff() |
|
||||
self._task: Optional[asyncio.Task] = None |
|
||||
self._handled_exceptions: Tuple[Type[Exception], ...] = ( |
|
||||
OSError, |
|
||||
HTTPException, |
|
||||
GatewayNotFound, |
|
||||
ConnectionClosed, |
|
||||
aiohttp.ClientError, |
|
||||
asyncio.TimeoutError, |
|
||||
) |
|
||||
|
|
||||
@property |
|
||||
def id(self) -> int: |
|
||||
# DiscordWebSocket.shard_id is set in the from_client classmethod |
|
||||
return self.ws.shard_id # type: ignore |
|
||||
|
|
||||
def launch(self) -> None: |
|
||||
self._task = self.loop.create_task(self.worker()) |
|
||||
|
|
||||
def _cancel_task(self) -> None: |
|
||||
if self._task is not None and not self._task.done(): |
|
||||
self._task.cancel() |
|
||||
|
|
||||
async def close(self) -> None: |
|
||||
self._cancel_task() |
|
||||
await self.ws.close(code=1000) |
|
||||
|
|
||||
async def disconnect(self) -> None: |
|
||||
await self.close() |
|
||||
self._dispatch('shard_disconnect', self.id) |
|
||||
|
|
||||
async def _handle_disconnect(self, e: Exception) -> None: |
|
||||
self._dispatch('disconnect') |
|
||||
self._dispatch('shard_disconnect', self.id) |
|
||||
if not self._reconnect: |
|
||||
self._queue_put(EventItem(EventType.close, self, e)) |
|
||||
return |
|
||||
|
|
||||
if self._client.is_closed(): |
|
||||
return |
|
||||
|
|
||||
if isinstance(e, OSError) and e.errno in (54, 10054): |
|
||||
# If we get Connection reset by peer then always try to RESUME the connection. |
|
||||
exc = ReconnectWebSocket(self.id, resume=True) |
|
||||
self._queue_put(EventItem(EventType.resume, self, exc)) |
|
||||
return |
|
||||
|
|
||||
if isinstance(e, ConnectionClosed): |
|
||||
if e.code == 4014: |
|
||||
self._queue_put(EventItem(EventType.terminate, self, PrivilegedIntentsRequired(self.id))) |
|
||||
return |
|
||||
if e.code != 1000: |
|
||||
self._queue_put(EventItem(EventType.close, self, e)) |
|
||||
return |
|
||||
|
|
||||
retry = self._backoff.delay() |
|
||||
_log.error('Attempting a reconnect for shard ID %s in %.2fs', self.id, retry, exc_info=e) |
|
||||
await asyncio.sleep(retry) |
|
||||
self._queue_put(EventItem(EventType.reconnect, self, e)) |
|
||||
|
|
||||
async def worker(self) -> None: |
|
||||
while not self._client.is_closed(): |
|
||||
try: |
|
||||
await self.ws.poll_event() |
|
||||
except ReconnectWebSocket as e: |
|
||||
etype = EventType.resume if e.resume else EventType.identify |
|
||||
self._queue_put(EventItem(etype, self, e)) |
|
||||
break |
|
||||
except self._handled_exceptions as e: |
|
||||
await self._handle_disconnect(e) |
|
||||
break |
|
||||
except asyncio.CancelledError: |
|
||||
break |
|
||||
except Exception as e: |
|
||||
self._queue_put(EventItem(EventType.terminate, self, e)) |
|
||||
break |
|
||||
|
|
||||
async def reidentify(self, exc: ReconnectWebSocket) -> None: |
|
||||
self._cancel_task() |
|
||||
self._dispatch('disconnect') |
|
||||
self._dispatch('shard_disconnect', self.id) |
|
||||
_log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id) |
|
||||
try: |
|
||||
coro = DiscordWebSocket.from_client( |
|
||||
self._client, |
|
||||
resume=exc.resume, |
|
||||
shard_id=self.id, |
|
||||
session=self.ws.session_id, |
|
||||
sequence=self.ws.sequence, |
|
||||
) |
|
||||
self.ws = await asyncio.wait_for(coro, timeout=60.0) |
|
||||
except self._handled_exceptions as e: |
|
||||
await self._handle_disconnect(e) |
|
||||
except asyncio.CancelledError: |
|
||||
return |
|
||||
except Exception as e: |
|
||||
self._queue_put(EventItem(EventType.terminate, self, e)) |
|
||||
else: |
|
||||
self.launch() |
|
||||
|
|
||||
async def reconnect(self) -> None: |
|
||||
self._cancel_task() |
|
||||
try: |
|
||||
coro = DiscordWebSocket.from_client(self._client, shard_id=self.id) |
|
||||
self.ws = await asyncio.wait_for(coro, timeout=60.0) |
|
||||
except self._handled_exceptions as e: |
|
||||
await self._handle_disconnect(e) |
|
||||
except asyncio.CancelledError: |
|
||||
return |
|
||||
except Exception as e: |
|
||||
self._queue_put(EventItem(EventType.terminate, self, e)) |
|
||||
else: |
|
||||
self.launch() |
|
||||
|
|
||||
|
|
||||
class ShardInfo: |
|
||||
"""A class that gives information and control over a specific shard. |
|
||||
|
|
||||
You can retrieve this object via :meth:`AutoShardedClient.get_shard` |
|
||||
or :attr:`AutoShardedClient.shards`. |
|
||||
|
|
||||
.. versionadded:: 1.4 |
|
||||
|
|
||||
Attributes |
|
||||
------------ |
|
||||
id: :class:`int` |
|
||||
The shard ID for this shard. |
|
||||
shard_count: Optional[:class:`int`] |
|
||||
The shard count for this cluster. If this is ``None`` then the bot has not started yet. |
|
||||
""" |
|
||||
|
|
||||
__slots__ = ('_parent', 'id', 'shard_count') |
|
||||
|
|
||||
def __init__(self, parent: Shard, shard_count: Optional[int]) -> None: |
|
||||
self._parent: Shard = parent |
|
||||
self.id: int = parent.id |
|
||||
self.shard_count: Optional[int] = shard_count |
|
||||
|
|
||||
def is_closed(self) -> bool: |
|
||||
""":class:`bool`: Whether the shard connection is currently closed.""" |
|
||||
return not self._parent.ws.open |
|
||||
|
|
||||
async def disconnect(self) -> None: |
|
||||
"""|coro| |
|
||||
|
|
||||
Disconnects a shard. When this is called, the shard connection will no |
|
||||
longer be open. |
|
||||
|
|
||||
If the shard is already disconnected this does nothing. |
|
||||
""" |
|
||||
if self.is_closed(): |
|
||||
return |
|
||||
|
|
||||
await self._parent.disconnect() |
|
||||
|
|
||||
async def reconnect(self) -> None: |
|
||||
"""|coro| |
|
||||
|
|
||||
Disconnects and then connects the shard again. |
|
||||
""" |
|
||||
if not self.is_closed(): |
|
||||
await self._parent.disconnect() |
|
||||
await self._parent.reconnect() |
|
||||
|
|
||||
async def connect(self) -> None: |
|
||||
"""|coro| |
|
||||
|
|
||||
Connects a shard. If the shard is already connected this does nothing. |
|
||||
""" |
|
||||
if not self.is_closed(): |
|
||||
return |
|
||||
|
|
||||
await self._parent.reconnect() |
|
||||
|
|
||||
@property |
|
||||
def latency(self) -> float: |
|
||||
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds for this shard.""" |
|
||||
return self._parent.ws.latency |
|
||||
|
|
||||
def is_ws_ratelimited(self) -> bool: |
|
||||
""":class:`bool`: Whether the websocket is currently rate limited. |
|
||||
|
|
||||
This can be useful to know when deciding whether you should query members |
|
||||
using HTTP or via the gateway. |
|
||||
|
|
||||
.. versionadded:: 1.6 |
|
||||
""" |
|
||||
return self._parent.ws.is_ratelimited() |
|
||||
|
|
||||
|
|
||||
class AutoShardedClient(Client): |
|
||||
"""A client similar to :class:`Client` except it handles the complications |
|
||||
of sharding for the user into a more manageable and transparent single |
|
||||
process bot. |
|
||||
|
|
||||
When using this client, you will be able to use it as-if it was a regular |
|
||||
:class:`Client` with a single shard when implementation wise internally it |
|
||||
is split up into multiple shards. This allows you to not have to deal with |
|
||||
IPC or other complicated infrastructure. |
|
||||
|
|
||||
It is recommended to use this client only if you have surpassed at least |
|
||||
1000 guilds. |
|
||||
|
|
||||
If no :attr:`.shard_count` is provided, then the library will use the |
|
||||
Bot Gateway endpoint call to figure out how many shards to use. |
|
||||
|
|
||||
If a ``shard_ids`` parameter is given, then those shard IDs will be used |
|
||||
to launch the internal shards. Note that :attr:`.shard_count` must be provided |
|
||||
if this is used. By default, when omitted, the client will launch shards from |
|
||||
0 to ``shard_count - 1``. |
|
||||
|
|
||||
Attributes |
|
||||
------------ |
|
||||
shard_ids: Optional[List[:class:`int`]] |
|
||||
An optional list of shard_ids to launch the shards with. |
|
||||
""" |
|
||||
|
|
||||
if TYPE_CHECKING: |
|
||||
_connection: AutoShardedConnectionState |
|
||||
|
|
||||
def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None: |
|
||||
kwargs.pop('shard_id', None) |
|
||||
self.shard_ids: Optional[List[int]] = kwargs.pop('shard_ids', None) |
|
||||
super().__init__(*args, loop=loop, **kwargs) |
|
||||
|
|
||||
if self.shard_ids is not None: |
|
||||
if self.shard_count is None: |
|
||||
raise ClientException('When passing manual shard_ids, you must provide a shard_count.') |
|
||||
elif not isinstance(self.shard_ids, (list, tuple)): |
|
||||
raise ClientException('shard_ids parameter must be a list or a tuple.') |
|
||||
|
|
||||
# instead of a single websocket, we have multiple |
|
||||
# the key is the shard_id |
|
||||
self.__shards = {} |
|
||||
self._connection._get_websocket = self._get_websocket |
|
||||
self._connection._get_client = lambda: self |
|
||||
self.__queue = asyncio.PriorityQueue() |
|
||||
|
|
||||
def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket: |
|
||||
if shard_id is None: |
|
||||
# guild_id won't be None if shard_id is None and shard_count won't be None here |
|
||||
shard_id = (guild_id >> 22) % self.shard_count # type: ignore |
|
||||
return self.__shards[shard_id].ws |
|
||||
|
|
||||
def _get_state(self, **options: Any) -> AutoShardedConnectionState: |
|
||||
return AutoShardedConnectionState( |
|
||||
dispatch=self.dispatch, |
|
||||
handlers=self._handlers, |
|
||||
hooks=self._hooks, |
|
||||
http=self.http, |
|
||||
loop=self.loop, |
|
||||
**options, |
|
||||
) |
|
||||
|
|
||||
@property |
|
||||
def latency(self) -> float: |
|
||||
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds. |
|
||||
|
|
||||
This operates similarly to :meth:`Client.latency` except it uses the average |
|
||||
latency of every shard's latency. To get a list of shard latency, check the |
|
||||
:attr:`latencies` property. Returns ``nan`` if there are no shards ready. |
|
||||
""" |
|
||||
if not self.__shards: |
|
||||
return float('nan') |
|
||||
return sum(latency for _, latency in self.latencies) / len(self.__shards) |
|
||||
|
|
||||
@property |
|
||||
def latencies(self) -> List[Tuple[int, float]]: |
|
||||
"""List[Tuple[:class:`int`, :class:`float`]]: A list of latencies between a HEARTBEAT and a HEARTBEAT_ACK in seconds. |
|
||||
|
|
||||
This returns a list of tuples with elements ``(shard_id, latency)``. |
|
||||
""" |
|
||||
return [(shard_id, shard.ws.latency) for shard_id, shard in self.__shards.items()] |
|
||||
|
|
||||
def get_shard(self, shard_id: int) -> Optional[ShardInfo]: |
|
||||
"""Optional[:class:`ShardInfo`]: Gets the shard information at a given shard ID or ``None`` if not found.""" |
|
||||
try: |
|
||||
parent = self.__shards[shard_id] |
|
||||
except KeyError: |
|
||||
return None |
|
||||
else: |
|
||||
return ShardInfo(parent, self.shard_count) |
|
||||
|
|
||||
@property |
|
||||
def shards(self) -> Dict[int, ShardInfo]: |
|
||||
"""Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object.""" |
|
||||
return {shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items()} |
|
||||
|
|
||||
async def launch_shard(self, gateway: str, shard_id: int, *, initial: bool = False) -> None: |
|
||||
try: |
|
||||
coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id) |
|
||||
ws = await asyncio.wait_for(coro, timeout=180.0) |
|
||||
except Exception: |
|
||||
_log.exception('Failed to connect for shard_id: %s. Retrying...', shard_id) |
|
||||
await asyncio.sleep(5.0) |
|
||||
return await self.launch_shard(gateway, shard_id) |
|
||||
|
|
||||
# keep reading the shard while others connect |
|
||||
self.__shards[shard_id] = ret = Shard(ws, self, self.__queue.put_nowait) |
|
||||
ret.launch() |
|
||||
|
|
||||
async def launch_shards(self) -> None: |
|
||||
if self.shard_count is None: |
|
||||
self.shard_count, gateway = await self.http.get_bot_gateway() |
|
||||
else: |
|
||||
gateway = await self.http.get_gateway() |
|
||||
|
|
||||
self._connection.shard_count = self.shard_count |
|
||||
|
|
||||
shard_ids = self.shard_ids or range(self.shard_count) |
|
||||
self._connection.shard_ids = shard_ids |
|
||||
|
|
||||
for shard_id in shard_ids: |
|
||||
initial = shard_id == shard_ids[0] |
|
||||
await self.launch_shard(gateway, shard_id, initial=initial) |
|
||||
|
|
||||
self._connection.shards_launched.set() |
|
||||
|
|
||||
async def connect(self, *, reconnect: bool = True) -> None: |
|
||||
self._reconnect = reconnect |
|
||||
await self.launch_shards() |
|
||||
|
|
||||
while not self.is_closed(): |
|
||||
item = await self.__queue.get() |
|
||||
if item.type == EventType.close: |
|
||||
await self.close() |
|
||||
if isinstance(item.error, ConnectionClosed): |
|
||||
if item.error.code != 1000: |
|
||||
raise item.error |
|
||||
if item.error.code == 4014: |
|
||||
raise PrivilegedIntentsRequired(item.shard.id) from None |
|
||||
return |
|
||||
elif item.type in (EventType.identify, EventType.resume): |
|
||||
await item.shard.reidentify(item.error) |
|
||||
elif item.type == EventType.reconnect: |
|
||||
await item.shard.reconnect() |
|
||||
elif item.type == EventType.terminate: |
|
||||
await self.close() |
|
||||
raise item.error |
|
||||
elif item.type == EventType.clean_close: |
|
||||
return |
|
||||
|
|
||||
async def close(self) -> None: |
|
||||
"""|coro| |
|
||||
|
|
||||
Closes the connection to Discord. |
|
||||
""" |
|
||||
if self.is_closed(): |
|
||||
return |
|
||||
|
|
||||
self._closed = True |
|
||||
|
|
||||
for vc in self.voice_clients: |
|
||||
try: |
|
||||
await vc.disconnect(force=True) |
|
||||
except Exception: |
|
||||
pass |
|
||||
|
|
||||
to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.__shards.values()] |
|
||||
if to_close: |
|
||||
await asyncio.wait(to_close) |
|
||||
|
|
||||
await self.http.close() |
|
||||
self.__queue.put_nowait(EventItem(EventType.clean_close, None, None)) |
|
||||
|
|
||||
async def change_presence( |
|
||||
self, |
|
||||
*, |
|
||||
activity: Optional[BaseActivity] = None, |
|
||||
status: Optional[Status] = None, |
|
||||
shard_id: int = None, |
|
||||
) -> None: |
|
||||
"""|coro| |
|
||||
|
|
||||
Changes the client's presence. |
|
||||
|
|
||||
Example: :: |
|
||||
|
|
||||
game = discord.Game("with the API") |
|
||||
await client.change_presence(status=discord.Status.idle, activity=game) |
|
||||
|
|
||||
.. versionchanged:: 2.0 |
|
||||
Removed the ``afk`` keyword-only parameter. |
|
||||
|
|
||||
Parameters |
|
||||
---------- |
|
||||
activity: Optional[:class:`BaseActivity`] |
|
||||
The activity being done. ``None`` if no currently active activity is done. |
|
||||
status: Optional[:class:`Status`] |
|
||||
Indicates what status to change to. If ``None``, then |
|
||||
:attr:`Status.online` is used. |
|
||||
shard_id: Optional[:class:`int`] |
|
||||
The shard_id to change the presence to. If not specified |
|
||||
or ``None``, then it will change the presence of every |
|
||||
shard the bot can see. |
|
||||
|
|
||||
Raises |
|
||||
------ |
|
||||
InvalidArgument |
|
||||
If the ``activity`` parameter is not of proper type. |
|
||||
""" |
|
||||
|
|
||||
if status is None: |
|
||||
status_value = 'online' |
|
||||
status_enum = Status.online |
|
||||
elif status is Status.offline: |
|
||||
status_value = 'invisible' |
|
||||
status_enum = Status.offline |
|
||||
else: |
|
||||
status_enum = status |
|
||||
status_value = str(status) |
|
||||
|
|
||||
if shard_id is None: |
|
||||
for shard in self.__shards.values(): |
|
||||
await shard.ws.change_presence(activity=activity, status=status_value) |
|
||||
|
|
||||
guilds = self._connection.guilds |
|
||||
else: |
|
||||
shard = self.__shards[shard_id] |
|
||||
await shard.ws.change_presence(activity=activity, status=status_value) |
|
||||
guilds = [g for g in self._connection.guilds if g.shard_id == shard_id] |
|
||||
|
|
||||
activities = () if activity is None else (activity,) |
|
||||
for guild in guilds: |
|
||||
me = guild.me |
|
||||
if me is None: |
|
||||
continue |
|
||||
|
|
||||
# Member.activities is typehinted as Tuple[ActivityType, ...], we may be setting it as Tuple[BaseActivity, ...] |
|
||||
me.activities = activities # type: ignore |
|
||||
me.status = status_enum |
|
||||
|
|
||||
def is_ws_ratelimited(self) -> bool: |
|
||||
""":class:`bool`: Whether the websocket is currently rate limited. |
|
||||
|
|
||||
This can be useful to know when deciding whether you should query members |
|
||||
using HTTP or via the gateway. |
|
||||
|
|
||||
This implementation checks if any of the shards are rate limited. |
|
||||
For more granular control, consider :meth:`ShardInfo.is_ws_ratelimited`. |
|
||||
|
|
||||
.. versionadded:: 1.6 |
|
||||
""" |
|
||||
return any(shard.ws.is_ratelimited() for shard in self.__shards.values()) |
|
Loading…
Reference in new issue