Browse Source

Add before_identify_hook to have finer control over IDENTIFY syncing

pull/5164/head
Rapptz 5 years ago
parent
commit
394b514cc9
  1. 43
      discord/client.py
  2. 5
      discord/gateway.py
  3. 17
      discord/shard.py
  4. 11
      discord/state.py

43
discord/client.py

@ -223,8 +223,12 @@ class Client:
'ready': self._handle_ready
}
self._hooks = {
'before_identify': self._call_before_identify_hook
}
self._connection = ConnectionState(dispatch=self.dispatch, handlers=self._handlers,
syncer=self._syncer, http=self.http, loop=self.loop, **options)
hooks=self._hooks, syncer=self._syncer, http=self.http, loop=self.loop, **options)
self._connection.shard_count = self.shard_count
self._closed = False
@ -394,6 +398,36 @@ class Client:
await self._connection.request_offline_members(guilds)
# hooks
async def _call_before_identify_hook(self, shard_id, *, initial=False):
# This hook is an internal hook that actually calls the public one.
# It allows the library to have its own hook without stepping on the
# toes of those who need to override their own hook.
await self.before_identify_hook(shard_id, initial=initial)
async def before_identify_hook(self, shard_id, *, initial=False):
"""|coro|
A hook that is called before IDENTIFYing a session. This is useful
if you wish to have more control over the synchronization of multiple
IDENTIFYing clients.
The default implementation sleeps for 5 seconds.
.. versionadded:: 1.4
Parameters
------------
shard_id: :class:`int`
The shard ID that requested being IDENTIFY'd
initial: :class:`bool`
Whether this IDENTIFY is the first initial IDENTIFY.
"""
if not initial:
await asyncio.sleep(5.0)
# login state management
async def login(self, token, *, bot=True):
@ -447,7 +481,7 @@ class Client:
await self.close()
async def _connect(self):
coro = DiscordWebSocket.from_client(self, shard_id=self.shard_id)
coro = DiscordWebSocket.from_client(self, initial=True, shard_id=self.shard_id)
self.ws = await asyncio.wait_for(coro, timeout=180.0)
while True:
try:
@ -455,11 +489,8 @@ class Client:
except ReconnectWebSocket as e:
log.info('Got a request to %s the websocket.', e.op)
self.dispatch('disconnect')
if not e.resume:
await asyncio.sleep(5.0)
coro = DiscordWebSocket.from_client(self, shard_id=self.shard_id, session=self.ws.session_id,
sequence=self.ws.sequence, resume=e.resume)
sequence=self.ws.sequence, resume=e.resume)
self.ws = await asyncio.wait_for(coro, timeout=180.0)
async def connect(self, *, reconnect=True):

5
discord/gateway.py

@ -250,7 +250,7 @@ class DiscordWebSocket:
return not self.socket.closed
@classmethod
async def from_client(cls, client, *, gateway=None, shard_id=None, session=None, sequence=None, resume=False):
async def from_client(cls, client, *, initial=False, gateway=None, shard_id=None, session=None, sequence=None, resume=False):
"""Creates a main websocket for Discord from a :class:`Client`.
This is for internal use only.
@ -265,6 +265,8 @@ class DiscordWebSocket:
ws._discord_parsers = client._connection.parsers
ws._dispatch = client.dispatch
ws.gateway = gateway
ws.call_hooks = client._connection.call_hooks
ws._initial_identify = initial
ws.shard_id = shard_id
ws.shard_count = client._connection.shard_count
ws.session_id = session
@ -345,6 +347,7 @@ class DiscordWebSocket:
'afk': False
}
await self.call_hooks('before_identify', self.shard_id, initial=self._initial_identify)
await self.send_as_json(payload)
log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id)

17
discord/shard.py

@ -96,9 +96,6 @@ class Shard:
self._task.cancel()
log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id)
if not exc.resume:
await asyncio.sleep(5.0)
coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id,
session=self.ws.session_id, sequence=self.ws.sequence)
self._dispatch('disconnect')
@ -144,7 +141,7 @@ class AutoShardedClient(Client):
self._connection = AutoShardedConnectionState(dispatch=self.dispatch,
handlers=self._handlers, syncer=self._syncer,
http=self.http, loop=self.loop, **kwargs)
hooks=self._hooks, http=self.http, loop=self.loop, **kwargs)
# instead of a single websocket, we have multiple
# the key is the shard_id
@ -208,12 +205,12 @@ class AutoShardedClient(Client):
sub_guilds = list(sub_guilds)
await self._connection.request_offline_members(sub_guilds, shard_id=shard_id)
async def launch_shard(self, gateway, shard_id):
async def launch_shard(self, gateway, shard_id, *, initial=False):
try:
coro = DiscordWebSocket.from_client(self, gateway=gateway, shard_id=shard_id)
coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id)
ws = await asyncio.wait_for(coro, timeout=180.0)
except Exception:
log.info('Failed to connect for shard_id: %s. Retrying...', shard_id)
log.exception('Failed to connect for shard_id: %s. Retrying...', shard_id)
await asyncio.sleep(5.0)
return await self.launch_shard(gateway, shard_id)
@ -232,11 +229,9 @@ class AutoShardedClient(Client):
shard_ids = self.shard_ids if self.shard_ids else range(self.shard_count)
self._connection.shard_ids = shard_ids
last_shard_id = shard_ids[-1]
for shard_id in shard_ids:
await self.launch_shard(gateway, shard_id)
if shard_id != last_shard_id:
await asyncio.sleep(5.0)
initial = shard_id == shard_ids[0]
await self.launch_shard(gateway, shard_id, initial=initial)
self._connection.shards_launched.set()

11
discord/state.py

@ -64,7 +64,7 @@ log = logging.getLogger(__name__)
ReadyState = namedtuple('ReadyState', ('launch', 'guilds'))
class ConnectionState:
def __init__(self, *, dispatch, handlers, syncer, http, loop, **options):
def __init__(self, *, dispatch, handlers, hooks, syncer, http, loop, **options):
self.loop = loop
self.http = http
self.max_messages = options.get('max_messages', 1000)
@ -75,6 +75,7 @@ class ConnectionState:
self.syncer = syncer
self.is_bot = None
self.handlers = handlers
self.hooks = hooks
self.shard_count = None
self._ready_task = None
self._fetch_offline = options.get('fetch_offline_members', True)
@ -170,6 +171,14 @@ class ConnectionState:
else:
func(*args, **kwargs)
async def call_hooks(self, key, *args, **kwargs):
try:
coro = self.hooks[key]
except KeyError:
pass
else:
await coro(*args, **kwargs)
@property
def self_id(self):
u = self.user

Loading…
Cancel
Save