From 041785937e091b7e282403d45dd0c68da340a8d4 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Sat, 4 Apr 2020 07:40:51 -0400 Subject: [PATCH] Add support for configuring allowed mentions per message or bot wide. --- discord/__init__.py | 1 + discord/abc.py | 18 +++++++-- discord/client.py | 6 ++- discord/http.py | 9 ++++- discord/mentions.py | 98 +++++++++++++++++++++++++++++++++++++++++++++ discord/state.py | 7 ++++ discord/webhook.py | 14 ++++++- docs/api.rst | 6 +++ 8 files changed, 152 insertions(+), 7 deletions(-) create mode 100644 discord/mentions.py diff --git a/discord/__init__.py b/discord/__init__.py index 2cca80ae9..3bc6a02d3 100644 --- a/discord/__init__.py +++ b/discord/__init__.py @@ -46,6 +46,7 @@ from .reaction import Reaction from . import utils, opus, abc from .enums import * from .embeds import Embed +from .mentions import AllowedMentions from .shard import AutoShardedClient from .player import * from .webhook import * diff --git a/discord/abc.py b/discord/abc.py index 2a14c3490..b4a07791c 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -768,7 +768,9 @@ class Messageable(metaclass=abc.ABCMeta): async def _get_channel(self): raise NotImplementedError - async def send(self, content=None, *, tts=False, embed=None, file=None, files=None, delete_after=None, nonce=None): + async def send(self, content=None, *, tts=False, embed=None, file=None, + files=None, delete_after=None, nonce=None, + mentions=None): """|coro| Sends a message to the destination with the content given. @@ -804,6 +806,10 @@ class Messageable(metaclass=abc.ABCMeta): If provided, the number of seconds to wait in the background before deleting the message we just sent. If the deletion fails, then it is silently ignored. + mentions: :class:`AllowedMentions` + Controls the mentions being processed in this message. + + .. versionadded:: 1.4 Raises -------- @@ -827,6 +833,12 @@ class Messageable(metaclass=abc.ABCMeta): if embed is not None: embed = embed.to_dict() + if mentions is not None: + if state.mentions is not None: + mentions = state.mentions.merge(mentions).to_dict() + else: + mentions = mentions.to_dict() + if file is not None and files is not None: raise InvalidArgument('cannot pass both file and files parameter to send()') @@ -848,12 +860,12 @@ class Messageable(metaclass=abc.ABCMeta): try: data = await state.http.send_files(channel.id, files=files, content=content, tts=tts, - embed=embed, nonce=nonce) + embed=embed, nonce=nonce, mentions=mentions) finally: for f in files: f.close() else: - data = await state.http.send_message(channel.id, content, tts=tts, embed=embed, nonce=nonce) + data = await state.http.send_message(channel.id, content, tts=tts, embed=embed, nonce=nonce, mentions=mentions) ret = state.create_message(channel=channel, data=data) if delete_after is not None: diff --git a/discord/client.py b/discord/client.py index 757886eff..bbc323d6a 100644 --- a/discord/client.py +++ b/discord/client.py @@ -142,13 +142,17 @@ class Client: The total number of shards. fetch_offline_members: :class:`bool` Indicates if :func:`.on_ready` should be delayed to fetch all offline - members from the guilds the bot belongs to. If this is ``False``\, then + members from the guilds the client belongs to. If this is ``False``\, then no offline members are received and :meth:`request_offline_members` must be used to fetch the offline members of the guild. status: Optional[:class:`.Status`] A status to start your presence with upon logging on to Discord. activity: Optional[:class:`.BaseActivity`] An activity to start your presence with upon logging on to Discord. + mention: Optional[:class:`AllowedMentions`] + Control how the client handles mentions by default on every message sent. + + .. versionadded:: 1.4 heartbeat_timeout: :class:`float` The maximum numbers of seconds before timing out and restarting the WebSocket in the case of not receiving a HEARTBEAT_ACK. Useful if diff --git a/discord/http.py b/discord/http.py index 4837bd2f0..40ba7e335 100644 --- a/discord/http.py +++ b/discord/http.py @@ -310,7 +310,7 @@ class HTTPClient: return self.request(Route('POST', '/users/@me/channels'), json=payload) - def send_message(self, channel_id, content, *, tts=False, embed=None, nonce=None): + def send_message(self, channel_id, content, *, tts=False, embed=None, nonce=None, mentions=None): r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id) payload = {} @@ -326,12 +326,15 @@ class HTTPClient: if nonce: payload['nonce'] = nonce + if mentions: + payload['allowed_mentions'] = mentions + return self.request(r, json=payload) def send_typing(self, channel_id): return self.request(Route('POST', '/channels/{channel_id}/typing', channel_id=channel_id)) - def send_files(self, channel_id, *, files, content=None, tts=False, embed=None, nonce=None): + def send_files(self, channel_id, *, files, content=None, tts=False, embed=None, nonce=None, mentions=None): r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id) form = aiohttp.FormData() @@ -342,6 +345,8 @@ class HTTPClient: payload['embed'] = embed if nonce: payload['nonce'] = nonce + if mentions: + payload['allowed_mentions'] = mentions form.add_field('payload_json', utils.to_json(payload)) if len(files) == 1: diff --git a/discord/mentions.py b/discord/mentions.py new file mode 100644 index 000000000..70aa8d44b --- /dev/null +++ b/discord/mentions.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2020 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +class _FakeBool: + def __repr__(self): + return 'True' + + def __eq__(self, other): + return other is True + + def __bool__(self): + return True + +default = _FakeBool() + +class AllowedMentions: + """A class that represents what mentions are allowed in a message. + + This class can be set during :class:`Client` initialization to apply + to every message sent. It can also be applied on a per message basis + via :meth:`abc.Messageable.send` for more fine-grained control. + + Attributes + ------------ + everyone: :class:`bool` + Whether to allow everyone and here mentions. Defaults to ``True``. + users: Union[:class:`bool`, List[:class:`abc.Snowflake`]] + Controls the users being mentioned. If ``True`` (the default) then + users are mentioned based on the message content. If ``False`` then + users are not mentioned at all. If a list of :class:`abc.Snowflake` + is given then only the users provided will be mentioned, provided those + users are in the message content. + roles: Union[:class:`bool`, List[:class:`abc.Snowflake`]] + Controls the roles being mentioned. If ``True`` (the default) then + roles are mentioned based on the message content. If ``False`` then + roles are not mentioned at all. If a list of :class:`abc.Snowflake` + is given then only the roles provided will be mentioned, provided those + roles are in the message content. + """ + + __slots__ = ('everyone', 'users', 'roles') + + def __init__(self, *, everyone=default, users=default, roles=default): + self.everyone = everyone + self.users = users + self.roles = roles + + def to_dict(self): + parse = [] + data = {} + + if self.everyone: + parse.append('everyone') + + if self.users == True: + parse.append('users') + elif self.users != False: + data['users'] = [x.id for x in self.users] + + if self.roles == True: + parse.append('roles') + elif self.roles != False: + data['roles'] = [x.id for x in self.roles] + + data['parse'] = parse + return data + + def merge(self, other): + # Creates a new AllowedMentions by merging from another one. + # Merge is done by using the 'self' values unless explicitly + # overridden by the 'other' values. + everyone = self.everyone if other.everyone is default else other.everyone + users = self.users if other.users is default else other.users + roles = self.roles if other.roles is default else other.roles + return AllowedMentions(everyone=everyone, roles=roles, users=users) diff --git a/discord/state.py b/discord/state.py index e4ecdc459..e3ba299c9 100644 --- a/discord/state.py +++ b/discord/state.py @@ -39,6 +39,7 @@ from .guild import Guild from .activity import BaseActivity from .user import User, ClientUser from .emoji import Emoji +from .mentions import AllowedMentions from .partial_emoji import PartialEmoji from .message import Message from .relationship import Relationship @@ -78,6 +79,12 @@ class ConnectionState: self._fetch_offline = options.get('fetch_offline_members', True) self.heartbeat_timeout = options.get('heartbeat_timeout', 60.0) self.guild_subscriptions = options.get('guild_subscriptions', True) + mentions = options.get('mentions') + + if mentions is not None and not isinstance(mentions, AllowedMentions): + raise TypeError('mentions parameter must be AllowedMentions') + + self.mentions = mentions # Only disable cache if both fetch_offline and guild_subscriptions are off. self._cache_members = (self._fetch_offline or self.guild_subscriptions) self._listeners = [] diff --git a/discord/webhook.py b/discord/webhook.py index 2fcfd0808..058e48553 100644 --- a/discord/webhook.py +++ b/discord/webhook.py @@ -688,7 +688,7 @@ class Webhook: return self._adapter.edit_webhook(**payload) def send(self, content=None, *, wait=False, username=None, avatar_url=None, tts=False, - file=None, files=None, embed=None, embeds=None): + file=None, files=None, embed=None, embeds=None, mentions=None): """|maybecoro| Sends a message using the webhook. @@ -732,6 +732,10 @@ class Webhook: embeds: List[:class:`Embed`] A list of embeds to send with the content. Maximum of 10. This cannot be mixed with the ``embed`` parameter. + mentions: :class:`AllowedMentions` + Controls the mentions being processed in this message. + + .. versionadded:: 1.4 Raises -------- @@ -777,6 +781,14 @@ class Webhook: if username: payload['username'] = username + if mentions: + try: + mentions = self._state.mentions.merge(mentions).to_dict() + except AttributeError: + mentions = mentions.to_dict() + finally: + payload['allowed_mentions'] = mentions + return self._adapter.execute_webhook(wait=wait, file=file, files=files, payload=payload) def execute(self, *args, **kwargs): diff --git a/docs/api.rst b/docs/api.rst index 1ba9b2f8c..7cf30bec4 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -2569,6 +2569,12 @@ Embed .. autoclass:: Embed :members: +AllowedMentions +~~~~~~~~~~~~~~~~~ + +.. autoclass:: AllowedMentions + :members: + File ~~~~~