Browse Source

Implement message search

pull/10109/head
dolfies 2 years ago
parent
commit
3108aabcce
  1. 278
      discord/abc.py
  2. 152
      discord/guild.py
  3. 23
      discord/http.py
  4. 30
      discord/message.py
  5. 16
      discord/state.py
  6. 2
      discord/types/embed.py
  7. 41
      discord/types/message.py

278
discord/abc.py

@ -101,6 +101,8 @@ if TYPE_CHECKING:
GuildChannel as GuildChannelPayload,
OverwriteType,
)
from .types.embed import EmbedType
from .types.message import MessageSearchAuthorType, MessageSearchHasType
from .types.snowflake import (
SnowflakeList,
)
@ -284,6 +286,132 @@ async def _handle_commands(
return
async def _handle_message_search(
destination: Union[Messageable, Guild],
*,
limit: Optional[int] = 25,
offset: int = 0,
before: SnowflakeTime = MISSING,
after: SnowflakeTime = MISSING,
include_nsfw: bool = MISSING,
content: str = MISSING,
channels: Collection[Snowflake] = MISSING,
authors: Collection[Snowflake] = MISSING,
author_types: Collection[MessageSearchAuthorType] = MISSING,
mentions: Collection[Snowflake] = MISSING,
mention_everyone: bool = MISSING,
pinned: bool = MISSING,
has: Collection[MessageSearchHasType] = MISSING,
embed_types: Collection[EmbedType] = MISSING,
embed_providers: Collection[str] = MISSING,
link_hostnames: Collection[str] = MISSING,
attachment_filenames: Collection[str] = MISSING,
attachment_extensions: Collection[str] = MISSING,
application_commands: Collection[Snowflake] = MISSING,
oldest_first: bool = False,
most_relevant: bool = False,
) -> AsyncIterator[Message]:
if limit is not None and limit < 0:
raise ValueError('limit must be greater than or equal to 0')
if offset < 0:
raise ValueError('offset must be greater than or equal to 0')
_state = destination._state
endpoint = _state.http.search_channel if isinstance(destination, Messageable) else _state.http.search_guild
entity_id = (await destination._get_channel()).id if isinstance(destination, Messageable) else destination.id
payload = {}
if isinstance(before, datetime):
before = Object(id=utils.time_snowflake(before, high=False))
if isinstance(after, datetime):
after = Object(id=utils.time_snowflake(after, high=True))
if (
include_nsfw is MISSING
and not isinstance(destination, Messageable)
and _state.user
and _state.user.nsfw_allowed is not None
):
include_nsfw = _state.user.nsfw_allowed
if before:
payload['max_id'] = before.id
if after:
payload['min_id'] = after.id
if include_nsfw is not MISSING:
payload['include_nsfw'] = str(include_nsfw).lower()
if content:
payload['content'] = content
if channels:
payload['channel_id'] = [c.id for c in channels]
if authors:
payload['author_id'] = [a.id for a in authors]
if author_types:
payload['author_type'] = list(author_types)
if mentions:
payload['mentions'] = [m.id for m in mentions]
if mention_everyone is not MISSING:
payload['mention_everyone'] = str(mention_everyone).lower()
if pinned is not MISSING:
payload['pinned'] = str(pinned).lower()
if has:
payload['has'] = list(has)
if embed_types:
payload['embed_type'] = list(embed_types)
if embed_providers:
payload['embed_provider'] = list(embed_providers)
if link_hostnames:
payload['link_hostname'] = list(link_hostnames)
if attachment_filenames:
payload['attachment_filename'] = list(attachment_filenames)
if attachment_extensions:
payload['attachment_extension'] = list(attachment_extensions)
if application_commands:
payload['command_id'] = [c.id for c in application_commands]
if oldest_first:
payload['sort_order'] = 'asc'
if most_relevant:
payload['sort_by'] = 'relevance'
while True:
retrieve = min(25 if limit is None else limit, 25)
if retrieve < 1:
return
if retrieve != 25:
payload['limit'] = retrieve
if offset:
payload['offset'] = offset
data = await endpoint(entity_id, payload)
threads = {int(thread['id']): thread for thread in data.get('threads', [])}
for member in data.get('members', []):
thread_id = int(member['id'])
thread = threads.get(thread_id)
if thread:
thread['member'] = member
length = len(data['messages'])
offset += length
if limit is not None:
limit -= length
# Terminate loop on next iteration; there's no data left after this
if len(data['messages']) < 25:
limit = 0
for raw_messages in data['messages']:
if not raw_messages:
continue
# Context is no longer sent, so this is probably fine
raw_message = raw_messages[0]
channel_id = int(raw_message['channel_id'])
if channel_id in threads:
raw_message['thread'] = threads[channel_id]
channel, _ = _state._get_guild_channel(raw_message)
yield _state.create_message(channel=channel, data=raw_message, search_result=data) # type: ignore
@runtime_checkable
class Snowflake(Protocol):
"""An ABC that details the common operations on a Discord model.
@ -1349,17 +1477,17 @@ class GuildChannel:
If this invite is invalid, a new invite will be created according to the parameters and returned.
.. versionadded:: 2.0
target_type: Optional[:class:`.InviteTarget`]
target_type: Optional[:class:`~discord.InviteTarget`]
The type of target for the voice channel invite, if any.
.. versionadded:: 2.0
target_user: Optional[:class:`User`]
target_user: Optional[:class:`~discord.User`]
The user whose stream to display for this invite, required if ``target_type`` is :attr:`.InviteTarget.stream`. The user must be streaming in the channel.
.. versionadded:: 2.0
target_application:: Optional[:class:`.Application`]
target_application:: Optional[:class:`~discord.Application`]
The embedded application for the invite, required if ``target_type`` is :attr:`.InviteTarget.embedded_application`.
.. versionadded:: 2.0
@ -2021,6 +2149,146 @@ class Messageable:
# There's no data left after this
break
def search(
self,
content: str = MISSING,
*,
limit: Optional[int] = 25,
offset: int = 0,
before: SnowflakeTime = MISSING,
after: SnowflakeTime = MISSING,
authors: Collection[Snowflake] = MISSING,
author_types: Collection[MessageSearchAuthorType] = MISSING,
mentions: Collection[Snowflake] = MISSING,
mention_everyone: bool = MISSING,
pinned: bool = MISSING,
has: Collection[MessageSearchHasType] = MISSING,
embed_types: Collection[EmbedType] = MISSING,
embed_providers: Collection[str] = MISSING,
link_hostnames: Collection[str] = MISSING,
attachment_filenames: Collection[str] = MISSING,
attachment_extensions: Collection[str] = MISSING,
application_commands: Collection[Snowflake] = MISSING,
oldest_first: bool = False,
most_relevant: bool = False,
) -> AsyncIterator[Message]:
"""Returns an :term:`asynchronous iterator` that enables searching the channel's messages.
You must have :attr:`~discord.Permissions.read_message_history` to do this.
.. note::
Due to a limitation with the Discord API, the :class:`.Message`
objects returned by this method do not contain complete
:attr:`.Message.reactions` data.
.. versionadded:: 2.1
Examples
---------
Usage ::
counter = 0
async for message in channel.search('hi', limit=200):
if message.author == client.user:
counter += 1
Flattening into a list: ::
messages = [message async for message in channel.search('test', limit=123)]
# messages is now a list of Message...
All parameters are optional.
Parameters
-----------
content: :class:`str`
The message content to search for.
limit: Optional[:class:`int`]
The number of messages to retrieve.
If ``None``, retrieves every message in the results. Note, however,
that this would make it a slow operation. Additionally, note that the
search API has a maximum pagination offset of 5000 (subject to change),
so a limit of over 5000 or ``None`` may eventually raise an exception.
offset: :class:`int`
The pagination offset to start at.
before: Union[:class:`~discord.abc.Snowflake`, :class:`datetime.datetime`]
Retrieve messages before this date or message.
If a datetime is provided, it is recommended to use a UTC aware datetime.
If the datetime is naive, it is assumed to be local time.
after: Union[:class:`~discord.abc.Snowflake`, :class:`datetime.datetime`]
Retrieve messages after this date or message.
If a datetime is provided, it is recommended to use a UTC aware datetime.
If the datetime is naive, it is assumed to be local time.
authors: List[:class:`~discord.User`]
The authors to filter by.
author_types: List[:class:`str`]
The author types to filter by. Can be one of ``user``, ``bot``, or ``webhook``.
These can be negated by prefixing with ``-``, which will exclude them.
mentions: List[:class:`~discord.User`]
The mentioned users to filter by.
mention_everyone: :class:`bool`
Whether to filter by messages that do or do not mention @everyone.
pinned: :class:`bool`
Whether to filter by messages that are or are not pinned.
has: List[:class:`str`]
The message attributes to filter by. Can be one of ``image``, ``sound``,
``video``, ``file``, ``sticker``, ``embed``, or ``link``. These can be
negated by prefixing with ``-``, which will exclude them.
embed_types: List[:class:`str`]
The embed types to filter by.
embed_providers: List[:class:`str`]
The embed providers to filter by (e.g. tenor).
link_hostnames: List[:class:`str`]
The link hostnames to filter by (e.g. google.com).
attachment_filenames: List[:class:`str`]
The attachment filenames to filter by.
attachment_extensions: List[:class:`str`]
The attachment extensions to filter by (e.g. txt).
application_commands: List[:class:`~discord.abc.ApplicationCommand`]
The used application commands to filter by.
oldest_first: :class:`bool`
Whether to return the oldest results first.
most_relevant: :class:`bool`
Whether to sort the results by relevance. Using this with ``oldest_first``
will return the least relevant results first.
Raises
------
~discord.Forbidden
You do not have permissions to search the channel's messages.
~discord.HTTPException
The request to search messages failed.
Yields
-------
:class:`~discord.Message`
The message with the message data parsed.
"""
return _handle_message_search(
self,
limit=limit,
offset=offset,
before=before,
after=after,
content=content,
authors=authors,
author_types=author_types,
mentions=mentions,
mention_everyone=mention_everyone,
pinned=pinned,
has=has,
embed_types=embed_types,
embed_providers=embed_providers,
link_hostnames=link_hostnames,
attachment_filenames=attachment_filenames,
attachment_extensions=attachment_extensions,
application_commands=application_commands,
oldest_first=oldest_first,
most_relevant=most_relevant,
)
def slash_commands(
self,
query: Optional[str] = None,
@ -2084,7 +2352,7 @@ class Messageable:
Yields
-------
:class:`.SlashCommand`
:class:`~discord.SlashCommand`
A slash command.
"""
return _handle_commands(
@ -2160,7 +2428,7 @@ class Messageable:
Yields
-------
:class:`.UserCommand`
:class:`~discord.UserCommand`
A user command.
"""
return _handle_commands(

152
discord/guild.py

@ -127,10 +127,12 @@ if TYPE_CHECKING:
StageChannel as StageChannelPayload,
ForumChannel as ForumChannelPayload,
)
from .types.embed import EmbedType
from .types.integration import IntegrationType
from .types.message import MessageSearchAuthorType, MessageSearchHasType
from .types.snowflake import SnowflakeList, Snowflake as _Snowflake
from .types.widget import EditWidgetSettings
from .message import EmojiInputType
from .message import EmojiInputType, Message
VocalGuildChannel = Union[VoiceChannel, StageChannel]
GuildChannel = Union[VocalGuildChannel, ForumChannel, TextChannel, CategoryChannel]
@ -2586,6 +2588,154 @@ class Guild(Hashable):
for e in data:
yield BanEntry(user=User(state=_state, data=e['user']), reason=e['reason'])
def search(
self,
content: str = MISSING,
*,
limit: Optional[int] = 25,
offset: int = 0,
before: SnowflakeTime = MISSING,
after: SnowflakeTime = MISSING,
include_nsfw: bool = MISSING,
channels: Collection[Snowflake] = MISSING,
authors: Collection[Snowflake] = MISSING,
author_types: Collection[MessageSearchAuthorType] = MISSING,
mentions: Collection[Snowflake] = MISSING,
mention_everyone: bool = MISSING,
pinned: bool = MISSING,
has: Collection[MessageSearchHasType] = MISSING,
embed_types: Collection[EmbedType] = MISSING,
embed_providers: Collection[str] = MISSING,
link_hostnames: Collection[str] = MISSING,
attachment_filenames: Collection[str] = MISSING,
attachment_extensions: Collection[str] = MISSING,
application_commands: Collection[Snowflake] = MISSING,
oldest_first: bool = False,
most_relevant: bool = False,
) -> AsyncIterator[Message]:
"""Returns an :term:`asynchronous iterator` that enables searching the guild's messages.
You must have :attr:`~Permissions.read_message_history` to do this.
.. note::
Due to a limitation with the Discord API, the :class:`.Message`
objects returned by this method do not contain complete
:attr:`.Message.reactions` data.
.. versionadded:: 2.1
Examples
---------
Usage ::
counter = 0
async for message in guild.search('hi', limit=200):
if message.author == client.user:
counter += 1
Flattening into a list: ::
messages = [message async for message in guild.search('test', limit=123)]
# messages is now a list of Message...
All parameters are optional.
Parameters
-----------
content: :class:`str`
The message content to search for.
limit: Optional[:class:`int`]
The number of messages to retrieve.
If ``None``, retrieves every message in the results. Note, however,
that this would make it a slow operation. Additionally, note that the
search API has a maximum pagination offset of 5000 (subject to change),
so a limit of over 5000 or ``None`` may eventually raise an exception.
offset: :class:`int`
The pagination offset to start at.
before: Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]
Retrieve messages before this date or message.
If a datetime is provided, it is recommended to use a UTC aware datetime.
If the datetime is naive, it is assumed to be local time.
after: Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]
Retrieve messages after this date or message.
If a datetime is provided, it is recommended to use a UTC aware datetime.
If the datetime is naive, it is assumed to be local time.
nsfw_allowed: :class:`bool`
Whether to include messages from NSFW channels. Defaults to :attr:`~discord.ClientUser.nsfw_allowed`.
channels: List[Union[:class:`abc.GuildChannel`, :class:`abc.PrivateChannel`, :class:`Thread`]]
The channels to filter by.
authors: List[:class:`User`]
The authors to filter by.
author_types: List[:class:`str`]
The author types to filter by. Can be one of ``user``, ``bot``, or ``webhook``.
These can be negated by prefixing with ``-``, which will exclude them.
mentions: List[:class:`User`]
The mentioned users to filter by.
mention_everyone: :class:`bool`
Whether to filter by messages that do or do not mention @everyone.
pinned: :class:`bool`
Whether to filter by messages that are or are not pinned.
has: List[:class:`str`]
The message attributes to filter by. Can be one of ``image``, ``sound``,
``video``, ``file``, ``sticker``, ``embed``, or ``link``. These can be
negated by prefixing with ``-``, which will exclude them.
embed_types: List[:class:`str`]
The embed types to filter by.
embed_providers: List[:class:`str`]
The embed providers to filter by (e.g. tenor).
link_hostnames: List[:class:`str`]
The link hostnames to filter by (e.g. google.com).
attachment_filenames: List[:class:`str`]
The attachment filenames to filter by.
attachment_extensions: List[:class:`str`]
The attachment extensions to filter by (e.g. txt).
application_commands: List[:class:`abc.ApplicationCommand`]
The used application commands to filter by.
oldest_first: :class:`bool`
Whether to return the oldest results first.
most_relevant: :class:`bool`
Whether to sort the results by relevance. Using this with ``oldest_first``
will return the least relevant results first.
Raises
------
Forbidden
You do not have permissions to search the channel's messages.
HTTPException
The request to search messages failed.
Yields
-------
:class:`Message`
The message with the message data parsed.
"""
return abc._handle_message_search(
self,
limit=limit,
offset=offset,
before=before,
after=after,
content=content,
include_nsfw=include_nsfw,
channels=channels,
authors=authors,
author_types=author_types,
mentions=mentions,
mention_everyone=mention_everyone,
pinned=pinned,
has=has,
embed_types=embed_types,
embed_providers=embed_providers,
link_hostnames=link_hostnames,
attachment_filenames=attachment_filenames,
attachment_extensions=attachment_extensions,
application_commands=application_commands,
oldest_first=oldest_first,
most_relevant=most_relevant,
)
async def prune_members(
self,
*,

23
discord/http.py

@ -760,6 +760,14 @@ class HTTPClient:
discord_hash or route_key,
)
# 202s must be retried
if response.status == 202 and isinstance(data, dict) and 'retry_after' in data:
# Sometimes retry_after is 0, but that's undesirable
retry_after: float = data['retry_after'] or 0.25
_log.debug('%s %s received a 202. Retrying in %s seconds...', method, url, retry_after)
await asyncio.sleep(retry_after)
continue
# Request was successful so just return the text/json
if 300 > response.status >= 200:
_log.debug('%s %s has received %s.', method, url, data)
@ -798,7 +806,7 @@ class HTTPClient:
_log.warning(fmt, method, url, retry_after)
_log.debug(
'Rate limit is being handled by bucket hash %s with %r major parameters',
'Rate limit is being handled by bucket hash %s with %r major parameters.',
bucket_hash,
route.major_parameters,
)
@ -833,8 +841,8 @@ class HTTPClient:
elif response.status >= 500:
raise DiscordServerError(response, data)
else:
if 'captcha_key' in data:
raise CaptchaRequired(response, data) # type: ignore # Should not be text at this point
if isinstance(data, dict) and 'captcha_key' in data:
raise CaptchaRequired(response, data)
raise HTTPException(response, data)
# This is handling exceptions from the request
@ -1259,6 +1267,15 @@ class HTTPClient:
return self.request(Route('GET', '/channels/{channel_id}/messages', channel_id=channel_id), params=params)
def search_guild(self, guild_id: Snowflake, payload: Dict[str, Any]) -> Response[message.MessageSearchResult]:
return self.request(Route('GET', '/guilds/{guild_id}/messages/search', guild_id=guild_id), params=payload)
def search_channel(self, channel_id: Snowflake, payload: Dict[str, Any]) -> Response[message.MessageSearchResult]:
return self.request(Route('GET', '/channels/{channel_id}/messages/search', channel_id=channel_id), params=payload)
def search_user(self, payload: Dict[str, Any]) -> Response[message.MessageSearchResult]:
return self.request(Route('GET', '/users/@me/messages/search'), json=payload)
def publish_message(self, channel_id: Snowflake, message_id: Snowflake) -> Response[message.Message]:
r = Route(
'POST',

30
discord/message.py

@ -82,6 +82,7 @@ if TYPE_CHECKING:
MessageReference as MessageReferencePayload,
MessageActivity as MessageActivityPayload,
RoleSubscriptionData as RoleSubscriptionDataPayload,
MessageSearchResult as MessageSearchResultPayload,
)
from .types.interactions import MessageInteraction as MessageInteractionPayload
@ -1426,6 +1427,24 @@ class Message(PartialMessage, Hashable):
The interaction that this message is a response to.
.. versionadded:: 2.0
hit: :class:`bool`
Whether the message was a hit in a search result. As surrounding messages
are no longer returned in search results, this is always ``True`` for search results.
.. versionadded:: 2.1
total_results: Optional[:class:`int`]
The total number of results for the search query. This is only present in search results.
.. versionadded:: 2.1
analytics_id: Optional[:class:`str`]
The search results analytics ID. This is only present in search results.
.. versionadded:: 2.1
doing_deep_historical_index: Optional[:class:`bool`]
The status of the document's current deep historical indexing operation, if any.
This is only present in search results.
.. versionadded:: 2.1
"""
__slots__ = (
@ -1460,6 +1479,10 @@ class Message(PartialMessage, Hashable):
'role_subscription',
'application_id',
'position',
'hit',
'total_results',
'analytics_id',
'doing_deep_historical_index',
)
if TYPE_CHECKING:
@ -1477,6 +1500,7 @@ class Message(PartialMessage, Hashable):
state: ConnectionState,
channel: MessageableChannel,
data: MessagePayload,
search_result: Optional[MessageSearchResultPayload] = None,
) -> None:
self.channel: MessageableChannel = channel
self.id: int = int(data['id'])
@ -1554,6 +1578,12 @@ class Message(PartialMessage, Hashable):
else:
self.role_subscription = RoleSubscriptionInfo(role_subscription)
search_payload = search_result or {}
self.hit: bool = data.get('hit', False)
self.total_results: Optional[int] = search_payload.get('total_results')
self.analytics_id: Optional[str] = search_payload.get('analytics_id')
self.doing_deep_historical_index: Optional[bool] = search_payload.get('doing_deep_historical_index')
for handler in ('author', 'member', 'mentions', 'mention_roles', 'call', 'interaction', 'components'):
try:
getattr(self, f'_handle_{handler}')(data[handler])

16
discord/state.py

@ -117,7 +117,11 @@ if TYPE_CHECKING:
from .types.emoji import Emoji as EmojiPayload, PartialEmoji as PartialEmojiPayload
from .types.sticker import GuildSticker as GuildStickerPayload
from .types.guild import Guild as GuildPayload
from .types.message import Message as MessagePayload, PartialMessage as PartialMessagePayload
from .types.message import (
Message as MessagePayload,
MessageSearchResult as MessageSearchResultPayload,
PartialMessage as PartialMessagePayload,
)
from .types import gateway as gw
from .types.voice import GuildVoiceState
from .types.activity import ClientStatus as ClientStatusPayload
@ -2718,8 +2722,14 @@ class ConnectionState:
if channel is not None:
return channel
def create_message(self, *, channel: MessageableChannel, data: MessagePayload) -> Message:
return Message(state=self, channel=channel, data=data)
def create_message(
self,
*,
channel: MessageableChannel,
data: MessagePayload,
search_result: Optional[MessageSearchResultPayload] = None,
) -> Message:
return Message(state=self, channel=channel, data=data, search_result=search_result)
def _update_message_references(self) -> None:
# self._messages won't be None when this is called

2
discord/types/embed.py

@ -71,7 +71,7 @@ class EmbedAuthor(TypedDict, total=False):
proxy_icon_url: str
EmbedType = Literal['rich', 'image', 'video', 'gifv', 'article', 'link']
EmbedType = Literal['rich', 'image', 'video', 'gifv', 'article', 'link', 'auto_moderation_message']
class Embed(TypedDict, total=False):

41
discord/types/message.py

@ -37,6 +37,7 @@ from .components import Component
from .interactions import MessageInteraction
from .application import BaseApplication
from .sticker import StickerItem
from .threads import Thread, ThreadMember
class PartialMessage(TypedDict):
@ -134,6 +135,8 @@ class Message(PartialMessage):
position: NotRequired[int]
call: NotRequired[Call]
role_subscription_data: NotRequired[RoleSubscriptionData]
hit: NotRequired[bool]
thread: NotRequired[Thread]
AllowedMentionType = Literal['roles', 'users', 'everyone']
@ -144,3 +147,41 @@ class AllowedMentions(TypedDict):
roles: SnowflakeList
users: SnowflakeList
replied_user: bool
class MessageSearchIndexingResult(TypedDict):
# Error but not quite
message: str
code: int
documents_indexed: int
retry_after: int
class MessageSearchResult(TypedDict):
messages: List[List[Message]]
threads: NotRequired[List[Thread]]
members: NotRequired[List[ThreadMember]]
total_results: int
analytics_id: str
doing_deep_historical_index: NotRequired[bool]
MessageSearchAuthorType = Literal['user', '-user', 'bot', '-bot', 'webhook', '-webhook']
MessageSearchHasType = Literal[
'image',
'-image',
'sound',
'-sound',
'video',
'-video',
'file',
'-file',
'sticker',
'-sticker',
'embed',
'-embed',
'link',
'-link',
]
MessageSearchSortType = Literal['timestamp', 'relevance']
MessageSearchSortOrder = Literal['desc', 'asc']

Loading…
Cancel
Save