diff --git a/discord/client.py b/discord/client.py index 859315699..068e3d65c 100644 --- a/discord/client.py +++ b/discord/client.py @@ -223,8 +223,12 @@ class Client: 'ready': self._handle_ready } + self._hooks = { + 'before_identify': self._call_before_identify_hook + } + self._connection = ConnectionState(dispatch=self.dispatch, handlers=self._handlers, - syncer=self._syncer, http=self.http, loop=self.loop, **options) + hooks=self._hooks, syncer=self._syncer, http=self.http, loop=self.loop, **options) self._connection.shard_count = self.shard_count self._closed = False @@ -394,6 +398,36 @@ class Client: await self._connection.request_offline_members(guilds) + # hooks + + async def _call_before_identify_hook(self, shard_id, *, initial=False): + # This hook is an internal hook that actually calls the public one. + # It allows the library to have its own hook without stepping on the + # toes of those who need to override their own hook. + await self.before_identify_hook(shard_id, initial=initial) + + async def before_identify_hook(self, shard_id, *, initial=False): + """|coro| + + A hook that is called before IDENTIFYing a session. This is useful + if you wish to have more control over the synchronization of multiple + IDENTIFYing clients. + + The default implementation sleeps for 5 seconds. + + .. versionadded:: 1.4 + + Parameters + ------------ + shard_id: :class:`int` + The shard ID that requested being IDENTIFY'd + initial: :class:`bool` + Whether this IDENTIFY is the first initial IDENTIFY. + """ + + if not initial: + await asyncio.sleep(5.0) + # login state management async def login(self, token, *, bot=True): @@ -447,7 +481,7 @@ class Client: await self.close() async def _connect(self): - coro = DiscordWebSocket.from_client(self, shard_id=self.shard_id) + coro = DiscordWebSocket.from_client(self, initial=True, shard_id=self.shard_id) self.ws = await asyncio.wait_for(coro, timeout=180.0) while True: try: @@ -455,11 +489,8 @@ class Client: except ReconnectWebSocket as e: log.info('Got a request to %s the websocket.', e.op) self.dispatch('disconnect') - if not e.resume: - await asyncio.sleep(5.0) - coro = DiscordWebSocket.from_client(self, shard_id=self.shard_id, session=self.ws.session_id, - sequence=self.ws.sequence, resume=e.resume) + sequence=self.ws.sequence, resume=e.resume) self.ws = await asyncio.wait_for(coro, timeout=180.0) async def connect(self, *, reconnect=True): diff --git a/discord/gateway.py b/discord/gateway.py index 1f82c9ef8..db3c7fd57 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -250,7 +250,7 @@ class DiscordWebSocket: return not self.socket.closed @classmethod - async def from_client(cls, client, *, gateway=None, shard_id=None, session=None, sequence=None, resume=False): + async def from_client(cls, client, *, initial=False, gateway=None, shard_id=None, session=None, sequence=None, resume=False): """Creates a main websocket for Discord from a :class:`Client`. This is for internal use only. @@ -265,6 +265,8 @@ class DiscordWebSocket: ws._discord_parsers = client._connection.parsers ws._dispatch = client.dispatch ws.gateway = gateway + ws.call_hooks = client._connection.call_hooks + ws._initial_identify = initial ws.shard_id = shard_id ws.shard_count = client._connection.shard_count ws.session_id = session @@ -345,6 +347,7 @@ class DiscordWebSocket: 'afk': False } + await self.call_hooks('before_identify', self.shard_id, initial=self._initial_identify) await self.send_as_json(payload) log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id) diff --git a/discord/shard.py b/discord/shard.py index d96ad0712..f817fb9ab 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -96,9 +96,6 @@ class Shard: self._task.cancel() log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id) - if not exc.resume: - await asyncio.sleep(5.0) - coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id, session=self.ws.session_id, sequence=self.ws.sequence) self._dispatch('disconnect') @@ -144,7 +141,7 @@ class AutoShardedClient(Client): self._connection = AutoShardedConnectionState(dispatch=self.dispatch, handlers=self._handlers, syncer=self._syncer, - http=self.http, loop=self.loop, **kwargs) + hooks=self._hooks, http=self.http, loop=self.loop, **kwargs) # instead of a single websocket, we have multiple # the key is the shard_id @@ -208,12 +205,12 @@ class AutoShardedClient(Client): sub_guilds = list(sub_guilds) await self._connection.request_offline_members(sub_guilds, shard_id=shard_id) - async def launch_shard(self, gateway, shard_id): + async def launch_shard(self, gateway, shard_id, *, initial=False): try: - coro = DiscordWebSocket.from_client(self, gateway=gateway, shard_id=shard_id) + coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id) ws = await asyncio.wait_for(coro, timeout=180.0) except Exception: - log.info('Failed to connect for shard_id: %s. Retrying...', shard_id) + log.exception('Failed to connect for shard_id: %s. Retrying...', shard_id) await asyncio.sleep(5.0) return await self.launch_shard(gateway, shard_id) @@ -232,11 +229,9 @@ class AutoShardedClient(Client): shard_ids = self.shard_ids if self.shard_ids else range(self.shard_count) self._connection.shard_ids = shard_ids - last_shard_id = shard_ids[-1] for shard_id in shard_ids: - await self.launch_shard(gateway, shard_id) - if shard_id != last_shard_id: - await asyncio.sleep(5.0) + initial = shard_id == shard_ids[0] + await self.launch_shard(gateway, shard_id, initial=initial) self._connection.shards_launched.set() diff --git a/discord/state.py b/discord/state.py index 6148889d9..8b793c20a 100644 --- a/discord/state.py +++ b/discord/state.py @@ -64,7 +64,7 @@ log = logging.getLogger(__name__) ReadyState = namedtuple('ReadyState', ('launch', 'guilds')) class ConnectionState: - def __init__(self, *, dispatch, handlers, syncer, http, loop, **options): + def __init__(self, *, dispatch, handlers, hooks, syncer, http, loop, **options): self.loop = loop self.http = http self.max_messages = options.get('max_messages', 1000) @@ -75,6 +75,7 @@ class ConnectionState: self.syncer = syncer self.is_bot = None self.handlers = handlers + self.hooks = hooks self.shard_count = None self._ready_task = None self._fetch_offline = options.get('fetch_offline_members', True) @@ -170,6 +171,14 @@ class ConnectionState: else: func(*args, **kwargs) + async def call_hooks(self, key, *args, **kwargs): + try: + coro = self.hooks[key] + except KeyError: + pass + else: + await coro(*args, **kwargs) + @property def self_id(self): u = self.user