From c187d87dae6b094259440f8aa2a278fef38ae6d2 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Fri, 11 Nov 2016 03:12:43 -0500 Subject: [PATCH] Re-add support for reactions. We now store emojis in a global cache and make things like adding and removing reactions part of the stateful Message class. --- discord/__init__.py | 2 +- discord/emoji.py | 3 + discord/guild.py | 4 +- discord/message.py | 130 ++++++++++++++++++++++++++++++++++++++++++-- discord/reaction.py | 72 +++++++++++++++++++----- discord/state.py | 115 +++++++++++++++++---------------------- 6 files changed, 238 insertions(+), 88 deletions(-) diff --git a/discord/__init__.py b/discord/__init__.py index 4806ffbd1..345c0d6ca 100644 --- a/discord/__init__.py +++ b/discord/__init__.py @@ -20,7 +20,7 @@ __version__ = '0.16.0' from .client import Client, AppInfo, ChannelPermissions from .user import User from .game import Game -from .emoji import Emoji +from .emoji import Emoji, PartialEmoji from .channel import * from .guild import Guild from .member import Member, VoiceState diff --git a/discord/emoji.py b/discord/emoji.py index 1f06c7e45..55d0ab736 100644 --- a/discord/emoji.py +++ b/discord/emoji.py @@ -25,10 +25,13 @@ DEALINGS IN THE SOFTWARE. """ import asyncio +from collections import namedtuple from . import utils from .mixins import Hashable +PartialEmoji = namedtuple('PartialEmoji', 'id name') + class Emoji(Hashable): """Represents a custom emoji. diff --git a/discord/guild.py b/discord/guild.py index 3ed07b200..5bb7ca66f 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -237,7 +237,7 @@ class Guild(Hashable): self.id = int(guild['id']) self.roles = [Role(guild=self, data=r, state=self._state) for r in guild.get('roles', [])] self.mfa_level = guild.get('mfa_level') - self.emojis = [Emoji(server=self, data=r, state=self._state) for r in guild.get('emojis', [])] + self.emojis = tuple(map(lambda d: self._state.store_emoji(self, d), guild.get('emojis', []))) self.features = guild.get('features', []) self.splash = guild.get('splash') @@ -653,7 +653,7 @@ class Guild(Hashable): img = utils._bytes_to_base64_data(image) data = yield from self._state.http.create_custom_emoji(self.id, name, img) - return Emoji(guild=self, data=data, state=self._state) + return self._state.store_emoji(self, data) @asyncio.coroutine def create_role(self, **fields): diff --git a/discord/message.py b/discord/message.py index 403178a76..9a149e1d7 100644 --- a/discord/message.py +++ b/discord/message.py @@ -29,10 +29,12 @@ import re from .user import User from .reaction import Reaction +from .emoji import Emoji from . import utils, abc from .object import Object from .calls import CallMessage from .enums import MessageType, try_enum +from .errors import InvalidArgument class Message: """Represents a message from Discord. @@ -66,8 +68,6 @@ class Message: In :issue:`very rare cases <21>` this could be a :class:`Object` instead. For the sake of convenience, this :class:`Object` instance has an attribute ``is_private`` set to ``True``. - guild: Optional[:class:`Guild`] - The guild that the message belongs to. If not applicable (i.e. a PM) then it's None instead. call: Optional[:class:`CallMessage`] The call that the message refers to. This is only applicable to messages of type :attr:`MessageType.call`. @@ -112,16 +112,15 @@ class Message: __slots__ = ( 'edited_timestamp', 'tts', 'content', 'channel', 'webhook_id', 'mention_everyone', 'embeds', 'id', 'mentions', 'author', - '_cs_channel_mentions', 'guild', '_cs_raw_mentions', 'attachments', + '_cs_channel_mentions', '_cs_raw_mentions', 'attachments', '_cs_clean_content', '_cs_raw_channel_mentions', 'nonce', 'pinned', 'role_mentions', '_cs_raw_role_mentions', 'type', 'call', '_cs_system_content', '_state', 'reactions' ) def __init__(self, *, state, channel, data): self._state = state - self.reactions = kwargs.pop('reactions') - for reaction in self.reactions: - reaction.message = self + self.id = int(data['id']) + self.reactions = [Reaction(message=self, data=d) for d in data.get('reactions', [])] self._update(channel, data) def _try_patch(self, data, key, transform): @@ -132,6 +131,41 @@ class Message: else: setattr(self, key, transform(value)) + def _add_reaction(self, data): + emoji = self._state.reaction_emoji(data['emoji']) + reaction = utils.find(lambda r: r.emoji == emoji, self.reactions) + is_me = data['me'] = int(data['user_id']) == self._state.self_id + + if reaction is None: + reaction = Reaction(message=self, data=data, emoji=emoji) + self.reactions.append(reaction) + else: + reaction.count += 1 + if is_me: + reaction.me = is_me + + return reaction + + def _remove_reaction(self, data): + emoji = self._state.reaction_emoji(data['emoji']) + reaction = utils.find(lambda r: r.emoji == emoji, self.reactions) + + if reaction is None: + # already removed? + raise ValueError('Emoji already removed?') + + # if reaction isn't in the list, we crash. This means discord + # sent bad data, or we stored improperly + reaction.count -= 1 + + if int(data['user_id']) == self._state.self_id: + reaction.me = False + if reaction.count == 0: + # this raises ValueError if something went wrong as well. + self.reactions.remove(reaction) + + return reaction + def _update(self, channel, data): self.channel = channel for handler in ('mentions', 'mention_roles', 'call'): @@ -198,6 +232,11 @@ class Message: call['participants'] = participants self.call = CallMessage(message=self, **call) + @property + def guild(self): + """Optional[:class:`Guild`]: The guild that the message belongs to, if applicable.""" + return getattr(self.channel, 'guild', None) + @utils.cached_slot_property('_cs_raw_mentions') def raw_mentions(self): """A property that returns an array of user IDs matched with @@ -428,3 +467,82 @@ class Message: yield from self._state.http.unpin_message(self.channel.id, self.id) self.pinned = False + + @asyncio.coroutine + def add_reaction(self, emoji): + """|coro| + + Add a reaction to the message. + + The emoji may be a unicode emoji or a custom server :class:`Emoji`. + + You must have the :attr:`Permissions.add_reactions` permission to + add new reactions to a message. + + Parameters + ------------ + emoji: :class:`Emoji` or str + The emoji to react with. + + Raises + -------- + HTTPException + Adding the reaction failed. + Forbidden + You do not have the proper permissions to react to the message. + NotFound + The emoji you specified was not found. + InvalidArgument + The emoji parameter is invalid. + """ + + if isinstance(emoji, Emoji): + emoji = '%s:%s' % (emoji.name, emoji.id) + elif isinstance(emoji, str): + pass # this is okay + else: + raise InvalidArgument('emoji argument must be a string or discord.Emoji') + + yield from self._state.http.add_reaction(self.id, self.channel.id, emoji) + + @asyncio.coroutine + def remove_reaction(self, emoji, member): + """|coro| + + Remove a reaction by the member from the message. + + The emoji may be a unicode emoji or a custom server :class:`Emoji`. + + If the reaction is not your own (i.e. ``member`` parameter is not you) then + the :attr:`Permissions.manage_messages` permission is needed. + + The ``member`` parameter must represent a member and meet + the :class:`abc.Snowflake` abc. + + Parameters + ------------ + emoji: :class:`Emoji` or str + The emoji to remove. + member: :class:`abc.Snowflake` + The member for which to remove the reaction. + + Raises + -------- + HTTPException + Removing the reaction failed. + Forbidden + You do not have the proper permissions to remove the reaction. + NotFound + The member or emoji you specified was not found. + InvalidArgument + The emoji parameter is invalid. + """ + + if isinstance(emoji, Emoji): + emoji = '%s:%s' % (emoji.name, emoji.id) + elif isinstance(emoji, str): + pass # this is okay + else: + raise InvalidArgument('emoji argument must be a string or discord.Emoji') + + yield from self._state.http.remove_reaction(self.id, self.channel.id, emoji, member.id) diff --git a/discord/reaction.py b/discord/reaction.py index 2e4f3ce4f..2e47e397d 100644 --- a/discord/reaction.py +++ b/discord/reaction.py @@ -24,7 +24,9 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from .emoji import Emoji +import asyncio + +from .user import User class Reaction: """Represents a reaction to a message. @@ -48,25 +50,27 @@ class Reaction: Attributes ----------- - emoji : :class:`Emoji` or str + emoji: :class:`Emoji` or str The reaction emoji. May be a custom emoji, or a unicode emoji. - custom_emoji : bool - If this is a custom emoji. - count : int + count: int Number of times this reaction was made - me : bool + me: bool If the user sent this reaction. message: :class:`Message` Message this reaction is for. """ - __slots__ = ['message', 'count', 'emoji', 'me', 'custom_emoji'] + __slots__ = ('message', 'count', 'emoji', 'me') + + def __init__(self, *, message, data, emoji=None): + self.message = message + self.emoji = message._state.reaction_emoji(data['emoji']) if emoji is None else emoji + self.count = data.get('count', 1) + self.me = data.get('me') - def __init__(self, **kwargs): - self.message = kwargs.get('message') - self.emoji = kwargs['emoji'] - self.count = kwargs.get('count', 1) - self.me = kwargs.get('me') - self.custom_emoji = isinstance(self.emoji, Emoji) + @property + def custom_emoji(self): + """bool: If this is a custom emoji.""" + return not isinstance(self.emoji, str) def __eq__(self, other): return isinstance(other, self.__class__) and other.emoji == self.emoji @@ -78,3 +82,45 @@ class Reaction: def __hash__(self): return hash(self.emoji) + + @asyncio.coroutine + def users(self, limit=100, after=None): + """|coro| + + Get the users that added this reaction. + + The ``after`` parameter must represent a member + and meet the :class:`abc.Snowflake` abc. + + Parameters + ------------ + limit: int + The maximum number of results to return. + after: :class:`abc.Snowflake` + For pagination, reactions are sorted by member. + + Raises + -------- + HTTPException + Getting the users for the reaction failed. + + Returns + -------- + List[:class:`User`] + A list of users who reacted to the message. + """ + + # TODO: Return an iterator a la `MessageChannel.history`? + + if self.custom_emoji: + emoji = '{0.name}:{0.id}'.format(self.emoji) + else: + emoji = self.emoji + + if after: + after = after.id + + msg = self.message + state = msg._state + data = yield from state.http.get_reaction_users(msg.id, msg.channel.id, emoji, limit, after=after) + return [User(state=state, data=user) for user in data] diff --git a/discord/state.py b/discord/state.py index c9a330b35..9df9f3658 100644 --- a/discord/state.py +++ b/discord/state.py @@ -27,7 +27,7 @@ DEALINGS IN THE SOFTWARE. from .guild import Guild from .user import User from .game import Game -from .emoji import Emoji +from .emoji import Emoji, PartialEmoji from .reaction import Reaction from .message import Message from .channel import * @@ -47,10 +47,16 @@ class ListenerType(enum.Enum): chunk = 0 Listener = namedtuple('Listener', ('type', 'future', 'predicate')) -StateContext = namedtuple('StateContext', 'store_user http self_id') log = logging.getLogger(__name__) ReadyState = namedtuple('ReadyState', ('launch', 'guilds')) +class StateContext: + __slots__ = ('store_user', 'http', 'self_id', 'store_emoji', 'reaction_emoji') + + def __init__(self, **kwargs): + for attr, value in kwargs.items(): + setattr(self, attr, value) + class ConnectionState: def __init__(self, *, dispatch, chunker, syncer, http, loop, **options): self.loop = loop @@ -60,7 +66,10 @@ class ConnectionState: self.syncer = syncer self.is_bot = None self._listeners = [] - self.ctx = StateContext(store_user=self.store_user, http=http, self_id=None) + self.ctx = StateContext(store_user=self.store_user, + store_emoji=self.store_emoji, + reaction_emoji=self._get_reaction_emoji, + http=http, self_id=None) self.clear() def clear(self): @@ -69,6 +78,7 @@ class ConnectionState: self.session_id = None self._calls = {} self._users = {} + self._emojis = {} self._guilds = {} self._voice_clients = {} self._private_channels = {} @@ -128,6 +138,14 @@ class ConnectionState: self._users[user_id] = user = User(state=self.ctx, data=data) return user + def store_emoji(self, guild, data): + emoji_id = int(data['id']) + try: + return self._emojis[emoji_id] + except KeyError: + self._emojis[emoji_id] = emoji = Emoji(guild=guild, state=self.ctx, data=data) + return emoji + @property def guilds(self): return self._guilds.values() @@ -274,26 +292,11 @@ class ConnectionState: self.dispatch('message_edit', older_message, message) def parse_message_reaction_add(self, data): - message = self._get_message(data['message_id']) + message = self._get_message(int(data['message_id'])) if message is not None: - emoji = self._get_reaction_emoji(**data.pop('emoji')) - reaction = utils.get(message.reactions, emoji=emoji) - - is_me = data['user_id'] == self.user.id - - if not reaction: - reaction = Reaction( - message=message, emoji=emoji, me=is_me, **data) - message.reactions.append(reaction) - else: - reaction.count += 1 - if is_me: - reaction.me = True - - channel = self.get_channel(data['channel_id']) - member = self._get_member(channel, data['user_id']) - - self.dispatch('reaction_add', reaction, member) + reaction = message._add_reaction(data) + user = self._get_reaction_user(message.channel, int(data['user_id'])) + self.dispatch('reaction_add', reaction, user) def parse_message_reaction_remove_all(self, data): message = self._get_message(data['message_id']) @@ -303,26 +306,15 @@ class ConnectionState: self.dispatch('reaction_clear', message, old_reactions) def parse_message_reaction_remove(self, data): - message = self._get_message(data['message_id']) + message = self._get_message(int(data['message_id'])) if message is not None: - emoji = self._get_reaction_emoji(**data['emoji']) - reaction = utils.get(message.reactions, emoji=emoji) - - # Eventual consistency means we can get out of order or duplicate removes. - if not reaction: - log.warning("Unexpected reaction remove {}".format(data)) - return - - reaction.count -= 1 - if data['user_id'] == self.user.id: - reaction.me = False - if reaction.count == 0: - message.reactions.remove(reaction) - - channel = self.get_channel(data['channel_id']) - member = self._get_member(channel, data['user_id']) - - self.dispatch('reaction_remove', reaction, member) + try: + reaction = message._remove_reaction(data) + except (AttributeError, ValueError) as e: # eventual consistency lol + pass + else: + user = self._get_reaction_user(message.channel, int(data['user_id'])) + self.dispatch('reaction_remove', reaction, user) def parse_presence_update(self, data): guild = self._get_guild(utils._get_as_snowflake(data, 'guild_id')) @@ -462,7 +454,7 @@ class ConnectionState: def parse_guild_emojis_update(self, data): guild = self._get_guild(int(data['guild_id'])) before_emojis = guild.emojis - guild.emojis = [Emoji(guild=guild, data=e, state=self.ctx) for e in data.get('emojis', [])] + guild.emojis = tuple(map(lambda d: self.store_emoji(guild, d), data['emojis'])) self.dispatch('guild_emojis_update', before_emojis, guild.emojis) def _get_create_guild(self, data): @@ -675,35 +667,26 @@ class ConnectionState: if call is not None: self.dispatch('call_remove', call) - def _get_member(self, channel, id): - if channel.is_private: - return utils.get(channel.recipients, id=id) + def _get_reaction_user(self, channel, user_id): + if isinstance(channel, DMChannel) and user_id == channel.recipient.id: + return channel.recipient + elif isinstance(channel, TextChannel): + return channel.guild.get_member(user_id) + elif isinstance(channel, GroupChannel): + return utils.find(lambda m: m.id == user_id, channel.recipients) else: - return channel.server.get_member(id) - - def _create_message(self, **message): - """Helper mostly for injecting reactions.""" - reactions = [ - self._create_reaction(**r) for r in message.pop('reactions', []) - ] - return Message(channel=message.pop('channel'), - reactions=reactions, **message) - - def _create_reaction(self, **reaction): - emoji = self._get_reaction_emoji(**reaction.pop('emoji')) - return Reaction(emoji=emoji, **reaction) + return None - def _get_reaction_emoji(self, **data): - id = data['id'] + def _get_reaction_emoji(self, data): + emoji_id = utils._get_as_snowflake(data, 'id') - if not id: + if not emoji_id: return data['name'] - for server in self.servers: - for emoji in server.emojis: - if emoji.id == id: - return emoji - return Emoji(server=None, **data) + try: + return self._emojis[emoji_id] + except KeyError: + return PartialEmoji(id=emoji_id, name=data['name']) def get_channel(self, id): if id is None: