From 47a58d354d3c289ce8fcd56f817976a43029887f Mon Sep 17 00:00:00 2001 From: Rapptz Date: Sat, 14 Oct 2017 21:17:27 -0400 Subject: [PATCH] Reimplement zlib streaming. This time with fewer bugs. It turned out that the crash was due to a synchronisation issue between the pending reads and the actual shard polling mechanism. Essentially the pending reads would be cancelled via a simple bool, but there would still be a pass left, and thus we would have a single pending read left before or after running the polling mechanism, and this would cause a race condition. Now the pending read mechanism is properly waited for before returning control back to the caller. --- discord/gateway.py | 15 +++++++++++++-- discord/http.py | 16 ++++++++++++---- discord/shard.py | 42 ++++++++++++++++++++++++++++++------------ 3 files changed, 55 insertions(+), 18 deletions(-) diff --git a/discord/gateway.py b/discord/gateway.py index 0ab027603..547d34015 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -186,6 +186,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): # ws related stuff self.session_id = None self.sequence = None + self._zlib = zlib.decompressobj() + self._buffer = bytearray() @classmethod @asyncio.coroutine @@ -312,8 +314,17 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): self._dispatch('socket_raw_receive', msg) if isinstance(msg, bytes): - msg = zlib.decompress(msg, 15, 10490000) # This is 10 MiB - msg = msg.decode('utf-8') + self._buffer.extend(msg) + + if len(msg) >= 4: + if msg[-4:] == b'\x00\x00\xff\xff': + msg = self._zlib.decompress(self._buffer) + msg = msg.decode('utf-8') + self._buffer = bytearray() + else: + return + else: + return msg = json.loads(msg) diff --git a/discord/http.py b/discord/http.py index fa6678eee..8c4ebb166 100644 --- a/discord/http.py +++ b/discord/http.py @@ -739,21 +739,29 @@ class HTTPClient: return self.request(Route('GET', '/oauth2/applications/@me')) @asyncio.coroutine - def 
get_gateway(self): + def get_gateway(self, *, encoding='json', v=6, zlib=True): try: data = yield from self.request(Route('GET', '/gateway')) except HTTPException as e: raise GatewayNotFound() from e - return data.get('url') + '?encoding=json&v=6' + if zlib: + value = '{0}?encoding={1}&v={2}&compress=zlib-stream' + else: + value = '{0}?encoding={1}&v={2}' + return value.format(data['url'], encoding, v) @asyncio.coroutine - def get_bot_gateway(self): + def get_bot_gateway(self, *, encoding='json', v=6, zlib=True): try: data = yield from self.request(Route('GET', '/gateway/bot')) except HTTPException as e: raise GatewayNotFound() from e + + if zlib: + value = '{0}?encoding={1}&v={2}&compress=zlib-stream' else: - return data['shards'], data['url'] + '?encoding=json&v=6' + value = '{0}?encoding={1}&v={2}' + return data['shards'], value.format(data['url'], encoding, v) def get_user_info(self, user_id): return self.request(Route('GET', '/users/{user_id}', user_id=user_id)) diff --git a/discord/shard.py b/discord/shard.py index 89463059f..f7f230db8 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -28,7 +28,7 @@ from .state import AutoShardedConnectionState from .client import Client from .gateway import * from .errors import ClientException, InvalidArgument -from . import compat +from . 
import compat, utils from .enums import Status import asyncio @@ -45,11 +45,32 @@ class Shard: self.loop = self._client.loop self._current = compat.create_future(self.loop) self._current.set_result(None) # we just need an already done future + self._pending = asyncio.Event(loop=self.loop) + self._pending_task = None @property def id(self): return self.ws.shard_id + def is_pending(self): + return not self._pending.is_set() + + def complete_pending_reads(self): + self._pending.set() + + def _pending_reads(self): + try: + while self.is_pending(): + yield from self.poll() + except asyncio.CancelledError: + pass + + def launch_pending_reads(self): + self._pending_task = compat.create_task(self._pending_reads(), loop=self.loop) + + def wait(self): + return self._pending_task + @asyncio.coroutine def poll(self): try: @@ -127,7 +148,6 @@ class AutoShardedClient(Client): return self.shards[i].ws self._connection._get_websocket = _get_websocket - self._still_sharding = True @asyncio.coroutine def _chunker(self, guild, *, shard_id=None): @@ -199,14 +219,6 @@ class AutoShardedClient(Client): sub_guilds = list(sub_guilds) yield from self._connection.request_offline_members(sub_guilds, shard_id=shard_id) - @asyncio.coroutine - def pending_reads(self, shard): - try: - while self._still_sharding: - yield from shard.poll() - except asyncio.CancelledError: - pass - @asyncio.coroutine def launch_shard(self, gateway, shard_id): try: @@ -235,7 +247,7 @@ class AutoShardedClient(Client): # keep reading the shard while others connect self.shards[shard_id] = ret = Shard(ws, self) - compat.create_task(self.pending_reads(ret), loop=self.loop) + ret.launch_pending_reads() yield from asyncio.sleep(5.0, loop=self.loop) @asyncio.coroutine @@ -252,7 +264,13 @@ class AutoShardedClient(Client): for shard_id in shard_ids: yield from self.launch_shard(gateway, shard_id) - self._still_sharding = False + shards_to_wait_for = [] + for shard in self.shards.values(): + shard.complete_pending_reads() + 
shards_to_wait_for.append(shard.wait()) + + # wait for all pending tasks to finish + yield from utils.sane_wait_for(shards_to_wait_for, timeout=300.0, loop=self.loop) @asyncio.coroutine def _connect(self):