Browse Source

Implement AutoShardedClient for transparent sharding.

This allows people to run their >2,500 guild bot in a single process
without the headaches of IPC/RPC or much difficulty.
pull/447/head
Rapptz 8 years ago
parent
commit
20041ea756
  1. 1
      discord/__init__.py
  2. 8
      discord/client.py
  3. 9
      discord/errors.py
  4. 80
      discord/gateway.py
  5. 8
      discord/guild.py
  6. 9
      discord/http.py
  7. 174
      discord/shard.py
  8. 80
      discord/state.py
  9. 3
      docs/api.rst

1
discord/__init__.py

@ -37,6 +37,7 @@ from . import utils, opus, compat, abc
from .enums import ChannelType, GuildRegion, Status, MessageType, VerificationLevel
from collections import namedtuple
from .embeds import Embed
from .shard import AutoShardedClient
import logging

8
discord/client.py

@ -142,6 +142,7 @@ class Client:
self.connection = ConnectionState(dispatch=self.dispatch, chunker=self.request_offline_members,
syncer=self._syncer, http=self.http, loop=self.loop, **options)
self.connection.shard_count = self.shard_count
self._closed = asyncio.Event(loop=self.loop)
self._is_logged_in = asyncio.Event(loop=self.loop)
self._is_ready = asyncio.Event(loop=self.loop)
@ -405,11 +406,14 @@ class Client:
while not self.is_closed:
try:
yield from self.ws.poll_event()
yield from ws.poll_event()
except (ReconnectWebSocket, ResumeWebSocket) as e:
resume = type(e) is ResumeWebSocket
log.info('Got ' + type(e).__name__)
self.ws = yield from DiscordWebSocket.from_client(self, resume=resume)
self.ws = yield from DiscordWebSocket.from_client(self, shard_id=self.shard_id,
session=self.ws.session_id,
sequence=self.ws.sequence,
resume=resume)
except ConnectionClosed as e:
yield from self.close()
if e.code != 1000:

9
discord/errors.py

@ -118,14 +118,17 @@ class ConnectionClosed(ClientException):
Attributes
-----------
code : int
code: int
The close code of the websocket.
reason : str
reason: str
The reason provided for the closure.
shard_id: Optional[int]
The shard ID that got closed if applicable.
"""
def __init__(self, original, *, shard_id=None):
    """Wrap a websocket close exception.

    Parameters
    -----------
    original
        The underlying close exception; its ``code`` and ``reason``
        attributes are copied onto this exception.
    shard_id: Optional[int]
        The shard ID that got closed if applicable. Defaults to ``None`` so
        that callers written against the one-argument signature keep working.
    """
    # This exception is just the same exception except
    # reconfigured to subclass ClientException for users
    self.code = original.code
    self.reason = original.reason
    self.shard_id = shard_id
    super().__init__(str(original))

80
discord/gateway.py

@ -47,11 +47,13 @@ __all__ = [ 'ReconnectWebSocket', 'DiscordWebSocket',
class ReconnectWebSocket(Exception):
    """Signals to handle the RECONNECT opcode.

    Attributes
    -----------
    shard_id: Optional[int]
        The shard ID of the websocket that raised the signal.
    """

    def __init__(self, shard_id):
        self.shard_id = shard_id
class ResumeWebSocket(Exception):
    """Signals to initialise via RESUME opcode instead of IDENTIFY.

    Attributes
    -----------
    shard_id: Optional[int]
        The shard ID of the websocket that raised the signal.
    """

    def __init__(self, shard_id):
        self.shard_id = shard_id
EventListener = namedtuple('EventListener', 'predicate event result future')
@ -81,7 +83,7 @@ class KeepAliveHandler(threading.Thread):
def get_payload(self):
    """Build the heartbeat payload for the owning websocket.

    The sequence number is read from the websocket itself rather than from
    the shared connection state, so each websocket reports its own sequence.
    """
    return {
        'op': self.ws.HEARTBEAT,
        'd': self.ws.sequence
    }
def stop(self):
@ -165,9 +167,13 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
# the keep alive
self._keep_alive = None
# ws related stuff
self.session_id = None
self.sequence = None
@classmethod
@asyncio.coroutine
def from_client(cls, client, *, resume=False):
def from_client(cls, client, *, shard_id=None, session=None, sequence=None, resume=False):
"""Creates a main websocket for Discord from a :class:`Client`.
This is for internal use only.
@ -180,8 +186,10 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
ws._connection = client.connection
ws._dispatch = client.dispatch
ws.gateway = gateway
ws.shard_id = client.shard_id
ws.shard_count = client.shard_count
ws.shard_id = shard_id
ws.shard_count = client.connection.shard_count
ws.session_id = session
ws.sequence = sequence
client.connection._update_references(ws)
@ -206,6 +214,35 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
else:
return ws
@classmethod
@asyncio.coroutine
def from_sharded_client(cls, client):
    """Creates one identified websocket per shard for a sharded client.

    This is for internal use only.

    When the client has no ``shard_count`` the bot gateway endpoint supplies
    both the shard count and the gateway URL; otherwise only the gateway URL
    is fetched. The resolved shard count is copied onto the client's
    connection state.

    Returns a list of websockets in ascending shard ID order.
    """
    if client.shard_count is not None:
        gateway = yield from client.http.get_gateway()
    else:
        client.shard_count, gateway = yield from client.http.get_bot_gateway()

    client.connection.shard_count = client.shard_count

    sockets = []
    for shard_id in range(client.shard_count):
        ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls)

        # wire the websocket up to the client it serves
        ws.token = client.http.token
        ws._connection = client.connection
        ws._dispatch = client.dispatch
        ws.gateway = gateway
        ws.shard_id = shard_id
        ws.shard_count = client.shard_count

        # the first event polled is expected to be OP HELLO
        yield from ws.poll_event()
        yield from ws.identify()
        sockets.append(ws)
        log.info('Sent IDENTIFY payload to create the websocket for shard_id: %s' % shard_id)

        # NOTE(review): presumably spaces out IDENTIFY payloads to respect a
        # gateway rate limit -- confirm against the gateway documentation.
        yield from asyncio.sleep(5.0, loop=client.loop)

    return sockets
def wait_for(self, event, predicate, result=None):
"""Waits for a DISPATCH'd event that meets the predicate.
@ -262,12 +299,11 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
@asyncio.coroutine
def resume(self):
"""Sends the RESUME packet."""
state = self._connection
payload = {
'op': self.RESUME,
'd': {
'seq': state.sequence,
'session_id': state.session_id,
'seq': self.sequence,
'session_id': self.session_id,
'token': self.token
}
}
@ -283,16 +319,15 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
msg = msg.decode('utf-8')
msg = json.loads(msg)
state = self._connection
log.debug('WebSocket Event: {}'.format(msg))
log.debug('For Shard ID {}: WebSocket Event: {}'.format(self.shard_id, msg))
self._dispatch('socket_response', msg)
op = msg.get('op')
data = msg.get('d')
seq = msg.get('s')
if seq is not None:
state.sequence = seq
self.sequence = seq
if op == self.RECONNECT:
# "reconnect" can only be handled by the Client
@ -300,7 +335,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
# internal exception signalling to reconnect.
log.info('Received RECONNECT opcode.')
yield from self.close()
raise ReconnectWebSocket()
raise ReconnectWebSocket(self.shard_id)
if op == self.HEARTBEAT_ACK:
return # disable noisy logging for now
@ -317,11 +352,11 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
return
if op == self.INVALIDATE_SESSION:
state.sequence = None
state.session_id = None
self.sequence = None
self.session_id = None
if data == True:
yield from self.close()
raise ResumeWebSocket()
raise ResumeWebSocket(self.shard_id)
yield from self.identify()
return
@ -334,9 +369,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
is_ready = event == 'READY'
if is_ready:
state.clear()
state.sequence = msg['s']
state.session_id = data['session_id']
self.sequence = msg['s']
self.session_id = data['session_id']
parser = 'parse_' + event.lower()
@ -389,9 +423,9 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
except websockets.exceptions.ConnectionClosed as e:
if self._can_handle_close(e.code):
log.info('Websocket closed with {0.code} ({0.reason}), attempting a reconnect.'.format(e))
raise ResumeWebSocket() from e
raise ResumeWebSocket(self.shard_id) from e
else:
raise ConnectionClosed(e) from e
raise ConnectionClosed(e, shard_id=self.shard_id) from e
@asyncio.coroutine
def send(self, data):
@ -404,7 +438,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
yield from super().send(utils.to_json(data))
except websockets.exceptions.ConnectionClosed as e:
if not self._can_handle_close(e.code):
raise ConnectionClosed(e) from e
raise ConnectionClosed(e, shard_id=self.shard_id) from e
@asyncio.coroutine
def change_presence(self, *, game=None, status=None, afk=False, since=0.0, idle=None):
@ -615,7 +649,7 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
msg = yield from asyncio.wait_for(self.recv(), timeout=30.0, loop=self.loop)
yield from self.received_message(json.loads(msg))
except websockets.exceptions.ConnectionClosed as e:
raise ConnectionClosed(e) from e
raise ConnectionClosed(e, shard_id=None) from e
@asyncio.coroutine
def close_connection(self, force=False):

8
discord/guild.py

@ -324,6 +324,14 @@ class Guild(Hashable):
"""Returns the true member count regardless of it being loaded fully or not."""
return self._member_count
@property
def shard_id(self):
    """Returns the shard ID for this guild if applicable.

    ``None`` is returned when the state has no shard count; otherwise the
    guild ID shifted right by 22 bits, modulo the shard count.
    """
    total = self._state.shard_count
    return None if total is None else (self.id >> 22) % total
@property
def created_at(self):
"""Returns the guild's creation time in UTC."""

9
discord/http.py

@ -588,5 +588,14 @@ class HTTPClient:
raise GatewayNotFound() from e
return data.get('url') + '?encoding=json&v=6'
@asyncio.coroutine
def get_bot_gateway(self):
    """Fetch the bot gateway information.

    Returns a ``(shards, url)`` tuple, where ``url`` already carries the
    ``?encoding=json&v=6`` query string.

    Raises ``GatewayNotFound`` if the request fails with an ``HTTPException``.
    """
    try:
        payload = yield from self.get(self.GATEWAY + '/bot', bucket=_func_())
    except HTTPException as exc:
        raise GatewayNotFound() from exc

    shard_total = payload['shards']
    gateway_url = payload['url'] + '?encoding=json&v=6'
    return shard_total, gateway_url
def get_user_info(self, user_id):
    """Issue a GET for the user resource identified by ``user_id``.

    Returns whatever ``self.get`` returns for ``{USERS}/{user_id}``.
    """
    route = '{0.USERS}/{1}'.format(self, user_id)
    return self.get(route, bucket=_func_())

174
discord/shard.py

@ -0,0 +1,174 @@
# -*- coding: utf-8 -*-
"""
The MIT License (MIT)
Copyright (c) 2015-2016 Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from .state import AutoShardedConnectionState
from .client import Client
from .gateway import *
from .errors import ConnectionClosed
from . import compat
import asyncio
import logging
log = logging.getLogger(__name__)
class Shard:
    """Pairs one gateway websocket with the client that owns it.

    Each instance tracks a single in-flight poll of its websocket so that the
    owning client can wait on every shard at once.
    """

    def __init__(self, ws, client):
        self.ws = ws
        self._client = client
        self.loop = self._client.loop
        # Seed the slot with an already-finished future so that the first
        # call to get_future() schedules a real poll.
        finished = asyncio.Future(loop=self.loop)
        finished.set_result(None)
        self._current = finished

    @property
    def id(self):
        """The shard ID stored on the underlying websocket."""
        return self.ws.shard_id

    @asyncio.coroutine
    def poll(self):
        """Poll the websocket for one event.

        On a reconnect or resume signal the websocket is replaced by a new
        one built through ``DiscordWebSocket.from_client``, carrying over
        this shard's ID, session ID and sequence. On ``ConnectionClosed`` the
        owning client is closed and the error is re-raised unless the close
        code is 1000.
        """
        try:
            yield from self.ws.poll_event()
        except (ReconnectWebSocket, ResumeWebSocket) as exc:
            signal_type = type(exc)
            should_resume = signal_type is ResumeWebSocket
            log.info('Got ' + signal_type.__name__)
            previous = self.ws
            self.ws = yield from DiscordWebSocket.from_client(self._client,
                                                              resume=should_resume,
                                                              shard_id=previous.shard_id,
                                                              session=previous.session_id,
                                                              sequence=previous.sequence)
        except ConnectionClosed as exc:
            yield from self._client.close()
            if exc.code != 1000:
                raise

    def get_future(self):
        """Return the current poll task, scheduling a new one if the last one finished."""
        if not self._current.done():
            return self._current

        self._current = compat.create_task(self.poll(), loop=self.loop)
        return self._current
class AutoShardedClient(Client):
    """A client similar to :class:`Client` except it handles the complications
    of sharding for the user into a more manageable and transparent single
    process bot.

    When using this client, you will be able to use it as-if it was a regular
    :class:`Client` with a single shard when implementation wise internally it
    is split up into multiple shards. This allows you to not have to deal with
    IPC or other complicated infrastructure.

    It is recommended to use this client only if you have surpassed at least
    1000 guilds.

    If no :attr:`shard_count` is provided, then the library will use the
    Bot Gateway endpoint call to figure out how many shards to use.
    """

    def __init__(self, *args, loop=None, **kwargs):
        # this client manages every shard itself, so a single shard_id
        # passed by the caller is discarded
        kwargs.pop('shard_id', None)
        super().__init__(*args, loop=loop, **kwargs)

        # swap the connection state for the shard-aware variant
        self.connection = AutoShardedConnectionState(dispatch=self.dispatch, chunker=self.request_offline_members,
                                                     syncer=self._syncer, http=self.http, loop=self.loop, **kwargs)

        # instead of a single websocket, we have multiple
        # the index is the shard_id
        self.shards = []

    @asyncio.coroutine
    def request_offline_members(self, guild, *, shard_id=None):
        """|coro|

        Requests previously offline members from the guild to be filled up
        into the :attr:`Guild.members` cache. This function is usually not
        called.

        When the client logs on and connects to the websocket, Discord does
        not provide the library with offline members if the number of members
        in the guild is larger than 250. You can check if a guild is large
        if :attr:`Guild.large` is ``True``.

        Parameters
        -----------
        guild: :class:`Guild` or list
            The guild to request offline members for. If this parameter is a
            list then it is interpreted as a list of guilds to request offline
            members for.
        shard_id: Optional[int]
            The shard whose websocket sends the request. If omitted, a single
            guild uses its own :attr:`Guild.shard_id`, and a list of guilds is
            grouped by each guild's shard ID with one request sent per shard.
        """
        try:
            guild_id = guild.id
        except AttributeError:
            # no ``id`` attribute: treat the argument as an iterable of guilds
            if shard_id is not None:
                requests = [(shard_id, [g.id for g in guild])]
            else:
                grouped = {}
                for g in guild:
                    grouped.setdefault(g.shard_id, []).append(g.id)
                requests = list(grouped.items())
        else:
            # compare against None explicitly: shard 0 is a valid shard ID
            if shard_id is None:
                shard_id = guild.shard_id
            requests = [(shard_id, guild_id)]

        for target_shard, ids in requests:
            payload = {
                'op': 8,
                'd': {
                    'guild_id': ids,
                    'query': '',
                    'limit': 0
                }
            }

            ws = self.shards[target_shard].ws
            yield from ws.send_as_json(payload)

    @asyncio.coroutine
    def connect(self):
        """|coro|

        Creates a websocket connection and lets the websocket listen
        to messages from discord.

        Raises
        -------
        GatewayNotFound
            If the gateway to connect to discord is not found. Usually if this
            is thrown then there is a discord API outage.
        ConnectionClosed
            The websocket connection has been terminated.
        """
        ret = yield from DiscordWebSocket.from_sharded_client(self)
        self.shards = [Shard(ws, self) for ws in ret]

        while not self.is_closed:
            pollers = [shard.get_future() for shard in self.shards]
            done, _ = yield from asyncio.wait(pollers, loop=self.loop, return_when=asyncio.FIRST_COMPLETED)
            for poller in done:
                # asyncio.wait never raises on behalf of the futures it
                # waits on, so fetch each result to re-raise whatever a
                # shard's poll raised.
                poller.result()

    @asyncio.coroutine
    def close(self):
        """|coro|

        Closes the connection to discord.
        """
        if self.is_closed:
            return

        for shard in self.shards:
            yield from shard.ws.close()

        yield from self.http.close()
        self._closed.set()
        self._is_ready.clear()

80
discord/state.py

@ -43,6 +43,7 @@ import datetime
import asyncio
import logging
import weakref
import itertools
class ListenerType(enum.Enum):
chunk = 0
@ -60,13 +61,12 @@ class ConnectionState:
self.chunker = chunker
self.syncer = syncer
self.is_bot = None
self.shard_count = None
self._listeners = []
self.clear()
def clear(self):
self.user = None
self.sequence = None
self.session_id = None
self._users = weakref.WeakValueDictionary()
self._calls = {}
self._emojis = {}
@ -355,7 +355,8 @@ class ConnectionState:
# the reason we're doing this is so it's also removed from the
# private channel by user cache as well
channel = self._get_private_channel(channel_id)
self._remove_private_channel(channel)
if channel is not None:
self._remove_private_channel(channel)
def parse_channel_update(self, data):
channel_type = try_enum(ChannelType, data.get('type'))
@ -701,3 +702,76 @@ class ConnectionState:
listener = Listener(ListenerType.chunk, future, lambda s: s.id == guild_id)
self._listeners.append(listener)
return future
class AutoShardedConnectionState(ConnectionState):
    """Connection state shared by every shard of an auto-sharded client.

    READY payloads from all shards feed one pending guild list. A single
    background task then requests member chunks shard by shard, dispatches
    ``shard_ready`` for each shard ID and finally dispatches ``ready``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._ready_state = ReadyState(launch=asyncio.Event(), guilds=[])
        self._ready_task = None

    @asyncio.coroutine
    def _delay_ready(self):
        """Chunk the guilds collected from READY, then dispatch ``ready``."""
        launch = self._ready_state.launch
        while not launch.is_set():
            # Set the flag and sleep for 2 seconds per shard. The loop only
            # ends once a whole sleep passes without the flag being cleared
            # again; per the original note this waits for the last
            # GUILD_CREATE (the code clearing the flag is not in this class).
            launch.set()
            yield from asyncio.sleep(2.0 * self.shard_count, loop=self.loop)

        # groupby only merges adjacent items, so sort by the same key first
        guilds = sorted(self._ready_state.guilds, key=lambda g: g.shard_id)

        # we only want to request ~75 guilds per chunk request.
        # we also want to split the chunks per shard_id
        for shard_id, sub_guilds in itertools.groupby(guilds, key=lambda g: g.shard_id):
            sub_guilds = list(sub_guilds)

            # collect the chunk futures for this shard
            chunks = []
            for guild in sub_guilds:
                chunks.extend(self.chunks_needed(guild))

            splits = [sub_guilds[i:i + 75] for i in range(0, len(sub_guilds), 75)]
            for split in splits:
                yield from self.chunker(split, shard_id=shard_id)

            # wait for the chunks
            if chunks:
                # asyncio.wait does not raise on timeout; it returns the
                # futures that are still pending, so check that set instead.
                _, pending = yield from asyncio.wait(chunks, timeout=len(chunks) * 30.0, loop=self.loop)
                if pending:
                    log.info('Somehow timed out waiting for chunks for %s shard_id' % shard_id)

            self.dispatch('shard_ready', shard_id)

        # remove the state
        try:
            del self._ready_state
        except AttributeError:
            pass # already been deleted somehow

        # regular users cannot shard so we won't worry about it here.

        # clear the finished task so that a later READY schedules a new one
        self._ready_task = None

        # dispatch the event
        self.dispatch('ready')

    def parse_ready(self, data):
        """Handle a READY payload from any shard.

        Stores the user, adds the payload's guilds and private channels, and
        schedules ``_delay_ready`` if no such task is currently recorded.
        """
        if not hasattr(self, '_ready_state'):
            self._ready_state = ReadyState(launch=asyncio.Event(), guilds=[])

        self.user = self.store_user(data['user'])

        guilds = self._ready_state.guilds
        for guild_data in data['guilds']:
            guild = self._add_guild_from_data(guild_data)
            if not self.is_bot or guild.large:
                guilds.append(guild)

        for pm in data.get('private_channels', []):
            factory, _ = _channel_factory(pm['type'])
            self._add_private_channel(factory(me=self.user, data=pm, state=self))

        if self._ready_task is None:
            self._ready_task = compat.create_task(self._delay_ready(), loop=self.loop)

3
docs/api.rst

@ -37,6 +37,9 @@ Client
.. autoclass:: Client
:members:
.. autoclass:: AutoShardedClient
:members:
Voice
-----

Loading…
Cancel
Save