|
@ -27,7 +27,6 @@ from __future__ import annotations |
|
|
import asyncio |
|
|
import asyncio |
|
|
from collections import deque, OrderedDict |
|
|
from collections import deque, OrderedDict |
|
|
import copy |
|
|
import copy |
|
|
import itertools |
|
|
|
|
|
import logging |
|
|
import logging |
|
|
from typing import ( |
|
|
from typing import ( |
|
|
Dict, |
|
|
Dict, |
|
@ -302,9 +301,6 @@ class ConnectionState: |
|
|
else: |
|
|
else: |
|
|
await coro(*args, **kwargs) |
|
|
await coro(*args, **kwargs) |
|
|
|
|
|
|
|
|
async def async_setup(self) -> None: |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
@property |
|
|
@property |
|
|
def self_id(self) -> Optional[int]: |
|
|
def self_id(self) -> Optional[int]: |
|
|
u = self.user |
|
|
u = self.user |
|
@ -561,7 +557,7 @@ class ConnectionState: |
|
|
if self._ready_task is not None: |
|
|
if self._ready_task is not None: |
|
|
self._ready_task.cancel() |
|
|
self._ready_task.cancel() |
|
|
|
|
|
|
|
|
self._ready_state = asyncio.Queue() |
|
|
self._ready_state: asyncio.Queue[Guild] = asyncio.Queue() |
|
|
self.clear(views=False) |
|
|
self.clear(views=False) |
|
|
self.user = user = ClientUser(state=self, data=data['user']) |
|
|
self.user = user = ClientUser(state=self, data=data['user']) |
|
|
self._users[user.id] = user # type: ignore |
|
|
self._users[user.id] = user # type: ignore |
|
@ -1111,6 +1107,15 @@ class ConnectionState: |
|
|
else: |
|
|
else: |
|
|
self.dispatch('guild_join', guild) |
|
|
self.dispatch('guild_join', guild) |
|
|
|
|
|
|
|
|
|
|
|
def _add_ready_state(self, guild: Guild) -> bool: |
|
|
|
|
|
try: |
|
|
|
|
|
# Notify the on_ready state, if any, that this guild is complete. |
|
|
|
|
|
self._ready_state.put_nowait(guild) |
|
|
|
|
|
except AttributeError: |
|
|
|
|
|
return False |
|
|
|
|
|
else: |
|
|
|
|
|
return True |
|
|
|
|
|
|
|
|
def parse_guild_create(self, data: gw.GuildCreateEvent) -> None: |
|
|
def parse_guild_create(self, data: gw.GuildCreateEvent) -> None: |
|
|
unavailable = data.get('unavailable') |
|
|
unavailable = data.get('unavailable') |
|
|
if unavailable is True: |
|
|
if unavailable is True: |
|
@ -1119,14 +1124,8 @@ class ConnectionState: |
|
|
|
|
|
|
|
|
guild = self._get_create_guild(data) |
|
|
guild = self._get_create_guild(data) |
|
|
|
|
|
|
|
|
try: |
|
|
if self._add_ready_state(guild): |
|
|
# Notify the on_ready state, if any, that this guild is complete. |
|
|
return # We're waiting for the ready event, put the rest on hold |
|
|
self._ready_state.put_nowait(guild) |
|
|
|
|
|
except AttributeError: |
|
|
|
|
|
pass |
|
|
|
|
|
else: |
|
|
|
|
|
# If we're waiting for the event, put the rest on hold |
|
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
|
|
# check if it requires chunking |
|
|
# check if it requires chunking |
|
|
if self._guild_needs_chunking(guild): |
|
|
if self._guild_needs_chunking(guild): |
|
@ -1510,8 +1509,12 @@ class ConnectionState: |
|
|
class AutoShardedConnectionState(ConnectionState): |
|
|
class AutoShardedConnectionState(ConnectionState): |
|
|
def __init__(self, *args: Any, **kwargs: Any) -> None: |
|
|
def __init__(self, *args: Any, **kwargs: Any) -> None: |
|
|
super().__init__(*args, **kwargs) |
|
|
super().__init__(*args, **kwargs) |
|
|
|
|
|
|
|
|
self.shard_ids: Union[List[int], range] = [] |
|
|
self.shard_ids: Union[List[int], range] = [] |
|
|
|
|
|
|
|
|
|
|
|
self._ready_tasks: Dict[int, asyncio.Task[None]] = {} |
|
|
|
|
|
self._ready_states: Dict[int, asyncio.Queue[Guild]] = {} |
|
|
|
|
|
|
|
|
def _update_message_references(self) -> None: |
|
|
def _update_message_references(self) -> None: |
|
|
# self._messages won't be None when this is called |
|
|
# self._messages won't be None when this is called |
|
|
for msg in self._messages: # type: ignore |
|
|
for msg in self._messages: # type: ignore |
|
@ -1525,9 +1528,6 @@ class AutoShardedConnectionState(ConnectionState): |
|
|
# channel will either be a TextChannel, Thread or Object |
|
|
# channel will either be a TextChannel, Thread or Object |
|
|
msg._rebind_cached_references(new_guild, channel) # type: ignore |
|
|
msg._rebind_cached_references(new_guild, channel) # type: ignore |
|
|
|
|
|
|
|
|
async def async_setup(self) -> None: |
|
|
|
|
|
self.shards_launched: asyncio.Event = asyncio.Event() |
|
|
|
|
|
|
|
|
|
|
|
async def chunker( |
|
|
async def chunker( |
|
|
self, |
|
|
self, |
|
|
guild_id: int, |
|
|
guild_id: int, |
|
@ -1541,76 +1541,80 @@ class AutoShardedConnectionState(ConnectionState): |
|
|
ws = self._get_websocket(guild_id, shard_id=shard_id) |
|
|
ws = self._get_websocket(guild_id, shard_id=shard_id) |
|
|
await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce) |
|
|
await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce) |
|
|
|
|
|
|
|
|
|
|
|
def _add_ready_state(self, guild: Guild) -> bool: |
|
|
|
|
|
try: |
|
|
|
|
|
# Notify the on_ready state, if any, that this guild is complete. |
|
|
|
|
|
self._ready_states[guild.shard_id].put_nowait(guild) |
|
|
|
|
|
except KeyError: |
|
|
|
|
|
return False |
|
|
|
|
|
else: |
|
|
|
|
|
return True |
|
|
|
|
|
|
|
|
async def _delay_ready(self) -> None: |
|
|
async def _delay_ready(self) -> None: |
|
|
await self.shards_launched.wait() |
|
|
await asyncio.gather(*self._ready_tasks.values()) |
|
|
processed = [] |
|
|
|
|
|
max_concurrency = len(self.shard_ids) * 2 |
|
|
# clear the current tasks |
|
|
current_bucket = [] |
|
|
self._ready_task = None |
|
|
while True: |
|
|
self._ready_tasks = {} |
|
|
# this snippet of code is basically waiting N seconds |
|
|
|
|
|
# until the last GUILD_CREATE was sent |
|
|
# dispatch the event |
|
|
try: |
|
|
self.call_handlers('ready') |
|
|
guild = await asyncio.wait_for(self._ready_state.get(), timeout=self.guild_ready_timeout) |
|
|
self.dispatch('ready') |
|
|
except asyncio.TimeoutError: |
|
|
|
|
|
break |
|
|
async def _delay_shard_ready(self, shard_id: int) -> None: |
|
|
else: |
|
|
try: |
|
|
if self._guild_needs_chunking(guild): |
|
|
states = [] |
|
|
_log.debug('Guild ID %d requires chunking, will be done in the background.', guild.id) |
|
|
while True: |
|
|
if len(current_bucket) >= max_concurrency: |
|
|
# this snippet of code is basically waiting N seconds |
|
|
try: |
|
|
# until the last GUILD_CREATE was sent |
|
|
await utils.sane_wait_for(current_bucket, timeout=max_concurrency * 70.0) |
|
|
try: |
|
|
except asyncio.TimeoutError: |
|
|
guild = await asyncio.wait_for(self._ready_states[shard_id].get(), timeout=self.guild_ready_timeout) |
|
|
fmt = 'Shard ID %s failed to wait for chunks from a sub-bucket with length %d' |
|
|
except asyncio.TimeoutError: |
|
|
_log.warning(fmt, guild.shard_id, len(current_bucket)) |
|
|
break |
|
|
finally: |
|
|
|
|
|
current_bucket = [] |
|
|
|
|
|
|
|
|
|
|
|
# Chunk the guild in the background while we wait for GUILD_CREATE streaming |
|
|
|
|
|
future = asyncio.ensure_future(self.chunk_guild(guild)) |
|
|
|
|
|
current_bucket.append(future) |
|
|
|
|
|
else: |
|
|
else: |
|
|
future = self.loop.create_future() |
|
|
if self._guild_needs_chunking(guild): |
|
|
future.set_result([]) |
|
|
future = await self.chunk_guild(guild, wait=False) |
|
|
|
|
|
states.append((guild, future)) |
|
|
|
|
|
else: |
|
|
|
|
|
if guild.unavailable is False: |
|
|
|
|
|
self.dispatch('guild_available', guild) |
|
|
|
|
|
else: |
|
|
|
|
|
self.dispatch('guild_join', guild) |
|
|
|
|
|
|
|
|
processed.append((guild, future)) |
|
|
for guild, future in states: |
|
|
|
|
|
try: |
|
|
|
|
|
await asyncio.wait_for(future, timeout=5.0) |
|
|
|
|
|
except asyncio.TimeoutError: |
|
|
|
|
|
_log.warning('Shard ID %s timed out waiting for chunks for guild_id %s.', guild.shard_id, guild.id) |
|
|
|
|
|
|
|
|
guilds = sorted(processed, key=lambda g: g[0].shard_id) |
|
|
|
|
|
for shard_id, info in itertools.groupby(guilds, key=lambda g: g[0].shard_id): |
|
|
|
|
|
children, futures = zip(*info) |
|
|
|
|
|
# 110 reqs/minute w/ 1 req/guild plus some buffer |
|
|
|
|
|
timeout = 61 * (len(children) / 110) |
|
|
|
|
|
try: |
|
|
|
|
|
await utils.sane_wait_for(futures, timeout=timeout) |
|
|
|
|
|
except asyncio.TimeoutError: |
|
|
|
|
|
_log.warning( |
|
|
|
|
|
'Shard ID %s failed to wait for chunks (timeout=%.2f) for %d guilds', shard_id, timeout, len(guilds) |
|
|
|
|
|
) |
|
|
|
|
|
for guild in children: |
|
|
|
|
|
if guild.unavailable is False: |
|
|
if guild.unavailable is False: |
|
|
self.dispatch('guild_available', guild) |
|
|
self.dispatch('guild_available', guild) |
|
|
else: |
|
|
else: |
|
|
self.dispatch('guild_join', guild) |
|
|
self.dispatch('guild_join', guild) |
|
|
|
|
|
|
|
|
self.dispatch('shard_ready', shard_id) |
|
|
# remove the state |
|
|
|
|
|
try: |
|
|
|
|
|
del self._ready_states[shard_id] |
|
|
|
|
|
except KeyError: |
|
|
|
|
|
pass # already been deleted somehow |
|
|
|
|
|
|
|
|
# remove the state |
|
|
except asyncio.CancelledError: |
|
|
try: |
|
|
pass |
|
|
del self._ready_state |
|
|
else: |
|
|
except AttributeError: |
|
|
# dispatch the event |
|
|
pass # already been deleted somehow |
|
|
self.dispatch('shard_ready', shard_id) |
|
|
|
|
|
|
|
|
# regular users cannot shard so we won't worry about it here. |
|
|
def parse_ready(self, data: gw.ReadyEvent) -> None: |
|
|
|
|
|
if self._ready_task is not None: |
|
|
|
|
|
self._ready_task.cancel() |
|
|
|
|
|
|
|
|
# clear the current task |
|
|
shard_id = data['shard'][0] # shard_id, num_shards |
|
|
self._ready_task = None |
|
|
|
|
|
|
|
|
|
|
|
# dispatch the event |
|
|
if shard_id in self._ready_tasks: |
|
|
self.call_handlers('ready') |
|
|
self._ready_tasks[shard_id].cancel() |
|
|
self.dispatch('ready') |
|
|
|
|
|
|
|
|
|
|
|
def parse_ready(self, data: gw.ReadyEvent) -> None: |
|
|
if shard_id not in self._ready_states: |
|
|
if not hasattr(self, '_ready_state'): |
|
|
self._ready_states[shard_id] = asyncio.Queue() |
|
|
self._ready_state = asyncio.Queue() |
|
|
|
|
|
|
|
|
|
|
|
self.user: Optional[ClientUser] |
|
|
self.user: Optional[ClientUser] |
|
|
self.user = user = ClientUser(state=self, data=data['user']) |
|
|
self.user = user = ClientUser(state=self, data=data['user']) |
|
@ -1633,9 +1637,12 @@ class AutoShardedConnectionState(ConnectionState): |
|
|
self._update_message_references() |
|
|
self._update_message_references() |
|
|
|
|
|
|
|
|
self.dispatch('connect') |
|
|
self.dispatch('connect') |
|
|
self.dispatch('shard_connect', data['__shard_id__']) # type: ignore # This is an internal discord.py key |
|
|
self.dispatch('shard_connect', shard_id) |
|
|
|
|
|
|
|
|
|
|
|
self._ready_tasks[shard_id] = asyncio.create_task(self._delay_shard_ready(shard_id)) |
|
|
|
|
|
|
|
|
if self._ready_task is None: |
|
|
# The delay task for every shard has been started |
|
|
|
|
|
if len(self._ready_tasks) == len(self.shard_ids): |
|
|
self._ready_task = asyncio.create_task(self._delay_ready()) |
|
|
self._ready_task = asyncio.create_task(self._delay_ready()) |
|
|
|
|
|
|
|
|
def parse_resumed(self, data: gw.ResumedEvent) -> None: |
|
|
def parse_resumed(self, data: gw.ResumedEvent) -> None: |
|
|