Browse Source

Immensely improve robustness of search pagination (fix #581)

pull/10109/head
dolfies 5 months ago
parent
commit
047fc48fb1
  1. 128
      discord/abc.py
  2. 21
      discord/guild.py

128
discord/abc.py

@ -27,6 +27,7 @@ from __future__ import annotations
import copy import copy
import asyncio import asyncio
from datetime import datetime from datetime import datetime
import logging
from operator import attrgetter from operator import attrgetter
from typing import ( from typing import (
Any, Any,
@ -75,6 +76,9 @@ __all__ = (
) )
T = TypeVar('T', bound=VoiceProtocol) T = TypeVar('T', bound=VoiceProtocol)
MISSING = utils.MISSING
_log = logging.getLogger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self from typing_extensions import Self
@ -114,8 +118,6 @@ if TYPE_CHECKING:
VocalChannel = Union[VoiceChannel, StageChannel, DMChannel, GroupChannel] VocalChannel = Union[VoiceChannel, StageChannel, DMChannel, GroupChannel]
SnowflakeTime = Union["Snowflake", datetime] SnowflakeTime = Union["Snowflake", datetime]
MISSING = utils.MISSING
class _Undefined: class _Undefined:
def __repr__(self) -> str: def __repr__(self) -> str:
@ -279,9 +281,13 @@ async def _handle_message_search(
attachment_filenames: Collection[str] = MISSING, attachment_filenames: Collection[str] = MISSING,
attachment_extensions: Collection[str] = MISSING, attachment_extensions: Collection[str] = MISSING,
application_commands: Collection[Snowflake] = MISSING, application_commands: Collection[Snowflake] = MISSING,
oldest_first: bool = False, oldest_first: bool = MISSING,
most_relevant: bool = False, most_relevant: bool = False,
) -> AsyncIterator[Message]: ) -> AsyncIterator[Message]:
# Important note for message search:
# The endpoint might sometimes time out while waiting for messages
# This will manifest as less results than the limit, even if there are more messages to be found
from .channel import PartialMessageable # circular import from .channel import PartialMessageable # circular import
if limit is not None and limit < 0: if limit is not None and limit < 0:
@ -321,6 +327,11 @@ async def _handle_message_search(
before = Object(id=utils.time_snowflake(before, high=False)) before = Object(id=utils.time_snowflake(before, high=False))
if isinstance(after, datetime): if isinstance(after, datetime):
after = Object(id=utils.time_snowflake(after, high=True)) after = Object(id=utils.time_snowflake(after, high=True))
after = after or OLDEST_OBJECT
if oldest_first is MISSING:
oldest_first = after != OLDEST_OBJECT
if ( if (
include_nsfw is MISSING include_nsfw is MISSING
and not isinstance(destination, Messageable) and not isinstance(destination, Messageable)
@ -329,10 +340,8 @@ async def _handle_message_search(
): ):
include_nsfw = _state.user.nsfw_allowed include_nsfw = _state.user.nsfw_allowed
if before: if offset:
payload['max_id'] = before.id payload['offset'] = offset
if after:
payload['min_id'] = after.id
if include_nsfw is not MISSING: if include_nsfw is not MISSING:
payload['include_nsfw'] = str(include_nsfw).lower() payload['include_nsfw'] = str(include_nsfw).lower()
if content: if content:
@ -366,18 +375,75 @@ async def _handle_message_search(
if oldest_first: if oldest_first:
payload['sort_order'] = 'asc' payload['sort_order'] = 'asc'
if most_relevant: if most_relevant:
# This is the default and it isn't respected anyway, but this ep is cursed enough as it is
# So we will go with what the client does
payload['sort_order'] = 'desc'
payload['sort_by'] = 'relevance' payload['sort_by'] = 'relevance'
async def _state_strategy(retrieve: int, state: Optional[Snowflake], limit: Optional[int]):
payload['limit'] = retrieve
if oldest_first and state:
payload['min_id'] = state.id
elif state:
payload['max_id'] = state.id
data = await endpoint(entity_id, payload)
if data['messages']:
if limit is not None:
limit -= len(data['messages'])
state = Object(id=int(data['messages'][-1][0]['id']))
return data, state, limit
async def _relevance_strategy(retrieve: int, _, limit: Optional[int]):
payload['limit'] = retrieve
data = await endpoint(entity_id, payload)
if data['messages']:
length = len(data['messages'])
if limit is not None:
limit -= length
payload['offset'] = (offset or 0) + length
return data, None, limit
predicate = None
if most_relevant:
strategy, state = _relevance_strategy, None
if before and after != OLDEST_OBJECT:
raise TypeError('Cannot use both before and after with most_relevant')
if before:
payload['max_id'] = before.id
if after != OLDEST_OBJECT:
payload['min_id'] = after.id
elif oldest_first:
strategy, state = _state_strategy, after
if before:
predicate = lambda m: int(m[0]['id']) < before.id
else:
strategy, state = _state_strategy, before
if after and after != OLDEST_OBJECT:
predicate = lambda m: int(m[0]['id']) > after.id
total_results = MISSING
total = 0
while True: while True:
retrieve = min(25 if limit is None else limit, 25) retrieve = 25 if limit is None else min(limit, 25)
if retrieve < 1: if retrieve < 1:
return return
if retrieve != 25:
payload['limit'] = retrieve
if offset:
payload['offset'] = offset
data = await endpoint(entity_id, payload) data, state, limit = await strategy(retrieve, state, limit)
if total_results is MISSING:
total_results = data['total_results']
messages = data['messages']
if predicate:
messages = filter(predicate, messages)
threads = {int(thread['id']): thread for thread in data.get('threads', [])} threads = {int(thread['id']): thread for thread in data.get('threads', [])}
for member in data.get('members', []): for member in data.get('members', []):
thread_id = int(member['id']) thread_id = int(member['id'])
@ -385,17 +451,11 @@ async def _handle_message_search(
if thread: if thread:
thread['member'] = member thread['member'] = member
length = len(data['messages']) count = 0
offset += length
if limit is not None:
limit -= length
# Terminate loop on next iteration; there's no data left after this
if len(data['messages']) < 25:
limit = 0
for raw_messages in data['messages']: for count, raw_messages in enumerate(messages, 1):
if not raw_messages: if not raw_messages:
_log.debug('Search for %s with payload %s yielded an empty subarray.', destination, payload)
continue continue
# Context is no longer sent, so this is probably fine # Context is no longer sent, so this is probably fine
@ -407,6 +467,13 @@ async def _handle_message_search(
channel = _resolve_channel(raw_message) channel = _resolve_channel(raw_message)
yield _state.create_message(channel=channel, data=raw_message, search_result=data) # type: ignore yield _state.create_message(channel=channel, data=raw_message, search_result=data) # type: ignore
if count == 0:
return
total += count
if total >= total_results:
return
@runtime_checkable @runtime_checkable
class Snowflake(Protocol): class Snowflake(Protocol):
@ -2317,7 +2384,7 @@ class Messageable:
attachment_filenames: Collection[str] = MISSING, attachment_filenames: Collection[str] = MISSING,
attachment_extensions: Collection[str] = MISSING, attachment_extensions: Collection[str] = MISSING,
application_commands: Collection[Snowflake] = MISSING, application_commands: Collection[Snowflake] = MISSING,
oldest_first: bool = False, oldest_first: bool = MISSING,
most_relevant: bool = False, most_relevant: bool = False,
) -> AsyncIterator[Message]: ) -> AsyncIterator[Message]:
"""Returns an :term:`asynchronous iterator` that enables searching the channel's messages. """Returns an :term:`asynchronous iterator` that enables searching the channel's messages.
@ -2356,9 +2423,7 @@ class Messageable:
limit: Optional[:class:`int`] limit: Optional[:class:`int`]
The number of messages to retrieve. The number of messages to retrieve.
If ``None``, retrieves every message in the results. Note, however, If ``None``, retrieves every message in the results. Note, however,
that this would make it a slow operation. Additionally, note that the that this would make it a slow operation.
search API has a maximum pagination offset of 5000 (subject to change),
so a limit of over 5000 or ``None`` may eventually raise an exception.
offset: :class:`int` offset: :class:`int`
The pagination offset to start at. The pagination offset to start at.
before: Union[:class:`~discord.abc.Snowflake`, :class:`datetime.datetime`] before: Union[:class:`~discord.abc.Snowflake`, :class:`datetime.datetime`]
@ -2397,17 +2462,18 @@ class Messageable:
application_commands: List[:class:`~discord.abc.ApplicationCommand`] application_commands: List[:class:`~discord.abc.ApplicationCommand`]
The used application commands to filter by. The used application commands to filter by.
oldest_first: :class:`bool` oldest_first: :class:`bool`
Whether to return the oldest results first. Whether to return the oldest results first. Defaults to ``True`` if
``after`` is specified, otherwise ``False``. Ignored when ``most_relevant`` is set.
most_relevant: :class:`bool` most_relevant: :class:`bool`
Whether to sort the results by relevance. Using this with ``oldest_first`` Whether to sort the results by relevance. Limits pagination to 9975 entries.
will return the least relevant results first. Prevents using both ``before`` and ``after``.
Raises Raises
------ ------
~discord.Forbidden
You do not have permissions to search the channel's messages.
~discord.HTTPException ~discord.HTTPException
The request to search messages failed. The request to search messages failed.
TypeError
Provided both ``before`` and ``after`` when ``most_relevant`` is set.
ValueError ValueError
Could not resolve the channel's guild ID. Could not resolve the channel's guild ID.

21
discord/guild.py

@ -2951,7 +2951,7 @@ class Guild(Hashable):
attachment_filenames: Collection[str] = MISSING, attachment_filenames: Collection[str] = MISSING,
attachment_extensions: Collection[str] = MISSING, attachment_extensions: Collection[str] = MISSING,
application_commands: Collection[Snowflake] = MISSING, application_commands: Collection[Snowflake] = MISSING,
oldest_first: bool = False, oldest_first: bool = MISSING,
most_relevant: bool = False, most_relevant: bool = False,
) -> AsyncIterator[Message]: ) -> AsyncIterator[Message]:
"""Returns an :term:`asynchronous iterator` that enables searching the guild's messages. """Returns an :term:`asynchronous iterator` that enables searching the guild's messages.
@ -2990,9 +2990,7 @@ class Guild(Hashable):
limit: Optional[:class:`int`] limit: Optional[:class:`int`]
The number of messages to retrieve. The number of messages to retrieve.
If ``None``, retrieves every message in the results. Note, however, If ``None``, retrieves every message in the results. Note, however,
that this would make it a slow operation. Additionally, note that the that this would make it a slow operation.
search API has a maximum pagination offset of 5000 (subject to change),
so a limit of over 5000 or ``None`` may eventually raise an exception.
offset: :class:`int` offset: :class:`int`
The pagination offset to start at. The pagination offset to start at.
before: Union[:class:`abc.Snowflake`, :class:`datetime.datetime`] before: Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]
@ -3035,17 +3033,20 @@ class Guild(Hashable):
application_commands: List[:class:`abc.ApplicationCommand`] application_commands: List[:class:`abc.ApplicationCommand`]
The used application commands to filter by. The used application commands to filter by.
oldest_first: :class:`bool` oldest_first: :class:`bool`
Whether to return the oldest results first. Whether to return the oldest results first. Defaults to ``True`` if
``before`` is specified, otherwise ``False``. Ignored when ``most_relevant`` is set.
most_relevant: :class:`bool` most_relevant: :class:`bool`
Whether to sort the results by relevance. Using this with ``oldest_first`` Whether to sort the results by relevance. Limits pagination to 9975 entries.
will return the least relevant results first. Prevents using both ``before`` and ``after``.
Raises Raises
------ ------
Forbidden ~discord.HTTPException
You do not have permissions to search the channel's messages.
HTTPException
The request to search messages failed. The request to search messages failed.
TypeError
Provided both ``before`` and ``after`` when ``most_relevant`` is set.
ValueError
Could not resolve the channel's guild ID.
Yields Yields
------- -------

Loading…
Cancel
Save