From 4768d950c590ba170ead20aad7ccc797a7d8e737 Mon Sep 17 00:00:00 2001
From: Rapptz <rapptz@gmail.com>
Date: Sun, 14 Feb 2016 19:24:26 -0500
Subject: [PATCH] Offline members are now added by default automatically.

This commit adds support for GUILD_MEMBERS_CHUNK which had to be done
due to forced large_threshold requirements in the library.
---
 discord/client.py | 87 +++++++++++++++++++++++++++++++++++++++++++----
 discord/server.py |  8 +++--
 discord/state.py  | 60 ++++++++++++++++++++++++++++----
 3 files changed, 139 insertions(+), 16 deletions(-)

diff --git a/discord/client.py b/discord/client.py
index 4e7083f71..fa096c984 100644
--- a/discord/client.py
+++ b/discord/client.py
@@ -51,7 +51,7 @@ import logging, traceback
 import sys, time, re, json
 import tempfile, os, hashlib
 import itertools
-import zlib
+import zlib, math
 from random import randint as random_integer
 
 PY35 = sys.version_info >= (3, 5)
@@ -81,6 +81,10 @@ class Client:
         Indicates if :meth:`login` should cache the authentication tokens. Defaults
         to ``True``. The method in which the cache is written is done by writing to
         disk to a temporary directory.
+    request_offline : Optional[bool]
+        Indicates if the client should request the offline members of every server.
+        If this is False, then member lists will not store offline members if the
+        number of members in the server is greater than 250. Defaults to ``True``.
 
     Attributes
     -----------
@@ -117,12 +121,13 @@ class Client:
         self.loop = asyncio.get_event_loop() if loop is None else loop
         self._listeners = []
         self.cache_auth = options.get('cache_auth', True)
+        self.request_offline = options.get('request_offline', True)
 
         max_messages = options.get('max_messages')
         if max_messages is None or max_messages < 100:
             max_messages = 5000
 
-        self.connection = ConnectionState(self.dispatch, max_messages)
+        self.connection = ConnectionState(self.dispatch, max_messages, loop=self.loop)
 
         # Blame React for this
         user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}'
@@ -143,6 +148,25 @@ class Client:
 
     # internals
 
+    def _get_all_chunks(self):
+        # a chunk has a maximum of 1000 members.
+        # we need to find out how many futures we're actually waiting for
+        large_servers = filter(lambda s: s.large, self.servers)
+        futures = []
+        for server in large_servers:
+            chunks_needed = math.ceil(server._member_count / 1000)
+            for chunk in range(chunks_needed):
+                futures.append(self.connection.receive_chunk(server.id))
+
+        return futures
+
+    @asyncio.coroutine
+    def _fill_offline(self):
+        yield from self.request_offline_members(filter(lambda s: s.large, self.servers))
+        chunks = self._get_all_chunks()
+        yield from asyncio.wait(chunks)
+        self.dispatch('ready')
+
     def _get_cache_filename(self, email):
         filename = hashlib.md5(email.encode('utf-8')).hexdigest()
         return os.path.join(tempfile.gettempdir(), 'discord_py', filename)
@@ -335,12 +359,13 @@ class Client:
             return
 
         event = msg.get('t')
+        is_ready = event == 'READY'
 
-        if event == 'READY':
+        if is_ready:
             self.connection.clear()
             self.session_id = data['session_id']
 
-        if event == 'READY' or event == 'RESUMED':
+        if is_ready or event == 'RESUMED':
             interval = data['heartbeat_interval'] / 1000.0
             self.keep_alive = utils.create_task(self.keep_alive_handler(interval), loop=self.loop)
 
@@ -362,10 +387,19 @@ class Client:
             return
 
         parser = 'parse_' + event.lower()
-        if hasattr(self.connection, parser):
-            getattr(self.connection, parser)(data)
+
+        try:
+            func = getattr(self.connection, parser)
+        except AttributeError:
+            log.info('Unhandled event {}'.format(event))
         else:
-            log.info("Unhandled event {}".format(event))
+            func(data)
+
+        if is_ready:
+            if self.request_offline:
+                utils.create_task(self._fill_offline(), loop=self.loop)
+            else:
+                self.dispatch('ready')
 
     @asyncio.coroutine
     def _make_websocket(self, initial=True):
@@ -389,6 +423,7 @@ class Client:
                         '$referring_domain': ''
                     },
                     'compress': True,
+                    'large_threshold': 250,
                     'v': 3
                 }
             }
@@ -1218,6 +1253,44 @@ class Client:
 
     # Member management
 
+    @asyncio.coroutine
+    def request_offline_members(self, server):
+        """|coro|
+
+        Requests previously offline members from the server to be filled up
+        into the :attr:`Server.members` cache. If the client was initialised
+        with ``request_offline`` as ``True`` then calling this function would
+        not do anything.
+
+        When the client logs on and connects to the websocket, Discord does
+        not provide the library with offline members if the number of members
+        in the server is larger than 250. You can check if a server is large
+        if :attr:`Server.large` is ``True``.
+
+        Parameters
+        -----------
+        server : :class:`Server` or iterable
+            The server to request offline members for. If this parameter is a
+            iterable then it is interpreted as an iterator of servers to
+            request offline members for.
+        """
+
+        if hasattr(server, 'id'):
+            guild_id = server.id
+        else:
+            guild_id = [s.id for s in server]
+
+        payload = {
+            'op': 8,
+            'd': {
+                'guild_id': guild_id,
+                'query': '',
+                'limit': 0
+            }
+        }
+
+        yield from self._send_ws(utils.to_json(payload))
+
     @asyncio.coroutine
     def kick(self, member):
         """|coro|
diff --git a/discord/server.py b/discord/server.py
index c787a6408..f95da70b7 100644
--- a/discord/server.py
+++ b/discord/server.py
@@ -84,9 +84,10 @@ class Server(Hashable):
         Check the :func:`on_server_unavailable` and :func:`on_server_available` events.
     """
 
-    __slots__ = [ 'afk_timeout', 'afk_channel', '_members', '_channels', 'icon',
-                  'name', 'id', 'owner', 'unavailable', 'name', 'me', 'region',
-                  '_default_role', '_default_channel', 'roles', '_member_count']
+    __slots__ = ['afk_timeout', 'afk_channel', '_members', '_channels', 'icon',
+                 'name', 'id', 'owner', 'unavailable', 'name', 'me', 'region',
+                 '_default_role', '_default_channel', 'roles', '_member_count',
+                 'large' ]
 
     def __init__(self, **kwargs):
         self._channels = {}
@@ -139,6 +140,7 @@ class Server(Hashable):
         # according to Stan, this is always available even if the guild is unavailable
         self._member_count = guild['member_count']
         self.name = guild.get('name')
+        self.large = guild.get('large', self._member_count > 250)
         self.region = guild.get('region')
         try:
             self.region = ServerRegion(self.region)
diff --git a/discord/state.py b/discord/state.py
index c929d8ff9..6c41f9a71 100644
--- a/discord/state.py
+++ b/discord/state.py
@@ -34,14 +34,26 @@ from .role import Role
 from . import utils
 from .enums import Status
 
-from collections import deque
+
+from collections import deque, namedtuple
 import copy
 import datetime
+import asyncio
+import enum
+import logging
+
+class ListenerType(enum.Enum):
+    chunk = 0
+
+Listener = namedtuple('Listener', ('type', 'future', 'predicate'))
+log = logging.getLogger(__name__)
 
 class ConnectionState:
-    def __init__(self, dispatch, max_messages):
+    def __init__(self, dispatch, max_messages, *, loop):
+        self.loop = loop
         self.max_messages = max_messages
         self.dispatch = dispatch
+        self._listeners = []
         self.clear()
 
     def clear(self):
@@ -52,6 +64,30 @@ class ConnectionState:
         self._private_channels_by_user = {}
         self.messages = deque(maxlen=self.max_messages)
 
+    def process_listeners(self, listener_type, argument, result):
+        removed = []
+        for i, listener in enumerate(self._listeners):
+            if listener.type != listener_type:
+                continue
+
+            future = listener.future
+            if future.cancelled():
+                removed.append(i)
+                continue
+
+            try:
+                passed = listener.predicate(argument)
+            except Exception as e:
+                future.set_exception(e)
+                removed.append(i)
+            else:
+                if passed:
+                    future.set_result(result)
+                    removed.append(i)
+
+        for index in reversed(removed):
+            del self._listeners[index]
+
     @property
     def servers(self):
         return self._servers.values()
@@ -103,9 +139,6 @@ class ConnectionState:
             self._add_private_channel(PrivateChannel(id=pm['id'],
                                      user=User(**pm['recipient'])))
 
-        # we're all ready
-        self.dispatch('ready')
-
     def parse_message_create(self, data):
         channel = self.get_channel(data.get('channel_id'))
         message = Message(channel=channel, **data)
@@ -213,7 +246,7 @@ class ConnectionState:
 
     def parse_guild_member_add(self, data):
         server = self._get_server(data.get('guild_id'))
-        self._add_member(server, data)
+        member = self._add_member(server, data)
         server._member_count += 1
         self.dispatch('member_join', member)
 
@@ -345,6 +378,15 @@ class ConnectionState:
                 role._update(**data['role'])
                 self.dispatch('server_role_update', old_role, role)
 
+    def parse_guild_members_chunk(self, data):
+        server = self._get_server(data.get('guild_id'))
+        members = data.get('members', [])
+        for member in members:
+            self._add_member(server, member)
+
+        log.info('processed a chunk for {} members.'.format(len(members)))
+        self.process_listeners(ListenerType.chunk, server, len(members))
+
     def parse_voice_state_update(self, data):
         server = self._get_server(data.get('guild_id'))
         if server is not None:
@@ -381,3 +423,9 @@ class ConnectionState:
         pm = self._get_private_channel(id)
         if pm is not None:
             return pm
+
+    def receive_chunk(self, guild_id):
+        future = asyncio.Future(loop=self.loop)
+        listener = Listener(ListenerType.chunk, future, lambda s: s.id == guild_id)
+        self._listeners.append(listener)
+        return future