diff --git a/discord/client.py b/discord/client.py index 068e3d65c..a0e27d8ee 100644 --- a/discord/client.py +++ b/discord/client.py @@ -480,19 +480,6 @@ class Client: """ await self.close() - async def _connect(self): - coro = DiscordWebSocket.from_client(self, initial=True, shard_id=self.shard_id) - self.ws = await asyncio.wait_for(coro, timeout=180.0) - while True: - try: - await self.ws.poll_event() - except ReconnectWebSocket as e: - log.info('Got a request to %s the websocket.', e.op) - self.dispatch('disconnect') - coro = DiscordWebSocket.from_client(self, shard_id=self.shard_id, session=self.ws.session_id, - sequence=self.ws.sequence, resume=e.resume) - self.ws = await asyncio.wait_for(coro, timeout=180.0) - async def connect(self, *, reconnect=True): """|coro| @@ -519,9 +506,22 @@ class Client: """ backoff = ExponentialBackoff() + ws_params = { + 'initial': True, + 'shard_id': self.shard_id, + } while not self.is_closed(): try: - await self._connect() + coro = DiscordWebSocket.from_client(self, **ws_params) + self.ws = await asyncio.wait_for(coro, timeout=60.0) + ws_params['initial'] = False + while True: + await self.ws.poll_event() + except ReconnectWebSocket as e: + log.info('Got a request to %s the websocket.', e.op) + self.dispatch('disconnect') + ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id) + continue except (OSError, HTTPException, GatewayNotFound, @@ -540,6 +540,11 @@ class Client: if self.is_closed(): return + # If we get connection reset by peer then try to RESUME + if isinstance(exc, OSError) and exc.errno in (54, 10054): + ws_params.update(sequence=self.ws.sequence, initial=False, resume=True, session=self.ws.session_id) + continue + # We should only get this when an unhandled close code happens, # such as a clean disconnect (1000) or a bad state (bad token, no sharding, etc) # sometimes, discord sends us 1000 for unknown reasons so we should reconnect @@ -552,6 +557,10 @@ class Client: retry = backoff.delay() log.exception("Attempting a reconnect in %.2fs", retry) await asyncio.sleep(retry) + # Always try to RESUME the connection + # If the connection is not RESUME-able then the gateway will invalidate the session. + # This is apparently what the official Discord client does. + ws_params.update(sequence=self.ws.sequence, resume=True, session=self.ws.session_id) async def close(self): """|coro| diff --git a/discord/gateway.py b/discord/gateway.py index 3f92ec1fe..f262477fb 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -508,16 +508,21 @@ class DiscordWebSocket: elif msg.type is aiohttp.WSMsgType.ERROR: log.debug('Received %s', msg) raise msg.data - elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE): + elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSE): log.debug('Received %s', msg) raise WebSocketClosure - except WebSocketClosure as e: + except WebSocketClosure: + # Ensure the keep alive handler is closed + if self._keep_alive: + self._keep_alive.stop() + self._keep_alive = None + if self._can_handle_close(): log.info('Websocket closed with %s, attempting a reconnect.', self.socket.close_code) - raise ReconnectWebSocket(self.shard_id) from e + raise ReconnectWebSocket(self.shard_id) from None elif self.socket.close_code is not None: log.info('Websocket closed with %s, cannot reconnect.', self.socket.close_code) - raise ConnectionClosed(self.socket, shard_id=self.shard_id) from e + raise ConnectionClosed(self.socket, shard_id=self.shard_id) from None async def send(self, data): self._dispatch('socket_raw_send', data) @@ -598,6 +603,7 @@ class DiscordWebSocket: async def close(self, code=4000): if self._keep_alive: self._keep_alive.stop() + self._keep_alive = None await self.socket.close(code=code) diff --git a/discord/http.py b/discord/http.py index 00e66ac20..ceb6137af 100644 --- a/discord/http.py +++ b/discord/http.py @@ -180,68 +180,76 @@ class HTTPClient: if files: for f in files: f.reset(seek=tries) - - async with self.__session.request(method, url, **kwargs) as r: - log.debug('%s %s with %s has returned %s', method, url, kwargs.get('data'), r.status) - - # even errors have text involved in them so this is safe to call - data = await json_or_text(r) - - # check if we have rate limit header information - remaining = r.headers.get('X-Ratelimit-Remaining') - if remaining == '0' and r.status != 429: - # we've depleted our current bucket - delta = utils._parse_ratelimit_header(r, use_clock=self.use_clock) - log.debug('A rate limit bucket has been exhausted (bucket: %s, retry: %s).', bucket, delta) - maybe_lock.defer() - self.loop.call_later(delta, lock.release) - - # the request was successful so just return the text/json - if 300 > r.status >= 200: - log.debug('%s %s has received %s', method, url, data) - return data - - # we are being rate limited - if r.status == 429: - if not r.headers.get('Via'): - # Banned by Cloudflare more than likely. + try: + async with self.__session.request(method, url, **kwargs) as r: + log.debug('%s %s with %s has returned %s', method, url, kwargs.get('data'), r.status) + + # even errors have text involved in them so this is safe to call + data = await json_or_text(r) + + # check if we have rate limit header information + remaining = r.headers.get('X-Ratelimit-Remaining') + if remaining == '0' and r.status != 429: + # we've depleted our current bucket + delta = utils._parse_ratelimit_header(r, use_clock=self.use_clock) + log.debug('A rate limit bucket has been exhausted (bucket: %s, retry: %s).', bucket, delta) + maybe_lock.defer() + self.loop.call_later(delta, lock.release) + + # the request was successful so just return the text/json + if 300 > r.status >= 200: + log.debug('%s %s has received %s', method, url, data) + return data + + # we are being rate limited + if r.status == 429: + if not r.headers.get('Via'): + # Banned by Cloudflare more than likely. + raise HTTPException(r, data) + + fmt = 'We are being rate limited. Retrying in %.2f seconds. Handled under the bucket "%s"' + + # sleep a bit + retry_after = data['retry_after'] / 1000.0 + log.warning(fmt, retry_after, bucket) + + # check if it's a global rate limit + is_global = data.get('global', False) + if is_global: + log.warning('Global rate limit has been hit. Retrying in %.2f seconds.', retry_after) + self._global_over.clear() + + await asyncio.sleep(retry_after) + log.debug('Done sleeping for the rate limit. Retrying...') + + # release the global lock now that the + # global rate limit has passed + if is_global: + self._global_over.set() + log.debug('Global rate limit is now over.') + + continue + + # we've received a 500 or 502, unconditional retry + if r.status in {500, 502}: + await asyncio.sleep(1 + tries * 2) + continue + + # the usual error cases + if r.status == 403: + raise Forbidden(r, data) + elif r.status == 404: + raise NotFound(r, data) + else: raise HTTPException(r, data) - fmt = 'We are being rate limited. Retrying in %.2f seconds. Handled under the bucket "%s"' - - # sleep a bit - retry_after = data['retry_after'] / 1000.0 - log.warning(fmt, retry_after, bucket) - - # check if it's a global rate limit - is_global = data.get('global', False) - if is_global: - log.warning('Global rate limit has been hit. Retrying in %.2f seconds.', retry_after) - self._global_over.clear() - - await asyncio.sleep(retry_after) - log.debug('Done sleeping for the rate limit. Retrying...') - - # release the global lock now that the - # global rate limit has passed - if is_global: - self._global_over.set() - log.debug('Global rate limit is now over.') - - continue - - # we've received a 500 or 502, unconditional retry - if r.status in {500, 502}: - await asyncio.sleep(1 + tries * 2) + # This is handling exceptions from the request + except OSError as e: + # Connection reset by peer + if e.errno in (54, 10054): + # Just re-do the request continue - # the usual error cases - if r.status == 403: - raise Forbidden(r, data) - elif r.status == 404: - raise NotFound(r, data) - else: - raise HTTPException(r, data) # We've run out of retries, raise. raise HTTPException(r, data) diff --git a/discord/shard.py b/discord/shard.py index dfa3849cd..7659e5ec7 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -112,6 +112,12 @@ class Shard: if self._client.is_closed(): return + if isinstance(e, OSError) and e.errno in (54, 10054): + # If we get Connection reset by peer then always try to RESUME the connection. + exc = ReconnectWebSocket(self.id, resume=True) + self._queue.put_nowait(EventItem(EventType.resume, self, exc)) + return + if isinstance(e, ConnectionClosed): if e.code != 1000: self._queue.put_nowait(EventItem(EventType.close, self, e)) @@ -142,7 +148,7 @@ class Shard: try: coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id, session=self.ws.session_id, sequence=self.ws.sequence) - self.ws = await asyncio.wait_for(coro, timeout=180.0) + self.ws = await asyncio.wait_for(coro, timeout=60.0) except self._handled_exceptions as e: await self._handle_disconnect(e) else: @@ -152,7 +158,7 @@ class Shard: self._cancel_task() try: coro = DiscordWebSocket.from_client(self._client, shard_id=self.id) - self.ws = await asyncio.wait_for(coro, timeout=180.0) + self.ws = await asyncio.wait_for(coro, timeout=60.0) except self._handled_exceptions as e: await self._handle_disconnect(e) else: