|
|
@ -27,6 +27,7 @@ from __future__ import annotations |
|
|
|
import copy |
|
|
|
import asyncio |
|
|
|
from datetime import datetime |
|
|
|
import logging |
|
|
|
from operator import attrgetter |
|
|
|
from typing import ( |
|
|
|
Any, |
|
|
@ -75,6 +76,9 @@ __all__ = ( |
|
|
|
) |
|
|
|
|
|
|
|
T = TypeVar('T', bound=VoiceProtocol) |
|
|
|
MISSING = utils.MISSING |
|
|
|
|
|
|
|
_log = logging.getLogger(__name__) |
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
|
from typing_extensions import Self |
|
|
@ -114,8 +118,6 @@ if TYPE_CHECKING: |
|
|
|
VocalChannel = Union[VoiceChannel, StageChannel, DMChannel, GroupChannel] |
|
|
|
SnowflakeTime = Union["Snowflake", datetime] |
|
|
|
|
|
|
|
MISSING = utils.MISSING |
|
|
|
|
|
|
|
|
|
|
|
class _Undefined: |
|
|
|
def __repr__(self) -> str: |
|
|
@ -279,9 +281,13 @@ async def _handle_message_search( |
|
|
|
attachment_filenames: Collection[str] = MISSING, |
|
|
|
attachment_extensions: Collection[str] = MISSING, |
|
|
|
application_commands: Collection[Snowflake] = MISSING, |
|
|
|
oldest_first: bool = False, |
|
|
|
oldest_first: bool = MISSING, |
|
|
|
most_relevant: bool = False, |
|
|
|
) -> AsyncIterator[Message]: |
|
|
|
# Important note for message search: |
|
|
|
# The endpoint might sometimes time out while waiting for messages |
|
|
|
# This will manifest as less results than the limit, even if there are more messages to be found |
|
|
|
|
|
|
|
from .channel import PartialMessageable # circular import |
|
|
|
|
|
|
|
if limit is not None and limit < 0: |
|
|
@ -321,6 +327,11 @@ async def _handle_message_search( |
|
|
|
before = Object(id=utils.time_snowflake(before, high=False)) |
|
|
|
if isinstance(after, datetime): |
|
|
|
after = Object(id=utils.time_snowflake(after, high=True)) |
|
|
|
|
|
|
|
after = after or OLDEST_OBJECT |
|
|
|
if oldest_first is MISSING: |
|
|
|
oldest_first = after != OLDEST_OBJECT |
|
|
|
|
|
|
|
if ( |
|
|
|
include_nsfw is MISSING |
|
|
|
and not isinstance(destination, Messageable) |
|
|
@ -329,10 +340,8 @@ async def _handle_message_search( |
|
|
|
): |
|
|
|
include_nsfw = _state.user.nsfw_allowed |
|
|
|
|
|
|
|
if before: |
|
|
|
payload['max_id'] = before.id |
|
|
|
if after: |
|
|
|
payload['min_id'] = after.id |
|
|
|
if offset: |
|
|
|
payload['offset'] = offset |
|
|
|
if include_nsfw is not MISSING: |
|
|
|
payload['include_nsfw'] = str(include_nsfw).lower() |
|
|
|
if content: |
|
|
@ -366,18 +375,75 @@ async def _handle_message_search( |
|
|
|
if oldest_first: |
|
|
|
payload['sort_order'] = 'asc' |
|
|
|
if most_relevant: |
|
|
|
# This is the default and it isn't respected anyway, but this ep is cursed enough as it is |
|
|
|
# So we will go with what the client does |
|
|
|
payload['sort_order'] = 'desc' |
|
|
|
payload['sort_by'] = 'relevance' |
|
|
|
|
|
|
|
async def _state_strategy(retrieve: int, state: Optional[Snowflake], limit: Optional[int]): |
|
|
|
payload['limit'] = retrieve |
|
|
|
if oldest_first and state: |
|
|
|
payload['min_id'] = state.id |
|
|
|
elif state: |
|
|
|
payload['max_id'] = state.id |
|
|
|
data = await endpoint(entity_id, payload) |
|
|
|
|
|
|
|
if data['messages']: |
|
|
|
if limit is not None: |
|
|
|
limit -= len(data['messages']) |
|
|
|
|
|
|
|
state = Object(id=int(data['messages'][-1][0]['id'])) |
|
|
|
|
|
|
|
return data, state, limit |
|
|
|
|
|
|
|
async def _relevance_strategy(retrieve: int, _, limit: Optional[int]): |
|
|
|
payload['limit'] = retrieve |
|
|
|
data = await endpoint(entity_id, payload) |
|
|
|
|
|
|
|
if data['messages']: |
|
|
|
length = len(data['messages']) |
|
|
|
if limit is not None: |
|
|
|
limit -= length |
|
|
|
payload['offset'] = (offset or 0) + length |
|
|
|
|
|
|
|
return data, None, limit |
|
|
|
|
|
|
|
predicate = None |
|
|
|
|
|
|
|
if most_relevant: |
|
|
|
strategy, state = _relevance_strategy, None |
|
|
|
if before and after != OLDEST_OBJECT: |
|
|
|
raise TypeError('Cannot use both before and after with most_relevant') |
|
|
|
if before: |
|
|
|
payload['max_id'] = before.id |
|
|
|
if after != OLDEST_OBJECT: |
|
|
|
payload['min_id'] = after.id |
|
|
|
elif oldest_first: |
|
|
|
strategy, state = _state_strategy, after |
|
|
|
if before: |
|
|
|
predicate = lambda m: int(m[0]['id']) < before.id |
|
|
|
else: |
|
|
|
strategy, state = _state_strategy, before |
|
|
|
if after and after != OLDEST_OBJECT: |
|
|
|
predicate = lambda m: int(m[0]['id']) > after.id |
|
|
|
|
|
|
|
total_results = MISSING |
|
|
|
total = 0 |
|
|
|
|
|
|
|
while True: |
|
|
|
retrieve = min(25 if limit is None else limit, 25) |
|
|
|
retrieve = 25 if limit is None else min(limit, 25) |
|
|
|
if retrieve < 1: |
|
|
|
return |
|
|
|
if retrieve != 25: |
|
|
|
payload['limit'] = retrieve |
|
|
|
if offset: |
|
|
|
payload['offset'] = offset |
|
|
|
|
|
|
|
data = await endpoint(entity_id, payload) |
|
|
|
data, state, limit = await strategy(retrieve, state, limit) |
|
|
|
|
|
|
|
if total_results is MISSING: |
|
|
|
total_results = data['total_results'] |
|
|
|
|
|
|
|
messages = data['messages'] |
|
|
|
if predicate: |
|
|
|
messages = filter(predicate, messages) |
|
|
|
|
|
|
|
threads = {int(thread['id']): thread for thread in data.get('threads', [])} |
|
|
|
for member in data.get('members', []): |
|
|
|
thread_id = int(member['id']) |
|
|
@ -385,17 +451,11 @@ async def _handle_message_search( |
|
|
|
if thread: |
|
|
|
thread['member'] = member |
|
|
|
|
|
|
|
length = len(data['messages']) |
|
|
|
offset += length |
|
|
|
if limit is not None: |
|
|
|
limit -= length |
|
|
|
count = 0 |
|
|
|
|
|
|
|
# Terminate loop on next iteration; there's no data left after this |
|
|
|
if len(data['messages']) < 25: |
|
|
|
limit = 0 |
|
|
|
|
|
|
|
for raw_messages in data['messages']: |
|
|
|
for count, raw_messages in enumerate(messages, 1): |
|
|
|
if not raw_messages: |
|
|
|
_log.debug('Search for %s with payload %s yielded an empty subarray.', destination, payload) |
|
|
|
continue |
|
|
|
|
|
|
|
# Context is no longer sent, so this is probably fine |
|
|
@ -407,6 +467,13 @@ async def _handle_message_search( |
|
|
|
channel = _resolve_channel(raw_message) |
|
|
|
yield _state.create_message(channel=channel, data=raw_message, search_result=data) # type: ignore |
|
|
|
|
|
|
|
if count == 0: |
|
|
|
return |
|
|
|
|
|
|
|
total += count |
|
|
|
if total >= total_results: |
|
|
|
return |
|
|
|
|
|
|
|
|
|
|
|
@runtime_checkable |
|
|
|
class Snowflake(Protocol): |
|
|
@ -2317,7 +2384,7 @@ class Messageable: |
|
|
|
attachment_filenames: Collection[str] = MISSING, |
|
|
|
attachment_extensions: Collection[str] = MISSING, |
|
|
|
application_commands: Collection[Snowflake] = MISSING, |
|
|
|
oldest_first: bool = False, |
|
|
|
oldest_first: bool = MISSING, |
|
|
|
most_relevant: bool = False, |
|
|
|
) -> AsyncIterator[Message]: |
|
|
|
"""Returns an :term:`asynchronous iterator` that enables searching the channel's messages. |
|
|
@ -2356,9 +2423,7 @@ class Messageable: |
|
|
|
limit: Optional[:class:`int`] |
|
|
|
The number of messages to retrieve. |
|
|
|
If ``None``, retrieves every message in the results. Note, however, |
|
|
|
that this would make it a slow operation. Additionally, note that the |
|
|
|
search API has a maximum pagination offset of 5000 (subject to change), |
|
|
|
so a limit of over 5000 or ``None`` may eventually raise an exception. |
|
|
|
that this would make it a slow operation. |
|
|
|
offset: :class:`int` |
|
|
|
The pagination offset to start at. |
|
|
|
before: Union[:class:`~discord.abc.Snowflake`, :class:`datetime.datetime`] |
|
|
@ -2397,17 +2462,18 @@ class Messageable: |
|
|
|
application_commands: List[:class:`~discord.abc.ApplicationCommand`] |
|
|
|
The used application commands to filter by. |
|
|
|
oldest_first: :class:`bool` |
|
|
|
Whether to return the oldest results first. |
|
|
|
Whether to return the oldest results first. Defaults to ``True`` if |
|
|
|
``after`` is specified, otherwise ``False``. Ignored when ``most_relevant`` is set. |
|
|
|
most_relevant: :class:`bool` |
|
|
|
Whether to sort the results by relevance. Using this with ``oldest_first`` |
|
|
|
will return the least relevant results first. |
|
|
|
Whether to sort the results by relevance. Limits pagination to 9975 entries. |
|
|
|
Prevents using both ``before`` and ``after``. |
|
|
|
|
|
|
|
Raises |
|
|
|
------ |
|
|
|
~discord.Forbidden |
|
|
|
You do not have permissions to search the channel's messages. |
|
|
|
~discord.HTTPException |
|
|
|
The request to search messages failed. |
|
|
|
TypeError |
|
|
|
Provided both ``before`` and ``after`` when ``most_relevant`` is set. |
|
|
|
ValueError |
|
|
|
Could not resolve the channel's guild ID. |
|
|
|
|
|
|
|