diff --git a/discord/backoff.py b/discord/backoff.py new file mode 100644 index 000000000..bfe87b17b --- /dev/null +++ b/discord/backoff.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import time +import random + +class ExponentialBackoff: + """An implementation of the exponential backoff algorithm + + Provides a convenient interface to implement an exponential backoff + for reconnecting or retrying transmissions in a distributed network. + + Once instantiated, the delay method will return the next interval to + wait for when retrying a connection or transmission. The maximum + delay increases exponentially with each retry up to a maximum of + 2^10 * base, and is reset if no more attempts are needed in a period + of 2^11 * base seconds. + + Parameters + ---------- + base: int + The base delay in seconds. The first retry-delay will be up to + this many seconds. + integral: bool + Set to True if whole periods of base is desirable, otherwise any + number in between may be returned. + """ + + def __init__(self, base=1, *, integral=False): + self._base = base + + self._exp = 0 + self._max = 10 + self._reset_time = base * 2 ** 11 + self._last_invocation = time.monotonic() + + # Use our own random instance to avoid messing with global one + rand = random.Random() + rand.seed() + + self._randfunc = rand.rand_range if integral else rand.uniform + + def delay(self): + """Compute the next delay + + Returns the next delay to wait according to the exponential + backoff algorithm. This is a value between 0 and base * 2^exp + where exponent starts off at 1 and is incremented at every + invocation of this method up to a maximum of 10. + + If a period of more than base * 2^11 has passed since the last + retry, the exponent is reset to 1. + """ + invocation = time.monotonic() + interval = invocation - self._last_invocation + self._last_invocation = invocation + + if interval > self._reset_time: + self._exp = 0 + + self._exp = min(self._exp + 1, self._max) + return self._randfunc(0, self._base * 2 ** self._exp) diff --git a/discord/client.py b/discord/client.py index d856a1c57..dc3b24ccb 100644 --- a/discord/client.py +++ b/discord/client.py @@ -35,6 +35,7 @@ from .emoji import Emoji from .http import HTTPClient from .state import ConnectionState from . import utils, compat +from .backoff import ExponentialBackoff import asyncio import aiohttp @@ -347,11 +348,35 @@ class Client: yield from self.close() @asyncio.coroutine - def connect(self): + def _connect(self): + self.ws = yield from DiscordWebSocket.from_client(self) + + while True: + try: + yield from self.ws.poll_event() + except ResumeWebSocket as e: + log.info('Got a request to RESUME the websocket.') + self.ws = yield from DiscordWebSocket.from_client(self, shard_id=self.shard_id, + session=self.ws.session_id, + sequence=self.ws.sequence, + resume=True) + + @asyncio.coroutine + def connect(self, *, reconnect=True): """|coro| Creates a websocket connection and lets the websocket listen - to messages from discord. + to messages from discord. This is a loop that runs the entire + event system and miscellaneous aspects of the library. Control + is not resumed until the WebSocket connection is terminated. + + Parameters + ----------- + reconnect: bool + If we should attempt reconnecting, either due to internet + failure or a specific failure on Discord's part. Certain + disconnects that lead to bad state will not be handled (such as + invalid sharding payloads or bad tokens). Raises ------- @@ -361,21 +386,31 @@ class Client: ConnectionClosed The websocket connection has been terminated. """ - self.ws = yield from DiscordWebSocket.from_client(self) + backoff = ExponentialBackoff() while not self.is_closed(): try: - yield from self.ws.poll_event() - except ResumeWebSocket as e: - log.info('Got a request to RESUME the websocket.') - self.ws = yield from DiscordWebSocket.from_client(self, shard_id=self.shard_id, - session=self.ws.session_id, - sequence=self.ws.sequence, - resume=True) + yield from self._connect() except ConnectionClosed as e: + # We should only get this when an unhandled close code happens, + # such as a clean disconnect (1000) or a bad state (bad token, no sharding, etc) + # in both cases we should just terminate our connection. yield from self.close() if e.code != 1000: raise + except (HTTPException, + GatewayNotFound, + aiohttp.ClientError, + websockets.InvalidHandshake, + websockets.WebSocketProtocolError) as e: + + if not reconnect: + yield from self.close() + raise + + retry = backoff.delay() + log.exception("Attempting a reconnect in {:.2f}s".format(retry)) + yield from asyncio.sleep(retry, loop=self.loop) @asyncio.coroutine def close(self): @@ -409,8 +444,11 @@ class Client: A shorthand coroutine for :meth:`login` + :meth:`connect`. """ - yield from self.login(*args, **kwargs) - yield from self.connect() + + bot = kwargs.pop('bot', True) + reconnect = kwargs.pop('reconnect', True) + yield from self.login(*args, bot=bot) + yield from self.connect(reconnect=reconnect) def run(self, *args, **kwargs): """A blocking call that abstracts away the `event loop`_ diff --git a/discord/shard.py b/discord/shard.py index dade454db..ab0ca31f3 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -60,11 +60,6 @@ class Shard: shard_id=self.id, session=self.ws.session_id, sequence=self.ws.sequence) - except ConnectionClosed as e: - yield from self._client.close() - if e.code != 1000: - raise - def get_future(self): if self._current.done(): self._current = compat.create_task(self.poll(), loop=self.loop) @@ -220,25 +215,15 @@ class AutoShardedClient(Client): self._still_sharding = False @asyncio.coroutine - def connect(self): - """|coro| - - Creates a websocket connection and lets the websocket listen - to messages from discord. - - Raises - ------- - GatewayNotFound - If the gateway to connect to discord is not found. Usually if this - is thrown then there is a discord API outage. - ConnectionClosed - The websocket connection has been terminated. - """ + def _connect(self): yield from self.launch_shards() - while not self.is_closed(): + while True: pollers = [shard.get_future() for shard in self.shards.values()] - yield from asyncio.wait(pollers, loop=self.loop, return_when=asyncio.FIRST_COMPLETED) + done, pending = yield from asyncio.wait(pollers, loop=self.loop, return_when=asyncio.FIRST_COMPLETED) + for f in done: + # we wanna re-raise to the main Client.connect handler if applicable + f.result() @asyncio.coroutine def close(self):