From d1cb30cccf39648e21c0f7c73cb087fc730b8e25 Mon Sep 17 00:00:00 2001 From: PikalaxALT Date: Thu, 26 Nov 2020 23:19:00 -0500 Subject: [PATCH] Implement discord.Message.reply --- discord/abc.py | 44 ++++++++++-- discord/ext/commands/context.py | 5 ++ discord/http.py | 9 ++- discord/mentions.py | 22 ++++-- discord/message.py | 123 ++++++++++++++++++++------------ discord/message_reference.py | 86 ++++++++++++++++++++++ 6 files changed, 229 insertions(+), 60 deletions(-) create mode 100644 discord/message_reference.py diff --git a/discord/abc.py b/discord/abc.py index 91da74dfa..4968ce1b1 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -39,6 +39,7 @@ from .invite import Invite from .file import File from .voice_client import VoiceClient, VoiceProtocol from . import utils +from .message_reference import _MessageType, MessageReference class _Undefined: def __repr__(self): @@ -802,7 +803,8 @@ class Messageable(metaclass=abc.ABCMeta): async def send(self, content=None, *, tts=False, embed=None, file=None, files=None, delete_after=None, nonce=None, - allowed_mentions=None): + allowed_mentions=None, reference=None, + mention_author=None): """|coro| Sends a message to the destination with the content given. @@ -848,6 +850,19 @@ class Messageable(metaclass=abc.ABCMeta): .. versionadded:: 1.4 + reference: Union[:class:`~discord.Message`, :class:`~discord.MessageReference`] + A reference to the :class:`~discord.Message` to which you are replying, this can be created using + :meth:`~discord.Message.to_reference` or passed directly as a :class:`~discord.Message`. You can control + whether this mentions the author of the referenced message using the :attr:`~discord.AllowedMentions.replied_user` + attribute of ``allowed_mentions`` or by setting ``mention_author``. + + .. versionadded:: 1.6 + + mention_author: Optional[:class:`bool`] + If set, overrides the :attr:`~discord.AllowedMentions.replied_user` attribute of ``allowed_mentions``. + + .. versionadded:: 1.6 + Raises -------- ~discord.HTTPException @@ -855,8 +870,10 @@ class Messageable(metaclass=abc.ABCMeta): ~discord.Forbidden You do not have the proper permissions to send the message. ~discord.InvalidArgument - The ``files`` list is not of the appropriate size or - you specified both ``file`` and ``files``. + The ``files`` list is not of the appropriate size, + you specified both ``file`` and ``files``, + or the ``reference`` object is not a :class:`~discord.Message` + or :class:`~discord.MessageReference`. Returns --------- @@ -878,6 +895,18 @@ class Messageable(metaclass=abc.ABCMeta): else: allowed_mentions = state.allowed_mentions and state.allowed_mentions.to_dict() + if mention_author is not None: + allowed_mentions = allowed_mentions or {} + allowed_mentions['replied_user'] = mention_author + + if reference is not None: + if isinstance(reference, _MessageType): + if not isinstance(reference, MessageReference): + reference = reference.to_reference() + reference = reference.to_dict() + else: + raise InvalidArgument('reference parameter must be Message or MessageReference') + if file is not None and files is not None: raise InvalidArgument('cannot pass both file and files parameter to send()') @@ -887,7 +916,8 @@ class Messageable(metaclass=abc.ABCMeta): try: data = await state.http.send_files(channel.id, files=[file], allowed_mentions=allowed_mentions, - content=content, tts=tts, embed=embed, nonce=nonce) + content=content, tts=tts, embed=embed, nonce=nonce, + message_reference=reference) finally: file.close() @@ -899,13 +929,15 @@ class Messageable(metaclass=abc.ABCMeta): try: data = await state.http.send_files(channel.id, files=files, content=content, tts=tts, - embed=embed, nonce=nonce, allowed_mentions=allowed_mentions) + embed=embed, nonce=nonce, allowed_mentions=allowed_mentions, + message_reference=reference) finally: for f in files: f.close() else: data = await state.http.send_message(channel.id, content, tts=tts, embed=embed, - nonce=nonce, allowed_mentions=allowed_mentions) + nonce=nonce, allowed_mentions=allowed_mentions, + message_reference=reference) ret = state.create_message(channel=channel, data=data) if delete_after is not None: diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index 3cf851c68..d129e819d 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -322,3 +322,8 @@ class Context(discord.abc.Messageable): return None except CommandError as e: await cmd.on_help_command_error(self, e) + + async def reply(self, content=None, **kwargs): + return await self.message.reply(content, **kwargs) + + reply.__doc__ = discord.Message.reply.__doc__ diff --git a/discord/http.py b/discord/http.py index 9459a9c55..887632f9e 100644 --- a/discord/http.py +++ b/discord/http.py @@ -342,7 +342,7 @@ class HTTPClient: return self.request(Route('POST', '/users/@me/channels'), json=payload) - def send_message(self, channel_id, content, *, tts=False, embed=None, nonce=None, allowed_mentions=None): + def send_message(self, channel_id, content, *, tts=False, embed=None, nonce=None, allowed_mentions=None, message_reference=None): r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id) payload = {} @@ -361,12 +361,15 @@ class HTTPClient: if allowed_mentions: payload['allowed_mentions'] = allowed_mentions + if message_reference: + payload['message_reference'] = message_reference + return self.request(r, json=payload) def send_typing(self, channel_id): return self.request(Route('POST', '/channels/{channel_id}/typing', channel_id=channel_id)) - def send_files(self, channel_id, *, files, content=None, tts=False, embed=None, nonce=None, allowed_mentions=None): + def send_files(self, channel_id, *, files, content=None, tts=False, embed=None, nonce=None, allowed_mentions=None, message_reference=None): r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id) form = aiohttp.FormData() @@ -379,6 +382,8 @@ class HTTPClient: payload['nonce'] = nonce if allowed_mentions: payload['allowed_mentions'] = allowed_mentions + if message_reference: + payload['message_reference'] = message_reference form.add_field('payload_json', utils.to_json(payload)) if len(files) == 1: diff --git a/discord/mentions.py b/discord/mentions.py index 73c9b500b..14fcecc57 100644 --- a/discord/mentions.py +++ b/discord/mentions.py @@ -59,14 +59,20 @@ class AllowedMentions: roles are not mentioned at all. If a list of :class:`abc.Snowflake` is given then only the roles provided will be mentioned, provided those roles are in the message content. + replied_user: :class:`bool` + Whether to mention the author of the message being replied to. Defaults + to ``True``. + + .. versionadded:: 1.6 """ - __slots__ = ('everyone', 'users', 'roles') + __slots__ = ('everyone', 'users', 'roles', 'replied_user') - def __init__(self, *, everyone=default, users=default, roles=default): + def __init__(self, *, everyone=default, users=default, roles=default, replied_user=default): self.everyone = everyone self.users = users self.roles = roles + self.replied_user = replied_user @classmethod def all(cls): @@ -74,7 +80,7 @@ class AllowedMentions: .. versionadded:: 1.5 """ - return cls(everyone=True, users=True, roles=True) + return cls(everyone=True, users=True, roles=True, replied_user=True) @classmethod def none(cls): @@ -82,7 +88,7 @@ class AllowedMentions: .. versionadded:: 1.5 """ - return cls(everyone=False, users=False, roles=False) + return cls(everyone=False, users=False, roles=False, replied_user=False) def to_dict(self): parse = [] @@ -101,6 +107,9 @@ class AllowedMentions: elif self.roles != False: data['roles'] = [x.id for x in self.roles] + if self.replied_user: + data['replied_user'] = True + data['parse'] = parse return data @@ -111,7 +120,8 @@ class AllowedMentions: everyone = self.everyone if other.everyone is default else other.everyone users = self.users if other.users is default else other.users roles = self.roles if other.roles is default else other.roles - return AllowedMentions(everyone=everyone, roles=roles, users=users) + replied_user = self.replied_user if other.replied_user is default else other.replied_user + return AllowedMentions(everyone=everyone, roles=roles, users=users, replied_user=replied_user) def __repr__(self): - return '{0.__class__.__qualname__}(everyone={0.everyone}, users={0.users}, roles={0.roles})'.format(self) + return '{0.__class__.__qualname__}(everyone={0.everyone}, users={0.users}, roles={0.roles}, replied_user={0.replied_user})'.format(self) diff --git a/discord/message.py b/discord/message.py index 7c255b91a..d1928464a 100644 --- a/discord/message.py +++ b/discord/message.py @@ -43,6 +43,8 @@ from .file import File from .utils import escape_mentions from .guild import Guild from .mixins import Hashable +from .mentions import AllowedMentions +from .message_reference import _MessageType, MessageReference from .sticker import Sticker @@ -210,36 +212,6 @@ class Attachment: data = await self.read(use_cached=use_cached) return File(io.BytesIO(data), filename=self.filename, spoiler=spoiler) -class MessageReference: - """Represents a reference to a :class:`Message`. - - .. versionadded:: 1.5 - - Attributes - ----------- - message_id: Optional[:class:`int`] - The id of the message referenced. - channel_id: :class:`int` - The channel id of the message referenced. - guild_id: Optional[:class:`int`] - The guild id of the message referenced. - """ - - __slots__ = ('message_id', 'channel_id', 'guild_id', '_state') - - def __init__(self, state, **kwargs): - self.message_id = utils._get_as_snowflake(kwargs, 'message_id') - self.channel_id = int(kwargs.pop('channel_id')) - self.guild_id = utils._get_as_snowflake(kwargs, 'guild_id') - self._state = state - - @property - def cached_message(self): - """Optional[:class:`Message`]: The cached message, if found in the internal message cache.""" - return self._state._get_message(self.message_id) - - def __repr__(self): - return ''.format(self) def flatten_handlers(cls): prefix = len('_handle_') @@ -258,7 +230,7 @@ def flatten_handlers(cls): return cls @flatten_handlers -class Message(Hashable): +class Message(Hashable, _MessageType): r"""Represents a message from Discord. There should be no need to create one of these manually. @@ -288,10 +260,10 @@ class Message(Hashable): call: Optional[:class:`CallMessage`] The call that the message refers to. This is only applicable to messages of type :attr:`MessageType.call`. - reference: Optional[:class:`MessageReference`] + reference: Optional[:class:`~discord.MessageReference`] The message that this message references. This is only applicable to messages of - type :attr:`MessageType.pins_add` or crossposted messages created by a - followed channel integration. + type :attr:`MessageType.pins_add`, crossposted messages created by a + followed channel integration, or message replies. .. versionadded:: 1.5 @@ -841,9 +813,22 @@ class Message(Hashable): before deleting the message we just edited. If the deletion fails, then it is silently ignored. allowed_mentions: Optional[:class:`~discord.AllowedMentions`] - Controls the mentions being processed in this message. + Controls the mentions being processed in this message. If this is + passed, then the object is merged with :attr:`~discord.Client.allowed_mentions`. + The merging behaviour only overrides attributes that have been explicitly passed + to the object, otherwise it uses the attributes set in :attr:`~discord.Client.allowed_mentions`. + If no object is passed at all then the defaults given by :attr:`~discord.Client.allowed_mentions` + are used instead. .. versionadded:: 1.4 + .. versionchanged:: 1.6 + :attr:`~discord.Client.allowed_mentions` serves as defaults unconditionally. + + mention_author: Optional[:class:`bool`] + Overrides the :attr:`~discord.AllowedMentions.replied_user` attribute + of ``allowed_mentions``. + + .. versionadded:: 1.6 Raises ------- @@ -881,17 +866,24 @@ class Message(Hashable): delete_after = fields.pop('delete_after', None) - try: - allowed_mentions = fields.pop('allowed_mentions') - except KeyError: - pass - else: - if allowed_mentions is not None: - if self._state.allowed_mentions is not None: - allowed_mentions = self._state.allowed_mentions.merge(allowed_mentions).to_dict() - else: - allowed_mentions = allowed_mentions.to_dict() - fields['allowed_mentions'] = allowed_mentions + mention_author = fields.pop('mention_author', None) + allowed_mentions = fields.pop('allowed_mentions', None) + if allowed_mentions is not None: + if self._state.allowed_mentions is not None: + allowed_mentions = self._state.allowed_mentions.merge(allowed_mentions) + allowed_mentions = allowed_mentions.to_dict() + if mention_author is not None: + allowed_mentions['replied_user'] = mention_author + fields['allowed_mentions'] = allowed_mentions + elif mention_author is not None: + if self._state.allowed_mentions is not None: + allowed_mentions = self._state.allowed_mentions.to_dict() + allowed_mentions['replied_user'] = mention_author + else: + allowed_mentions = {'replied_user': mention_author} + fields['allowed_mentions'] = allowed_mentions + elif self._state.allowed_mentions is not None: + fields['allowed_mentions'] = self._state.allowed_mentions.to_dict() if fields: data = await self._state.http.edit_message(self.channel.id, self.id, **fields) @@ -1127,3 +1119,42 @@ class Message(Hashable): if state.is_bot: raise ClientException('Must not be a bot account to ack messages.') return await state.http.ack_message(self.channel.id, self.id) + + async def reply(self, content=None, **kwargs): + """|coro| + + A shortcut method to :meth:`abc.Messageable.send` to reply to the + :class:`Message`. + + .. versionadded:: 1.6 + + Raises + -------- + ~discord.HTTPException + Sending the message failed. + ~discord.Forbidden + You do not have the proper permissions to send the message. + ~discord.InvalidArgument + The ``files`` list is not of the appropriate size or + you specified both ``file`` and ``files``. + + Returns + --------- + :class:`Message` + The message that was sent. + """ + + return await self.channel.send(content, reference=self, **kwargs) + + def to_reference(self): + """Creates a :class:`~discord.MessageReference` from the current message. + + .. versionadded:: 1.6 + + Returns + --------- + :class:`~discord.MessageReference` + The reference to this message. + """ + + return MessageReference.from_message(self) diff --git a/discord/message_reference.py b/discord/message_reference.py new file mode 100644 index 000000000..3dcb78710 --- /dev/null +++ b/discord/message_reference.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2020 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from . import utils + +class _MessageType: + __slots__ = () + +class MessageReference(_MessageType): + """Represents a reference to a :class:`~discord.Message`. + + .. versionadded:: 1.5 + + Attributes + ----------- + message_id: Optional[:class:`int`] + The id of the message referenced. + channel_id: :class:`int` + The channel id of the message referenced. + guild_id: Optional[:class:`int`] + The guild id of the message referenced. + """ + + __slots__ = ('message_id', 'channel_id', 'guild_id', '_state') + + def __init__(self, state, **kwargs): + self.message_id = utils._get_as_snowflake(kwargs, 'message_id') + self.channel_id = int(kwargs.pop('channel_id')) + self.guild_id = utils._get_as_snowflake(kwargs, 'guild_id') + self._state = state + + @classmethod + def from_message(cls, message): + """Creates a :class:`MessageReference` from an existing :class:`~discord.Message`. + + .. versionadded:: 1.6 + + Parameters + ---------- + message: :class:`~discord.Message` + The message to be converted into a reference. + + Returns + ------- + :class:`MessageReference` + A reference to the message. + """ + return cls(message._state, message_id=message.id, channel_id=message.channel.id, guild_id=getattr(message.guild, 'id', None)) + + @property + def cached_message(self): + """Optional[:class:`~discord.Message`]: The cached message, if found in the internal message cache.""" + return self._state._get_message(self.message_id) + + def __repr__(self): + return ''.format(self) + + def to_dict(self): + result = {'message_id': self.message_id} if self.message_id is not None else {} + result['channel_id'] = self.channel_id + if self.guild_id is not None: + result['guild_id'] = self.guild_id + return result