From 20041ea756305f20c86a621232639932c50f107c Mon Sep 17 00:00:00 2001 From: Rapptz Date: Sat, 7 Jan 2017 21:55:47 -0500 Subject: [PATCH] Implement AutoShardedClient for transparent sharding. This allows people to run their >2,500 guild bot in a single process without the headaches of IPC/RPC or much difficulty. --- discord/__init__.py | 1 + discord/client.py | 8 +- discord/errors.py | 9 ++- discord/gateway.py | 80 ++++++++++++++------ discord/guild.py | 8 ++ discord/http.py | 9 +++ discord/shard.py | 174 ++++++++++++++++++++++++++++++++++++++++++++ discord/state.py | 80 +++++++++++++++++++- docs/api.rst | 3 + 9 files changed, 341 insertions(+), 31 deletions(-) create mode 100644 discord/shard.py diff --git a/discord/__init__.py b/discord/__init__.py index e8a29e452..4d04aeb8b 100644 --- a/discord/__init__.py +++ b/discord/__init__.py @@ -37,6 +37,7 @@ from . import utils, opus, compat, abc from .enums import ChannelType, GuildRegion, Status, MessageType, VerificationLevel from collections import namedtuple from .embeds import Embed +from .shard import AutoShardedClient import logging diff --git a/discord/client.py b/discord/client.py index 2e0696c90..f8f458708 100644 --- a/discord/client.py +++ b/discord/client.py @@ -142,6 +142,7 @@ class Client: self.connection = ConnectionState(dispatch=self.dispatch, chunker=self.request_offline_members, syncer=self._syncer, http=self.http, loop=self.loop, **options) + self.connection.shard_count = self.shard_count self._closed = asyncio.Event(loop=self.loop) self._is_logged_in = asyncio.Event(loop=self.loop) self._is_ready = asyncio.Event(loop=self.loop) @@ -405,11 +406,14 @@ class Client: while not self.is_closed: try: - yield from self.ws.poll_event() + yield from ws.poll_event() except (ReconnectWebSocket, ResumeWebSocket) as e: resume = type(e) is ResumeWebSocket log.info('Got ' + type(e).__name__) - self.ws = yield from DiscordWebSocket.from_client(self, resume=resume) + self.ws = yield from DiscordWebSocket.from_client(self, shard_id=self.shard_id, + session=self.ws.session_id, + sequence=self.ws.sequence, + resume=resume) except ConnectionClosed as e: yield from self.close() if e.code != 1000: diff --git a/discord/errors.py b/discord/errors.py index 5449b77ef..46751b627 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -118,14 +118,17 @@ class ConnectionClosed(ClientException): Attributes ----------- - code : int + code: int The close code of the websocket. - reason : str + reason: str The reason provided for the closure. + shard_id: Optional[int] + The shard ID that got closed if applicable. """ - def __init__(self, original): + def __init__(self, original, *, shard_id): # This exception is just the same exception except # reconfigured to subclass ClientException for users self.code = original.code self.reason = original.reason + self.shard_id = shard_id super().__init__(str(original)) diff --git a/discord/gateway.py b/discord/gateway.py index 2154cc984..fcba2dfcc 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -47,11 +47,13 @@ __all__ = [ 'ReconnectWebSocket', 'DiscordWebSocket', class ReconnectWebSocket(Exception): """Signals to handle the RECONNECT opcode.""" - pass + def __init__(self, shard_id): + self.shard_id = shard_id class ResumeWebSocket(Exception): """Signals to initialise via RESUME opcode instead of IDENTIFY.""" - pass + def __init__(self, shard_id): + self.shard_id = shard_id EventListener = namedtuple('EventListener', 'predicate event result future') @@ -81,7 +83,7 @@ class KeepAliveHandler(threading.Thread): def get_payload(self): return { 'op': self.ws.HEARTBEAT, - 'd': self.ws._connection.sequence + 'd': self.ws.sequence } def stop(self): @@ -165,9 +167,13 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): # the keep alive self._keep_alive = None + # ws related stuff + self.session_id = None + self.sequence = None + @classmethod @asyncio.coroutine - def from_client(cls, client, *, resume=False): + def from_client(cls, client, *, shard_id=None, session=None, sequence=None, resume=False): """Creates a main websocket for Discord from a :class:`Client`. This is for internal use only. @@ -180,8 +186,10 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): ws._connection = client.connection ws._dispatch = client.dispatch ws.gateway = gateway - ws.shard_id = client.shard_id - ws.shard_count = client.shard_count + ws.shard_id = shard_id + ws.shard_count = client.connection.shard_count + ws.session_id = session + ws.sequence = sequence client.connection._update_references(ws) @@ -206,6 +214,35 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): else: return ws + @classmethod + @asyncio.coroutine + def from_sharded_client(cls, client): + if client.shard_count is None: + client.shard_count, gateway = yield from client.http.get_bot_gateway() + else: + gateway = yield from client.http.get_gateway() + + ret = [] + client.connection.shard_count = client.shard_count + + for shard_id in range(client.shard_count): + ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls) + ws.token = client.http.token + ws._connection = client.connection + ws._dispatch = client.dispatch + ws.gateway = gateway + ws.shard_id = shard_id + ws.shard_count = client.shard_count + + # OP HELLO + yield from ws.poll_event() + yield from ws.identify() + ret.append(ws) + log.info('Sent IDENTIFY payload to create the websocket for shard_id: %s' % shard_id) + yield from asyncio.sleep(5.0, loop=client.loop) + + return ret + def wait_for(self, event, predicate, result=None): """Waits for a DISPATCH'd event that meets the predicate. @@ -262,12 +299,11 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): @asyncio.coroutine def resume(self): """Sends the RESUME packet.""" - state = self._connection payload = { 'op': self.RESUME, 'd': { - 'seq': state.sequence, - 'session_id': state.session_id, + 'seq': self.sequence, + 'session_id': self.session_id, 'token': self.token } } @@ -283,16 +319,15 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): msg = msg.decode('utf-8') msg = json.loads(msg) - state = self._connection - log.debug('WebSocket Event: {}'.format(msg)) + log.debug('For Shard ID {}: WebSocket Event: {}'.format(self.shard_id, msg)) self._dispatch('socket_response', msg) op = msg.get('op') data = msg.get('d') seq = msg.get('s') if seq is not None: - state.sequence = seq + self.sequence = seq if op == self.RECONNECT: # "reconnect" can only be handled by the Client @@ -300,7 +335,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): # internal exception signalling to reconnect. log.info('Received RECONNECT opcode.') yield from self.close() - raise ReconnectWebSocket() + raise ReconnectWebSocket(self.shard_id) if op == self.HEARTBEAT_ACK: return # disable noisy logging for now @@ -317,11 +352,11 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): return if op == self.INVALIDATE_SESSION: - state.sequence = None - state.session_id = None + self.sequence = None + self.session_id = None if data == True: yield from self.close() - raise ResumeWebSocket() + raise ResumeWebSocket(self.shard_id) yield from self.identify() return @@ -334,9 +369,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): is_ready = event == 'READY' if is_ready: - state.clear() - state.sequence = msg['s'] - state.session_id = data['session_id'] + self.sequence = msg['s'] + self.session_id = data['session_id'] parser = 'parse_' + event.lower() @@ -389,9 +423,9 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): except websockets.exceptions.ConnectionClosed as e: if self._can_handle_close(e.code): log.info('Websocket closed with {0.code} ({0.reason}), attempting a reconnect.'.format(e)) - raise ResumeWebSocket() from e + raise ResumeWebSocket(self.shard_id) from e else: - raise ConnectionClosed(e) from e + raise ConnectionClosed(e, shard_id=self.shard_id) from e @asyncio.coroutine def send(self, data): @@ -404,7 +438,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): yield from super().send(utils.to_json(data)) except websockets.exceptions.ConnectionClosed as e: if not self._can_handle_close(e.code): - raise ConnectionClosed(e) from e + raise ConnectionClosed(e, shard_id=self.shard_id) from e @asyncio.coroutine def change_presence(self, *, game=None, status=None, afk=False, since=0.0, idle=None): @@ -615,7 +649,7 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol): msg = yield from asyncio.wait_for(self.recv(), timeout=30.0, loop=self.loop) yield from self.received_message(json.loads(msg)) except websockets.exceptions.ConnectionClosed as e: - raise ConnectionClosed(e) from e + raise ConnectionClosed(e, shard_id=None) from e @asyncio.coroutine def close_connection(self, force=False): diff --git a/discord/guild.py b/discord/guild.py index 0f37a214a..2255c297c 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -324,6 +324,14 @@ class Guild(Hashable): """Returns the true member count regardless of it being loaded fully or not.""" return self._member_count + @property + def shard_id(self): + """Returns the shard ID for this guild if applicable.""" + count = self._state.shard_count + if count is None: + return None + return (self.id >> 22) % count + @property def created_at(self): """Returns the guild's creation time in UTC.""" diff --git a/discord/http.py b/discord/http.py index 2b885dec4..4e5410d08 100644 --- a/discord/http.py +++ b/discord/http.py @@ -588,5 +588,14 @@ class HTTPClient: raise GatewayNotFound() from e return data.get('url') + '?encoding=json&v=6' + @asyncio.coroutine + def get_bot_gateway(self): + try: + data = yield from self.get(self.GATEWAY + '/bot', bucket=_func_()) + except HTTPException as e: + raise GatewayNotFound() from e + else: + return data['shards'], data['url'] + '?encoding=json&v=6' + def get_user_info(self, user_id): return self.get('{0.USERS}/{1}'.format(self, user_id), bucket=_func_()) diff --git a/discord/shard.py b/discord/shard.py new file mode 100644 index 000000000..2be0ea128 --- /dev/null +++ b/discord/shard.py @@ -0,0 +1,174 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2016 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from .state import AutoShardedConnectionState +from .client import Client +from .gateway import * +from .errors import ConnectionClosed +from . import compat + +import asyncio +import logging + +log = logging.getLogger(__name__) + +class Shard: + def __init__(self, ws, client): + self.ws = ws + self._client = client + self.loop = self._client.loop + self._current = asyncio.Future(loop=self.loop) + self._current.set_result(None) # we just need an already done future + + @property + def id(self): + return self.ws.shard_id + + @asyncio.coroutine + def poll(self): + try: + yield from self.ws.poll_event() + except (ReconnectWebSocket, ResumeWebSocket) as e: + resume = type(e) is ResumeWebSocket + log.info('Got ' + type(e).__name__) + self.ws = yield from DiscordWebSocket.from_client(self._client, resume=resume, + shard_id=self.id, + session=self.ws.session_id, + sequence=self.ws.sequence) + except ConnectionClosed as e: + yield from self._client.close() + if e.code != 1000: + raise + + def get_future(self): + if self._current.done(): + self._current = compat.create_task(self.poll(), loop=self.loop) + + return self._current + +class AutoShardedClient(Client): + """A client similar to :class:`Client` except it handles the complications + of sharding for the user into a more manageable and transparent single + process bot. + + When using this client, you will be able to use it as-if it was a regular + :class:`Client` with a single shard when implementation wise internally it + is split up into multiple shards. This allows you to not have to deal with + IPC or other complicated infrastructure. + + It is recommended to use this client only if you have surpassed at least + 1000 guilds. + + If no :attr:`shard_count` is provided, then the library will use the + Bot Gateway endpoint call to figure out how many shards to use. + """ + def __init__(self, *args, loop=None, **kwargs): + kwargs.pop('shard_id', None) + super().__init__(*args, loop=loop, **kwargs) + + self.connection = AutoShardedConnectionState(dispatch=self.dispatch, chunker=self.request_offline_members, + syncer=self._syncer, http=self.http, loop=self.loop, **kwargs) + + # instead of a single websocket, we have multiple + # the index is the shard_id + self.shards = [] + + @asyncio.coroutine + def request_offline_members(self, guild, *, shard_id=None): + """|coro| + + Requests previously offline members from the guild to be filled up + into the :attr:`Guild.members` cache. This function is usually not + called. + + When the client logs on and connects to the websocket, Discord does + not provide the library with offline members if the number of members + in the guild is larger than 250. You can check if a guild is large + if :attr:`Guild.large` is ``True``. + + Parameters + ----------- + guild: :class:`Guild` or list + The guild to request offline members for. If this parameter is a + list then it is interpreted as a list of guilds to request offline + members for. + """ + + try: + guild_id = guild.id + shard_id = shard_id or guild.shard_id + except AttributeError: + guild_id = [s.id for s in guild] + + payload = { + 'op': 8, + 'd': { + 'guild_id': guild_id, + 'query': '', + 'limit': 0 + } + } + + ws = self.shards[shard_id].ws + yield from ws.send_as_json(payload) + + @asyncio.coroutine + def connect(self): + """|coro| + + Creates a websocket connection and lets the websocket listen + to messages from discord. + + Raises + ------- + GatewayNotFound + If the gateway to connect to discord is not found. Usually if this + is thrown then there is a discord API outage. + ConnectionClosed + The websocket connection has been terminated. + """ + ret = yield from DiscordWebSocket.from_sharded_client(self) + self.shards = [Shard(ws, self) for ws in ret] + + while not self.is_closed: + pollers = [shard.get_future() for shard in self.shards] + yield from asyncio.wait(pollers, loop=self.loop, return_when=asyncio.FIRST_COMPLETED) + + @asyncio.coroutine + def close(self): + """|coro| + + Closes the connection to discord. + """ + if self.is_closed: + return + + for shard in self.shards: + yield from shard.ws.close() + + yield from self.http.close() + self._closed.set() + self._is_ready.clear() diff --git a/discord/state.py b/discord/state.py index 383b559fa..bd7fbdbe8 100644 --- a/discord/state.py +++ b/discord/state.py @@ -43,6 +43,7 @@ import datetime import asyncio import logging import weakref +import itertools class ListenerType(enum.Enum): chunk = 0 @@ -60,13 +61,12 @@ class ConnectionState: self.chunker = chunker self.syncer = syncer self.is_bot = None + self.shard_count = None self._listeners = [] self.clear() def clear(self): self.user = None - self.sequence = None - self.session_id = None self._users = weakref.WeakValueDictionary() self._calls = {} self._emojis = {} @@ -355,7 +355,8 @@ class ConnectionState: # the reason we're doing this is so it's also removed from the # private channel by user cache as well channel = self._get_private_channel(channel_id) - self._remove_private_channel(channel) + if channel is not None: + self._remove_private_channel(channel) def parse_channel_update(self, data): channel_type = try_enum(ChannelType, data.get('type')) @@ -701,3 +702,76 @@ class ConnectionState: listener = Listener(ListenerType.chunk, future, lambda s: s.id == guild_id) self._listeners.append(listener) return future + +class AutoShardedConnectionState(ConnectionState): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._ready_state = ReadyState(launch=asyncio.Event(), guilds=[]) + self._ready_task = None + + @asyncio.coroutine + def _delay_ready(self): + launch = self._ready_state.launch + while not launch.is_set(): + # this snippet of code is basically waiting 2 seconds + # until the last GUILD_CREATE was sent + launch.set() + yield from asyncio.sleep(2.0 * self.shard_count, loop=self.loop) + + guilds = sorted(self._ready_state.guilds, key=lambda g: g.shard_id) + + # we only want to request ~75 guilds per chunk request. + # we also want to split the chunks per shard_id + for shard_id, sub_guilds in itertools.groupby(guilds, key=lambda g: g.shard_id): + sub_guilds = list(sub_guilds) + + # split chunks by shard ID + chunks = [] + for guild in sub_guilds: + chunks.extend(self.chunks_needed(guild)) + + splits = [sub_guilds[i:i + 75] for i in range(0, len(sub_guilds), 75)] + for split in splits: + yield from self.chunker(split, shard_id=shard_id) + + # wait for the chunks + if chunks: + try: + yield from asyncio.wait(chunks, timeout=len(chunks) * 30.0, loop=self.loop) + except asyncio.TimeoutError: + log.info('Somehow timed out waiting for chunks for %s shard_id' % shard_id) + + self.dispatch('shard_ready', shard_id) + + # sleep a second for every shard ID. + # yield from asyncio.sleep(1.0, loop=self.loop) + + # remove the state + try: + del self._ready_state + except AttributeError: + pass # already been deleted somehow + + # regular users cannot shard so we won't worry about it here. + + # dispatch the event + self.dispatch('ready') + + def parse_ready(self, data): + if not hasattr(self, '_ready_state'): + self._ready_state = ReadyState(launch=asyncio.Event(), guilds=[]) + + self.user = self.store_user(data['user']) + + guilds = self._ready_state.guilds + for guild_data in data['guilds']: + guild = self._add_guild_from_data(guild_data) + if not self.is_bot or guild.large: + guilds.append(guild) + + for pm in data.get('private_channels', []): + factory, _ = _channel_factory(pm['type']) + self._add_private_channel(factory(me=self.user, data=pm, state=self)) + + if self._ready_task is None: + self._ready_task = compat.create_task(self._delay_ready(), loop=self.loop) diff --git a/docs/api.rst b/docs/api.rst index bcbf1f471..e001d8bbd 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -37,6 +37,9 @@ Client .. autoclass:: Client :members: +.. autoclass:: AutoShardedClient + :members: + Voice -----