diff --git a/discord/client.py b/discord/client.py index b1b87ca95..7c5975423 100644 --- a/discord/client.py +++ b/discord/client.py @@ -50,7 +50,7 @@ import aiohttp import websockets import logging, traceback -import sys, re, io +import sys, re, io, enum import tempfile, os, hashlib import itertools import datetime @@ -70,6 +70,10 @@ def app_info_icon_url(self): AppInfo.icon_url = property(app_info_icon_url) +class WaitForType(enum.Enum): + message = 0 + reaction = 1 + ChannelPermissions = namedtuple('ChannelPermissions', 'target overwrite') ChannelPermissions.__new__.__defaults__ = (PermissionOverwrite(),) @@ -194,9 +198,36 @@ class Client: log.info('a problem occurred while updating the login cache') pass + def handle_reaction_add(self, reaction, user): + removed = [] + for i, (condition, future, event_type) in enumerate(self._listeners): + if event_type is not WaitForType.reaction: + continue + + if future.cancelled(): + removed.append(i) + continue + + try: + result = condition(reaction, user) + except Exception as e: + future.set_exception(e) + removed.append(i) + else: + if result: + future.set_result((reaction, user)) + removed.append(i) + + + for idx in reversed(removed): + del self._listeners[idx] + def handle_message(self, message): removed = [] - for i, (condition, future) in enumerate(self._listeners): + for i, (condition, future, event_type) in enumerate(self._listeners): + if event_type is not WaitForType.message: + continue + if future.cancelled(): removed.append(i) continue @@ -614,45 +645,45 @@ class Client: .. code-block:: python :emphasize-lines: 5 - @client.async_event - def on_message(message): + @client.event + async def on_message(message): if message.content.startswith('$greet'): - yield from client.send_message(message.channel, 'Say hello') - msg = yield from client.wait_for_message(author=message.author, content='hello') - yield from client.send_message(message.channel, 'Hello.') + await client.send_message(message.channel, 'Say hello') + msg = await client.wait_for_message(author=message.author, content='hello') + await client.send_message(message.channel, 'Hello.') Asking for a follow-up question: .. code-block:: python :emphasize-lines: 6 - @client.async_event - def on_message(message): + @client.event + async def on_message(message): if message.content.startswith('$start'): - yield from client.send_message(message.channel, 'Type $stop 4 times.') + await client.send_message(message.channel, 'Type $stop 4 times.') for i in range(4): - msg = yield from client.wait_for_message(author=message.author, content='$stop') + msg = await client.wait_for_message(author=message.author, content='$stop') fmt = '{} left to go...' - yield from client.send_message(message.channel, fmt.format(3 - i)) + await client.send_message(message.channel, fmt.format(3 - i)) - yield from client.send_message(message.channel, 'Good job!') + await client.send_message(message.channel, 'Good job!') Advanced filters using ``check``: .. code-block:: python :emphasize-lines: 9 - @client.async_event - def on_message(message): + @client.event + async def on_message(message): if message.content.startswith('$cool'): - yield from client.send_message(message.channel, 'Who is cool? Type $name namehere') + await client.send_message(message.channel, 'Who is cool? Type $name namehere') def check(msg): return msg.content.startswith('$name') - message = yield from client.wait_for_message(author=message.author, check=check) + message = await client.wait_for_message(author=message.author, check=check) name = message.content[len('$name'):].strip() - yield from client.send_message(message.channel, '{} is cool indeed'.format(name)) + await client.send_message(message.channel, '{} is cool indeed'.format(name)) Parameters @@ -693,13 +724,107 @@ class Client: return result future = asyncio.Future(loop=self.loop) - self._listeners.append((predicate, future)) + self._listeners.append((predicate, future, WaitForType.message)) try: message = yield from asyncio.wait_for(future, timeout, loop=self.loop) except asyncio.TimeoutError: message = None return message + + @asyncio.coroutine + def wait_for_reaction(self, emoji=None, *, user=None, timeout=None, message=None, check=None): + """|coro| + + Waits for a message reaction from Discord. This is similar to :meth:`wait_for_message` + and could be seen as another :func:`on_reaction_add` event outside of the actual event. + This could be used for follow up situations. + + Similar to :meth:`wait_for_message`, the keyword arguments are combined using logical + AND operator. The ``check`` keyword argument can be used to pass in more complicated + checks and must a regular function taking in two arguments, ``(reaction, user)``. It + must not be a coroutine. + + The ``timeout`` parameter is passed into asyncio.wait_for. By default, it + does not timeout. Instead of throwing ``asyncio.TimeoutError`` the coroutine + catches the exception and returns ``None`` instead of a the ``(reaction, user)`` + tuple. + + If the ``check`` predicate throws an exception, then the exception is propagated. + + The ``emoji`` parameter can be either a :class:`Emoji`, a ``str`` representing + an emoji, or a sequence of either type. If the ``emoji`` parameter is a sequence + then the first reaction emoji that is in the list is returned. If ``None`` is + passed then the first reaction emoji used is returned. + + This function returns the **first reaction that meets the requirements**. + + Examples + --------- + + Basic Example: + + .. code-block:: python + + @client.event + async def on_message(message): + if message.content.startswith('$react'): + msg = await client.send_message(message.channel, 'React with thumbs up or thumbs down.') + (reaction, user) = await client.wait_for_reaction(['\N{THUMBS UP SIGN}', + '\N{THUMBS DOWN SIGN}'], + message=msg) + await client.send_message(message.channel, '{} reacted with {.emoji}!'.format(user, reaction)) + + Parameters + ----------- + timeout: float + The number of seconds to wait before returning ``None``. + user: :class:`Member` or :class:`User` + The user the reaction must be from. + emoji: str or :class:`Emoji` or sequence + The emoji that we are waiting to react with. + message: :class:`Message` + The message that we want the reaction to be from. + check: function + A predicate for other complicated checks. The predicate must take + ``(reaction, user)`` as its two parameters, which ``reaction`` being a + :class:`Reaction` and ``user`` being either a :class:`User` or a + :class:`Member`. + + Returns + -------- + tuple + A tuple of ``(reaction, user)`` similar to :func:`on_reaction_add`. + """ + + emoji_check = lambda r: True + if isinstance(emoji, (str, Emoji)): + emoji_check = lambda r: r.emoji == emoji + else: + emoji_check = lambda r: r.emoji in emoji + + def predicate(reaction, reaction_user): + result = emoji_check(reaction) + + if message is not None: + result = result and message.id == reaction.message.id + + if user is not None: + result = result and user.id == reaction_user.id + + if callable(check): + # the exception thrown by check is propagated through the future. + result = result and check(reaction, reaction_user) + + return result + + future = asyncio.Future(loop=self.loop) + self._listeners.append((predicate, future, WaitForType.reaction)) + try: + return (yield from asyncio.wait_for(future, timeout, loop=self.loop)) + except asyncio.TimeoutError: + return None + # event registration def event(self, coro):