Browse Source

Replace wait_for_* with a generic Client.wait_for

pull/468/head
Rapptz 8 years ago
parent
commit
e5cb7d295c
  1. 372
      discord/client.py

372
discord/client.py

@ -41,7 +41,7 @@ import aiohttp
import websockets import websockets
import logging, traceback import logging, traceback
import sys, re, io, enum import sys, re, io
import itertools import itertools
import datetime import datetime
from collections import namedtuple from collections import namedtuple
@ -51,7 +51,6 @@ PY35 = sys.version_info >= (3, 5)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
AppInfo = namedtuple('AppInfo', 'id name description icon owner') AppInfo = namedtuple('AppInfo', 'id name description icon owner')
WaitedReaction = namedtuple('WaitedReaction', 'reaction user')
def app_info_icon_url(self): def app_info_icon_url(self):
"""Retrieves the application's icon_url if it exists. Empty string otherwise.""" """Retrieves the application's icon_url if it exists. Empty string otherwise."""
@ -62,10 +61,6 @@ def app_info_icon_url(self):
AppInfo.icon_url = property(app_info_icon_url) AppInfo.icon_url = property(app_info_icon_url)
class WaitForType(enum.Enum):
message = 0
reaction = 1
class Client: class Client:
"""Represents a client connection that connects to Discord. """Represents a client connection that connects to Discord.
This class is used to interact with the Discord WebSocket and API. This class is used to interact with the Discord WebSocket and API.
@ -113,7 +108,7 @@ class Client:
self.ws = None self.ws = None
self.email = None self.email = None
self.loop = asyncio.get_event_loop() if loop is None else loop self.loop = asyncio.get_event_loop() if loop is None else loop
self._listeners = [] self._listeners = {}
self.shard_id = options.get('shard_id') self.shard_id = options.get('shard_id')
self.shard_count = options.get('shard_count') self.shard_count = options.get('shard_count')
@ -125,8 +120,6 @@ class Client:
self.connection.shard_count = self.shard_count self.connection.shard_count = self.shard_count
self._closed = asyncio.Event(loop=self.loop) self._closed = asyncio.Event(loop=self.loop)
self._is_logged_in = asyncio.Event(loop=self.loop)
self._is_ready = asyncio.Event(loop=self.loop)
# if VoiceClient.warn_nacl: # if VoiceClient.warn_nacl:
# VoiceClient.warn_nacl = False # VoiceClient.warn_nacl = False
@ -156,57 +149,6 @@ class Client:
yield from self.ws.send_as_json(payload) yield from self.ws.send_as_json(payload)
def handle_reaction_add(self, reaction, user):
removed = []
for i, (condition, future, event_type) in enumerate(self._listeners):
if event_type is not WaitForType.reaction:
continue
if future.cancelled():
removed.append(i)
continue
try:
result = condition(reaction, user)
except Exception as e:
future.set_exception(e)
removed.append(i)
else:
if result:
future.set_result(WaitedReaction(reaction, user))
removed.append(i)
for idx in reversed(removed):
del self._listeners[idx]
def handle_message(self, message):
removed = []
for i, (condition, future, event_type) in enumerate(self._listeners):
if event_type is not WaitForType.message:
continue
if future.cancelled():
removed.append(i)
continue
try:
result = condition(message)
except Exception as e:
future.set_exception(e)
removed.append(i)
else:
if result:
future.set_result(message)
removed.append(i)
for idx in reversed(removed):
del self._listeners[idx]
def handle_ready(self):
self._is_ready.set()
def _resolve_invite(self, invite): def _resolve_invite(self, invite):
if isinstance(invite, Invite) or isinstance(invite, Object): if isinstance(invite, Invite) or isinstance(invite, Object):
return invite.id return invite.id
@ -264,6 +206,35 @@ class Client:
method = 'on_' + event method = 'on_' + event
handler = 'handle_' + event handler = 'handle_' + event
listeners = self._listeners.get(event)
if listeners:
removed = []
for i, (future, condition) in enumerate(listeners):
if future.cancelled():
removed.append(i)
continue
try:
result = condition(*args)
except Exception as e:
future.set_exception(e)
removed.append(i)
else:
if result:
if len(args) == 0:
future.set_result(None)
elif len(args) == 1:
future.set_result(args[0])
else:
future.set_result(args)
removed.append(i)
if len(removed) == len(listeners):
self._listeners.pop(event)
else:
for idx in reversed(removed):
del listeners[idx]
try: try:
actual_handler = getattr(self, handler) actual_handler = getattr(self, handler)
except AttributeError: except AttributeError:
@ -353,7 +324,6 @@ class Client:
data = yield from self.http.static_login(token, bot=bot) data = yield from self.http.static_login(token, bot=bot)
self.email = data.get('email', None) self.email = data.get('email', None)
self.connection.is_bot = bot self.connection.is_bot = bot
self._is_logged_in.set()
@asyncio.coroutine @asyncio.coroutine
def logout(self): def logout(self):
@ -362,7 +332,6 @@ class Client:
Logs out of Discord and closes all connections. Logs out of Discord and closes all connections.
""" """
yield from self.close() yield from self.close()
self._is_logged_in.clear()
@asyncio.coroutine @asyncio.coroutine
def connect(self): def connect(self):
@ -420,7 +389,6 @@ class Client:
yield from self.http.close() yield from self.http.close()
self._closed.set() self._closed.set()
self._is_ready.clear()
@asyncio.coroutine @asyncio.coroutine
def start(self, *args, **kwargs): def start(self, *args, **kwargs):
@ -474,12 +442,7 @@ class Client:
finally: finally:
self.loop.close() self.loop.close()
# properties # properties
@property
def is_logged_in(self):
"""bool: Indicates if the client has logged in successfully."""
return self._is_logged_in.is_set()
@property @property
def is_closed(self): def is_closed(self):
@ -550,250 +513,83 @@ class Client:
# listeners/waiters # listeners/waiters
@asyncio.coroutine def wait_for(self, event, *, check=None, timeout=None):
def wait_until_ready(self):
"""|coro|
This coroutine waits until the client is all ready. This could be considered
another way of asking for :func:`discord.on_ready` except meant for your own
background tasks.
"""
yield from self._is_ready.wait()
@asyncio.coroutine
def wait_until_login(self):
"""|coro| """|coro|
This coroutine waits until the client is logged on successfully. This Waits for a WebSocket event to be dispatched.
is different from waiting until the client's state is all ready. For
that check :func:`discord.on_ready` and :meth:`wait_until_ready`.
"""
yield from self._is_logged_in.wait()
@asyncio.coroutine
def wait_for_message(self, timeout=None, *, author=None, channel=None, content=None, check=None):
"""|coro|
Waits for a message reply from Discord. This could be seen as another
:func:`discord.on_message` event outside of the actual event. This could
also be used for follow-ups and easier user interactions.
The keyword arguments passed into this function are combined using the logical and
operator. The ``check`` keyword argument can be used to pass in more complicated
checks and must be a regular function (not a coroutine).
The ``timeout`` parameter is passed into `asyncio.wait_for`_. By default, it This could be used to wait for a user to reply to a message,
does not timeout. Instead of throwing ``asyncio.TimeoutError`` the coroutine or to react to a message, or to edit a message in a self-contained
catches the exception and returns ``None`` instead of a :class:`Message`. way.
If the ``check`` predicate throws an exception, then the exception is propagated. The ``timeout`` parameter is passed onto `asyncio.wait_for`_. By default,
it does not timeout. Note that this does propagate the
``asyncio.TimeoutError`` for you in case of timeout and is provided for
ease of use.
This function returns the **first message that meets the requirements**. In case the event returns multiple arguments, a tuple containing those
arguments is returned instead. Please check the
:ref:`documentation <discord-api-events>` for a list of events and their
parameters.
.. _asyncio.wait_for: https://docs.python.org/3/library/asyncio-task.html#asyncio.wait_for This function returns the **first event that meets the requirements**.
Examples Examples
---------- ---------
Basic example:
.. code-block:: python Waiting for a user reply: ::
:emphasize-lines: 5
@client.event @client.event
async def on_message(message): async def on_message(message):
if message.content.startswith('$greet'): if message.content.startswith('$greet'):
await message.channel.send('Say hello') await message.channel.send('Say hello!')
msg = await client.wait_for_message(author=message.author, content='hello')
await message.channel.send('Hello.')
Asking for a follow-up question:
.. code-block:: python
:emphasize-lines: 6
@client.event
async def on_message(message):
if message.content.startswith('$start'):
await message.channel.send('Type $stop 4 times.')
for i in range(4):
msg = await client.wait_for_message(author=message.author, content='$stop')
fmt = '{} left to go...'
await message.channel.send(fmt.format(3 - i))
await message.channel.send('Good job!')
Advanced filters using ``check``:
.. code-block:: python
:emphasize-lines: 9
@client.event
async def on_message(message):
if message.content.startswith('$cool'):
await message.channel.send('Who is cool? Type $name namehere')
def check(msg):
return msg.content.startswith('$name')
message = await client.wait_for_message(author=message.author, check=check) def check(m):
name = message.content[len('$name'):].strip() return m.content == 'hello' and m.channel == message.channel
await message.channel.send('{} is cool indeed'.format(name))
msg = await client.wait_for('message', check=check)
await message.channel.send('Hello {.author}!'.format(msg))
Parameters Parameters
----------- ------------
timeout : float event: str
The number of seconds to wait before returning ``None``. The event name, similar to the :ref:`event reference <discord-api-events>`,
author : :class:`Member` or :class:`User` but without the ``on_`` prefix, to wait for.
The author the message must be from. check: Optional[predicate]
channel : :class:`Channel` or :class:`PrivateChannel` or :class:`Object` A predicate to check what to wait for. The arguments must meet the
The channel the message must be from. parameters of the event being waited for.
content : str timeout: Optional[float]
The exact content the message must have. The number of seconds to wait before timing out and raising
check : function ``asyncio.TimeoutError``\.
A predicate for other complicated checks. The predicate must take
a :class:`Message` as its only parameter.
Returns
--------
:class:`Message`
The message that you requested for.
"""
def predicate(message):
result = True
if author is not None:
result = result and message.author == author
if content is not None:
result = result and message.content == content
if channel is not None:
result = result and message.channel.id == channel.id
if callable(check):
# the exception thrown by check is propagated through the future.
result = result and check(message)
return result
future = compat.create_future(self.loop)
self._listeners.append((predicate, future, WaitForType.message))
try:
message = yield from asyncio.wait_for(future, timeout, loop=self.loop)
except asyncio.TimeoutError:
message = None
return message
@asyncio.coroutine
def wait_for_reaction(self, emoji=None, *, user=None, timeout=None, message=None, check=None):
"""|coro|
Waits for a message reaction from Discord. This is similar to :meth:`wait_for_message`
and could be seen as another :func:`on_reaction_add` event outside of the actual event.
This could be used for follow up situations.
Similar to :meth:`wait_for_message`, the keyword arguments are combined using logical
AND operator. The ``check`` keyword argument can be used to pass in more complicated
checks and must a regular function taking in two arguments, ``(reaction, user)``. It
must not be a coroutine.
The ``timeout`` parameter is passed into asyncio.wait_for. By default, it
does not timeout. Instead of throwing ``asyncio.TimeoutError`` the coroutine
catches the exception and returns ``None`` instead of a the ``(reaction, user)``
tuple.
If the ``check`` predicate throws an exception, then the exception is propagated.
The ``emoji`` parameter can be either a :class:`Emoji`, a ``str`` representing
an emoji, or a sequence of either type. If the ``emoji`` parameter is a sequence
then the first reaction emoji that is in the list is returned. If ``None`` is
passed then the first reaction emoji used is returned.
This function returns the **first reaction that meets the requirements**.
Examples
---------
Basic Example:
.. code-block:: python
@client.event
async def on_message(message):
if message.content.startswith('$react'):
msg = await message.channel.send('React with thumbs up or thumbs down.')
res = await client.wait_for_reaction(['\N{THUMBS UP SIGN}', '\N{THUMBS DOWN SIGN}'], message=msg)
await message.channel.send('{0.user} reacted with {0.reaction.emoji}!'.format(res))
Checking for reaction emoji regardless of skin tone:
.. code-block:: python
@client.event
async def on_message(message):
if message.content.startswith('$react'):
msg = await message.channel.send('React with thumbs up or thumbs down.')
def check(reaction, user):
e = str(reaction.emoji)
return e.startswith(('\N{THUMBS UP SIGN}', '\N{THUMBS DOWN SIGN}'))
res = await client.wait_for_reaction(message=msg, check=check)
await message.channel.send('{0.user} reacted with {0.reaction.emoji}!'.format(res))
Parameters Raises
----------- -------
timeout: float asyncio.TimeoutError
The number of seconds to wait before returning ``None``. If a timeout is provided and it was reached.
user: :class:`Member` or :class:`User`
The user the reaction must be from.
emoji: str or :class:`Emoji` or sequence
The emoji that we are waiting to react with.
message: :class:`Message`
The message that we want the reaction to be from.
check: function
A predicate for other complicated checks. The predicate must take
``(reaction, user)`` as its two parameters, which ``reaction`` being a
:class:`Reaction` and ``user`` being either a :class:`User` or a
:class:`Member`.
Returns Returns
-------- --------
namedtuple Any
A namedtuple with attributes ``reaction`` and ``user`` similar to :func:`on_reaction_add`. Returns no arguments, a single argument, or a tuple of multiple
arguments that mirrors the parameters passed in the
:ref:`event reference <discord-api-events>`.
""" """
if emoji is None:
emoji_check = lambda r: True
elif isinstance(emoji, (str, Emoji)):
emoji_check = lambda r: r.emoji == emoji
else:
emoji_check = lambda r: r.emoji in emoji
def predicate(reaction, reaction_user):
result = emoji_check(reaction)
if message is not None:
result = result and message.id == reaction.message.id
if user is not None:
result = result and user.id == reaction_user.id
if callable(check):
# the exception thrown by check is propagated through the future.
result = result and check(reaction, reaction_user)
return result
future = compat.create_future(self.loop) future = compat.create_future(self.loop)
self._listeners.append((predicate, future, WaitForType.reaction)) if check is None:
def _check(*args):
return True
check = _check
ev = event.lower()
try: try:
return (yield from asyncio.wait_for(future, timeout, loop=self.loop)) listeners = self._listeners[ev]
except asyncio.TimeoutError: except KeyError:
return None listeners = []
self._listeners[ev] = listeners
listeners.append((future, check))
return asyncio.wait_for(future, timeout, loop=self.loop)
# event registration # event registration

Loading…
Cancel
Save