diff --git a/discord/client.py b/discord/client.py index 7e49aa890..98f62a442 100644 --- a/discord/client.py +++ b/discord/client.py @@ -41,6 +41,7 @@ from .permissions import Permissions from . import utils from .enums import ChannelType, ServerRegion from .voice_client import VoiceClient +from .iterators import LogsFromIterator import asyncio import aiohttp @@ -53,6 +54,7 @@ import itertools import zlib from random import randint as random_integer +PY35 = sys.version_info >= (3, 5) log = logging.getLogger(__name__) request_logging_format = '{method} {response.url} has returned {response.status}' request_success_log = '{response.url} with {json} received {data}' @@ -1115,24 +1117,6 @@ class Client: @asyncio.coroutine def _logs_from(self, channel, limit=100, before=None, after=None): - url = '{}/{}/messages'.format(endpoints.CHANNELS, channel.id) - params = { - 'limit': limit - } - - if before: - params['before'] = before.id - if after: - params['after'] = after.id - - response = yield from aiohttp.get(url, params=params, headers=self.headers, loop=self.loop) - log.debug(request_logging_format.format(method='GET', response=response)) - yield from utils._verify_successful_response(response) - messages = yield from response.json() - return messages - - @asyncio.coroutine - def logs_from(self, channel, limit=100, *, before=None, after=None): """|coro| This coroutine returns a generator that obtains logs from a specified channel. @@ -1172,24 +1156,54 @@ class Client: if message.content.startswith('!hello'): if message.author == client.user: yield from client.edit_message(message, 'goodbye') + + Python 3.5 Usage :: + + counter = 0 + async for message in client.logs_from(channel, limit=500): + if message.author == client.user: + counter += 1 """ + url = '{}/{}/messages'.format(endpoints.CHANNELS, channel.id) + params = { + 'limit': limit + } - def generator(data): - for message in data: - yield Message(channel=channel, **message) - - result = [] - while limit > 0: - retrieve = limit if limit <= 100 else 100 - data = yield from self._logs_from(channel, retrieve, before, after) - if len(data): - limit -= retrieve - result.extend(data) - before = Object(id=data[-1]['id']) - else: - break + if before: + params['before'] = before.id + if after: + params['after'] = after.id + + response = yield from aiohttp.get(url, params=params, headers=self.headers, loop=self.loop) + log.debug(request_logging_format.format(method='GET', response=response)) + yield from utils._verify_successful_response(response) + messages = yield from response.json() + return messages + + if PY35: + def logs_from(self, channel, limit=100, *, before=None, after=None): + return LogsFromIterator(self, channel, limit, before, after) + else: + @asyncio.coroutine + def logs_from(self, channel, limit=100, *, before=None, after=None): + def generator(data): + for message in data: + yield Message(channel=channel, **message) + + result = [] + while limit > 0: + retrieve = limit if limit <= 100 else 100 + data = yield from self._logs_from(channel, retrieve, before, after) + if len(data): + limit -= retrieve + result.extend(data) + before = Object(id=data[-1]['id']) + else: + break + + return generator(result) - return generator(result) + logs_from.__doc__ = _logs_from.__doc__ # Member management diff --git a/discord/iterators.py b/discord/iterators.py new file mode 100644 index 000000000..b0cb0c772 --- /dev/null +++ b/discord/iterators.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2016 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import sys +import asyncio +import aiohttp +from .message import Message +from .object import Object + +PY35 = sys.version_info >= (3, 5) + +class LogsFromIterator: + def __init__(self, client, channel, limit, before, after): + self.client = client + self.channel = channel + self.limit = limit + self.before = before + self.after = after + self.messages = asyncio.LifoQueue() + + @asyncio.coroutine + def fill_messages(self): + if self.limit > 0: + retrieve = self.limit if self.limit <= 100 else 100 + data = yield from self.client._logs_from(self.channel, retrieve, self.before, self.after) + if len(data): + self.limit -= retrieve + self.before = Object(id=data[-1]['id']) + for element in data: + yield from self.messages.put(Message(channel=self.channel, **element)) + + if PY35: + @asyncio.coroutine + def __aiter__(self): + return self + + @asyncio.coroutine + def __anext__(self): + if self.messages.empty(): + yield from self.fill_messages() + + try: + msg = self.messages.get_nowait() + return msg + except asyncio.QueueEmpty: + # if we're still empty at this point... + # we didn't get any new messages so stop looping + raise StopAsyncIteration()