From 7431a127cf33802c1bcc65beda8eb08cfa2b6655 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Tue, 3 Jan 2017 09:05:08 -0500 Subject: [PATCH] Change Messageable channel getter to be a coroutine. --- discord/abc.py | 16 +++++++++------- discord/channel.py | 3 +++ discord/ext/commands/context.py | 1 + discord/iterators.py | 16 ++++++++++------ 4 files changed, 23 insertions(+), 13 deletions(-) diff --git a/discord/abc.py b/discord/abc.py index 8ef15db30..89229279e 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -467,6 +467,7 @@ class GuildChannel: class Messageable(metaclass=abc.ABCMeta): __slots__ = () + @asyncio.coroutine @abc.abstractmethod def _get_channel(self): raise NotImplementedError @@ -534,7 +535,7 @@ class Messageable(metaclass=abc.ABCMeta): The message that was sent. """ - channel = self._get_channel() + channel = yield from self._get_channel() guild_id = self._get_guild_id() state = self._state content = str(content) if content else None @@ -576,7 +577,7 @@ class Messageable(metaclass=abc.ABCMeta): *Typing* indicator will go away after 10 seconds, or after a message is sent. """ - channel = self._get_channel() + channel = yield from self._get_channel() yield from self._state.http.send_typing(channel.id) def typing(self): @@ -596,7 +597,8 @@ class Messageable(metaclass=abc.ABCMeta): await channel.send_message('done!') """ - return Typing(self._get_channel()) + channel = yield from self._get_channel() + return Typing(channel) @asyncio.coroutine def get_message(self, id): @@ -626,7 +628,7 @@ class Messageable(metaclass=abc.ABCMeta): Retrieving the message failed. """ - channel = self._get_channel() + channel = yield from self._get_channel() data = yield from self._state.http.get_message(channel.id, id) return state.create_message(channel=channel, data=data) @@ -660,7 +662,7 @@ class Messageable(metaclass=abc.ABCMeta): raise ClientException('Can only delete messages in the range of [2, 100]') message_ids = [m.id for m in messages] - channel = self._get_channel() + channel = yield from self._get_channel() guild_id = self._get_guild_id() yield from self._state.http.delete_messages(channel.id, message_ids, guild_id) @@ -677,7 +679,7 @@ class Messageable(metaclass=abc.ABCMeta): Retrieving the pinned messages failed. """ - channel = self._get_channel() + channel = yield from self._get_channel() state = self._state data = yield from state.http.pins_from(channel.id) return [state.create_message(channel=channel, data=m) for m in data] @@ -745,7 +747,7 @@ class Messageable(metaclass=abc.ABCMeta): if message.author == client.user: counter += 1 """ - return LogsFromIterator(self._get_channel(), limit=limit, before=before, after=after, around=around, reverse=reverse) + return LogsFromIterator(self, limit=limit, before=before, after=after, around=around, reverse=reverse) @asyncio.coroutine def purge(self, *, limit=100, check=None, before=None, after=None, around=None): diff --git a/discord/channel.py b/discord/channel.py index 89a6051e9..8efae5462 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -88,6 +88,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): self.position = data['position'] self._fill_overwrites(data) + @asyncio.coroutine def _get_channel(self): return self @@ -262,6 +263,7 @@ class DMChannel(discord.abc.Messageable, Hashable): self.me = me self.id = int(data['id']) + @asyncio.coroutine def _get_channel(self): return self @@ -360,6 +362,7 @@ class GroupChannel(discord.abc.Messageable, Hashable): else: self.owner = utils.find(lambda u: u.id == owner_id, self.recipients) + @asyncio.coroutine def _get_channel(self): return self diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index ff78c562a..59f091179 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -117,6 +117,7 @@ class Context(discord.abc.Messageable): ret = yield from command.callback(*arguments, **kwargs) return ret + @asyncio.coroutine def _get_channel(self): return self.channel diff --git a/discord/iterators.py b/discord/iterators.py index 86b5e9de6..3ac75593e 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -70,7 +70,7 @@ class LogsFromIterator: will be out of order. """ - def __init__(self, channel, limit, + def __init__(self, messageable, limit, before=None, after=None, around=None, reverse=None): if isinstance(before, datetime.datetime): @@ -80,9 +80,7 @@ class LogsFromIterator: if isinstance(around, datetime.datetime): around = Object(id=time_snowflake(around)) - self.channel = channel - self.ctx = channel._state - self.logs_from = channel._state.http.logs_from + self.messageable = messageable self.limit = limit self.before = before self.after = after @@ -135,6 +133,13 @@ class LogsFromIterator: @asyncio.coroutine def fill_messages(self): + if not hasattr(self, 'channel'): + # do the required set up + channel = yield from self.messageable._get_channel() + self.channel = channel + self.state = channel._state + self.logs_from = channel._state.http.logs_from + if self.limit > 0: retrieve = self.limit if self.limit <= 100 else 100 data = yield from self._retrieve_messages(retrieve) @@ -144,9 +149,8 @@ class LogsFromIterator: data = filter(self._filter, data) channel = self.channel - state = self.ctx for element in data: - yield from self.messages.put(state.create_message(channel=channel, data=element)) + yield from self.messages.put(self.state.create_message(channel=channel, data=element)) @asyncio.coroutine def _retrieve_messages(self, retrieve):