From 4768d950c590ba170ead20aad7ccc797a7d8e737 Mon Sep 17 00:00:00 2001 From: Rapptz <rapptz@gmail.com> Date: Sun, 14 Feb 2016 19:24:26 -0500 Subject: [PATCH] Offline members are now added by default automatically. This commit adds support for GUILD_MEMBERS_CHUNK which had to be done due to forced large_threshold requirements in the library. --- discord/client.py | 87 +++++++++++++++++++++++++++++++++++++++++++---- discord/server.py | 8 +++-- discord/state.py | 60 ++++++++++++++++++++++++++++---- 3 files changed, 139 insertions(+), 16 deletions(-) diff --git a/discord/client.py b/discord/client.py index 4e7083f71..fa096c984 100644 --- a/discord/client.py +++ b/discord/client.py @@ -51,7 +51,7 @@ import logging, traceback import sys, time, re, json import tempfile, os, hashlib import itertools -import zlib +import zlib, math from random import randint as random_integer PY35 = sys.version_info >= (3, 5) @@ -81,6 +81,10 @@ class Client: Indicates if :meth:`login` should cache the authentication tokens. Defaults to ``True``. The method in which the cache is written is done by writing to disk to a temporary directory. + request_offline : Optional[bool] + Indicates if the client should request the offline members of every server. + If this is False, then member lists will not store offline members if the + number of members in the server is greater than 250. Defaults to ``True``. Attributes ----------- @@ -117,12 +121,13 @@ class Client: self.loop = asyncio.get_event_loop() if loop is None else loop self._listeners = [] self.cache_auth = options.get('cache_auth', True) + self.request_offline = options.get('request_offline', True) max_messages = options.get('max_messages') if max_messages is None or max_messages < 100: max_messages = 5000 - self.connection = ConnectionState(self.dispatch, max_messages) + self.connection = ConnectionState(self.dispatch, max_messages, loop=self.loop) # Blame React for this user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}' @@ -143,6 +148,25 @@ class Client: # internals + def _get_all_chunks(self): + # a chunk has a maximum of 1000 members. + # we need to find out how many futures we're actually waiting for + large_servers = filter(lambda s: s.large, self.servers) + futures = [] + for server in large_servers: + chunks_needed = math.ceil(server._member_count / 1000) + for chunk in range(chunks_needed): + futures.append(self.connection.receive_chunk(server.id)) + + return futures + + @asyncio.coroutine + def _fill_offline(self): + yield from self.request_offline_members(filter(lambda s: s.large, self.servers)) + chunks = self._get_all_chunks() + yield from asyncio.wait(chunks) + self.dispatch('ready') + def _get_cache_filename(self, email): filename = hashlib.md5(email.encode('utf-8')).hexdigest() return os.path.join(tempfile.gettempdir(), 'discord_py', filename) @@ -335,12 +359,13 @@ class Client: return event = msg.get('t') + is_ready = event == 'READY' - if event == 'READY': + if is_ready: self.connection.clear() self.session_id = data['session_id'] - if event == 'READY' or event == 'RESUMED': + if is_ready or event == 'RESUMED': interval = data['heartbeat_interval'] / 1000.0 self.keep_alive = utils.create_task(self.keep_alive_handler(interval), loop=self.loop) @@ -362,10 +387,19 @@ class Client: return parser = 'parse_' + event.lower() - if hasattr(self.connection, parser): - getattr(self.connection, parser)(data) + + try: + func = getattr(self.connection, parser) + except AttributeError: + log.info('Unhandled event {}'.format(event)) else: - log.info("Unhandled event {}".format(event)) + func(data) + + if is_ready: + if self.request_offline: + utils.create_task(self._fill_offline(), loop=self.loop) + else: + self.dispatch('ready') @asyncio.coroutine def _make_websocket(self, initial=True): @@ -389,6 +423,7 @@ class Client: '$referring_domain': '' }, 'compress': True, + 'large_threshold': 250, 'v': 3 } } @@ -1218,6 +1253,44 @@ class Client: # Member management + @asyncio.coroutine + def request_offline_members(self, server): + """|coro| + + Requests previously offline members from the server to be filled up + into the :attr:`Server.members` cache. If the client was initialised + with ``request_offline`` as ``True`` then calling this function would + not do anything. + + When the client logs on and connects to the websocket, Discord does + not provide the library with offline members if the number of members + in the server is larger than 250. You can check if a server is large + if :attr:`Server.large` is ``True``. + + Parameters + ----------- + server : :class:`Server` or iterable + The server to request offline members for. If this parameter is a + iterable then it is interpreted as an iterator of servers to + request offline members for. + """ + + if hasattr(server, 'id'): + guild_id = server.id + else: + guild_id = [s.id for s in server] + + payload = { + 'op': 8, + 'd': { + 'guild_id': guild_id, + 'query': '', + 'limit': 0 + } + } + + yield from self._send_ws(utils.to_json(payload)) + @asyncio.coroutine def kick(self, member): """|coro| diff --git a/discord/server.py b/discord/server.py index c787a6408..f95da70b7 100644 --- a/discord/server.py +++ b/discord/server.py @@ -84,9 +84,10 @@ class Server(Hashable): Check the :func:`on_server_unavailable` and :func:`on_server_available` events. """ - __slots__ = [ 'afk_timeout', 'afk_channel', '_members', '_channels', 'icon', - 'name', 'id', 'owner', 'unavailable', 'name', 'me', 'region', - '_default_role', '_default_channel', 'roles', '_member_count'] + __slots__ = ['afk_timeout', 'afk_channel', '_members', '_channels', 'icon', + 'name', 'id', 'owner', 'unavailable', 'name', 'me', 'region', + '_default_role', '_default_channel', 'roles', '_member_count', + 'large' ] def __init__(self, **kwargs): self._channels = {} @@ -139,6 +140,7 @@ class Server(Hashable): # according to Stan, this is always available even if the guild is unavailable self._member_count = guild['member_count'] self.name = guild.get('name') + self.large = guild.get('large', self._member_count > 250) self.region = guild.get('region') try: self.region = ServerRegion(self.region) diff --git a/discord/state.py b/discord/state.py index c929d8ff9..6c41f9a71 100644 --- a/discord/state.py +++ b/discord/state.py @@ -34,14 +34,26 @@ from .role import Role from . import utils from .enums import Status -from collections import deque + +from collections import deque, namedtuple import copy import datetime +import asyncio +import enum +import logging + +class ListenerType(enum.Enum): + chunk = 0 + +Listener = namedtuple('Listener', ('type', 'future', 'predicate')) +log = logging.getLogger(__name__) class ConnectionState: - def __init__(self, dispatch, max_messages): + def __init__(self, dispatch, max_messages, *, loop): + self.loop = loop self.max_messages = max_messages self.dispatch = dispatch + self._listeners = [] self.clear() def clear(self): @@ -52,6 +64,30 @@ class ConnectionState: self._private_channels_by_user = {} self.messages = deque(maxlen=self.max_messages) + def process_listeners(self, listener_type, argument, result): + removed = [] + for i, listener in enumerate(self._listeners): + if listener.type != listener_type: + continue + + future = listener.future + if future.cancelled(): + removed.append(i) + continue + + try: + passed = listener.predicate(argument) + except Exception as e: + future.set_exception(e) + removed.append(i) + else: + if passed: + future.set_result(result) + removed.append(i) + + for index in reversed(removed): + del self._listeners[index] + @property def servers(self): return self._servers.values() @@ -103,9 +139,6 @@ class ConnectionState: self._add_private_channel(PrivateChannel(id=pm['id'], user=User(**pm['recipient']))) - # we're all ready - self.dispatch('ready') - def parse_message_create(self, data): channel = self.get_channel(data.get('channel_id')) message = Message(channel=channel, **data) @@ -213,7 +246,7 @@ class ConnectionState: def parse_guild_member_add(self, data): server = self._get_server(data.get('guild_id')) - self._add_member(server, data) + member = self._add_member(server, data) server._member_count += 1 self.dispatch('member_join', member) @@ -345,6 +378,15 @@ class ConnectionState: role._update(**data['role']) self.dispatch('server_role_update', old_role, role) + def parse_guild_members_chunk(self, data): + server = self._get_server(data.get('guild_id')) + members = data.get('members', []) + for member in members: + self._add_member(server, member) + + log.info('processed a chunk for {} members.'.format(len(members))) + self.process_listeners(ListenerType.chunk, server, len(members)) + def parse_voice_state_update(self, data): server = self._get_server(data.get('guild_id')) if server is not None: @@ -381,3 +423,9 @@ class ConnectionState: pm = self._get_private_channel(id) if pm is not None: return pm + + def receive_chunk(self, guild_id): + future = asyncio.Future(loop=self.loop) + listener = Listener(ListenerType.chunk, future, lambda s: s.id == guild_id) + self._listeners.append(listener) + return future