diff --git a/disco/api/client.py b/disco/api/client.py index 04b9b32..abc261f 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -155,8 +155,11 @@ class APIClient(LoggingClass): def channels_messages_delete_bulk(self, channel, messages): self.http(Routes.CHANNELS_MESSAGES_DELETE_BULK, dict(channel=channel), json={'messages': messages}) - def channels_messages_reactions_get(self, channel, message, emoji): - r = self.http(Routes.CHANNELS_MESSAGES_REACTIONS_GET, dict(channel=channel, message=message, emoji=emoji)) + def channels_messages_reactions_get(self, channel, message, emoji, after=None, limit=100): + r = self.http( + Routes.CHANNELS_MESSAGES_REACTIONS_GET, + dict(channel=channel, message=message, emoji=emoji), + params={'after': after, 'limit': limit}) return User.create_map(self.client, r.json()) def channels_messages_reactions_create(self, channel, message, emoji): diff --git a/disco/api/http.py b/disco/api/http.py index 9d15550..732f88c 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -22,8 +22,9 @@ HTTPMethod = Enum( def to_bytes(obj): - if isinstance(obj, six.text_type): - return obj.encode('utf-8') + if six.PY2: + if isinstance(obj, six.text_type): + return obj.encode('utf-8') return obj diff --git a/disco/types/message.py b/disco/types/message.py index 0d1f219..7b5bcb0 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -10,6 +10,7 @@ from disco.types.base import ( SlottedModel, Field, ListField, AutoDictField, snowflake, text, datetime, enum ) +from disco.util.paginator import Paginator from disco.util.snowflake import to_snowflake from disco.util.functional import cached_property from disco.types.user import User @@ -300,7 +301,22 @@ class Message(SlottedModel): """ return self.client.api.channels_messages_delete(self.channel_id, self.id) - def get_reactors(self, emoji): + def get_reactors_iter(self, emoji, *args, **kwargs): + """ + Returns an iterator which paginates the reactors for the given emoji. + """ + if isinstance(emoji, Emoji): + emoji = emoji.to_string() + + return Paginator( + self.client.api.channels_messages_reactions_get, + self.channel_id, + self.id, + emoji, + *args, + **kwargs) + + def get_reactors(self, emoji, *args, **kwargs): """ Returns an list of users who reacted to this message with the given emoji. @@ -309,11 +325,15 @@ class Message(SlottedModel): list(:class:`User`) The users who reacted. """ + if isinstance(emoji, Emoji): + emoji = emoji.to_string() + return self.client.api.channels_messages_reactions_get( self.channel_id, self.id, - emoji - ) + emoji, + *args, + **kwargs) def create_reaction(self, emoji): warnings.warn( @@ -322,8 +342,17 @@ class Message(SlottedModel): return self.add_reaction(emoji) def add_reaction(self, emoji): + """ + Adds a reaction to the message. + + Parameters + ---------- + emoji : Emoji|str + An emoji or string representing an emoji + """ if isinstance(emoji, Emoji): emoji = emoji.to_string() + self.client.api.channels_messages_reactions_create( self.channel_id, self.id, diff --git a/disco/util/paginator.py b/disco/util/paginator.py new file mode 100644 index 0000000..f690cff --- /dev/null +++ b/disco/util/paginator.py @@ -0,0 +1,43 @@ +import operator + + +class Paginator(object): + """ + Implements a class which provides paginated iteration over an endpoint. + """ + def __init__(self, func, *args, **kwargs): + self.func = func + self.args = args + self.kwargs = kwargs + + self._key = kwargs.pop('key', operator.attrgetter('id')) + self._bulk = kwargs.pop('bulk', False) + self._after = kwargs.pop('after', None) + self._buffer = [] + + def fill(self): + self.kwargs['after'] = self._after + result = self.func(*self.args, **self.kwargs) + + if not len(result): + raise StopIteration + + self._buffer.extend(result) + self._after = self._key(result[-1]) + + def next(self): + return self.__next__() + + def __iter__(self): + return self + + def __next__(self): + if not len(self._buffer): + self.fill() + + if self._bulk: + res = self._buffer + self._buffer = [] + return res + else: + return self._buffer.pop() diff --git a/examples/basic_plugin.py b/examples/basic_plugin.py index a48b21b..60c5e8b 100644 --- a/examples/basic_plugin.py +++ b/examples/basic_plugin.py @@ -12,7 +12,7 @@ class BasicPlugin(Plugin): self.log.info('Message created: {}: {}'.format(msg.author, msg.content)) @Plugin.command('echo', '') - def on_test_command(self, event, content): + def on_echo_command(self, event, content): event.msg.reply(content) @Plugin.command('spam', ' ')