Browse Source

Immensely improve robustness of search pagination (fix #581)

pull/10109/head
dolfies 5 months ago
parent
commit
047fc48fb1
  1. 128
      discord/abc.py
  2. 21
      discord/guild.py

128
discord/abc.py

@ -27,6 +27,7 @@ from __future__ import annotations
import copy
import asyncio
from datetime import datetime
import logging
from operator import attrgetter
from typing import (
Any,
@ -75,6 +76,9 @@ __all__ = (
)
T = TypeVar('T', bound=VoiceProtocol)
MISSING = utils.MISSING
_log = logging.getLogger(__name__)
if TYPE_CHECKING:
from typing_extensions import Self
@ -114,8 +118,6 @@ if TYPE_CHECKING:
VocalChannel = Union[VoiceChannel, StageChannel, DMChannel, GroupChannel]
SnowflakeTime = Union["Snowflake", datetime]
MISSING = utils.MISSING
class _Undefined:
def __repr__(self) -> str:
@ -279,9 +281,13 @@ async def _handle_message_search(
attachment_filenames: Collection[str] = MISSING,
attachment_extensions: Collection[str] = MISSING,
application_commands: Collection[Snowflake] = MISSING,
oldest_first: bool = False,
oldest_first: bool = MISSING,
most_relevant: bool = False,
) -> AsyncIterator[Message]:
# Important note for message search:
# The endpoint might sometimes time out while waiting for messages
# This will manifest as fewer results than the limit, even if there are more messages to be found
from .channel import PartialMessageable # circular import
if limit is not None and limit < 0:
@ -321,6 +327,11 @@ async def _handle_message_search(
before = Object(id=utils.time_snowflake(before, high=False))
if isinstance(after, datetime):
after = Object(id=utils.time_snowflake(after, high=True))
after = after or OLDEST_OBJECT
if oldest_first is MISSING:
oldest_first = after != OLDEST_OBJECT
if (
include_nsfw is MISSING
and not isinstance(destination, Messageable)
@ -329,10 +340,8 @@ async def _handle_message_search(
):
include_nsfw = _state.user.nsfw_allowed
if before:
payload['max_id'] = before.id
if after:
payload['min_id'] = after.id
if offset:
payload['offset'] = offset
if include_nsfw is not MISSING:
payload['include_nsfw'] = str(include_nsfw).lower()
if content:
@ -366,18 +375,75 @@ async def _handle_message_search(
if oldest_first:
payload['sort_order'] = 'asc'
if most_relevant:
# This is the default and it isn't respected anyway, but this ep is cursed enough as it is
# So we will go with what the client does
payload['sort_order'] = 'desc'
payload['sort_by'] = 'relevance'
async def _state_strategy(retrieve: int, state: Optional[Snowflake], limit: Optional[int]):
    """Fetch one page of results using a message-ID cursor.

    ``state`` is the cursor left by the previous page (or the initial
    bound). When set, it is sent as ``min_id`` if ``oldest_first`` is
    truthy and as ``max_id`` otherwise.

    Returns a ``(data, state, limit)`` tuple: the raw response, the cursor
    built from the first entry of the last message group on the page, and
    the limit reduced by the number of groups returned (left as ``None``
    when no limit is in effect). ``state`` and ``limit`` are returned
    unchanged if the page holds no messages.
    """
    payload['limit'] = retrieve
    if state:
        # Ascending searches walk forward from the cursor, descending ones backward.
        cursor_key = 'min_id' if oldest_first else 'max_id'
        payload[cursor_key] = state.id

    data = await endpoint(entity_id, payload)
    page = data['messages']
    if page:
        if limit is not None:
            limit -= len(page)
        # The last group on the page becomes the cursor for the next request.
        state = Object(id=int(page[-1][0]['id']))
    return data, state, limit
async def _relevance_strategy(retrieve: int, _, limit: Optional[int]):
    """Fetch one page of relevance-sorted results using offset pagination.

    The second positional argument (the cursor used by the ID-based
    strategy) is ignored; pagination is tracked through ``payload['offset']``.

    Returns a ``(data, None, limit)`` tuple: the raw response, no cursor,
    and the limit reduced by the number of message groups returned (left
    as ``None`` when no limit is in effect).
    """
    payload['limit'] = retrieve
    data = await endpoint(entity_id, payload)
    if data['messages']:
        length = len(data['messages'])
        if limit is not None:
            limit -= length
        # Advance from the offset that was just requested, not from the
        # caller's initial ``offset``: the latter never changes, so using it
        # would request the same page again on every call after the first.
        payload['offset'] = payload.get('offset', offset or 0) + length
    return data, None, limit
predicate = None
if most_relevant:
strategy, state = _relevance_strategy, None
if before and after != OLDEST_OBJECT:
raise TypeError('Cannot use both before and after with most_relevant')
if before:
payload['max_id'] = before.id
if after != OLDEST_OBJECT:
payload['min_id'] = after.id
elif oldest_first:
strategy, state = _state_strategy, after
if before:
predicate = lambda m: int(m[0]['id']) < before.id
else:
strategy, state = _state_strategy, before
if after and after != OLDEST_OBJECT:
predicate = lambda m: int(m[0]['id']) > after.id
total_results = MISSING
total = 0
while True:
retrieve = min(25 if limit is None else limit, 25)
retrieve = 25 if limit is None else min(limit, 25)
if retrieve < 1:
return
if retrieve != 25:
payload['limit'] = retrieve
if offset:
payload['offset'] = offset
data = await endpoint(entity_id, payload)
data, state, limit = await strategy(retrieve, state, limit)
if total_results is MISSING:
total_results = data['total_results']
messages = data['messages']
if predicate:
messages = filter(predicate, messages)
threads = {int(thread['id']): thread for thread in data.get('threads', [])}
for member in data.get('members', []):
thread_id = int(member['id'])
@ -385,17 +451,11 @@ async def _handle_message_search(
if thread:
thread['member'] = member
length = len(data['messages'])
offset += length
if limit is not None:
limit -= length
count = 0
# Terminate loop on next iteration; there's no data left after this
if len(data['messages']) < 25:
limit = 0
for raw_messages in data['messages']:
for count, raw_messages in enumerate(messages, 1):
if not raw_messages:
_log.debug('Search for %s with payload %s yielded an empty subarray.', destination, payload)
continue
# Context is no longer sent, so this is probably fine
@ -407,6 +467,13 @@ async def _handle_message_search(
channel = _resolve_channel(raw_message)
yield _state.create_message(channel=channel, data=raw_message, search_result=data) # type: ignore
if count == 0:
return
total += count
if total >= total_results:
return
@runtime_checkable
class Snowflake(Protocol):
@ -2317,7 +2384,7 @@ class Messageable:
attachment_filenames: Collection[str] = MISSING,
attachment_extensions: Collection[str] = MISSING,
application_commands: Collection[Snowflake] = MISSING,
oldest_first: bool = False,
oldest_first: bool = MISSING,
most_relevant: bool = False,
) -> AsyncIterator[Message]:
"""Returns an :term:`asynchronous iterator` that enables searching the channel's messages.
@ -2356,9 +2423,7 @@ class Messageable:
limit: Optional[:class:`int`]
The number of messages to retrieve.
If ``None``, retrieves every message in the results. Note, however,
that this would make it a slow operation. Additionally, note that the
search API has a maximum pagination offset of 5000 (subject to change),
so a limit of over 5000 or ``None`` may eventually raise an exception.
that this would make it a slow operation.
offset: :class:`int`
The pagination offset to start at.
before: Union[:class:`~discord.abc.Snowflake`, :class:`datetime.datetime`]
@ -2397,17 +2462,18 @@ class Messageable:
application_commands: List[:class:`~discord.abc.ApplicationCommand`]
The used application commands to filter by.
oldest_first: :class:`bool`
Whether to return the oldest results first.
Whether to return the oldest results first. Defaults to ``True`` if
``after`` is specified, otherwise ``False``. Ignored when ``most_relevant`` is set.
most_relevant: :class:`bool`
Whether to sort the results by relevance. Using this with ``oldest_first``
will return the least relevant results first.
Whether to sort the results by relevance. Limits pagination to 9975 entries.
Prevents using both ``before`` and ``after``.
Raises
------
~discord.Forbidden
You do not have permissions to search the channel's messages.
~discord.HTTPException
The request to search messages failed.
TypeError
Provided both ``before`` and ``after`` when ``most_relevant`` is set.
ValueError
Could not resolve the channel's guild ID.

21
discord/guild.py

@ -2951,7 +2951,7 @@ class Guild(Hashable):
attachment_filenames: Collection[str] = MISSING,
attachment_extensions: Collection[str] = MISSING,
application_commands: Collection[Snowflake] = MISSING,
oldest_first: bool = False,
oldest_first: bool = MISSING,
most_relevant: bool = False,
) -> AsyncIterator[Message]:
"""Returns an :term:`asynchronous iterator` that enables searching the guild's messages.
@ -2990,9 +2990,7 @@ class Guild(Hashable):
limit: Optional[:class:`int`]
The number of messages to retrieve.
If ``None``, retrieves every message in the results. Note, however,
that this would make it a slow operation. Additionally, note that the
search API has a maximum pagination offset of 5000 (subject to change),
so a limit of over 5000 or ``None`` may eventually raise an exception.
that this would make it a slow operation.
offset: :class:`int`
The pagination offset to start at.
before: Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]
@ -3035,17 +3033,20 @@ class Guild(Hashable):
application_commands: List[:class:`abc.ApplicationCommand`]
The used application commands to filter by.
oldest_first: :class:`bool`
Whether to return the oldest results first.
Whether to return the oldest results first. Defaults to ``True`` if
``after`` is specified, otherwise ``False``. Ignored when ``most_relevant`` is set.
most_relevant: :class:`bool`
Whether to sort the results by relevance. Using this with ``oldest_first``
will return the least relevant results first.
Whether to sort the results by relevance. Limits pagination to 9975 entries.
Prevents using both ``before`` and ``after``.
Raises
------
Forbidden
You do not have permissions to search the guild's messages.
HTTPException
~discord.HTTPException
The request to search messages failed.
TypeError
Provided both ``before`` and ``after`` when ``most_relevant`` is set.
ValueError
Could not resolve the channel's guild ID.
Yields
-------

Loading…
Cancel
Save