diff --git a/discord/async_client.py b/discord/async_client.py new file mode 100644 index 000000000..22f763913 --- /dev/null +++ b/discord/async_client.py @@ -0,0 +1,474 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from . import endpoints +from .user import User +from .channel import Channel, PrivateChannel +from .server import Server +from .message import Message +from .invite import Invite +from .object import Object +from .errors import * +from .state import ConnectionState +from . import utils + +import asyncio +import aiohttp +import websockets + +import logging, traceback +import sys, time, re, json + +log = logging.getLogger(__name__) +request_logging_format = '{method} {response.url} has returned {response.status}' +request_success_log = '{response.url} with {json} received {data}' + +def to_json(obj): + return json.dumps(obj, separators=(',', ':'), ensure_ascii=True) + +class Client: + """Represents a client connection that connects to Discord. + This class is used to interact with the Discord WebSocket and API. + + A number of options can be passed to the :class:`Client`. + + Parameters + ---------- + max_messages : Optional[int] + The maximum number of messages to store in :attr:`messages`. + This defaults to 5000. Passing in `None` or a value of ``<= 0`` + will use the default instead of the passed in value. + loop : Optional[event loop]. + The `event loop`_ to use for asynchronous operations. Defaults to ``None``, + in which case the default event loop is used via ``asyncio.get_event_loop()``. + + Attributes + ----------- + user : Optional[:class:`User`] + Represents the connected client. None if not logged in. + servers : list of :class:`Server` + The servers that the connected client is a member of. + private_channels : list of :class:`PrivateChannel` + The private channels that the connected client is participating on. + messages : deque_ of :class:`Message` + A deque_ of :class:`Message` that the client has received from all + servers and private messages. The number of messages stored in this + deque is controlled by the ``max_messages`` parameter. + email : Optional[str] + The email used to login. This is only set if login is successful, + otherwise it's None. + gateway : Optional[str] + The websocket gateway the client is currently connected to. Could be None. + loop + The `event loop`_ that the client uses for HTTP requests and websocket operations. + + .. _deque: https://docs.python.org/3.4/library/collections.html#collections.deque + .. _event loop: https://docs.python.org/3/library/asyncio-eventloops.html + """ + def __init__(self, *, loop=None, **options): + self.ws = None + self.token = None + self.gateway = None + self.loop = asyncio.get_event_loop() if loop is None else loop + + max_messages = options.get('max_messages') + if max_messages is None or max_messages <= 0: + max_messages = 5000 + + self.connection = ConnectionState(self.dispatch, max_messages) + self.session = aiohttp.ClientSession(loop=self.loop) + self.headers = { + 'content-type': 'application/json', + } + self._closed = False + + def _resolve_mentions(self, content, mentions): + if isinstance(mentions, list): + return [user.id for user in mentions] + elif mentions == True: + return re.findall(r'<@(\d+)>', content) + else: + return [] + + def _resolve_invite(self, invite): + if isinstance(invite, Invite) or isinstance(invite, Object): + return invite.id + else: + rx = r'(?:https?\:\/\/)?discord\.gg\/(.+)' + m = re.match(rx, invite) + if m: + return m.group(1) + return invite + + def _resolve_destination(self, destination): + if isinstance(destination, (Channel, PrivateChannel, Server)): + return destination.id + elif isinstance(destination, User): + found = utils.find(lambda pm: pm.user == destination, self.private_channels) + if found is None: + # Couldn't find the user, so start a PM with them first. + self.start_private_message(destination) + channel_id = self.private_channels[-1].id + return channel_id + else: + return found.id + elif isinstance(destination, Object): + return destination.id + else: + raise InvalidArgument('Destination must be Channel, PrivateChannel, User, or Object') + + # Compatibility shim + def __getattr__(self, name): + if name in ('user', 'email', 'servers', 'private_channels', 'messages'): + return getattr(self.connection, name) + else: + msg = "'{}' object has no attribute '{}'" + raise AttributeError(msg.format(self.__class__, name)) + + # Compatibility shim + def __setattr__(self, name, value): + if name in ('user', 'email', 'servers', 'private_channels', + 'messages'): + return setattr(self.connection, name, value) + else: + object.__setattr__(self, name, value) + + @property + def is_logged_in(self): + """bool: Indicates if the client has logged in successfully.""" + return self._is_logged_in + + @asyncio.coroutine + def _get_gateway(self): + resp = yield from self.session.get(endpoints.GATEWAY, headers=self.headers) + if resp.status != 200: + raise GatewayNotFound() + data = yield from resp.json() + return data.get('url') + + @asyncio.coroutine + def _run_event(self, event, *args, **kwargs): + try: + yield from getattr(self, event)(*args, **kwargs) + except Exception as e: + yield from self.on_error(event, *args, **kwargs) + + def dispatch(self, event, *args, **kwargs): + log.debug('Dispatching event {}'.format(event)) + method = 'on_' + event + handler = 'handle_' + event + + if hasattr(self, handler): + getattr(self, handler)(*args, **kwargs) + + if hasattr(self, method): + utils.create_task(self._run_event(method, *args, **kwargs), loop=self.loop) + + def get_channel(self, id): + """Returns a :class:`Channel` or :class:`PrivateChannel` with the following ID. If not found, returns None.""" + return self.connection.get_channel(id) + + @asyncio.coroutine + def login(self, email, password): + """|coro| + + Logs in the client with the specified credentials. + + Parameters + ---------- + email : str + The email used to login. + password : str + The password used to login. + + Raises + ------ + LoginFailure + The wrong credentials are passed. + HTTPException + An unknown HTTP related error occurred, + usually when it isn't 200 or the known incorrect credentials + passing status code. + """ + payload = { + 'email': email, + 'password': password + } + + data = to_json(payload) + resp = yield from self.session.post(endpoints.LOGIN, data=data, headers=self.headers) + log.debug(request_logging_format.format(method='POST', response=resp)) + if resp.status == 400: + raise LoginFailure('Improper credentials have been passed.') + elif resp.status != 200: + data = yield from resp.json() + raise HTTPException(resp, data.get('message')) + + log.info('logging in returned status code {}'.format(resp.status)) + self.email = email + + body = yield from resp.json() + self.token = body['token'] + self.headers['authorization'] = self.token + self._is_logged_in = True + + @asyncio.coroutine + def keep_alive_handler(self, interval): + while not self._closed: + payload = { + 'op': 1, + 'd': int(time.time()) + } + + msg = 'Keeping websocket alive with timestamp {}' + log.debug(msg.format(payload['d'])) + yield from self.ws.send(to_json(payload)) + yield from asyncio.sleep(interval) + + @asyncio.coroutine + def on_error(self, event_method, *args, **kwargs): + """|coro| + + The default error handler provided by the client. + + By default this prints to ``sys.stderr`` however it could be + overridden to have a different implementation. + Check :func:`discord.on_error` for more details. + """ + print('Ignoring exception in {}'.format(event_method), file=sys.stderr) + traceback.print_exc() + + def received_message(self, msg): + log.debug('WebSocket Event: {}'.format(msg)) + self.dispatch('socket_response', msg) + + op = msg.get('op') + data = msg.get('d') + + if op != 0: + log.info('Unhandled op {}'.format(op)) + return + + event = msg.get('t') + + if event == 'READY': + interval = data['heartbeat_interval'] / 1000.0 + self.keep_alive = utils.create_task(self.keep_alive_handler(interval), loop=self.loop) + + if event in ('READY', 'MESSAGE_CREATE', 'MESSAGE_DELETE', + 'MESSAGE_UPDATE', 'PRESENCE_UPDATE', 'USER_UPDATE', + 'CHANNEL_DELETE', 'CHANNEL_UPDATE', 'CHANNEL_CREATE', + 'GUILD_MEMBER_ADD', 'GUILD_MEMBER_REMOVE', + 'GUILD_MEMBER_UPDATE', 'GUILD_CREATE', 'GUILD_DELETE', + 'GUILD_ROLE_CREATE', 'GUILD_ROLE_DELETE', 'TYPING_START', + 'GUILD_ROLE_UPDATE', 'VOICE_STATE_UPDATE'): + parser = 'parse_' + event.lower() + if hasattr(self.connection, parser): + getattr(self.connection, parser)(data) + else: + log.info("Unhandled event {}".format(event)) + + @asyncio.coroutine + def _make_websocket(self): + if not self.is_logged_in: + raise ClientException('You must be logged in to connect') + + self.gateway = yield from self._get_gateway() + self.ws = yield from websockets.connect(self.gateway) + self.ws.max_size = None + log.info('Created websocket connected to {0.gateway}'.format(self)) + payload = { + 'op': 2, + 'd': { + 'token': self.token, + 'properties': { + '$os': sys.platform, + '$browser': 'discord.py', + '$device': 'discord.py', + '$referrer': '', + '$referring_domain': '' + }, + 'v': 3 + } + } + + yield from self.ws.send(to_json(payload)) + log.info('sent the initial payload to create the websocket') + + @asyncio.coroutine + def connect(self): + """|coro| + + Creates a websocket connection and connects to the websocket listen + to messages from discord. + + This function is implemented using a while loop in the background. + If you need to run this event listening in another thread then + you should run it in an executor or schedule the coroutine to + be executed later using ``loop.create_task``. + + This function throws :exc:`ClientException` if called before + logging in via :meth:`login`. + """ + yield from self._make_websocket() + + while not self._closed: + msg = yield from self.ws.recv() + if msg is None: + yield from self.ws.close() + self._closed = True + self.keep_alive.cancel() + break + + self.received_message(json.loads(msg)) + + def event(self, coro): + """A decorator that registers an event to listen to. + + You can find more info about the events on the :ref:`documentation below `. + + The events must be a |corourl|_, if not, :exc:`ClientException` is raised. + + Example: :: + + @client.event + @asyncio.coroutine + def on_ready(): + print('Ready!') + """ + + if not asyncio.iscoroutinefunction(coro): + raise ClientException('event registered must be a coroutine function') + + setattr(self, coro.__name__, coro) + log.info('{0.__name__} has successfully been registered as an event'.format(coro)) + return coro + + @asyncio.coroutine + def start_private_message(self, user): + """|coro| + + Starts a private message with the user. This allows you to + :meth:`send_message` to the user. + + Note + ----- + This method should rarely be called as :meth:`send_message` + does it automatically for you. + + Parameters + ----------- + user : :class:`User` + The user to start the private message with. + + Raises + ------ + HTTPException + The request failed. + """ + + if not isinstance(user, User): + raise TypeError('user argument must be a User') + + payload = { + 'recipient_id': user.id + } + + r = requests.post('{}/{}/channels'.format(endpoints.USERS, self.user.id), json=payload, headers=self.headers) + log.debug(request_logging_format.format(response=r)) + utils._verify_successful_response(r) + data = r.json() + log.debug(request_success_log.format(response=r, json=payload, data=data)) + self.private_channels.append(PrivateChannel(id=data['id'], user=user)) + + @asyncio.coroutine + def send_message(self, destination, content, *, mentions=True, tts=False): + """|coro| + + Sends a message to the destination given with the content given. + + The destination could be a :class:`Channel`, :class:`PrivateChannel` or :class:`Server`. + For convenience it could also be a :class:`User`. If it's a :class:`User` or :class:`PrivateChannel` + then it sends the message via private message, otherwise it sends the message to the channel. + If the destination is a :class:`Server` then it's equivalent to calling + :meth:`Server.get_default_channel` and sending it there. If it is a :class:`Object` + instance then it is assumed to be the destination ID. + + .. versionchanged:: 0.9.0 + ``str`` being allowed was removed and replaced with :class:`Object`. + + The content must be a type that can convert to a string through ``str(content)``. + + The mentions must be either an array of :class:`User` to mention or a boolean. If + ``mentions`` is ``True`` then all the users mentioned in the content are mentioned, otherwise + no one is mentioned. Note that to mention someone in the content, you should use :meth:`User.mention`. + + Parameters + ------------ + destination + The location to send the message. + content + The content of the message to send. + mentions + A list of :class:`User` to mention in the message or a boolean. Ignored for private messages. + tts : bool + Indicates if the message should be sent using text-to-speech. + + Raises + -------- + HTTPException + Sending the message failed. + InvalidArgument + The destination parameter is invalid. + + Returns + --------- + :class:`Message` + The message that was sent. + """ + + channel_id = self._resolve_destination(destination) + + content = str(content) + mentions = self._resolve_mentions(content, mentions) + + url = '{base}/{id}/messages'.format(base=endpoints.CHANNELS, id=channel_id) + payload = { + 'content': content, + 'mentions': mentions + } + + if tts: + payload['tts'] = True + + resp = yield from self.session.post(url, data=to_json(payload), headers=self.headers) + log.debug(request_logging_format.format(method='POST', response=resp)) + yield from utils._verify_successful_response(resp) + data = yield from resp.json() + log.debug(request_success_log.format(response=resp, json=payload, data=data)) + channel = self.get_channel(data.get('channel_id')) + message = Message(channel=channel, **data) + return message diff --git a/discord/errors.py b/discord/errors.py index abe0bff92..0d75fb536 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -24,11 +24,6 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -try: - import http.client as httplib -except ImportError: - import httplib - class DiscordException(Exception): """Base exception class for discord.py @@ -56,28 +51,19 @@ class HTTPException(DiscordException): .. attribute:: response The response of the failed HTTP request. This is an - instance of `requests.Response`__. + instance of `aiohttp.ClientResponse`__. - __ http://docs.python-requests.org/en/latest/api/#requests.Response + __ http://aiohttp.readthedocs.org/en/stable/client_reference.html#aiohttp.ClientResponse """ def __init__(self, response, message=None): self.response = response - if message is None: - message = httplib.responses.get(response.status_code, 'HTTP error') - - message = '{0} (status code: {1.response.status_code})'.format(message, self) - - try: - data = response.json() - response_error = data['message'] - if response_error: - message = '{}: {}'.format(message, response_error) - except: - pass + fmt = '{0.reason} (status code: {0.status})' + if message: + fmt = fmt + ': {1}' - super(HTTPException, self).__init__(message) + super().__init__(fmt.format(self.response, message)) class InvalidArgument(ClientException): """Exception that's thrown when an argument to a function diff --git a/discord/state.py b/discord/state.py new file mode 100644 index 000000000..b62817a74 --- /dev/null +++ b/discord/state.py @@ -0,0 +1,295 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from .server import Server +from .user import User +from .message import Message +from .channel import Channel, PrivateChannel +from .member import Member +from .role import Role +from . import utils + +from collections import deque +import copy +import datetime + +class ConnectionState: + def __init__(self, dispatch, max_messages): + self.user = None + self.email = None + self.servers = [] + self.private_channels = [] + self.messages = deque(maxlen=max_messages) + self.dispatch = dispatch + + def _get_message(self, msg_id): + return utils.find(lambda m: m.id == msg_id, self.messages) + + def _get_server(self, guild_id): + return utils.find(lambda g: g.id == guild_id, self.servers) + + def _add_server(self, guild): + server = Server(**guild) + self.servers.append(server) + + def parse_ready(self, data): + self.user = User(**data['user']) + guilds = data.get('guilds') + + for guild in guilds: + if guild.get('unavailable', False): + continue + self._add_server(guild) + + for pm in data.get('private_channels'): + self.private_channels.append(PrivateChannel(id=pm['id'], + user=User(**pm['recipient']))) + + # we're all ready + self.dispatch('ready') + + def parse_message_create(self, data): + channel = self.get_channel(data.get('channel_id')) + message = Message(channel=channel, **data) + self.dispatch('message', message) + self.messages.append(message) + + def parse_message_delete(self, data): + channel = self.get_channel(data.get('channel_id')) + message_id = data.get('id') + found = self._get_message(message_id) + if found is not None: + self.dispatch('message_delete', found) + self.messages.remove(found) + + def parse_message_update(self, data): + older_message = self._get_message(data.get('id')) + if older_message is not None: + # create a copy of the new message + message = copy.copy(older_message) + # update the new update + for attr in data: + if attr == 'channel_id' or attr == 'author': + continue + value = data[attr] + if 'time' in attr: + setattr(message, attr, utils.parse_time(value)) + else: + setattr(message, attr, value) + self.dispatch('message_edit', older_message, message) + # update the older message + older_message = message + + def parse_presence_update(self, data): + server = self._get_server(data.get('guild_id')) + if server is not None: + status = data.get('status') + user = data['user'] + member_id = user['id'] + member = utils.find(lambda m: m.id == member_id, server.members) + if member is not None: + old_member = copy.copy(member) + member.status = data.get('status') + member.game_id = data.get('game_id') + member.name = user.get('username', member.name) + member.avatar = user.get('avatar', member.avatar) + + # call the event now + self.dispatch('status', member, old_member.game_id, old_member.status) + self.dispatch('member_update', old_member, member) + + def parse_user_update(self, data): + self.user = User(**data) + + def parse_channel_delete(self, data): + server = self._get_server(data.get('guild_id')) + if server is not None: + channel_id = data.get('id') + channel = utils.find(lambda c: c.id == channel_id, server.channels) + try: + server.channels.remove(channel) + self.dispatch('channel_delete', channel) + except ValueError: + return + + def parse_channel_update(self, data): + server = self._get_server(data.get('guild_id')) + if server is not None: + channel_id = data.get('id') + channel = utils.find(lambda c: c.id == channel_id, server.channels) + channel.update(server=server, **data) + self.dispatch('channel_update', channel) + + def parse_channel_create(self, data): + is_private = data.get('is_private', False) + channel = None + if is_private: + recipient = User(**data.get('recipient')) + pm_id = data.get('id') + channel = PrivateChannel(id=pm_id, user=recipient) + self.private_channels.append(channel) + else: + server = self._get_server(data.get('guild_id')) + if server is not None: + channel = Channel(server=server, **data) + server.channels.append(channel) + + self.dispatch('channel_create', channel) + + def parse_guild_member_add(self, data): + server = self._get_server(data.get('guild_id')) + member = Member(server=server, deaf=False, mute=False, **data) + server.members.append(member) + self.dispatch('member_join', member) + + def parse_guild_member_remove(self, data): + server = self._get_server(data.get('guild_id')) + if server is not None: + user_id = data['user']['id'] + member = utils.find(lambda m: m.id == user_id, server.members) + try: + server.members.remove(member) + self.dispatch('member_remove', member) + except ValueError: + return + + def parse_guild_member_update(self, data): + server = self._get_server(data.get('guild_id')) + user_id = data['user']['id'] + member = utils.find(lambda m: m.id == user_id, server.members) + if member is not None: + user = data['user'] + old_member = copy.copy(member) + member.name = user['username'] + member.discriminator = user['discriminator'] + member.avatar = user['avatar'] + member.roles = [] + # update the roles + for role in server.roles: + if role.id in data['roles']: + member.roles.append(role) + + self.dispatch('member_update', old_member, member) + + def parse_guild_create(self, data): + unavailable = data.get('unavailable') + if unavailable == False: + # GUILD_CREATE with unavailable in the response + # usually means that the server has become available + # and is therefore in the cache + server = self._get_server(data.get('id')) + if server is not None: + server.unavailable = False + self.dispatch('server_available', server) + return + + if unavailable == True: + # joined a server with unavailable == True so.. + return + + # if we're at this point then it was probably + # unavailable during the READY event and is now + # available, so it isn't in the cache... + + self._add_server(data) + self.dispatch('server_join', self.servers[-1]) + + def parse_guild_delete(self, data): + server = self._get_server(data.get('id')) + if data.get('unavailable', False) and server is not None: + # GUILD_DELETE with unavailable being True means that the + # server that was available is now currently unavailable + server.unavailable = True + self.dispatch('server_unavailable', server) + return + + try: + self.servers.remove(server) + self.dispatch('server_remove', server) + except ValueError: + return + + def parse_guild_role_create(self, data): + server = self._get_server(data.get('guild_id')) + role_data = data.get('role', {}) + everyone = server.id == role_data.get('id') + role = Role(everyone=everyone, **role_data) + server.roles.append(role) + self.dispatch('server_role_create', server, role) + + def parse_guild_role_delete(self, data): + server = self._get_server(data.get('guild_id')) + if server is not None: + role_id = data.get('role_id') + role = utils.find(lambda r: r.id == role_id, server.roles) + server.roles.remove(role) + self.dispatch('server_role_delete', server, role) + + def parse_guild_role_update(self, data): + server = self._get_server(data.get('guild_id')) + if server is not None: + role_id = data['role']['id'] + role = utils.find(lambda r: r.id == role_id, server.roles) + role.update(**data['role']) + self.dispatch('server_role_update', role) + + def parse_voice_state_update(self, data): + server = self._get_server(data.get('guild_id')) + if server is not None: + updated_member = server._update_voice_state(data) + self.dispatch('voice_state_update', updated_member) + + def parse_typing_start(self, data): + channel = self.get_channel(data.get('channel_id')) + if channel is not None: + member = None + user_id = data.get('user_id') + is_private = getattr(channel, 'is_private', None) + if is_private == None: + return + + if is_private: + member = channel.user + else: + members = channel.server.members + member = utils.find(lambda m: m.id == user_id, members) + + if member is not None: + timestamp = datetime.datetime.utcfromtimestamp(data.get('timestamp')) + self.dispatch('typing', channel, member, timestamp) + + def get_channel(self, id): + if id is None: + return None + + for server in self.servers: + for channel in server.channels: + if channel.id == id: + return channel + + for pm in self.private_channels: + if pm.id == id: + return pm diff --git a/discord/utils.py b/discord/utils.py index 423ddc8ae..3e22c710c 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -27,7 +27,7 @@ DEALINGS IN THE SOFTWARE. from re import split as re_split from .errors import HTTPException, InvalidArgument import datetime - +import asyncio def parse_time(timestamp): if timestamp: @@ -60,11 +60,13 @@ def find(predicate, seq): def _null_event(*args, **kwargs): pass +@asyncio.coroutine def _verify_successful_response(response): - code = response.status_code + code = response.status success = code >= 200 and code < 300 if not success: - raise HTTPException(response) + data = yield from response.json() + raise HTTPException(response, data.get('message')) def _get_mime_type_for_image(data): if data.startswith(b'\x89\x50\x4E\x47\x0D\x0A\x1A\x0A'): @@ -73,3 +75,8 @@ def _get_mime_type_for_image(data): return 'image/jpeg' else: raise InvalidArgument('Unsupported image type given') + +try: + create_task = asyncio.ensure_future +except AttributeError: + create_task = asyncio.async diff --git a/docs/api.rst b/docs/api.rst index 935f05dc7..1e0bdd624 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -34,17 +34,34 @@ overriding the specific events. For example: :: import discord class MyClient(discord.Client): + + @asyncio.coroutine def on_message(self, message): - self.send_message(message.channel, 'Hello World!') + yield from self.send_message(message.channel, 'Hello World!') If an event handler raises an exception, :func:`on_error` will be called to handle it, which defaults to print a traceback and ignore the exception. +.. warning:: + + All the events must be a |corourl|_. If they aren't, then you might get unexpected + errors. In order to turn a function into a coroutine they must either be decorated + with ``@asyncio.coroutine`` or in Python 3.5+ be defined using the ``async def`` + declaration. + + The following two functions are examples of coroutine functions: :: + + async def on_ready(): + pass + + @asyncio.coroutine + def on_ready(): + pass + .. versionadded:: 0.7.0 Subclassing to listen to events. - .. function:: on_ready() Called when the client is done preparing the data received from Discord. Usually after login is successful diff --git a/docs/conf.py b/docs/conf.py index 3183d7761..bb5d585c4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -32,12 +32,19 @@ sys.path.insert(0, os.path.abspath('..')) extensions = [ 'sphinx.ext.autodoc', 'sphinx.ext.extlinks', + 'sphinx.ext.napoleon', ] extlinks = { 'issue': ('https://github.com/Rapptz/discord.py/issues/%s', 'issue '), } +rst_prolog = """ +.. |coro| replace:: This function is a |corourl|_. +.. |corourl| replace:: *coroutine* +.. _corourl: https://docs.python.org/3/library/asyncio-task.html#coroutine +""" + # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates']