|
|
@ -22,8 +22,9 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER |
|
|
|
DEALINGS IN THE SOFTWARE. |
|
|
|
""" |
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
|
|
import asyncio |
|
|
|
import itertools |
|
|
|
import logging |
|
|
|
|
|
|
|
import aiohttp |
|
|
@ -34,22 +35,29 @@ from .backoff import ExponentialBackoff |
|
|
|
from .gateway import * |
|
|
|
from .errors import ( |
|
|
|
ClientException, |
|
|
|
InvalidArgument, |
|
|
|
HTTPException, |
|
|
|
GatewayNotFound, |
|
|
|
ConnectionClosed, |
|
|
|
PrivilegedIntentsRequired, |
|
|
|
) |
|
|
|
|
|
|
|
from . import utils |
|
|
|
from .enums import Status |
|
|
|
|
|
|
|
from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict, TypeVar |
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
|
from .gateway import DiscordWebSocket |
|
|
|
from .activity import BaseActivity |
|
|
|
from .enums import Status |
|
|
|
|
|
|
|
EI = TypeVar('EI', bound='EventItem') |
|
|
|
|
|
|
|
__all__ = ( |
|
|
|
'AutoShardedClient', |
|
|
|
'ShardInfo', |
|
|
|
) |
|
|
|
|
|
|
|
log = logging.getLogger(__name__) |
|
|
|
log: logging.Logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
class EventType: |
|
|
|
close = 0 |
|
|
@ -62,36 +70,36 @@ class EventType: |
|
|
|
class EventItem: |
|
|
|
__slots__ = ('type', 'shard', 'error') |
|
|
|
|
|
|
|
def __init__(self, etype, shard, error): |
|
|
|
self.type = etype |
|
|
|
self.shard = shard |
|
|
|
self.error = error |
|
|
|
def __init__(self, etype: int, shard: Optional['Shard'], error: Optional[Exception]) -> None: |
|
|
|
self.type: int = etype |
|
|
|
self.shard: Optional['Shard'] = shard |
|
|
|
self.error: Optional[Exception] = error |
|
|
|
|
|
|
|
def __lt__(self, other): |
|
|
|
def __lt__(self: EI, other: EI) -> bool: |
|
|
|
if not isinstance(other, EventItem): |
|
|
|
return NotImplemented |
|
|
|
return self.type < other.type |
|
|
|
|
|
|
|
def __eq__(self, other): |
|
|
|
def __eq__(self: EI, other: EI) -> bool: |
|
|
|
if not isinstance(other, EventItem): |
|
|
|
return NotImplemented |
|
|
|
return self.type == other.type |
|
|
|
|
|
|
|
def __hash__(self): |
|
|
|
def __hash__(self) -> int: |
|
|
|
return hash(self.type) |
|
|
|
|
|
|
|
class Shard: |
|
|
|
def __init__(self, ws, client, queue_put): |
|
|
|
self.ws = ws |
|
|
|
self._client = client |
|
|
|
self._dispatch = client.dispatch |
|
|
|
self._queue_put = queue_put |
|
|
|
self.loop = self._client.loop |
|
|
|
self._disconnect = False |
|
|
|
def __init__(self, ws: DiscordWebSocket, client: AutoShardedClient, queue_put: Callable[[EventItem], None]) -> None: |
|
|
|
self.ws: DiscordWebSocket = ws |
|
|
|
self._client: Client = client |
|
|
|
self._dispatch: Callable[..., None] = client.dispatch |
|
|
|
self._queue_put: Callable[[EventItem], None] = queue_put |
|
|
|
self.loop: asyncio.AbstractEventLoop = self._client.loop |
|
|
|
self._disconnect: bool = False |
|
|
|
self._reconnect = client._reconnect |
|
|
|
self._backoff = ExponentialBackoff() |
|
|
|
self._task = None |
|
|
|
self._handled_exceptions = ( |
|
|
|
self._backoff: ExponentialBackoff = ExponentialBackoff() |
|
|
|
self._task: Optional[asyncio.Task] = None |
|
|
|
self._handled_exceptions: Tuple[Type[Exception], ...] = ( |
|
|
|
OSError, |
|
|
|
HTTPException, |
|
|
|
GatewayNotFound, |
|
|
@ -101,25 +109,26 @@ class Shard: |
|
|
|
) |
|
|
|
|
|
|
|
@property |
|
|
|
def id(self): |
|
|
|
return self.ws.shard_id |
|
|
|
def id(self) -> int: |
|
|
|
# DiscordWebSocket.shard_id is set in the from_client classmethod |
|
|
|
return self.ws.shard_id # type: ignore |
|
|
|
|
|
|
|
def launch(self): |
|
|
|
def launch(self) -> None: |
|
|
|
self._task = self.loop.create_task(self.worker()) |
|
|
|
|
|
|
|
def _cancel_task(self): |
|
|
|
def _cancel_task(self) -> None: |
|
|
|
if self._task is not None and not self._task.done(): |
|
|
|
self._task.cancel() |
|
|
|
|
|
|
|
async def close(self): |
|
|
|
async def close(self) -> None: |
|
|
|
self._cancel_task() |
|
|
|
await self.ws.close(code=1000) |
|
|
|
|
|
|
|
async def disconnect(self): |
|
|
|
async def disconnect(self) -> None: |
|
|
|
await self.close() |
|
|
|
self._dispatch('shard_disconnect', self.id) |
|
|
|
|
|
|
|
async def _handle_disconnect(self, e): |
|
|
|
async def _handle_disconnect(self, e: Exception) -> None: |
|
|
|
self._dispatch('disconnect') |
|
|
|
self._dispatch('shard_disconnect', self.id) |
|
|
|
if not self._reconnect: |
|
|
@ -148,7 +157,7 @@ class Shard: |
|
|
|
await asyncio.sleep(retry) |
|
|
|
self._queue_put(EventItem(EventType.reconnect, self, e)) |
|
|
|
|
|
|
|
async def worker(self): |
|
|
|
async def worker(self) -> None: |
|
|
|
while not self._client.is_closed(): |
|
|
|
try: |
|
|
|
await self.ws.poll_event() |
|
|
@ -165,7 +174,7 @@ class Shard: |
|
|
|
self._queue_put(EventItem(EventType.terminate, self, e)) |
|
|
|
break |
|
|
|
|
|
|
|
async def reidentify(self, exc): |
|
|
|
async def reidentify(self, exc: ReconnectWebSocket) -> None: |
|
|
|
self._cancel_task() |
|
|
|
self._dispatch('disconnect') |
|
|
|
self._dispatch('shard_disconnect', self.id) |
|
|
@ -183,7 +192,7 @@ class Shard: |
|
|
|
else: |
|
|
|
self.launch() |
|
|
|
|
|
|
|
async def reconnect(self): |
|
|
|
async def reconnect(self) -> None: |
|
|
|
self._cancel_task() |
|
|
|
try: |
|
|
|
coro = DiscordWebSocket.from_client(self._client, shard_id=self.id) |
|
|
@ -215,16 +224,16 @@ class ShardInfo: |
|
|
|
|
|
|
|
__slots__ = ('_parent', 'id', 'shard_count') |
|
|
|
|
|
|
|
def __init__(self, parent, shard_count): |
|
|
|
self._parent = parent |
|
|
|
self.id = parent.id |
|
|
|
self.shard_count = shard_count |
|
|
|
def __init__(self, parent: Shard, shard_count: Optional[int]) -> None: |
|
|
|
self._parent: Shard = parent |
|
|
|
self.id: int = parent.id |
|
|
|
self.shard_count: Optional[int] = shard_count |
|
|
|
|
|
|
|
def is_closed(self): |
|
|
|
def is_closed(self) -> bool: |
|
|
|
""":class:`bool`: Whether the shard connection is currently closed.""" |
|
|
|
return not self._parent.ws.open |
|
|
|
|
|
|
|
async def disconnect(self): |
|
|
|
async def disconnect(self) -> None: |
|
|
|
"""|coro| |
|
|
|
|
|
|
|
Disconnects a shard. When this is called, the shard connection will no |
|
|
@ -237,7 +246,7 @@ class ShardInfo: |
|
|
|
|
|
|
|
await self._parent.disconnect() |
|
|
|
|
|
|
|
async def reconnect(self): |
|
|
|
async def reconnect(self) -> None: |
|
|
|
"""|coro| |
|
|
|
|
|
|
|
Disconnects and then connects the shard again. |
|
|
@ -246,7 +255,7 @@ class ShardInfo: |
|
|
|
await self._parent.disconnect() |
|
|
|
await self._parent.reconnect() |
|
|
|
|
|
|
|
async def connect(self): |
|
|
|
async def connect(self) -> None: |
|
|
|
"""|coro| |
|
|
|
|
|
|
|
Connects a shard. If the shard is already connected this does nothing. |
|
|
@ -257,11 +266,11 @@ class ShardInfo: |
|
|
|
await self._parent.reconnect() |
|
|
|
|
|
|
|
@property |
|
|
|
def latency(self): |
|
|
|
def latency(self) -> float: |
|
|
|
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds for this shard.""" |
|
|
|
return self._parent.ws.latency |
|
|
|
|
|
|
|
def is_ws_ratelimited(self): |
|
|
|
def is_ws_ratelimited(self) -> bool: |
|
|
|
""":class:`bool`: Whether the websocket is currently rate limited. |
|
|
|
|
|
|
|
This can be useful to know when deciding whether you should query members |
|
|
@ -297,9 +306,12 @@ class AutoShardedClient(Client): |
|
|
|
shard_ids: Optional[List[:class:`int`]] |
|
|
|
An optional list of shard_ids to launch the shards with. |
|
|
|
""" |
|
|
|
def __init__(self, *args, loop=None, **kwargs): |
|
|
|
if TYPE_CHECKING: |
|
|
|
_connection: AutoShardedConnectionState |
|
|
|
|
|
|
|
def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None: |
|
|
|
kwargs.pop('shard_id', None) |
|
|
|
self.shard_ids = kwargs.pop('shard_ids', None) |
|
|
|
self.shard_ids: Optional[List[int]] = kwargs.pop('shard_ids', None) |
|
|
|
super().__init__(*args, loop=loop, **kwargs) |
|
|
|
|
|
|
|
if self.shard_ids is not None: |
|
|
@ -315,18 +327,19 @@ class AutoShardedClient(Client): |
|
|
|
self._connection._get_client = lambda: self |
|
|
|
self.__queue = asyncio.PriorityQueue() |
|
|
|
|
|
|
|
def _get_websocket(self, guild_id=None, *, shard_id=None): |
|
|
|
def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket: |
|
|
|
if shard_id is None: |
|
|
|
shard_id = (guild_id >> 22) % self.shard_count |
|
|
|
# guild_id won't be None if shard_id is None and shard_count won't be None here |
|
|
|
shard_id = (guild_id >> 22) % self.shard_count # type: ignore |
|
|
|
return self.__shards[shard_id].ws |
|
|
|
|
|
|
|
def _get_state(self, **options): |
|
|
|
def _get_state(self, **options: Any) -> AutoShardedConnectionState: |
|
|
|
return AutoShardedConnectionState(dispatch=self.dispatch, |
|
|
|
handlers=self._handlers, |
|
|
|
hooks=self._hooks, http=self.http, loop=self.loop, **options) |
|
|
|
|
|
|
|
@property |
|
|
|
def latency(self): |
|
|
|
def latency(self) -> float: |
|
|
|
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds. |
|
|
|
|
|
|
|
This operates similarly to :meth:`Client.latency` except it uses the average |
|
|
@ -338,14 +351,14 @@ class AutoShardedClient(Client): |
|
|
|
return sum(latency for _, latency in self.latencies) / len(self.__shards) |
|
|
|
|
|
|
|
@property |
|
|
|
def latencies(self): |
|
|
|
def latencies(self) -> List[Tuple[int, float]]: |
|
|
|
"""List[Tuple[:class:`int`, :class:`float`]]: A list of latencies between a HEARTBEAT and a HEARTBEAT_ACK in seconds. |
|
|
|
|
|
|
|
This returns a list of tuples with elements ``(shard_id, latency)``. |
|
|
|
""" |
|
|
|
return [(shard_id, shard.ws.latency) for shard_id, shard in self.__shards.items()] |
|
|
|
|
|
|
|
def get_shard(self, shard_id): |
|
|
|
def get_shard(self, shard_id: int) -> Optional[ShardInfo]: |
|
|
|
"""Optional[:class:`ShardInfo`]: Gets the shard information at a given shard ID or ``None`` if not found.""" |
|
|
|
try: |
|
|
|
parent = self.__shards[shard_id] |
|
|
@ -355,11 +368,11 @@ class AutoShardedClient(Client): |
|
|
|
return ShardInfo(parent, self.shard_count) |
|
|
|
|
|
|
|
@property |
|
|
|
def shards(self): |
|
|
|
def shards(self) -> Dict[int, ShardInfo]: |
|
|
|
"""Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object.""" |
|
|
|
return { shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items() } |
|
|
|
|
|
|
|
async def launch_shard(self, gateway, shard_id, *, initial=False): |
|
|
|
async def launch_shard(self, gateway: str, shard_id: int, *, initial: bool = False) -> None: |
|
|
|
try: |
|
|
|
coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id) |
|
|
|
ws = await asyncio.wait_for(coro, timeout=180.0) |
|
|
@ -372,7 +385,7 @@ class AutoShardedClient(Client): |
|
|
|
self.__shards[shard_id] = ret = Shard(ws, self, self.__queue.put_nowait) |
|
|
|
ret.launch() |
|
|
|
|
|
|
|
async def launch_shards(self): |
|
|
|
async def launch_shards(self) -> None: |
|
|
|
if self.shard_count is None: |
|
|
|
self.shard_count, gateway = await self.http.get_bot_gateway() |
|
|
|
else: |
|
|
@ -389,7 +402,7 @@ class AutoShardedClient(Client): |
|
|
|
|
|
|
|
self._connection.shards_launched.set() |
|
|
|
|
|
|
|
async def connect(self, *, reconnect=True): |
|
|
|
async def connect(self, *, reconnect: bool = True) -> None: |
|
|
|
self._reconnect = reconnect |
|
|
|
await self.launch_shards() |
|
|
|
|
|
|
@ -413,7 +426,7 @@ class AutoShardedClient(Client): |
|
|
|
elif item.type == EventType.clean_close: |
|
|
|
return |
|
|
|
|
|
|
|
async def close(self): |
|
|
|
async def close(self) -> None: |
|
|
|
"""|coro| |
|
|
|
|
|
|
|
Closes the connection to Discord. |
|
|
@ -425,7 +438,7 @@ class AutoShardedClient(Client): |
|
|
|
|
|
|
|
for vc in self.voice_clients: |
|
|
|
try: |
|
|
|
await vc.disconnect() |
|
|
|
await vc.disconnect(force=True) |
|
|
|
except Exception: |
|
|
|
pass |
|
|
|
|
|
|
@ -436,7 +449,7 @@ class AutoShardedClient(Client): |
|
|
|
await self.http.close() |
|
|
|
self.__queue.put_nowait(EventItem(EventType.clean_close, None, None)) |
|
|
|
|
|
|
|
async def change_presence(self, *, activity=None, status=None, shard_id=None): |
|
|
|
async def change_presence(self, *, activity: Optional[BaseActivity] = None, status: Optional[Status] = None, shard_id: int = None) -> None: |
|
|
|
"""|coro| |
|
|
|
|
|
|
|
Changes the client's presence. |
|
|
@ -468,23 +481,23 @@ class AutoShardedClient(Client): |
|
|
|
""" |
|
|
|
|
|
|
|
if status is None: |
|
|
|
status = 'online' |
|
|
|
status_value = 'online' |
|
|
|
status_enum = Status.online |
|
|
|
elif status is Status.offline: |
|
|
|
status = 'invisible' |
|
|
|
status_value = 'invisible' |
|
|
|
status_enum = Status.offline |
|
|
|
else: |
|
|
|
status_enum = status |
|
|
|
status = str(status) |
|
|
|
status_value = str(status) |
|
|
|
|
|
|
|
if shard_id is None: |
|
|
|
for shard in self.__shards.values(): |
|
|
|
await shard.ws.change_presence(activity=activity, status=status) |
|
|
|
await shard.ws.change_presence(activity=activity, status=status_value) |
|
|
|
|
|
|
|
guilds = self._connection.guilds |
|
|
|
else: |
|
|
|
shard = self.__shards[shard_id] |
|
|
|
await shard.ws.change_presence(activity=activity, status=status) |
|
|
|
await shard.ws.change_presence(activity=activity, status=status_value) |
|
|
|
guilds = [g for g in self._connection.guilds if g.shard_id == shard_id] |
|
|
|
|
|
|
|
activities = () if activity is None else (activity,) |
|
|
@ -493,10 +506,11 @@ class AutoShardedClient(Client): |
|
|
|
if me is None: |
|
|
|
continue |
|
|
|
|
|
|
|
me.activities = activities |
|
|
|
# Member.activities is typehinted as Tuple[ActivityType, ...], we may be setting it as Tuple[BaseActivity, ...] |
|
|
|
me.activities = activities # type: ignore |
|
|
|
me.status = status_enum |
|
|
|
|
|
|
|
def is_ws_ratelimited(self): |
|
|
|
def is_ws_ratelimited(self) -> bool: |
|
|
|
""":class:`bool`: Whether the websocket is currently rate limited. |
|
|
|
|
|
|
|
This can be useful to know when deciding whether you should query members |
|
|
|