diff --git a/discord/client.py b/discord/client.py index 9ec9adf04..73037aa8a 100644 --- a/discord/client.py +++ b/discord/client.py @@ -51,7 +51,7 @@ import logging, traceback import sys, time, re, json import tempfile, os, hashlib import itertools -import zlib, math +import zlib from random import randint as random_integer PY35 = sys.version_info >= (3, 5) @@ -122,7 +122,7 @@ class Client: if max_messages is None or max_messages < 100: max_messages = 5000 - self.connection = ConnectionState(self.dispatch, max_messages, loop=self.loop) + self.connection = ConnectionState(self.dispatch, self.request_offline_members, max_messages, loop=self.loop) # Blame Jake for this user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}' @@ -145,28 +145,6 @@ class Client: # internals - def _get_all_chunks(self): - # a chunk has a maximum of 1000 members. - # we need to find out how many futures we're actually waiting for - large_servers = filter(lambda s: s.large, self.servers) - futures = [] - for server in large_servers: - chunks_needed = math.ceil(server._member_count / 1000) - for chunk in range(chunks_needed): - futures.append(self.connection.receive_chunk(server.id)) - - return futures - - @asyncio.coroutine - def _fill_offline(self): - yield from self.request_offline_members(filter(lambda s: s.large, self.servers)) - chunks = self._get_all_chunks() - - if chunks: - yield from asyncio.wait(chunks) - - self.dispatch('ready') - def _get_cache_filename(self, email): filename = hashlib.md5(email.encode('utf-8')).hexdigest() return os.path.join(tempfile.gettempdir(), 'discord_py', filename) @@ -392,11 +370,10 @@ class Client: func = getattr(self.connection, parser) except AttributeError: log.info('Unhandled event {}'.format(event)) - else: - func(data) - if is_ready: - utils.create_task(self._fill_offline(), loop=self.loop) + result = func(data) + if asyncio.iscoroutine(result): + utils.create_task(result, loop=self.loop) @asyncio.coroutine def _make_websocket(self, initial=True): diff --git a/discord/state.py b/discord/state.py index 6db172182..fec2e111a 100644 --- a/discord/state.py +++ b/discord/state.py @@ -36,10 +36,9 @@ from .enums import Status from collections import deque, namedtuple -import copy +import copy, enum, math import datetime import asyncio -import enum import logging class ListenerType(enum.Enum): @@ -49,10 +48,11 @@ Listener = namedtuple('Listener', ('type', 'future', 'predicate')) log = logging.getLogger(__name__) class ConnectionState: - def __init__(self, dispatch, max_messages, *, loop): + def __init__(self, dispatch, chunker, max_messages, *, loop): self.loop = loop self.max_messages = max_messages self.dispatch = dispatch + self.chunker = chunker self._listeners = [] self.clear() @@ -128,6 +128,7 @@ class ConnectionState: self._add_server(server) return server + @asyncio.coroutine def parse_ready(self, data): self.user = User(**data['user']) guilds = data.get('guilds') @@ -139,6 +140,23 @@ class ConnectionState: self._add_private_channel(PrivateChannel(id=pm['id'], user=User(**pm['recipient']))) + # a chunk has a maximum of 1000 members. + # we need to find out how many futures we're actually waiting for + + large_servers = [s for s in self.servers if s.large] + yield from self.chunker(large_servers) + + chunks = [] + for server in large_servers: + chunks_needed = math.ceil(server._member_count / 1000) + for chunk in range(chunks_needed): + chunks.append(self.receive_chunk(server.id)) + + if chunks: + yield from asyncio.wait(chunks) + + self.dispatch('ready') + def parse_message_create(self, data): channel = self.get_channel(data.get('channel_id')) message = Message(channel=channel, **data)