Browse Source

Add Client.wait_for_reaction to wait for a reaction from a user.

pull/372/merge
Rapptz 9 years ago
parent
commit
0e8a92cbac
  1. 163
      discord/client.py

163
discord/client.py

@ -50,7 +50,7 @@ import aiohttp
import websockets import websockets
import logging, traceback import logging, traceback
import sys, re, io import sys, re, io, enum
import tempfile, os, hashlib import tempfile, os, hashlib
import itertools import itertools
import datetime import datetime
@ -70,6 +70,10 @@ def app_info_icon_url(self):
AppInfo.icon_url = property(app_info_icon_url) AppInfo.icon_url = property(app_info_icon_url)
class WaitForType(enum.Enum):
message = 0
reaction = 1
ChannelPermissions = namedtuple('ChannelPermissions', 'target overwrite') ChannelPermissions = namedtuple('ChannelPermissions', 'target overwrite')
ChannelPermissions.__new__.__defaults__ = (PermissionOverwrite(),) ChannelPermissions.__new__.__defaults__ = (PermissionOverwrite(),)
@ -194,9 +198,36 @@ class Client:
log.info('a problem occurred while updating the login cache') log.info('a problem occurred while updating the login cache')
pass pass
def handle_reaction_add(self, reaction, user):
removed = []
for i, (condition, future, event_type) in enumerate(self._listeners):
if event_type is not WaitForType.reaction:
continue
if future.cancelled():
removed.append(i)
continue
try:
result = condition(reaction, user)
except Exception as e:
future.set_exception(e)
removed.append(i)
else:
if result:
future.set_result((reaction, user))
removed.append(i)
for idx in reversed(removed):
del self._listeners[idx]
def handle_message(self, message): def handle_message(self, message):
removed = [] removed = []
for i, (condition, future) in enumerate(self._listeners): for i, (condition, future, event_type) in enumerate(self._listeners):
if event_type is not WaitForType.message:
continue
if future.cancelled(): if future.cancelled():
removed.append(i) removed.append(i)
continue continue
@ -614,45 +645,45 @@ class Client:
.. code-block:: python .. code-block:: python
:emphasize-lines: 5 :emphasize-lines: 5
@client.async_event @client.event
def on_message(message): async def on_message(message):
if message.content.startswith('$greet'): if message.content.startswith('$greet'):
yield from client.send_message(message.channel, 'Say hello') await client.send_message(message.channel, 'Say hello')
msg = yield from client.wait_for_message(author=message.author, content='hello') msg = await client.wait_for_message(author=message.author, content='hello')
yield from client.send_message(message.channel, 'Hello.') await client.send_message(message.channel, 'Hello.')
Asking for a follow-up question: Asking for a follow-up question:
.. code-block:: python .. code-block:: python
:emphasize-lines: 6 :emphasize-lines: 6
@client.async_event @client.event
def on_message(message): async def on_message(message):
if message.content.startswith('$start'): if message.content.startswith('$start'):
yield from client.send_message(message.channel, 'Type $stop 4 times.') await client.send_message(message.channel, 'Type $stop 4 times.')
for i in range(4): for i in range(4):
msg = yield from client.wait_for_message(author=message.author, content='$stop') msg = await client.wait_for_message(author=message.author, content='$stop')
fmt = '{} left to go...' fmt = '{} left to go...'
yield from client.send_message(message.channel, fmt.format(3 - i)) await client.send_message(message.channel, fmt.format(3 - i))
yield from client.send_message(message.channel, 'Good job!') await client.send_message(message.channel, 'Good job!')
Advanced filters using ``check``: Advanced filters using ``check``:
.. code-block:: python .. code-block:: python
:emphasize-lines: 9 :emphasize-lines: 9
@client.async_event @client.event
def on_message(message): async def on_message(message):
if message.content.startswith('$cool'): if message.content.startswith('$cool'):
yield from client.send_message(message.channel, 'Who is cool? Type $name namehere') await client.send_message(message.channel, 'Who is cool? Type $name namehere')
def check(msg): def check(msg):
return msg.content.startswith('$name') return msg.content.startswith('$name')
message = yield from client.wait_for_message(author=message.author, check=check) message = await client.wait_for_message(author=message.author, check=check)
name = message.content[len('$name'):].strip() name = message.content[len('$name'):].strip()
yield from client.send_message(message.channel, '{} is cool indeed'.format(name)) await client.send_message(message.channel, '{} is cool indeed'.format(name))
Parameters Parameters
@ -693,13 +724,107 @@ class Client:
return result return result
future = asyncio.Future(loop=self.loop) future = asyncio.Future(loop=self.loop)
self._listeners.append((predicate, future)) self._listeners.append((predicate, future, WaitForType.message))
try: try:
message = yield from asyncio.wait_for(future, timeout, loop=self.loop) message = yield from asyncio.wait_for(future, timeout, loop=self.loop)
except asyncio.TimeoutError: except asyncio.TimeoutError:
message = None message = None
return message return message
@asyncio.coroutine
def wait_for_reaction(self, emoji=None, *, user=None, timeout=None, message=None, check=None):
"""|coro|
Waits for a message reaction from Discord. This is similar to :meth:`wait_for_message`
and could be seen as another :func:`on_reaction_add` event outside of the actual event.
This could be used for follow up situations.
Similar to :meth:`wait_for_message`, the keyword arguments are combined using logical
AND operator. The ``check`` keyword argument can be used to pass in more complicated
checks and must a regular function taking in two arguments, ``(reaction, user)``. It
must not be a coroutine.
The ``timeout`` parameter is passed into asyncio.wait_for. By default, it
does not timeout. Instead of throwing ``asyncio.TimeoutError`` the coroutine
catches the exception and returns ``None`` instead of a the ``(reaction, user)``
tuple.
If the ``check`` predicate throws an exception, then the exception is propagated.
The ``emoji`` parameter can be either a :class:`Emoji`, a ``str`` representing
an emoji, or a sequence of either type. If the ``emoji`` parameter is a sequence
then the first reaction emoji that is in the list is returned. If ``None`` is
passed then the first reaction emoji used is returned.
This function returns the **first reaction that meets the requirements**.
Examples
---------
Basic Example:
.. code-block:: python
@client.event
async def on_message(message):
if message.content.startswith('$react'):
msg = await client.send_message(message.channel, 'React with thumbs up or thumbs down.')
(reaction, user) = await client.wait_for_reaction(['\N{THUMBS UP SIGN}',
'\N{THUMBS DOWN SIGN}'],
message=msg)
await client.send_message(message.channel, '{} reacted with {.emoji}!'.format(user, reaction))
Parameters
-----------
timeout: float
The number of seconds to wait before returning ``None``.
user: :class:`Member` or :class:`User`
The user the reaction must be from.
emoji: str or :class:`Emoji` or sequence
The emoji that we are waiting to react with.
message: :class:`Message`
The message that we want the reaction to be from.
check: function
A predicate for other complicated checks. The predicate must take
``(reaction, user)`` as its two parameters, which ``reaction`` being a
:class:`Reaction` and ``user`` being either a :class:`User` or a
:class:`Member`.
Returns
--------
tuple
A tuple of ``(reaction, user)`` similar to :func:`on_reaction_add`.
"""
emoji_check = lambda r: True
if isinstance(emoji, (str, Emoji)):
emoji_check = lambda r: r.emoji == emoji
else:
emoji_check = lambda r: r.emoji in emoji
def predicate(reaction, reaction_user):
result = emoji_check(reaction)
if message is not None:
result = result and message.id == reaction.message.id
if user is not None:
result = result and user.id == reaction_user.id
if callable(check):
# the exception thrown by check is propagated through the future.
result = result and check(reaction, reaction_user)
return result
future = asyncio.Future(loop=self.loop)
self._listeners.append((predicate, future, WaitForType.reaction))
try:
return (yield from asyncio.wait_for(future, timeout, loop=self.loop))
except asyncio.TimeoutError:
return None
# event registration # event registration
def event(self, coro): def event(self, coro):

Loading…
Cancel
Save