From e5cb7d295c9c8ea5ca52308b4286452a64729b83 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Wed, 25 Jan 2017 22:26:49 -0500 Subject: [PATCH] Replace wait_for_* with a generic Client.wait_for --- discord/client.py | 372 +++++++++++----------------------------------- 1 file changed, 84 insertions(+), 288 deletions(-) diff --git a/discord/client.py b/discord/client.py index ec05cb205..765090161 100644 --- a/discord/client.py +++ b/discord/client.py @@ -41,7 +41,7 @@ import aiohttp import websockets import logging, traceback -import sys, re, io, enum +import sys, re, io import itertools import datetime from collections import namedtuple @@ -51,7 +51,6 @@ PY35 = sys.version_info >= (3, 5) log = logging.getLogger(__name__) AppInfo = namedtuple('AppInfo', 'id name description icon owner') -WaitedReaction = namedtuple('WaitedReaction', 'reaction user') def app_info_icon_url(self): """Retrieves the application's icon_url if it exists. Empty string otherwise.""" @@ -62,10 +61,6 @@ def app_info_icon_url(self): AppInfo.icon_url = property(app_info_icon_url) -class WaitForType(enum.Enum): - message = 0 - reaction = 1 - class Client: """Represents a client connection that connects to Discord. This class is used to interact with the Discord WebSocket and API. @@ -113,7 +108,7 @@ class Client: self.ws = None self.email = None self.loop = asyncio.get_event_loop() if loop is None else loop - self._listeners = [] + self._listeners = {} self.shard_id = options.get('shard_id') self.shard_count = options.get('shard_count') @@ -125,8 +120,6 @@ class Client: self.connection.shard_count = self.shard_count self._closed = asyncio.Event(loop=self.loop) - self._is_logged_in = asyncio.Event(loop=self.loop) - self._is_ready = asyncio.Event(loop=self.loop) # if VoiceClient.warn_nacl: # VoiceClient.warn_nacl = False @@ -156,57 +149,6 @@ class Client: yield from self.ws.send_as_json(payload) - def handle_reaction_add(self, reaction, user): - removed = [] - for i, (condition, future, event_type) in enumerate(self._listeners): - if event_type is not WaitForType.reaction: - continue - - if future.cancelled(): - removed.append(i) - continue - - try: - result = condition(reaction, user) - except Exception as e: - future.set_exception(e) - removed.append(i) - else: - if result: - future.set_result(WaitedReaction(reaction, user)) - removed.append(i) - - - for idx in reversed(removed): - del self._listeners[idx] - - def handle_message(self, message): - removed = [] - for i, (condition, future, event_type) in enumerate(self._listeners): - if event_type is not WaitForType.message: - continue - - if future.cancelled(): - removed.append(i) - continue - - try: - result = condition(message) - except Exception as e: - future.set_exception(e) - removed.append(i) - else: - if result: - future.set_result(message) - removed.append(i) - - - for idx in reversed(removed): - del self._listeners[idx] - - def handle_ready(self): - self._is_ready.set() - def _resolve_invite(self, invite): if isinstance(invite, Invite) or isinstance(invite, Object): return invite.id @@ -264,6 +206,35 @@ class Client: method = 'on_' + event handler = 'handle_' + event + listeners = self._listeners.get(event) + if listeners: + removed = [] + for i, (future, condition) in enumerate(listeners): + if future.cancelled(): + removed.append(i) + continue + + try: + result = condition(*args) + except Exception as e: + future.set_exception(e) + removed.append(i) + else: + if result: + if len(args) == 0: + future.set_result(None) + elif len(args) == 1: + future.set_result(args[0]) + else: + future.set_result(args) + removed.append(i) + + if len(removed) == len(listeners): + self._listeners.pop(event) + else: + for idx in reversed(removed): + del listeners[idx] + try: actual_handler = getattr(self, handler) except AttributeError: @@ -353,7 +324,6 @@ class Client: data = yield from self.http.static_login(token, bot=bot) self.email = data.get('email', None) self.connection.is_bot = bot - self._is_logged_in.set() @asyncio.coroutine def logout(self): @@ -362,7 +332,6 @@ class Client: Logs out of Discord and closes all connections. """ yield from self.close() - self._is_logged_in.clear() @asyncio.coroutine def connect(self): @@ -420,7 +389,6 @@ class Client: yield from self.http.close() self._closed.set() - self._is_ready.clear() @asyncio.coroutine def start(self, *args, **kwargs): @@ -474,12 +442,7 @@ class Client: finally: self.loop.close() - # properties - - @property - def is_logged_in(self): - """bool: Indicates if the client has logged in successfully.""" - return self._is_logged_in.is_set() + # properties @property def is_closed(self): @@ -550,250 +513,83 @@ class Client: # listeners/waiters - @asyncio.coroutine - def wait_until_ready(self): - """|coro| - - This coroutine waits until the client is all ready. This could be considered - another way of asking for :func:`discord.on_ready` except meant for your own - background tasks. - """ - yield from self._is_ready.wait() - - @asyncio.coroutine - def wait_until_login(self): + def wait_for(self, event, *, check=None, timeout=None): """|coro| - This coroutine waits until the client is logged on successfully. This - is different from waiting until the client's state is all ready. For - that check :func:`discord.on_ready` and :meth:`wait_until_ready`. - """ - yield from self._is_logged_in.wait() - - @asyncio.coroutine - def wait_for_message(self, timeout=None, *, author=None, channel=None, content=None, check=None): - """|coro| - - Waits for a message reply from Discord. This could be seen as another - :func:`discord.on_message` event outside of the actual event. This could - also be used for follow-ups and easier user interactions. - - The keyword arguments passed into this function are combined using the logical and - operator. The ``check`` keyword argument can be used to pass in more complicated - checks and must be a regular function (not a coroutine). + Waits for a WebSocket event to be dispatched. - The ``timeout`` parameter is passed into `asyncio.wait_for`_. By default, it - does not timeout. Instead of throwing ``asyncio.TimeoutError`` the coroutine - catches the exception and returns ``None`` instead of a :class:`Message`. + This could be used to wait for a user to reply to a message, + or to react to a message, or to edit a message in a self-contained + way. - If the ``check`` predicate throws an exception, then the exception is propagated. + The ``timeout`` parameter is passed onto `asyncio.wait_for`_. By default, + it does not timeout. Note that this does propagate the + ``asyncio.TimeoutError`` for you in case of timeout and is provided for + ease of use. - This function returns the **first message that meets the requirements**. + In case the event returns multiple arguments, a tuple containing those + arguments is returned instead. Please check the + :ref:`documentation ` for a list of events and their + parameters. - .. _asyncio.wait_for: https://docs.python.org/3/library/asyncio-task.html#asyncio.wait_for + This function returns the **first event that meets the requirements**. Examples - ---------- - - Basic example: + --------- - .. code-block:: python - :emphasize-lines: 5 + Waiting for a user reply: :: @client.event async def on_message(message): if message.content.startswith('$greet'): - await message.channel.send('Say hello') - msg = await client.wait_for_message(author=message.author, content='hello') - await message.channel.send('Hello.') - - Asking for a follow-up question: - - .. code-block:: python - :emphasize-lines: 6 - - @client.event - async def on_message(message): - if message.content.startswith('$start'): - await message.channel.send('Type $stop 4 times.') - for i in range(4): - msg = await client.wait_for_message(author=message.author, content='$stop') - fmt = '{} left to go...' - await message.channel.send(fmt.format(3 - i)) - - await message.channel.send('Good job!') - - Advanced filters using ``check``: - - .. code-block:: python - :emphasize-lines: 9 - - @client.event - async def on_message(message): - if message.content.startswith('$cool'): - await message.channel.send('Who is cool? Type $name namehere') - - def check(msg): - return msg.content.startswith('$name') + await message.channel.send('Say hello!') - message = await client.wait_for_message(author=message.author, check=check) - name = message.content[len('$name'):].strip() - await message.channel.send('{} is cool indeed'.format(name)) + def check(m): + return m.content == 'hello' and m.channel == message.channel + msg = await client.wait_for('message', check=check) + await message.channel.send('Hello {.author}!'.format(msg)) Parameters - ----------- - timeout : float - The number of seconds to wait before returning ``None``. - author : :class:`Member` or :class:`User` - The author the message must be from. - channel : :class:`Channel` or :class:`PrivateChannel` or :class:`Object` - The channel the message must be from. - content : str - The exact content the message must have. - check : function - A predicate for other complicated checks. The predicate must take - a :class:`Message` as its only parameter. - - Returns - -------- - :class:`Message` - The message that you requested for. - """ - - def predicate(message): - result = True - if author is not None: - result = result and message.author == author - - if content is not None: - result = result and message.content == content - - if channel is not None: - result = result and message.channel.id == channel.id - - if callable(check): - # the exception thrown by check is propagated through the future. - result = result and check(message) - - return result - - future = compat.create_future(self.loop) - self._listeners.append((predicate, future, WaitForType.message)) - try: - message = yield from asyncio.wait_for(future, timeout, loop=self.loop) - except asyncio.TimeoutError: - message = None - return message - - - @asyncio.coroutine - def wait_for_reaction(self, emoji=None, *, user=None, timeout=None, message=None, check=None): - """|coro| - - Waits for a message reaction from Discord. This is similar to :meth:`wait_for_message` - and could be seen as another :func:`on_reaction_add` event outside of the actual event. - This could be used for follow up situations. - - Similar to :meth:`wait_for_message`, the keyword arguments are combined using logical - AND operator. The ``check`` keyword argument can be used to pass in more complicated - checks and must a regular function taking in two arguments, ``(reaction, user)``. It - must not be a coroutine. - - The ``timeout`` parameter is passed into asyncio.wait_for. By default, it - does not timeout. Instead of throwing ``asyncio.TimeoutError`` the coroutine - catches the exception and returns ``None`` instead of a the ``(reaction, user)`` - tuple. - - If the ``check`` predicate throws an exception, then the exception is propagated. - - The ``emoji`` parameter can be either a :class:`Emoji`, a ``str`` representing - an emoji, or a sequence of either type. If the ``emoji`` parameter is a sequence - then the first reaction emoji that is in the list is returned. If ``None`` is - passed then the first reaction emoji used is returned. - - This function returns the **first reaction that meets the requirements**. - - Examples - --------- - - Basic Example: - - .. code-block:: python - - @client.event - async def on_message(message): - if message.content.startswith('$react'): - msg = await message.channel.send('React with thumbs up or thumbs down.') - res = await client.wait_for_reaction(['\N{THUMBS UP SIGN}', '\N{THUMBS DOWN SIGN}'], message=msg) - await message.channel.send('{0.user} reacted with {0.reaction.emoji}!'.format(res)) - - Checking for reaction emoji regardless of skin tone: - - .. code-block:: python - - @client.event - async def on_message(message): - if message.content.startswith('$react'): - msg = await message.channel.send('React with thumbs up or thumbs down.') - - def check(reaction, user): - e = str(reaction.emoji) - return e.startswith(('\N{THUMBS UP SIGN}', '\N{THUMBS DOWN SIGN}')) - - res = await client.wait_for_reaction(message=msg, check=check) - await message.channel.send('{0.user} reacted with {0.reaction.emoji}!'.format(res)) + ------------ + event: str + The event name, similar to the :ref:`event reference `, + but without the ``on_`` prefix, to wait for. + check: Optional[predicate] + A predicate to check what to wait for. The arguments must meet the + parameters of the event being waited for. + timeout: Optional[float] + The number of seconds to wait before timing out and raising + ``asyncio.TimeoutError``\. - Parameters - ----------- - timeout: float - The number of seconds to wait before returning ``None``. - user: :class:`Member` or :class:`User` - The user the reaction must be from. - emoji: str or :class:`Emoji` or sequence - The emoji that we are waiting to react with. - message: :class:`Message` - The message that we want the reaction to be from. - check: function - A predicate for other complicated checks. The predicate must take - ``(reaction, user)`` as its two parameters, which ``reaction`` being a - :class:`Reaction` and ``user`` being either a :class:`User` or a - :class:`Member`. + Raises + ------- + asyncio.TimeoutError + If a timeout is provided and it was reached. Returns -------- - namedtuple - A namedtuple with attributes ``reaction`` and ``user`` similar to :func:`on_reaction_add`. + Any + Returns no arguments, a single argument, or a tuple of multiple + arguments that mirrors the parameters passed in the + :ref:`event reference `. """ - if emoji is None: - emoji_check = lambda r: True - elif isinstance(emoji, (str, Emoji)): - emoji_check = lambda r: r.emoji == emoji - else: - emoji_check = lambda r: r.emoji in emoji - - def predicate(reaction, reaction_user): - result = emoji_check(reaction) - - if message is not None: - result = result and message.id == reaction.message.id - - if user is not None: - result = result and user.id == reaction_user.id - - if callable(check): - # the exception thrown by check is propagated through the future. - result = result and check(reaction, reaction_user) - - return result - future = compat.create_future(self.loop) - self._listeners.append((predicate, future, WaitForType.reaction)) + if check is None: + def _check(*args): + return True + check = _check + + ev = event.lower() try: - return (yield from asyncio.wait_for(future, timeout, loop=self.loop)) - except asyncio.TimeoutError: - return None + listeners = self._listeners[ev] + except KeyError: + listeners = [] + self._listeners[ev] = listeners + + listeners.append((future, check)) + return asyncio.wait_for(future, timeout, loop=self.loop) # event registration