Browse Source

Separately delay ready event for each shard

pull/7932/head
Lilly Rose Berner 3 years ago
committed by GitHub
parent
commit
2dbf14bb72
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      discord/client.py
  2. 2
      discord/gateway.py
  3. 2
      discord/shard.py
  4. 155
      discord/state.py
  5. 7
      discord/types/gateway.py

1
discord/client.py

@ -484,7 +484,6 @@ class Client:
self.loop = loop self.loop = loop
self.http.loop = loop self.http.loop = loop
self._connection.loop = loop self._connection.loop = loop
await self._connection.async_setup()
self._ready = asyncio.Event() self._ready = asyncio.Event()

2
discord/gateway.py

@ -546,8 +546,6 @@ class DiscordWebSocket:
self._trace = trace = data.get('_trace', []) self._trace = trace = data.get('_trace', [])
self.sequence = msg['s'] self.sequence = msg['s']
self.session_id = data['session_id'] self.session_id = data['session_id']
# pass back shard ID to ready handler
data['__shard_id__'] = self.shard_id
_log.info( _log.info(
'Shard ID %s has connected to Gateway: %s (Session ID: %s).', 'Shard ID %s has connected to Gateway: %s (Session ID: %s).',
self.shard_id, self.shard_id,

2
discord/shard.py

@ -423,8 +423,6 @@ class AutoShardedClient(Client):
initial = shard_id == shard_ids[0] initial = shard_id == shard_ids[0]
await self.launch_shard(gateway, shard_id, initial=initial) await self.launch_shard(gateway, shard_id, initial=initial)
self._connection.shards_launched.set()
async def _async_setup_hook(self) -> None: async def _async_setup_hook(self) -> None:
await super()._async_setup_hook() await super()._async_setup_hook()
self.__queue = asyncio.PriorityQueue() self.__queue = asyncio.PriorityQueue()

155
discord/state.py

@ -27,7 +27,6 @@ from __future__ import annotations
import asyncio import asyncio
from collections import deque, OrderedDict from collections import deque, OrderedDict
import copy import copy
import itertools
import logging import logging
from typing import ( from typing import (
Dict, Dict,
@ -302,9 +301,6 @@ class ConnectionState:
else: else:
await coro(*args, **kwargs) await coro(*args, **kwargs)
async def async_setup(self) -> None:
pass
@property @property
def self_id(self) -> Optional[int]: def self_id(self) -> Optional[int]:
u = self.user u = self.user
@ -561,7 +557,7 @@ class ConnectionState:
if self._ready_task is not None: if self._ready_task is not None:
self._ready_task.cancel() self._ready_task.cancel()
self._ready_state = asyncio.Queue() self._ready_state: asyncio.Queue[Guild] = asyncio.Queue()
self.clear(views=False) self.clear(views=False)
self.user = user = ClientUser(state=self, data=data['user']) self.user = user = ClientUser(state=self, data=data['user'])
self._users[user.id] = user # type: ignore self._users[user.id] = user # type: ignore
@ -1111,6 +1107,15 @@ class ConnectionState:
else: else:
self.dispatch('guild_join', guild) self.dispatch('guild_join', guild)
def _add_ready_state(self, guild: Guild) -> bool:
try:
# Notify the on_ready state, if any, that this guild is complete.
self._ready_state.put_nowait(guild)
except AttributeError:
return False
else:
return True
def parse_guild_create(self, data: gw.GuildCreateEvent) -> None: def parse_guild_create(self, data: gw.GuildCreateEvent) -> None:
unavailable = data.get('unavailable') unavailable = data.get('unavailable')
if unavailable is True: if unavailable is True:
@ -1119,14 +1124,8 @@ class ConnectionState:
guild = self._get_create_guild(data) guild = self._get_create_guild(data)
try: if self._add_ready_state(guild):
# Notify the on_ready state, if any, that this guild is complete. return # We're waiting for the ready event, put the rest on hold
self._ready_state.put_nowait(guild)
except AttributeError:
pass
else:
# If we're waiting for the event, put the rest on hold
return
# check if it requires chunking # check if it requires chunking
if self._guild_needs_chunking(guild): if self._guild_needs_chunking(guild):
@ -1510,8 +1509,12 @@ class ConnectionState:
class AutoShardedConnectionState(ConnectionState): class AutoShardedConnectionState(ConnectionState):
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.shard_ids: Union[List[int], range] = [] self.shard_ids: Union[List[int], range] = []
self._ready_tasks: Dict[int, asyncio.Task[None]] = {}
self._ready_states: Dict[int, asyncio.Queue[Guild]] = {}
def _update_message_references(self) -> None: def _update_message_references(self) -> None:
# self._messages won't be None when this is called # self._messages won't be None when this is called
for msg in self._messages: # type: ignore for msg in self._messages: # type: ignore
@ -1525,9 +1528,6 @@ class AutoShardedConnectionState(ConnectionState):
# channel will either be a TextChannel, Thread or Object # channel will either be a TextChannel, Thread or Object
msg._rebind_cached_references(new_guild, channel) # type: ignore msg._rebind_cached_references(new_guild, channel) # type: ignore
async def async_setup(self) -> None:
self.shards_launched: asyncio.Event = asyncio.Event()
async def chunker( async def chunker(
self, self,
guild_id: int, guild_id: int,
@ -1541,76 +1541,80 @@ class AutoShardedConnectionState(ConnectionState):
ws = self._get_websocket(guild_id, shard_id=shard_id) ws = self._get_websocket(guild_id, shard_id=shard_id)
await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce) await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce)
def _add_ready_state(self, guild: Guild) -> bool:
try:
# Notify the on_ready state, if any, that this guild is complete.
self._ready_states[guild.shard_id].put_nowait(guild)
except KeyError:
return False
else:
return True
async def _delay_ready(self) -> None: async def _delay_ready(self) -> None:
await self.shards_launched.wait() await asyncio.gather(*self._ready_tasks.values())
processed = []
max_concurrency = len(self.shard_ids) * 2 # clear the current tasks
current_bucket = [] self._ready_task = None
while True: self._ready_tasks = {}
# this snippet of code is basically waiting N seconds
# until the last GUILD_CREATE was sent # dispatch the event
try: self.call_handlers('ready')
guild = await asyncio.wait_for(self._ready_state.get(), timeout=self.guild_ready_timeout) self.dispatch('ready')
except asyncio.TimeoutError:
break async def _delay_shard_ready(self, shard_id: int) -> None:
else: try:
if self._guild_needs_chunking(guild): states = []
_log.debug('Guild ID %d requires chunking, will be done in the background.', guild.id) while True:
if len(current_bucket) >= max_concurrency: # this snippet of code is basically waiting N seconds
try: # until the last GUILD_CREATE was sent
await utils.sane_wait_for(current_bucket, timeout=max_concurrency * 70.0) try:
except asyncio.TimeoutError: guild = await asyncio.wait_for(self._ready_states[shard_id].get(), timeout=self.guild_ready_timeout)
fmt = 'Shard ID %s failed to wait for chunks from a sub-bucket with length %d' except asyncio.TimeoutError:
_log.warning(fmt, guild.shard_id, len(current_bucket)) break
finally:
current_bucket = []
# Chunk the guild in the background while we wait for GUILD_CREATE streaming
future = asyncio.ensure_future(self.chunk_guild(guild))
current_bucket.append(future)
else: else:
future = self.loop.create_future() if self._guild_needs_chunking(guild):
future.set_result([]) future = await self.chunk_guild(guild, wait=False)
states.append((guild, future))
else:
if guild.unavailable is False:
self.dispatch('guild_available', guild)
else:
self.dispatch('guild_join', guild)
processed.append((guild, future)) for guild, future in states:
try:
await asyncio.wait_for(future, timeout=5.0)
except asyncio.TimeoutError:
_log.warning('Shard ID %s timed out waiting for chunks for guild_id %s.', guild.shard_id, guild.id)
guilds = sorted(processed, key=lambda g: g[0].shard_id)
for shard_id, info in itertools.groupby(guilds, key=lambda g: g[0].shard_id):
children, futures = zip(*info)
# 110 reqs/minute w/ 1 req/guild plus some buffer
timeout = 61 * (len(children) / 110)
try:
await utils.sane_wait_for(futures, timeout=timeout)
except asyncio.TimeoutError:
_log.warning(
'Shard ID %s failed to wait for chunks (timeout=%.2f) for %d guilds', shard_id, timeout, len(guilds)
)
for guild in children:
if guild.unavailable is False: if guild.unavailable is False:
self.dispatch('guild_available', guild) self.dispatch('guild_available', guild)
else: else:
self.dispatch('guild_join', guild) self.dispatch('guild_join', guild)
self.dispatch('shard_ready', shard_id) # remove the state
try:
del self._ready_states[shard_id]
except KeyError:
pass # already been deleted somehow
# remove the state except asyncio.CancelledError:
try: pass
del self._ready_state else:
except AttributeError: # dispatch the event
pass # already been deleted somehow self.dispatch('shard_ready', shard_id)
# regular users cannot shard so we won't worry about it here. def parse_ready(self, data: gw.ReadyEvent) -> None:
if self._ready_task is not None:
self._ready_task.cancel()
# clear the current task shard_id = data['shard'][0] # shard_id, num_shards
self._ready_task = None
# dispatch the event if shard_id in self._ready_tasks:
self.call_handlers('ready') self._ready_tasks[shard_id].cancel()
self.dispatch('ready')
def parse_ready(self, data: gw.ReadyEvent) -> None: if shard_id not in self._ready_states:
if not hasattr(self, '_ready_state'): self._ready_states[shard_id] = asyncio.Queue()
self._ready_state = asyncio.Queue()
self.user: Optional[ClientUser] self.user: Optional[ClientUser]
self.user = user = ClientUser(state=self, data=data['user']) self.user = user = ClientUser(state=self, data=data['user'])
@ -1633,9 +1637,12 @@ class AutoShardedConnectionState(ConnectionState):
self._update_message_references() self._update_message_references()
self.dispatch('connect') self.dispatch('connect')
self.dispatch('shard_connect', data['__shard_id__']) # type: ignore # This is an internal discord.py key self.dispatch('shard_connect', shard_id)
self._ready_tasks[shard_id] = asyncio.create_task(self._delay_shard_ready(shard_id))
if self._ready_task is None: # The delay task for every shard has been started
if len(self._ready_tasks) == len(self.shard_ids):
self._ready_task = asyncio.create_task(self._delay_ready()) self._ready_task = asyncio.create_task(self._delay_ready())
def parse_resumed(self, data: gw.ResumedEvent) -> None: def parse_resumed(self, data: gw.ResumedEvent) -> None:

7
discord/types/gateway.py

@ -60,17 +60,12 @@ class GatewayBot(Gateway):
session_start_limit: SessionStartLimit session_start_limit: SessionStartLimit
class ShardInfo(TypedDict):
shard_id: int
shard_count: int
class ReadyEvent(TypedDict): class ReadyEvent(TypedDict):
v: int v: int
user: User user: User
guilds: List[UnavailableGuild] guilds: List[UnavailableGuild]
session_id: str session_id: str
shard: ShardInfo shard: List[int] # shard_id, num_shards
application: GatewayAppInfo application: GatewayAppInfo

Loading…
Cancel
Save