|
|
@ -33,61 +33,58 @@ import websockets |
|
|
|
from .state import AutoShardedConnectionState |
|
|
|
from .client import Client |
|
|
|
from .gateway import * |
|
|
|
from .errors import ClientException, InvalidArgument |
|
|
|
from .errors import ClientException, InvalidArgument, ConnectionClosed |
|
|
|
from . import utils |
|
|
|
from .enums import Status |
|
|
|
|
|
|
|
log = logging.getLogger(__name__) |
|
|
|
|
|
|
|
class EventType: |
|
|
|
Close = 0 |
|
|
|
Resume = 1 |
|
|
|
Identify = 2 |
|
|
|
|
|
|
|
class Shard: |
|
|
|
def __init__(self, ws, client): |
|
|
|
self.ws = ws |
|
|
|
self._client = client |
|
|
|
self._dispatch = client.dispatch |
|
|
|
self._queue = client._queue |
|
|
|
self.loop = self._client.loop |
|
|
|
self._current = self.loop.create_future() |
|
|
|
self._current.set_result(None) # we just need an already done future |
|
|
|
self._pending = asyncio.Event() |
|
|
|
self._pending_task = None |
|
|
|
self._task = None |
|
|
|
|
|
|
|
@property |
|
|
|
def id(self): |
|
|
|
return self.ws.shard_id |
|
|
|
|
|
|
|
def is_pending(self): |
|
|
|
return not self._pending.is_set() |
|
|
|
|
|
|
|
def complete_pending_reads(self): |
|
|
|
self._pending.set() |
|
|
|
|
|
|
|
async def _pending_reads(self): |
|
|
|
try: |
|
|
|
while self.is_pending(): |
|
|
|
await self.poll() |
|
|
|
except asyncio.CancelledError: |
|
|
|
pass |
|
|
|
|
|
|
|
def launch_pending_reads(self): |
|
|
|
self._pending_task = asyncio.ensure_future(self._pending_reads(), loop=self.loop) |
|
|
|
|
|
|
|
def wait(self): |
|
|
|
return self._pending_task |
|
|
|
def launch(self): |
|
|
|
self._task = self.loop.create_task(self.worker()) |
|
|
|
|
|
|
|
async def poll(self): |
|
|
|
try: |
|
|
|
await self.ws.poll_event() |
|
|
|
except ResumeWebSocket: |
|
|
|
log.info('Got a request to RESUME the websocket at Shard ID %s.', self.id) |
|
|
|
coro = DiscordWebSocket.from_client(self._client, resume=True, shard_id=self.id, |
|
|
|
session=self.ws.session_id, sequence=self.ws.sequence) |
|
|
|
self._dispatch('disconnect') |
|
|
|
self.ws = await asyncio.wait_for(coro, timeout=180.0) |
|
|
|
|
|
|
|
def get_future(self): |
|
|
|
if self._current.done(): |
|
|
|
self._current = asyncio.ensure_future(self.poll(), loop=self.loop) |
|
|
|
async def worker(self): |
|
|
|
while True: |
|
|
|
try: |
|
|
|
await self.ws.poll_event() |
|
|
|
except ReconnectWebSocket as e: |
|
|
|
etype = EventType.resume if e.resume else EventType.identify |
|
|
|
self._queue.put_nowait((etype, self, e)) |
|
|
|
break |
|
|
|
except ConnectionClosed as e: |
|
|
|
self._queue.put_nowait((EventType.close, self, e)) |
|
|
|
break |
|
|
|
|
|
|
|
async def reconnect(self, exc): |
|
|
|
if self._task is not None and not self._task.done(): |
|
|
|
self._task.cancel() |
|
|
|
|
|
|
|
log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id) |
|
|
|
if not exc.resume: |
|
|
|
await asyncio.sleep(5.0) |
|
|
|
|
|
|
|
return self._current |
|
|
|
coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id, |
|
|
|
session=self.ws.session_id, sequence=self.ws.sequence) |
|
|
|
self._dispatch('disconnect') |
|
|
|
self.ws = await asyncio.wait_for(coro, timeout=180.0) |
|
|
|
self.launch() |
|
|
|
|
|
|
|
class AutoShardedClient(Client): |
|
|
|
"""A client similar to :class:`Client` except it handles the complications |
|
|
@ -134,6 +131,7 @@ class AutoShardedClient(Client): |
|
|
|
# the key is the shard_id |
|
|
|
self.shards = {} |
|
|
|
self._connection._get_websocket = self._get_websocket |
|
|
|
self._queue = asyncio.PriorityQueue() |
|
|
|
|
|
|
|
def _get_websocket(self, guild_id=None, *, shard_id=None): |
|
|
|
if shard_id is None: |
|
|
@ -220,8 +218,10 @@ class AutoShardedClient(Client): |
|
|
|
|
|
|
|
# keep reading the shard while others connect |
|
|
|
self.shards[shard_id] = ret = Shard(ws, self) |
|
|
|
ret.launch_pending_reads() |
|
|
|
await asyncio.sleep(5.0) |
|
|
|
ret.launch() |
|
|
|
|
|
|
|
if len(self.shards) == self.shard_count: |
|
|
|
self._connection.shards_launched.set() |
|
|
|
|
|
|
|
async def launch_shards(self): |
|
|
|
if self.shard_count is None: |
|
|
@ -234,26 +234,29 @@ class AutoShardedClient(Client): |
|
|
|
shard_ids = self.shard_ids if self.shard_ids else range(self.shard_count) |
|
|
|
self._connection.shard_ids = shard_ids |
|
|
|
|
|
|
|
last_shard_id = shard_ids[-1] |
|
|
|
for shard_id in shard_ids: |
|
|
|
await self.launch_shard(gateway, shard_id) |
|
|
|
if shard_id != last_shard_id: |
|
|
|
await asyncio.sleep(5.0) |
|
|
|
|
|
|
|
shards_to_wait_for = [] |
|
|
|
for shard in self.shards.values(): |
|
|
|
shard.complete_pending_reads() |
|
|
|
shards_to_wait_for.append(shard.wait()) |
|
|
|
# shards_to_wait_for = [] |
|
|
|
# for shard in self.shards.values(): |
|
|
|
# shard.complete_pending_reads() |
|
|
|
# shards_to_wait_for.append(shard.wait()) |
|
|
|
|
|
|
|
# wait for all pending tasks to finish |
|
|
|
await utils.sane_wait_for(shards_to_wait_for, timeout=300.0) |
|
|
|
# # wait for all pending tasks to finish |
|
|
|
# await utils.sane_wait_for(shards_to_wait_for, timeout=300.0) |
|
|
|
|
|
|
|
async def _connect(self): |
|
|
|
await self.launch_shards() |
|
|
|
|
|
|
|
while True: |
|
|
|
pollers = [shard.get_future() for shard in self.shards.values()] |
|
|
|
done, _ = await asyncio.wait(pollers, return_when=asyncio.FIRST_COMPLETED) |
|
|
|
for f in done: |
|
|
|
# we wanna re-raise to the main Client.connect handler if applicable |
|
|
|
f.result() |
|
|
|
etype, shard, exc = await self._queue.get() |
|
|
|
if etype == EventType.close: |
|
|
|
raise exc |
|
|
|
elif etype in (EventType.identify, EventType.resume): |
|
|
|
await shard.reconnect(exc) |
|
|
|
|
|
|
|
async def close(self): |
|
|
|
"""|coro| |
|
|
|