Browse Source

Separately delay ready event for each shard

pull/7932/head
Lilly Rose Berner 3 years ago
committed by GitHub
parent
commit
2dbf14bb72
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      discord/client.py
  2. 2
      discord/gateway.py
  3. 2
      discord/shard.py
  4. 155
      discord/state.py
  5. 7
      discord/types/gateway.py

1
discord/client.py

@ -484,7 +484,6 @@ class Client:
self.loop = loop
self.http.loop = loop
self._connection.loop = loop
await self._connection.async_setup()
self._ready = asyncio.Event()

2
discord/gateway.py

@ -546,8 +546,6 @@ class DiscordWebSocket:
self._trace = trace = data.get('_trace', [])
self.sequence = msg['s']
self.session_id = data['session_id']
# pass back shard ID to ready handler
data['__shard_id__'] = self.shard_id
_log.info(
'Shard ID %s has connected to Gateway: %s (Session ID: %s).',
self.shard_id,

2
discord/shard.py

@ -423,8 +423,6 @@ class AutoShardedClient(Client):
initial = shard_id == shard_ids[0]
await self.launch_shard(gateway, shard_id, initial=initial)
self._connection.shards_launched.set()
async def _async_setup_hook(self) -> None:
await super()._async_setup_hook()
self.__queue = asyncio.PriorityQueue()

155
discord/state.py

@ -27,7 +27,6 @@ from __future__ import annotations
import asyncio
from collections import deque, OrderedDict
import copy
import itertools
import logging
from typing import (
Dict,
@ -302,9 +301,6 @@ class ConnectionState:
else:
await coro(*args, **kwargs)
async def async_setup(self) -> None:
pass
@property
def self_id(self) -> Optional[int]:
u = self.user
@ -561,7 +557,7 @@ class ConnectionState:
if self._ready_task is not None:
self._ready_task.cancel()
self._ready_state = asyncio.Queue()
self._ready_state: asyncio.Queue[Guild] = asyncio.Queue()
self.clear(views=False)
self.user = user = ClientUser(state=self, data=data['user'])
self._users[user.id] = user # type: ignore
@ -1111,6 +1107,15 @@ class ConnectionState:
else:
self.dispatch('guild_join', guild)
def _add_ready_state(self, guild: Guild) -> bool:
try:
# Notify the on_ready state, if any, that this guild is complete.
self._ready_state.put_nowait(guild)
except AttributeError:
return False
else:
return True
def parse_guild_create(self, data: gw.GuildCreateEvent) -> None:
unavailable = data.get('unavailable')
if unavailable is True:
@ -1119,14 +1124,8 @@ class ConnectionState:
guild = self._get_create_guild(data)
try:
# Notify the on_ready state, if any, that this guild is complete.
self._ready_state.put_nowait(guild)
except AttributeError:
pass
else:
# If we're waiting for the event, put the rest on hold
return
if self._add_ready_state(guild):
return # We're waiting for the ready event, put the rest on hold
# check if it requires chunking
if self._guild_needs_chunking(guild):
@ -1510,8 +1509,12 @@ class ConnectionState:
class AutoShardedConnectionState(ConnectionState):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.shard_ids: Union[List[int], range] = []
self._ready_tasks: Dict[int, asyncio.Task[None]] = {}
self._ready_states: Dict[int, asyncio.Queue[Guild]] = {}
def _update_message_references(self) -> None:
# self._messages won't be None when this is called
for msg in self._messages: # type: ignore
@ -1525,9 +1528,6 @@ class AutoShardedConnectionState(ConnectionState):
# channel will either be a TextChannel, Thread or Object
msg._rebind_cached_references(new_guild, channel) # type: ignore
async def async_setup(self) -> None:
self.shards_launched: asyncio.Event = asyncio.Event()
async def chunker(
self,
guild_id: int,
@ -1541,76 +1541,80 @@ class AutoShardedConnectionState(ConnectionState):
ws = self._get_websocket(guild_id, shard_id=shard_id)
await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce)
def _add_ready_state(self, guild: Guild) -> bool:
try:
# Notify the on_ready state, if any, that this guild is complete.
self._ready_states[guild.shard_id].put_nowait(guild)
except KeyError:
return False
else:
return True
async def _delay_ready(self) -> None:
await self.shards_launched.wait()
processed = []
max_concurrency = len(self.shard_ids) * 2
current_bucket = []
while True:
# this snippet of code is basically waiting N seconds
# until the last GUILD_CREATE was sent
try:
guild = await asyncio.wait_for(self._ready_state.get(), timeout=self.guild_ready_timeout)
except asyncio.TimeoutError:
break
else:
if self._guild_needs_chunking(guild):
_log.debug('Guild ID %d requires chunking, will be done in the background.', guild.id)
if len(current_bucket) >= max_concurrency:
try:
await utils.sane_wait_for(current_bucket, timeout=max_concurrency * 70.0)
except asyncio.TimeoutError:
fmt = 'Shard ID %s failed to wait for chunks from a sub-bucket with length %d'
_log.warning(fmt, guild.shard_id, len(current_bucket))
finally:
current_bucket = []
# Chunk the guild in the background while we wait for GUILD_CREATE streaming
future = asyncio.ensure_future(self.chunk_guild(guild))
current_bucket.append(future)
await asyncio.gather(*self._ready_tasks.values())
# clear the current tasks
self._ready_task = None
self._ready_tasks = {}
# dispatch the event
self.call_handlers('ready')
self.dispatch('ready')
async def _delay_shard_ready(self, shard_id: int) -> None:
try:
states = []
while True:
# this snippet of code is basically waiting N seconds
# until the last GUILD_CREATE was sent
try:
guild = await asyncio.wait_for(self._ready_states[shard_id].get(), timeout=self.guild_ready_timeout)
except asyncio.TimeoutError:
break
else:
future = self.loop.create_future()
future.set_result([])
if self._guild_needs_chunking(guild):
future = await self.chunk_guild(guild, wait=False)
states.append((guild, future))
else:
if guild.unavailable is False:
self.dispatch('guild_available', guild)
else:
self.dispatch('guild_join', guild)
processed.append((guild, future))
for guild, future in states:
try:
await asyncio.wait_for(future, timeout=5.0)
except asyncio.TimeoutError:
_log.warning('Shard ID %s timed out waiting for chunks for guild_id %s.', guild.shard_id, guild.id)
guilds = sorted(processed, key=lambda g: g[0].shard_id)
for shard_id, info in itertools.groupby(guilds, key=lambda g: g[0].shard_id):
children, futures = zip(*info)
# 110 reqs/minute w/ 1 req/guild plus some buffer
timeout = 61 * (len(children) / 110)
try:
await utils.sane_wait_for(futures, timeout=timeout)
except asyncio.TimeoutError:
_log.warning(
'Shard ID %s failed to wait for chunks (timeout=%.2f) for %d guilds', shard_id, timeout, len(guilds)
)
for guild in children:
if guild.unavailable is False:
self.dispatch('guild_available', guild)
else:
self.dispatch('guild_join', guild)
self.dispatch('shard_ready', shard_id)
# remove the state
try:
del self._ready_states[shard_id]
except KeyError:
pass # already been deleted somehow
# remove the state
try:
del self._ready_state
except AttributeError:
pass # already been deleted somehow
except asyncio.CancelledError:
pass
else:
# dispatch the event
self.dispatch('shard_ready', shard_id)
# regular users cannot shard so we won't worry about it here.
def parse_ready(self, data: gw.ReadyEvent) -> None:
if self._ready_task is not None:
self._ready_task.cancel()
# clear the current task
self._ready_task = None
shard_id = data['shard'][0] # shard_id, num_shards
# dispatch the event
self.call_handlers('ready')
self.dispatch('ready')
if shard_id in self._ready_tasks:
self._ready_tasks[shard_id].cancel()
def parse_ready(self, data: gw.ReadyEvent) -> None:
if not hasattr(self, '_ready_state'):
self._ready_state = asyncio.Queue()
if shard_id not in self._ready_states:
self._ready_states[shard_id] = asyncio.Queue()
self.user: Optional[ClientUser]
self.user = user = ClientUser(state=self, data=data['user'])
@ -1633,9 +1637,12 @@ class AutoShardedConnectionState(ConnectionState):
self._update_message_references()
self.dispatch('connect')
self.dispatch('shard_connect', data['__shard_id__']) # type: ignore # This is an internal discord.py key
self.dispatch('shard_connect', shard_id)
self._ready_tasks[shard_id] = asyncio.create_task(self._delay_shard_ready(shard_id))
if self._ready_task is None:
# The delay task for every shard has been started
if len(self._ready_tasks) == len(self.shard_ids):
self._ready_task = asyncio.create_task(self._delay_ready())
def parse_resumed(self, data: gw.ResumedEvent) -> None:

7
discord/types/gateway.py

@ -60,17 +60,12 @@ class GatewayBot(Gateway):
session_start_limit: SessionStartLimit
class ShardInfo(TypedDict):
shard_id: int
shard_count: int
class ReadyEvent(TypedDict):
v: int
user: User
guilds: List[UnavailableGuild]
session_id: str
shard: ShardInfo
shard: List[int] # shard_id, num_shards
application: GatewayAppInfo

Loading…
Cancel
Save