Browse Source

Add around parameter to LogsFromIterator.

pull/365/head
khazhyk 9 years ago
parent
commit
158ac6bb50
  1. 23
      discord/client.py
  2. 4
      discord/http.py
  3. 32
      discord/iterators.py

23
discord/client.py

@ -978,7 +978,7 @@ class Client:
yield from self.http.delete_messages(channel.id, message_ids, guild_id) yield from self.http.delete_messages(channel.id, message_ids, guild_id)
@asyncio.coroutine @asyncio.coroutine
def purge_from(self, channel, *, limit=100, check=None, before=None, after=None): def purge_from(self, channel, *, limit=100, check=None, before=None, after=None, around=None):
"""|coro| """|coro|
Purges a list of messages that meet the criteria given by the predicate Purges a list of messages that meet the criteria given by the predicate
@ -1007,6 +1007,9 @@ class Client:
after : :class:`Message` or `datetime` after : :class:`Message` or `datetime`
The message or date after which all deleted messages must be. The message or date after which all deleted messages must be.
If a date is provided it must be a timezone-naive datetime representing UTC time. If a date is provided it must be a timezone-naive datetime representing UTC time.
around : :class:`Message` or `datetime`
The message or date around which all deleted messages must be.
If a date is provided it must be a timezone-naive datetime representing UTC time.
Raises Raises
------- -------
@ -1040,8 +1043,10 @@ class Client:
before = Object(utils.time_snowflake(before, high=False)) before = Object(utils.time_snowflake(before, high=False))
if isinstance(after, datetime.datetime): if isinstance(after, datetime.datetime):
after = Object(utils.time_snowflake(after, high=True)) after = Object(utils.time_snowflake(after, high=True))
if isinstance(around, datetime.datetime):
around = Object(utils.time_snowflake(around, high=True))
iterator = LogsFromIterator(self, channel, limit, before=before, after=after) iterator = LogsFromIterator(self, channel, limit, before=before, after=after, around=around)
ret = [] ret = []
count = 0 count = 0
@ -1209,7 +1214,7 @@ class Client:
data = yield from self.http.pins_from(channel.id) data = yield from self.http.pins_from(channel.id)
return [Message(channel=channel, **m) for m in data] return [Message(channel=channel, **m) for m in data]
def _logs_from(self, channel, limit=100, before=None, after=None): def _logs_from(self, channel, limit=100, before=None, after=None, around=None):
"""|coro| """|coro|
This coroutine returns a generator that obtains logs from a specified channel. This coroutine returns a generator that obtains logs from a specified channel.
@ -1226,6 +1231,9 @@ class Client:
after : :class:`Message` or `datetime` after : :class:`Message` or `datetime`
The message or date after which all returned messages must be. The message or date after which all returned messages must be.
If a date is provided it must be a timezone-naive datetime representing UTC time. If a date is provided it must be a timezone-naive datetime representing UTC time.
around : :class:`Message` or `datetime`
The message or date around which all returned messages must be.
If a date is provided it must be a timezone-naive datetime representing UTC time.
Raises Raises
------ ------
@ -1261,17 +1269,20 @@ class Client:
""" """
before = getattr(before, 'id', None) before = getattr(before, 'id', None)
after = getattr(after, 'id', None) after = getattr(after, 'id', None)
around = getattr(around, 'id', None)
return self.http.logs_from(channel.id, limit, before=before, after=after) return self.http.logs_from(channel.id, limit, before=before, after=after, around=around)
if PY35: if PY35:
def logs_from(self, channel, limit=100, *, before=None, after=None, reverse=False): def logs_from(self, channel, limit=100, *, before=None, after=None, around=None, reverse=False):
if isinstance(before, datetime.datetime): if isinstance(before, datetime.datetime):
before = Object(utils.time_snowflake(before, high=False)) before = Object(utils.time_snowflake(before, high=False))
if isinstance(after, datetime.datetime): if isinstance(after, datetime.datetime):
after = Object(utils.time_snowflake(after, high=True)) after = Object(utils.time_snowflake(after, high=True))
if isinstance(around, datetime.datetime):
around = Object(utils.time_snowflake(around))
return LogsFromIterator(self, channel, limit, before=before, after=after, reverse=reverse) return LogsFromIterator(self, channel, limit, before=before, after=after, around=around, reverse=reverse)
else: else:
@asyncio.coroutine @asyncio.coroutine
def logs_from(self, channel, limit=100, *, before=None, after=None): def logs_from(self, channel, limit=100, *, before=None, after=None):

4
discord/http.py

@ -265,7 +265,7 @@ class HTTPClient:
url = '{0.CHANNELS}/{1}/messages/{2}'.format(self, channel_id, message_id) url = '{0.CHANNELS}/{1}/messages/{2}'.format(self, channel_id, message_id)
return self.get(url, bucket=_func_()) return self.get(url, bucket=_func_())
def logs_from(self, channel_id, limit, before=None, after=None): def logs_from(self, channel_id, limit, before=None, after=None, around=None):
url = '{0.CHANNELS}/{1}/messages'.format(self, channel_id) url = '{0.CHANNELS}/{1}/messages'.format(self, channel_id)
params = { params = {
'limit': limit 'limit': limit
@ -275,6 +275,8 @@ class HTTPClient:
params['before'] = before params['before'] = before
if after: if after:
params['after'] = after params['after'] = after
if around:
params['around'] = around
return self.get(url, params=params, bucket=_func_()) return self.get(url, params=params, bucket=_func_())

32
discord/iterators.py

@ -60,6 +60,9 @@ class LogsFromIterator:
Message before which all messages must be. Message before which all messages must be.
after : :class:`Message` or id-like after : :class:`Message` or id-like
Message after which all messages must be. Message after which all messages must be.
around : :class:`Message` or id-like
Message around which all messages must be. Limit max 101. Note that if
limit is an even number, this will return at most limit+1 messages.
reverse : bool reverse : bool
If set to true, return messages in oldest->newest order. Recommended If set to true, return messages in oldest->newest order. Recommended
when using with "after" queries with limit over 100, otherwise messages when using with "after" queries with limit over 100, otherwise messages
@ -67,17 +70,33 @@ class LogsFromIterator:
""" """
def __init__(self, client, channel, limit, def __init__(self, client, channel, limit,
before=None, after=None, reverse=False): before=None, after=None, around=None, reverse=False):
self.client = client self.client = client
self.channel = channel self.channel = channel
self.limit = limit self.limit = limit
self.before = before self.before = before
self.after = after self.after = after
self.around = around
self.reverse = reverse self.reverse = reverse
self._filter = None # message dict -> bool self._filter = None # message dict -> bool
self.messages = asyncio.Queue() self.messages = asyncio.Queue()
if self.before and self.after: if self.around:
if self.limit > 101:
raise ValueError("LogsFrom max limit 101 when specifying around parameter")
elif self.limit == 101:
self.limit = 100 # Thanks discord
elif self.limit == 1:
raise ValueError("Use get_message.")
self._retrieve_messages = self._retrieve_messages_around_strategy
if self.before and self.after:
self._filter = lambda m: int(self.after.id) < int(m['id']) < int(self.before.id)
elif self.before:
self._filter = lambda m: int(m['id']) < int(self.before.id)
elif self.after:
self._filter = lambda m: int(self.after.id) < int(m['id'])
elif self.before and self.after:
if self.reverse: if self.reverse:
self._retrieve_messages = self._retrieve_messages_after_strategy self._retrieve_messages = self._retrieve_messages_after_strategy
self._filter = lambda m: int(m['id']) < int(self.before.id) self._filter = lambda m: int(m['id']) < int(self.before.id)
@ -131,6 +150,15 @@ class LogsFromIterator:
self.after = Object(id=data[0]['id']) self.after = Object(id=data[0]['id'])
return data return data
@asyncio.coroutine
def _retrieve_messages_around_strategy(self, retrieve):
"""Retrieve messages using around parameter."""
if self.around:
data = yield from self.client._logs_from(self.channel, retrieve, around=self.around)
self.around = None
return data
return []
if PY35: if PY35:
@asyncio.coroutine @asyncio.coroutine
def __aiter__(self): def __aiter__(self):

Loading…
Cancel
Save