Browse Source

Reimplement zlib streaming.

This time with less bugs. It turned out that the crash was due to a
synchronisation issue between the pending reads and the actual shard
polling mechanism.

Essentially the pending reads would be cancelled via a simple bool but
there would still be a pass left and thus we would have a single
pending read left before or after running the polling mechanism and
this would cause a race condition.

Now the pending read mechanism is properly waited for before returning
control back to the caller.
pull/853/head
Rapptz 8 years ago
parent
commit
47a58d354d
  1. 15
      discord/gateway.py
  2. 16
      discord/http.py
  3. 42
      discord/shard.py

15
discord/gateway.py

@ -186,6 +186,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
# ws related stuff
self.session_id = None
self.sequence = None
self._zlib = zlib.decompressobj()
self._buffer = bytearray()
@classmethod
@asyncio.coroutine
@ -312,8 +314,17 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
self._dispatch('socket_raw_receive', msg)
if isinstance(msg, bytes):
msg = zlib.decompress(msg, 15, 10490000) # This is 10 MiB
msg = msg.decode('utf-8')
self._buffer.extend(msg)
if len(msg) >= 4:
if msg[-4:] == b'\x00\x00\xff\xff':
msg = self._zlib.decompress(self._buffer)
msg = msg.decode('utf-8')
self._buffer = bytearray()
else:
return
else:
return
msg = json.loads(msg)

16
discord/http.py

@ -739,21 +739,29 @@ class HTTPClient:
return self.request(Route('GET', '/oauth2/applications/@me'))
@asyncio.coroutine
def get_gateway(self):
def get_gateway(self, *, encoding='json', v=6, zlib=True):
try:
data = yield from self.request(Route('GET', '/gateway'))
except HTTPException as e:
raise GatewayNotFound() from e
return data.get('url') + '?encoding=json&v=6'
if zlib:
value = '{0}?encoding={1}&v={2}&compress=zlib-stream'
else:
value = '{0}?encoding={1}&v={2}'
return value.format(data['url'], encoding, v)
@asyncio.coroutine
def get_bot_gateway(self):
def get_bot_gateway(self, *, encoding='json', v=6, zlib=True):
try:
data = yield from self.request(Route('GET', '/gateway/bot'))
except HTTPException as e:
raise GatewayNotFound() from e
if zlib:
value = '{0}?encoding={1}&v={2}&compress=zlib-stream'
else:
return data['shards'], data['url'] + '?encoding=json&v=6'
value = '{0}?encoding={1}&v={2}'
return data['shards'], value.format(data['url'], encoding, v)
def get_user_info(self, user_id):
return self.request(Route('GET', '/users/{user_id}', user_id=user_id))

42
discord/shard.py

@ -28,7 +28,7 @@ from .state import AutoShardedConnectionState
from .client import Client
from .gateway import *
from .errors import ClientException, InvalidArgument
from . import compat
from . import compat, utils
from .enums import Status
import asyncio
@ -45,11 +45,32 @@ class Shard:
self.loop = self._client.loop
self._current = compat.create_future(self.loop)
self._current.set_result(None) # we just need an already done future
self._pending = asyncio.Event(loop=self.loop)
self._pending_task = None
@property
def id(self):
return self.ws.shard_id
def is_pending(self):
return not self._pending.is_set()
def complete_pending_reads(self):
self._pending.set()
def _pending_reads(self):
try:
while self.is_pending():
yield from self.poll()
except asyncio.CancelledError:
pass
def launch_pending_reads(self):
self._pending_task = compat.create_task(self._pending_reads(), loop=self.loop)
def wait(self):
return self._pending_task
@asyncio.coroutine
def poll(self):
try:
@ -127,7 +148,6 @@ class AutoShardedClient(Client):
return self.shards[i].ws
self._connection._get_websocket = _get_websocket
self._still_sharding = True
@asyncio.coroutine
def _chunker(self, guild, *, shard_id=None):
@ -199,14 +219,6 @@ class AutoShardedClient(Client):
sub_guilds = list(sub_guilds)
yield from self._connection.request_offline_members(sub_guilds, shard_id=shard_id)
@asyncio.coroutine
def pending_reads(self, shard):
try:
while self._still_sharding:
yield from shard.poll()
except asyncio.CancelledError:
pass
@asyncio.coroutine
def launch_shard(self, gateway, shard_id):
try:
@ -235,7 +247,7 @@ class AutoShardedClient(Client):
# keep reading the shard while others connect
self.shards[shard_id] = ret = Shard(ws, self)
compat.create_task(self.pending_reads(ret), loop=self.loop)
ret.launch_pending_reads()
yield from asyncio.sleep(5.0, loop=self.loop)
@asyncio.coroutine
@ -252,7 +264,13 @@ class AutoShardedClient(Client):
for shard_id in shard_ids:
yield from self.launch_shard(gateway, shard_id)
self._still_sharding = False
shards_to_wait_for = []
for shard in self.shards.values():
shard.complete_pending_reads()
shards_to_wait_for.append(shard.wait())
# wait for all pending tasks to finish
yield from utils.sane_wait_for(shards_to_wait_for, timeout=300.0, loop=self.loop)
@asyncio.coroutine
def _connect(self):

Loading…
Cancel
Save