diff --git a/discord/client.py b/discord/client.py index ff5ec16fa..daec48e40 100644 --- a/discord/client.py +++ b/discord/client.py @@ -93,6 +93,7 @@ class Client: self.token = None self.gateway = None self.loop = asyncio.get_event_loop() if loop is None else loop + self._listeners = [] max_messages = options.get('max_messages') if max_messages is None or max_messages < 100: @@ -117,6 +118,27 @@ class Client: # internals + def handle_message(self, message): + removed = [] + for i, (condition, future) in enumerate(self._listeners): + if future.cancelled(): + removed.append(i) + continue + + try: + result = condition(message) + except Exception as e: + future.set_exception(e) + removed.append(i) + else: + if result: + future.set_result(message) + removed.append(i) + + + for idx in reversed(removed): + del self._listeners[idx] + def _resolve_mentions(self, content, mentions): if isinstance(mentions, list): return [user.id for user in mentions] @@ -336,6 +358,120 @@ class Client: for member in server.members: yield member + + @asyncio.coroutine + def wait_for_message(self, timeout=None, *, author=None, channel=None, content=None, check=None): + """|coro| + + Waits for a message reply from Discord. This could be seen as another + :func:`discord.on_message` event outside of the actual event. This could + also be used for follow-ups and easier user interactions. + + The keyword arguments passed into this function are combined using the logical and + operator. The ``check`` keyword argument can be used to pass in more complicated + checks and must be a regular function (not a coroutine). + + The ``timeout`` parameter is passed into `asyncio.wait_for`_. By default, it + does not timeout. Instead of throwing ``asyncio.TimeoutError`` the coroutine + catches the exception and returns ``None`` instead of a :class:`Message`. + + If the ``check`` predicate throws an exception, then the exception is propagated. + + This function returns the **first message that meets the requirements**. + + .. _asyncio.wait_for: https://docs.python.org/3/library/asyncio-task.html#asyncio.wait_for + + Examples + ---------- + + Basic example: + + .. code-block:: python + :emphasize-lines: 5 + + @client.async_event + def on_message(message): + if message.content.startswith('$greet') + yield from client.send_message(message.channel, 'Say hello') + msg = yield from client.wait_for_message(author=message.author, content='hello') + yield from client.send_message(message.channel, 'Hello.') + + Asking for a follow-up question: + + .. code-block:: python + :emphasize-lines: 6 + + @client.async_event + def on_message(message): + if message.content.startswith('$start') + yield from client.send_message(message.channel, 'Type $stop 4 times.') + for i in range(4): + msg = yield from client.wait_for_message(author=message.author, content='$stop') + fmt = '{} left to go...' + yield from client.send_message(message.channel, fmt.format(3 - i)) + + yield from client.send_message(message.channel, 'Good job!') + + Advanced filters using ``check``: + + .. code-block:: python + :emphasize-lines: 9 + + @client.async_event + def on_message(message): + if message.content.startswith('$cool'): + yield from client.send_message(message.channel, 'Who is cool? Type $name namehere') + + def check(msg): + return msg.content.startswith('$name') + + message = yield from client.wait_for_message(author=message.author, check=check) + name = message.content[len('$name'):].strip() + yield from client.send_message(message.channel, '{} is cool indeed'.format(name)) + + + Parameters + ----------- + timeout : float + The number of seconds to wait before returning ``None``. + author : :class:`Member` or :class:`User` + The author the message must be from. + channel : :class:`Channel` or :class:`PrivateChannel` or :class:`Object` + The channel the message must be from. + content : str + The exact content the message must have. + check : function + A predicate for other complicated checks. The predicate must take + a :class:`Message` as its only parameter. + + Returns + -------- + :class:`Message` + The message that you requested for. + """ + + def predicate(message): + result = message.author == author + if content is not None: + result = result and message.content == content + + if channel is not None: + result = result and message.channel.id == channel.id + + if callable(check): + # the exception thrown by check is propagated through the future. + result = result and check(message) + + return result + + future = asyncio.Future(loop=self.loop) + self._listeners.append((predicate, future)) + try: + message = yield from asyncio.wait_for(future, timeout, loop=self.loop) + except asyncio.TimeoutError: + message = None + return message + # login state management @asyncio.coroutine