diff --git a/discord/client.py b/discord/client.py index ff2479590..b1b87ca95 100644 --- a/discord/client.py +++ b/discord/client.py @@ -953,7 +953,7 @@ class Client: data = yield from self.http.send_message(channel_id, content, guild_id=guild_id, tts=tts) channel = self.get_channel(data.get('channel_id')) - message = Message(channel=channel, **data) + message = self.connection._create_message(channel=channel, **data) return message @asyncio.coroutine @@ -1035,7 +1035,7 @@ class Client: data = yield from self.http.send_file(channel_id, buffer, guild_id=guild_id, filename=filename, content=content, tts=tts) channel = self.get_channel(data.get('channel_id')) - message = Message(channel=channel, **data) + message = self.connection._create_message(channel=channel, **data) return message @asyncio.coroutine @@ -1234,7 +1234,7 @@ class Client: content = str(new_content) guild_id = channel.server.id if not getattr(channel, 'is_private', True) else None data = yield from self.http.edit_message(message.id, channel.id, content, guild_id=guild_id) - return Message(channel=channel, **data) + return self.connection._create_message(channel=channel, **data) @asyncio.coroutine def get_message(self, channel, id): @@ -1267,7 +1267,7 @@ class Client: """ data = yield from self.http.get_message(channel.id, id) - return Message(channel=channel, **data) + return self.connection._create_message(channel=channel, **data) @asyncio.coroutine def pin_message(self, message): @@ -1337,7 +1337,7 @@ class Client: """ data = yield from self.http.pins_from(channel.id) - return [Message(channel=channel, **m) for m in data] + return [self.connection._create_message(channel=channel, **m) for m in data] def _logs_from(self, channel, limit=100, before=None, after=None, around=None): """|coro| @@ -1418,7 +1418,7 @@ class Client: def generator(data): for message in data: - yield Message(channel=channel, **message) + yield self.connection._create_message(channel=channel, **message) result = [] while limit > 0: diff --git a/discord/iterators.py b/discord/iterators.py index fbf1a72c6..2ea514367 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -72,6 +72,7 @@ class LogsFromIterator: def __init__(self, client, channel, limit, before=None, after=None, around=None, reverse=False): self.client = client + self.connection = client.connection self.channel = channel self.limit = limit self.before = before @@ -125,7 +126,9 @@ class LogsFromIterator: if self._filter: data = filter(self._filter, data) for element in data: - yield from self.messages.put(Message(channel=self.channel, **element)) + yield from self.messages.put( + self.connection._create_message( + channel=self.channel, **element)) @asyncio.coroutine def _retrieve_messages(self, retrieve): diff --git a/discord/message.py b/discord/message.py index d2bdf87e5..e6e2fdd17 100644 --- a/discord/message.py +++ b/discord/message.py @@ -115,6 +115,9 @@ class Message: '_system_content', 'reactions' ] def __init__(self, **kwargs): + self.reactions = kwargs.pop('reactions') + for reaction in self.reactions: + reaction.message = self self._update(**kwargs) def _update(self, **data): @@ -138,7 +141,6 @@ class Message: self._handle_upgrades(data.get('channel_id')) self._handle_mentions(data.get('mentions', []), data.get('mention_roles', [])) self._handle_call(data.get('call')) - self.reactions = [Reaction(message=self, **reaction) for reaction in data.get('reactions', [])] # clear the cached properties cached = filter(lambda attr: attr[0] == '_', self.__slots__) diff --git a/discord/reaction.py b/discord/reaction.py index ec30fa226..7232a7b44 100644 --- a/discord/reaction.py +++ b/discord/reaction.py @@ -62,19 +62,11 @@ class Reaction: __slots__ = ['message', 'count', 'emoji', 'me', 'custom_emoji'] def __init__(self, **kwargs): - self.message = kwargs.pop('message') - self._from_data(kwargs) - - def _from_data(self, reaction): - self.count = reaction.get('count', 1) - self.me = reaction.get('me') - emoji = reaction['emoji'] - if emoji['id']: - self.custom_emoji = True - self.emoji = Emoji(server=None, id=emoji['id'], name=emoji['name']) - else: - self.custom_emoji = False - self.emoji = emoji['name'] + self.message = kwargs.get('message') + self.emoji = kwargs['emoji'] + self.count = kwargs.get('count', 1) + self.me = kwargs.get('me') + self.custom_emoji = isinstance(self.emoji, Emoji) def __eq__(self, other): return isinstance(other, self.__class__) and other.emoji == self.emoji diff --git a/discord/state.py b/discord/state.py index 4d3855fc2..00b0e06fa 100644 --- a/discord/state.py +++ b/discord/state.py @@ -219,7 +219,7 @@ class ConnectionState: def parse_message_create(self, data): channel = self.get_channel(data.get('channel_id')) - message = Message(channel=channel, **data) + message = self._create_message(channel=channel, **data) self.dispatch('message', message) self.messages.append(message) @@ -255,17 +255,14 @@ class ConnectionState: def parse_message_reaction_add(self, data): message = self._get_message(data['message_id']) if message is not None: - if data['emoji']['id']: - reaction_emoji = Emoji(server=None, **data['emoji']) - else: - reaction_emoji = data['emoji']['name'] - reaction = utils.get( - message.reactions, emoji=reaction_emoji) + emoji = self._get_reaction_emoji(**data.pop('emoji')) + reaction = utils.get(message.reactions, emoji=emoji) is_me = data['user_id'] == self.user.id if not reaction: - reaction = Reaction(message=message, me=is_me, **data) + reaction = Reaction( + message=message, emoji=emoji, me=is_me, **data) message.reactions.append(reaction) else: reaction.count += 1 @@ -280,12 +277,8 @@ class ConnectionState: def parse_message_reaction_remove(self, data): message = self._get_message(data['message_id']) if message is not None: - if data['emoji']['id']: - reaction_emoji = Emoji(server=None, **data['emoji']) - else: - reaction_emoji = data['emoji']['name'] - reaction = utils.get( - message.reactions, emoji=reaction_emoji) + emoji = self._get_reaction_emoji(**data['emoji']) + reaction = utils.get(message.reactions, emoji=emoji) # if reaction isn't in the list, we crash. This means discord # sent bad data, or we stored improperly @@ -680,6 +673,30 @@ class ConnectionState: else: return channel.server.get_member(id) + def _create_message(self, **message): + """Helper mostly for injecting reactions.""" + reactions = [ + self._create_reaction(**r) for r in message.pop('reactions', []) + ] + return Message(channel=message.pop('channel'), + reactions=reactions, **message) + + def _create_reaction(self, **reaction): + emoji = self._get_reaction_emoji(**reaction.pop('emoji')) + return Reaction(emoji=emoji, **reaction) + + def _get_reaction_emoji(self, **data): + id = data['id'] + + if id is None: + return data['name'] + + for server in self.servers: + for emoji in server.emojis: + if emoji.id == id: + return emoji + return Emoji(server=None, **data) + def get_channel(self, id): if id is None: return None