Browse Source

Move chunking logic back into ConnectionState.

This allows for a nicer design when dealing with parsers that could
end up being coroutines.
pull/124/head
Rapptz 9 years ago
parent
commit
425bd2c091
  1. 33
      discord/client.py
  2. 24
      discord/state.py

33
discord/client.py

@ -51,7 +51,7 @@ import logging, traceback
import sys, time, re, json import sys, time, re, json
import tempfile, os, hashlib import tempfile, os, hashlib
import itertools import itertools
import zlib, math import zlib
from random import randint as random_integer from random import randint as random_integer
PY35 = sys.version_info >= (3, 5) PY35 = sys.version_info >= (3, 5)
@ -122,7 +122,7 @@ class Client:
if max_messages is None or max_messages < 100: if max_messages is None or max_messages < 100:
max_messages = 5000 max_messages = 5000
self.connection = ConnectionState(self.dispatch, max_messages, loop=self.loop) self.connection = ConnectionState(self.dispatch, self.request_offline_members, max_messages, loop=self.loop)
# Blame Jake for this # Blame Jake for this
user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}' user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}'
@ -145,28 +145,6 @@ class Client:
# internals # internals
def _get_all_chunks(self):
# a chunk has a maximum of 1000 members.
# we need to find out how many futures we're actually waiting for
large_servers = filter(lambda s: s.large, self.servers)
futures = []
for server in large_servers:
chunks_needed = math.ceil(server._member_count / 1000)
for chunk in range(chunks_needed):
futures.append(self.connection.receive_chunk(server.id))
return futures
@asyncio.coroutine
def _fill_offline(self):
yield from self.request_offline_members(filter(lambda s: s.large, self.servers))
chunks = self._get_all_chunks()
if chunks:
yield from asyncio.wait(chunks)
self.dispatch('ready')
def _get_cache_filename(self, email): def _get_cache_filename(self, email):
filename = hashlib.md5(email.encode('utf-8')).hexdigest() filename = hashlib.md5(email.encode('utf-8')).hexdigest()
return os.path.join(tempfile.gettempdir(), 'discord_py', filename) return os.path.join(tempfile.gettempdir(), 'discord_py', filename)
@ -392,11 +370,10 @@ class Client:
func = getattr(self.connection, parser) func = getattr(self.connection, parser)
except AttributeError: except AttributeError:
log.info('Unhandled event {}'.format(event)) log.info('Unhandled event {}'.format(event))
else:
func(data)
if is_ready: result = func(data)
utils.create_task(self._fill_offline(), loop=self.loop) if asyncio.iscoroutine(result):
utils.create_task(result, loop=self.loop)
@asyncio.coroutine @asyncio.coroutine
def _make_websocket(self, initial=True): def _make_websocket(self, initial=True):

24
discord/state.py

@ -36,10 +36,9 @@ from .enums import Status
from collections import deque, namedtuple from collections import deque, namedtuple
import copy import copy, enum, math
import datetime import datetime
import asyncio import asyncio
import enum
import logging import logging
class ListenerType(enum.Enum): class ListenerType(enum.Enum):
@ -49,10 +48,11 @@ Listener = namedtuple('Listener', ('type', 'future', 'predicate'))
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class ConnectionState: class ConnectionState:
def __init__(self, dispatch, max_messages, *, loop): def __init__(self, dispatch, chunker, max_messages, *, loop):
self.loop = loop self.loop = loop
self.max_messages = max_messages self.max_messages = max_messages
self.dispatch = dispatch self.dispatch = dispatch
self.chunker = chunker
self._listeners = [] self._listeners = []
self.clear() self.clear()
@ -128,6 +128,7 @@ class ConnectionState:
self._add_server(server) self._add_server(server)
return server return server
@asyncio.coroutine
def parse_ready(self, data): def parse_ready(self, data):
self.user = User(**data['user']) self.user = User(**data['user'])
guilds = data.get('guilds') guilds = data.get('guilds')
@ -139,6 +140,23 @@ class ConnectionState:
self._add_private_channel(PrivateChannel(id=pm['id'], self._add_private_channel(PrivateChannel(id=pm['id'],
user=User(**pm['recipient']))) user=User(**pm['recipient'])))
# a chunk has a maximum of 1000 members.
# we need to find out how many futures we're actually waiting for
large_servers = [s for s in self.servers if s.large]
yield from self.chunker(large_servers)
chunks = []
for server in large_servers:
chunks_needed = math.ceil(server._member_count / 1000)
for chunk in range(chunks_needed):
chunks.append(self.receive_chunk(server.id))
if chunks:
yield from asyncio.wait(chunks)
self.dispatch('ready')
def parse_message_create(self, data): def parse_message_create(self, data):
channel = self.get_channel(data.get('channel_id')) channel = self.get_channel(data.get('channel_id'))
message = Message(channel=channel, **data) message = Message(channel=channel, **data)

Loading…
Cancel
Save