diff --git a/discord/abc.py b/discord/abc.py index bd8ae6b0f..d8f68fa2c 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -27,6 +27,7 @@ from __future__ import annotations import copy import asyncio from datetime import datetime +import logging from operator import attrgetter from typing import ( Any, @@ -75,6 +76,9 @@ __all__ = ( ) T = TypeVar('T', bound=VoiceProtocol) +MISSING = utils.MISSING + +_log = logging.getLogger(__name__) if TYPE_CHECKING: from typing_extensions import Self @@ -114,8 +118,6 @@ if TYPE_CHECKING: VocalChannel = Union[VoiceChannel, StageChannel, DMChannel, GroupChannel] SnowflakeTime = Union["Snowflake", datetime] -MISSING = utils.MISSING - class _Undefined: def __repr__(self) -> str: @@ -279,9 +281,13 @@ async def _handle_message_search( attachment_filenames: Collection[str] = MISSING, attachment_extensions: Collection[str] = MISSING, application_commands: Collection[Snowflake] = MISSING, - oldest_first: bool = False, + oldest_first: bool = MISSING, most_relevant: bool = False, ) -> AsyncIterator[Message]: + # Important note for message search: + # The endpoint might sometimes time out while waiting for messages + # This will manifest as less results than the limit, even if there are more messages to be found + from .channel import PartialMessageable # circular import if limit is not None and limit < 0: @@ -321,6 +327,11 @@ async def _handle_message_search( before = Object(id=utils.time_snowflake(before, high=False)) if isinstance(after, datetime): after = Object(id=utils.time_snowflake(after, high=True)) + + after = after or OLDEST_OBJECT + if oldest_first is MISSING: + oldest_first = after != OLDEST_OBJECT + if ( include_nsfw is MISSING and not isinstance(destination, Messageable) @@ -329,10 +340,8 @@ async def _handle_message_search( ): include_nsfw = _state.user.nsfw_allowed - if before: - payload['max_id'] = before.id - if after: - payload['min_id'] = after.id + if offset: + payload['offset'] = offset if include_nsfw is not MISSING: payload['include_nsfw'] = str(include_nsfw).lower() if content: @@ -366,18 +375,75 @@ async def _handle_message_search( if oldest_first: payload['sort_order'] = 'asc' if most_relevant: + # This is the default and it isn't respected anyway, but this ep is cursed enough as it is + # So we will go with what the client does + payload['sort_order'] = 'desc' payload['sort_by'] = 'relevance' + async def _state_strategy(retrieve: int, state: Optional[Snowflake], limit: Optional[int]): + payload['limit'] = retrieve + if oldest_first and state: + payload['min_id'] = state.id + elif state: + payload['max_id'] = state.id + data = await endpoint(entity_id, payload) + + if data['messages']: + if limit is not None: + limit -= len(data['messages']) + + state = Object(id=int(data['messages'][-1][0]['id'])) + + return data, state, limit + + async def _relevance_strategy(retrieve: int, _, limit: Optional[int]): + payload['limit'] = retrieve + data = await endpoint(entity_id, payload) + + if data['messages']: + length = len(data['messages']) + if limit is not None: + limit -= length + payload['offset'] = (offset or 0) + length + + return data, None, limit + + predicate = None + + if most_relevant: + strategy, state = _relevance_strategy, None + if before and after != OLDEST_OBJECT: + raise TypeError('Cannot use both before and after with most_relevant') + if before: + payload['max_id'] = before.id + if after != OLDEST_OBJECT: + payload['min_id'] = after.id + elif oldest_first: + strategy, state = _state_strategy, after + if before: + predicate = lambda m: int(m[0]['id']) < before.id + else: + strategy, state = _state_strategy, before + if after and after != OLDEST_OBJECT: + predicate = lambda m: int(m[0]['id']) > after.id + + total_results = MISSING + total = 0 + while True: - retrieve = min(25 if limit is None else limit, 25) + retrieve = 25 if limit is None else min(limit, 25) if retrieve < 1: return - if retrieve != 25: - payload['limit'] = retrieve - if offset: - payload['offset'] = offset - data = await endpoint(entity_id, payload) + data, state, limit = await strategy(retrieve, state, limit) + + if total_results is MISSING: + total_results = data['total_results'] + + messages = data['messages'] + if predicate: + messages = filter(predicate, messages) + threads = {int(thread['id']): thread for thread in data.get('threads', [])} for member in data.get('members', []): thread_id = int(member['id']) @@ -385,17 +451,11 @@ async def _handle_message_search( if thread: thread['member'] = member - length = len(data['messages']) - offset += length - if limit is not None: - limit -= length + count = 0 - # Terminate loop on next iteration; there's no data left after this - if len(data['messages']) < 25: - limit = 0 - - for raw_messages in data['messages']: + for count, raw_messages in enumerate(messages, 1): if not raw_messages: + _log.debug('Search for %s with payload %s yielded an empty subarray.', destination, payload) continue # Context is no longer sent, so this is probably fine @@ -407,6 +467,13 @@ async def _handle_message_search( channel = _resolve_channel(raw_message) yield _state.create_message(channel=channel, data=raw_message, search_result=data) # type: ignore + if count == 0: + return + + total += count + if total >= total_results: + return + @runtime_checkable class Snowflake(Protocol): @@ -2317,7 +2384,7 @@ class Messageable: attachment_filenames: Collection[str] = MISSING, attachment_extensions: Collection[str] = MISSING, application_commands: Collection[Snowflake] = MISSING, - oldest_first: bool = False, + oldest_first: bool = MISSING, most_relevant: bool = False, ) -> AsyncIterator[Message]: """Returns an :term:`asynchronous iterator` that enables searching the channel's messages. @@ -2356,9 +2423,7 @@ class Messageable: limit: Optional[:class:`int`] The number of messages to retrieve. If ``None``, retrieves every message in the results. Note, however, - that this would make it a slow operation. Additionally, note that the - search API has a maximum pagination offset of 5000 (subject to change), - so a limit of over 5000 or ``None`` may eventually raise an exception. + that this would make it a slow operation. offset: :class:`int` The pagination offset to start at. before: Union[:class:`~discord.abc.Snowflake`, :class:`datetime.datetime`] @@ -2397,17 +2462,18 @@ class Messageable: application_commands: List[:class:`~discord.abc.ApplicationCommand`] The used application commands to filter by. oldest_first: :class:`bool` - Whether to return the oldest results first. + Whether to return the oldest results first. Defaults to ``True`` if + ``after`` is specified, otherwise ``False``. Ignored when ``most_relevant`` is set. most_relevant: :class:`bool` - Whether to sort the results by relevance. Using this with ``oldest_first`` - will return the least relevant results first. + Whether to sort the results by relevance. Limits pagination to 9975 entries. + Prevents using both ``before`` and ``after``. Raises ------ - ~discord.Forbidden - You do not have permissions to search the channel's messages. ~discord.HTTPException The request to search messages failed. + TypeError + Provided both ``before`` and ``after`` when ``most_relevant`` is set. ValueError Could not resolve the channel's guild ID. diff --git a/discord/guild.py b/discord/guild.py index 7f9f68fb7..ba41a950e 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -2951,7 +2951,7 @@ class Guild(Hashable): attachment_filenames: Collection[str] = MISSING, attachment_extensions: Collection[str] = MISSING, application_commands: Collection[Snowflake] = MISSING, - oldest_first: bool = False, + oldest_first: bool = MISSING, most_relevant: bool = False, ) -> AsyncIterator[Message]: """Returns an :term:`asynchronous iterator` that enables searching the guild's messages. @@ -2990,9 +2990,7 @@ class Guild(Hashable): limit: Optional[:class:`int`] The number of messages to retrieve. If ``None``, retrieves every message in the results. Note, however, - that this would make it a slow operation. Additionally, note that the - search API has a maximum pagination offset of 5000 (subject to change), - so a limit of over 5000 or ``None`` may eventually raise an exception. + that this would make it a slow operation. offset: :class:`int` The pagination offset to start at. before: Union[:class:`abc.Snowflake`, :class:`datetime.datetime`] @@ -3035,17 +3033,20 @@ class Guild(Hashable): application_commands: List[:class:`abc.ApplicationCommand`] The used application commands to filter by. oldest_first: :class:`bool` - Whether to return the oldest results first. + Whether to return the oldest results first. Defaults to ``True`` if + ``before`` is specified, otherwise ``False``. Ignored when ``most_relevant`` is set. most_relevant: :class:`bool` - Whether to sort the results by relevance. Using this with ``oldest_first`` - will return the least relevant results first. + Whether to sort the results by relevance. Limits pagination to 9975 entries. + Prevents using both ``before`` and ``after``. Raises ------ - Forbidden - You do not have permissions to search the channel's messages. - HTTPException + ~discord.HTTPException The request to search messages failed. + TypeError + Provided both ``before`` and ``after`` when ``most_relevant`` is set. + ValueError + Could not resolve the channel's guild ID. Yields -------