From 963dd8aad50e4fc7c7e25386f858a3176900448d Mon Sep 17 00:00:00 2001 From: dolfies Date: Fri, 25 Aug 2023 18:27:49 +0300 Subject: [PATCH] Fix message search in guild channels and index not available retry --- discord/abc.py | 37 +++++++++++++++++++++++++++++++++---- discord/http.py | 11 +++++++++-- 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/discord/abc.py b/discord/abc.py index d0b6236ca..7cc80f52e 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -103,7 +103,7 @@ if TYPE_CHECKING: OverwriteType, ) from .types.embed import EmbedType - from .types.message import MessageSearchAuthorType, MessageSearchHasType + from .types.message import MessageSearchAuthorType, MessageSearchHasType, PartialMessage as PartialMessagePayload from .types.snowflake import ( SnowflakeList, ) @@ -312,14 +312,41 @@ async def _handle_message_search( oldest_first: bool = False, most_relevant: bool = False, ) -> AsyncIterator[Message]: + from .channel import PartialMessageable # circular import + if limit is not None and limit < 0: raise ValueError('limit must be greater than or equal to 0') if offset < 0: raise ValueError('offset must be greater than or equal to 0') + # Guild channels must go through the guild search endpoint _state = destination._state - endpoint = _state.http.search_channel if isinstance(destination, Messageable) else _state.http.search_guild - entity_id = (await destination._get_channel()).id if isinstance(destination, Messageable) else destination.id + endpoint = _state.http.search_guild + entity_id = None + channel = None + if isinstance(destination, Messageable): + channel = await destination._get_channel() + if isinstance(channel, PrivateChannel): + endpoint = _state.http.search_channel + entity_id = channel.id + else: + channels = [channel] + entity_id = getattr(channel.guild, 'id', getattr(channel, 'guild_id', None)) + else: + entity_id = destination.id + if not entity_id: + raise ValueError('Could not resolve channel guild ID') + + _channels = {c.id: c for c in channels} if channels else {} + if channel: + _channels[channel.id] = channel + + def _resolve_channel(message: PartialMessagePayload, /): + _channel, _ = _state._get_guild_channel(message) + if isinstance(_channel, PartialMessageable) and _channel.id in _channels: + return _channels[_channel.id] + return _channel + payload = {} if isinstance(before, datetime): @@ -409,7 +436,7 @@ async def _handle_message_search( if channel_id in threads: raw_message['thread'] = threads[channel_id] - channel, _ = _state._get_guild_channel(raw_message) + channel = _resolve_channel(raw_message) yield _state.create_message(channel=channel, data=raw_message, search_result=data) # type: ignore @@ -2336,6 +2363,8 @@ class Messageable: You do not have permissions to search the channel's messages. ~discord.HTTPException The request to search messages failed. + ValueError + Could not resolve the channel's guild ID. Yields ------- diff --git a/discord/http.py b/discord/http.py index f86f53252..a6942b98e 100644 --- a/discord/http.py +++ b/discord/http.py @@ -840,9 +840,16 @@ class HTTPClient: ) # 202s must be retried - if response.status == 202 and isinstance(data, dict) and 'retry_after' in data: + if response.status == 202 and isinstance(data, dict) and data['code'] == 110000: + # We update the `attempts` query parameter + params = kwargs.get('params') + if not params: + kwargs['params'] = {'attempts': 1} + else: + params['attempts'] = (params.get('attempts') or 0) + 1 + # Sometimes retry_after is 0, but that's undesirable - retry_after: float = data['retry_after'] or 0.25 + retry_after: float = data['retry_after'] or 5 _log.debug('%s %s received a 202. Retrying in %s seconds...', method, url, retry_after) await asyncio.sleep(retry_after) continue