committed by
GitHub
8 changed files with 386 additions and 930 deletions
@ -1,753 +0,0 @@ |
|||
""" |
|||
The MIT License (MIT) |
|||
|
|||
Copyright (c) 2015-present Rapptz |
|||
|
|||
Permission is hereby granted, free of charge, to any person obtaining a |
|||
copy of this software and associated documentation files (the "Software"), |
|||
to deal in the Software without restriction, including without limitation |
|||
the rights to use, copy, modify, merge, publish, distribute, sublicense, |
|||
and/or sell copies of the Software, and to permit persons to whom the |
|||
Software is furnished to do so, subject to the following conditions: |
|||
|
|||
The above copyright notice and this permission notice shall be included in |
|||
all copies or substantial portions of the Software. |
|||
|
|||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS |
|||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
|||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
|||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
|||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING |
|||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER |
|||
DEALINGS IN THE SOFTWARE. |
|||
""" |
|||
|
|||
from __future__ import annotations |
|||
|
|||
import asyncio |
|||
import datetime |
|||
from typing import Awaitable, TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator |
|||
|
|||
from .errors import NoMoreItems |
|||
from .utils import snowflake_time, time_snowflake, maybe_coroutine |
|||
from .object import Object |
|||
from .audit_logs import AuditLogEntry |
|||
|
|||
__all__ = ( |
|||
'ReactionIterator', |
|||
'HistoryIterator', |
|||
'AuditLogIterator', |
|||
'GuildIterator', |
|||
'MemberIterator', |
|||
) |
|||
|
|||
if TYPE_CHECKING: |
|||
from .types.audit_log import ( |
|||
AuditLog as AuditLogPayload, |
|||
) |
|||
from .types.guild import ( |
|||
Guild as GuildPayload, |
|||
) |
|||
from .types.message import ( |
|||
Message as MessagePayload, |
|||
) |
|||
from .types.user import ( |
|||
PartialUser as PartialUserPayload, |
|||
) |
|||
|
|||
from .types.threads import ( |
|||
Thread as ThreadPayload, |
|||
) |
|||
|
|||
from .member import Member |
|||
from .user import User |
|||
from .message import Message |
|||
from .audit_logs import AuditLogEntry |
|||
from .guild import Guild |
|||
from .threads import Thread |
|||
from .abc import Snowflake |
|||
|
|||
T = TypeVar('T') |
|||
OT = TypeVar('OT') |
|||
_Func = Callable[[T], Union[OT, Awaitable[OT]]] |
|||
|
|||
OLDEST_OBJECT = Object(id=0) |
|||
|
|||
|
|||
class _AsyncIterator(AsyncIterator[T]): |
|||
__slots__ = () |
|||
|
|||
async def next(self) -> T: |
|||
raise NotImplementedError |
|||
|
|||
def get(self, **attrs: Any) -> Awaitable[Optional[T]]: |
|||
def predicate(elem: T): |
|||
for attr, val in attrs.items(): |
|||
nested = attr.split('__') |
|||
obj = elem |
|||
for attribute in nested: |
|||
obj = getattr(obj, attribute) |
|||
|
|||
if obj != val: |
|||
return False |
|||
return True |
|||
|
|||
return self.find(predicate) |
|||
|
|||
async def find(self, predicate: _Func[T, bool]) -> Optional[T]: |
|||
while True: |
|||
try: |
|||
elem = await self.next() |
|||
except NoMoreItems: |
|||
return None |
|||
|
|||
ret = await maybe_coroutine(predicate, elem) |
|||
if ret: |
|||
return elem |
|||
|
|||
def chunk(self, max_size: int) -> _ChunkedAsyncIterator[T]: |
|||
if max_size <= 0: |
|||
raise ValueError('async iterator chunk sizes must be greater than 0.') |
|||
return _ChunkedAsyncIterator(self, max_size) |
|||
|
|||
def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]: |
|||
return _MappedAsyncIterator(self, func) |
|||
|
|||
def filter(self, predicate: _Func[T, bool]) -> _FilteredAsyncIterator[T]: |
|||
return _FilteredAsyncIterator(self, predicate) |
|||
|
|||
async def flatten(self) -> List[T]: |
|||
return [element async for element in self] |
|||
|
|||
async def __anext__(self) -> T: |
|||
try: |
|||
return await self.next() |
|||
except NoMoreItems: |
|||
raise StopAsyncIteration() |
|||
|
|||
|
|||
def _identity(x): |
|||
return x |
|||
|
|||
|
|||
class _ChunkedAsyncIterator(_AsyncIterator[List[T]]): |
|||
def __init__(self, iterator, max_size): |
|||
self.iterator = iterator |
|||
self.max_size = max_size |
|||
|
|||
async def next(self) -> List[T]: |
|||
ret: List[T] = [] |
|||
n = 0 |
|||
while n < self.max_size: |
|||
try: |
|||
item = await self.iterator.next() |
|||
except NoMoreItems: |
|||
if ret: |
|||
return ret |
|||
raise |
|||
else: |
|||
ret.append(item) |
|||
n += 1 |
|||
return ret |
|||
|
|||
|
|||
class _MappedAsyncIterator(_AsyncIterator[T]): |
|||
def __init__(self, iterator, func): |
|||
self.iterator = iterator |
|||
self.func = func |
|||
|
|||
async def next(self) -> T: |
|||
# this raises NoMoreItems and will propagate appropriately |
|||
item = await self.iterator.next() |
|||
return await maybe_coroutine(self.func, item) |
|||
|
|||
|
|||
class _FilteredAsyncIterator(_AsyncIterator[T]): |
|||
def __init__(self, iterator, predicate): |
|||
self.iterator = iterator |
|||
|
|||
if predicate is None: |
|||
predicate = _identity |
|||
|
|||
self.predicate = predicate |
|||
|
|||
async def next(self) -> T: |
|||
getter = self.iterator.next |
|||
pred = self.predicate |
|||
while True: |
|||
# propagate NoMoreItems similar to _MappedAsyncIterator |
|||
item = await getter() |
|||
ret = await maybe_coroutine(pred, item) |
|||
if ret: |
|||
return item |
|||
|
|||
|
|||
class ReactionIterator(_AsyncIterator[Union['User', 'Member']]): |
|||
def __init__(self, message, emoji, limit=100, after=None): |
|||
self.message = message |
|||
self.limit = limit |
|||
self.after = after |
|||
state = message._state |
|||
self.getter = state.http.get_reaction_users |
|||
self.state = state |
|||
self.emoji = emoji |
|||
self.guild = message.guild |
|||
self.channel_id = message.channel.id |
|||
self.users = asyncio.Queue() |
|||
|
|||
async def next(self) -> Union[User, Member]: |
|||
if self.users.empty(): |
|||
await self.fill_users() |
|||
|
|||
try: |
|||
return self.users.get_nowait() |
|||
except asyncio.QueueEmpty: |
|||
raise NoMoreItems() |
|||
|
|||
async def fill_users(self): |
|||
# this is a hack because >circular imports< |
|||
from .user import User |
|||
|
|||
if self.limit > 0: |
|||
retrieve = self.limit if self.limit <= 100 else 100 |
|||
|
|||
after = self.after.id if self.after else None |
|||
data: List[PartialUserPayload] = await self.getter( |
|||
self.channel_id, self.message.id, self.emoji, retrieve, after=after |
|||
) |
|||
|
|||
if data: |
|||
self.limit -= retrieve |
|||
self.after = Object(id=int(data[-1]['id'])) |
|||
|
|||
if self.guild is None or isinstance(self.guild, Object): |
|||
for element in reversed(data): |
|||
await self.users.put(User(state=self.state, data=element)) |
|||
else: |
|||
for element in reversed(data): |
|||
member_id = int(element['id']) |
|||
member = self.guild.get_member(member_id) |
|||
if member is not None: |
|||
await self.users.put(member) |
|||
else: |
|||
await self.users.put(User(state=self.state, data=element)) |
|||
|
|||
|
|||
class HistoryIterator(_AsyncIterator['Message']): |
|||
"""Iterator for receiving a channel's message history. |
|||
|
|||
The messages endpoint has two behaviours we care about here: |
|||
If ``before`` is specified, the messages endpoint returns the `limit` |
|||
newest messages before ``before``, sorted with newest first. For filling over |
|||
100 messages, update the ``before`` parameter to the oldest message received. |
|||
Messages will be returned in order by time. |
|||
If ``after`` is specified, it returns the ``limit`` oldest messages after |
|||
``after``, sorted with newest first. For filling over 100 messages, update the |
|||
``after`` parameter to the newest message received. If messages are not |
|||
reversed, they will be out of order (99-0, 199-100, so on) |
|||
|
|||
A note that if both ``before`` and ``after`` are specified, ``before`` is ignored by the |
|||
messages endpoint. |
|||
|
|||
Parameters |
|||
----------- |
|||
messageable: :class:`abc.Messageable` |
|||
Messageable class to retrieve message history from. |
|||
limit: :class:`int` |
|||
Maximum number of messages to retrieve |
|||
before: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] |
|||
Message before which all messages must be. |
|||
after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] |
|||
Message after which all messages must be. |
|||
around: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] |
|||
Message around which all messages must be. Limit max 101. Note that if |
|||
limit is an even number, this will return at most limit+1 messages. |
|||
oldest_first: Optional[:class:`bool`] |
|||
If set to ``True``, return messages in oldest->newest order. Defaults to |
|||
``True`` if `after` is specified, otherwise ``False``. |
|||
""" |
|||
|
|||
def __init__(self, messageable, limit, before=None, after=None, around=None, oldest_first=None): |
|||
|
|||
if isinstance(before, datetime.datetime): |
|||
before = Object(id=time_snowflake(before, high=False)) |
|||
if isinstance(after, datetime.datetime): |
|||
after = Object(id=time_snowflake(after, high=True)) |
|||
if isinstance(around, datetime.datetime): |
|||
around = Object(id=time_snowflake(around)) |
|||
|
|||
if oldest_first is None: |
|||
self.reverse = after is not None |
|||
else: |
|||
self.reverse = oldest_first |
|||
|
|||
self.messageable = messageable |
|||
self.limit = limit |
|||
self.before = before |
|||
self.after = after or OLDEST_OBJECT |
|||
self.around = around |
|||
|
|||
self._filter = None # message dict -> bool |
|||
|
|||
self.state = self.messageable._state |
|||
self.logs_from = self.state.http.logs_from |
|||
self.messages = asyncio.Queue() |
|||
|
|||
if self.around: |
|||
if self.limit is None: |
|||
raise ValueError('history does not support around with limit=None') |
|||
if self.limit > 101: |
|||
raise ValueError("history max limit 101 when specifying around parameter") |
|||
elif self.limit == 101: |
|||
self.limit = 100 # Thanks discord |
|||
|
|||
self._retrieve_messages = self._retrieve_messages_around_strategy # type: ignore |
|||
if self.before and self.after: |
|||
self._filter = lambda m: self.after.id < int(m['id']) < self.before.id |
|||
elif self.before: |
|||
self._filter = lambda m: int(m['id']) < self.before.id |
|||
elif self.after: |
|||
self._filter = lambda m: self.after.id < int(m['id']) |
|||
else: |
|||
if self.reverse: |
|||
self._retrieve_messages = self._retrieve_messages_after_strategy # type: ignore |
|||
if self.before: |
|||
self._filter = lambda m: int(m['id']) < self.before.id |
|||
else: |
|||
self._retrieve_messages = self._retrieve_messages_before_strategy # type: ignore |
|||
if self.after and self.after != OLDEST_OBJECT: |
|||
self._filter = lambda m: int(m['id']) > self.after.id |
|||
|
|||
async def next(self) -> Message: |
|||
if self.messages.empty(): |
|||
await self.fill_messages() |
|||
|
|||
try: |
|||
return self.messages.get_nowait() |
|||
except asyncio.QueueEmpty: |
|||
raise NoMoreItems() |
|||
|
|||
def _get_retrieve(self): |
|||
l = self.limit |
|||
if l is None or l > 100: |
|||
r = 100 |
|||
else: |
|||
r = l |
|||
self.retrieve = r |
|||
return r > 0 |
|||
|
|||
async def fill_messages(self): |
|||
if not hasattr(self, 'channel'): |
|||
# do the required set up |
|||
channel = await self.messageable._get_channel() |
|||
self.channel = channel |
|||
|
|||
if self._get_retrieve(): |
|||
data = await self._retrieve_messages(self.retrieve) |
|||
if len(data) < 100: |
|||
self.limit = 0 # terminate the infinite loop |
|||
|
|||
if self.reverse: |
|||
data = reversed(data) |
|||
if self._filter: |
|||
data = filter(self._filter, data) |
|||
|
|||
channel = self.channel |
|||
for element in data: |
|||
await self.messages.put(self.state.create_message(channel=channel, data=element)) |
|||
|
|||
async def _retrieve_messages(self, retrieve) -> List[Message]: |
|||
"""Retrieve messages and update next parameters.""" |
|||
raise NotImplementedError |
|||
|
|||
async def _retrieve_messages_before_strategy(self, retrieve): |
|||
"""Retrieve messages using before parameter.""" |
|||
before = self.before.id if self.before else None |
|||
data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, before=before) |
|||
if len(data): |
|||
if self.limit is not None: |
|||
self.limit -= retrieve |
|||
self.before = Object(id=int(data[-1]['id'])) |
|||
return data |
|||
|
|||
async def _retrieve_messages_after_strategy(self, retrieve): |
|||
"""Retrieve messages using after parameter.""" |
|||
after = self.after.id if self.after else None |
|||
data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, after=after) |
|||
if len(data): |
|||
if self.limit is not None: |
|||
self.limit -= retrieve |
|||
self.after = Object(id=int(data[0]['id'])) |
|||
return data |
|||
|
|||
async def _retrieve_messages_around_strategy(self, retrieve): |
|||
"""Retrieve messages using around parameter.""" |
|||
if self.around: |
|||
around = self.around.id if self.around else None |
|||
data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, around=around) |
|||
self.around = None |
|||
return data |
|||
return [] |
|||
|
|||
|
|||
class AuditLogIterator(_AsyncIterator['AuditLogEntry']): |
|||
def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None): |
|||
if isinstance(before, datetime.datetime): |
|||
before = Object(id=time_snowflake(before, high=False)) |
|||
if isinstance(after, datetime.datetime): |
|||
after = Object(id=time_snowflake(after, high=True)) |
|||
|
|||
if oldest_first is None: |
|||
self.reverse = after is not None |
|||
else: |
|||
self.reverse = oldest_first |
|||
|
|||
self.guild = guild |
|||
self.loop = guild._state.loop |
|||
self.request = guild._state.http.get_audit_logs |
|||
self.limit = limit |
|||
self.before = before |
|||
self.user_id = user_id |
|||
self.action_type = action_type |
|||
self.after = OLDEST_OBJECT |
|||
self._users = {} |
|||
self._state = guild._state |
|||
|
|||
self._filter = None # entry dict -> bool |
|||
|
|||
self.entries = asyncio.Queue() |
|||
|
|||
if self.reverse: |
|||
self._strategy = self._after_strategy |
|||
if self.before: |
|||
self._filter = lambda m: int(m['id']) < self.before.id |
|||
else: |
|||
self._strategy = self._before_strategy |
|||
if self.after and self.after != OLDEST_OBJECT: |
|||
self._filter = lambda m: int(m['id']) > self.after.id |
|||
|
|||
async def _before_strategy(self, retrieve): |
|||
before = self.before.id if self.before else None |
|||
data: AuditLogPayload = await self.request( |
|||
self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, before=before |
|||
) |
|||
|
|||
entries = data.get('audit_log_entries', []) |
|||
if len(data) and entries: |
|||
if self.limit is not None: |
|||
self.limit -= retrieve |
|||
self.before = Object(id=int(entries[-1]['id'])) |
|||
return data.get('users', []), entries |
|||
|
|||
async def _after_strategy(self, retrieve): |
|||
after = self.after.id if self.after else None |
|||
data: AuditLogPayload = await self.request( |
|||
self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, after=after |
|||
) |
|||
entries = data.get('audit_log_entries', []) |
|||
if len(data) and entries: |
|||
if self.limit is not None: |
|||
self.limit -= retrieve |
|||
self.after = Object(id=int(entries[0]['id'])) |
|||
return data.get('users', []), entries |
|||
|
|||
async def next(self) -> AuditLogEntry: |
|||
if self.entries.empty(): |
|||
await self._fill() |
|||
|
|||
try: |
|||
return self.entries.get_nowait() |
|||
except asyncio.QueueEmpty: |
|||
raise NoMoreItems() |
|||
|
|||
def _get_retrieve(self): |
|||
l = self.limit |
|||
if l is None or l > 100: |
|||
r = 100 |
|||
else: |
|||
r = l |
|||
self.retrieve = r |
|||
return r > 0 |
|||
|
|||
async def _fill(self): |
|||
from .user import User |
|||
|
|||
if self._get_retrieve(): |
|||
users, data = await self._strategy(self.retrieve) |
|||
if len(data) < 100: |
|||
self.limit = 0 # terminate the infinite loop |
|||
|
|||
if self.reverse: |
|||
data = reversed(data) |
|||
if self._filter: |
|||
data = filter(self._filter, data) |
|||
|
|||
for user in users: |
|||
u = User(data=user, state=self._state) |
|||
self._users[u.id] = u |
|||
|
|||
for element in data: |
|||
# TODO: remove this if statement later |
|||
if element['action_type'] is None: |
|||
continue |
|||
|
|||
await self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild)) |
|||
|
|||
|
|||
class GuildIterator(_AsyncIterator['Guild']): |
|||
"""Iterator for receiving the client's guilds. |
|||
|
|||
The guilds endpoint has the same two behaviours as described |
|||
in :class:`HistoryIterator`: |
|||
If ``before`` is specified, the guilds endpoint returns the ``limit`` |
|||
newest guilds before ``before``, sorted with newest first. For filling over |
|||
100 guilds, update the ``before`` parameter to the oldest guild received. |
|||
Guilds will be returned in order by time. |
|||
If `after` is specified, it returns the ``limit`` oldest guilds after ``after``, |
|||
sorted with newest first. For filling over 100 guilds, update the ``after`` |
|||
parameter to the newest guild received, If guilds are not reversed, they |
|||
will be out of order (99-0, 199-100, so on) |
|||
|
|||
Not that if both ``before`` and ``after`` are specified, ``before`` is ignored by the |
|||
guilds endpoint. |
|||
|
|||
Parameters |
|||
----------- |
|||
bot: :class:`discord.Client` |
|||
The client to retrieve the guilds from. |
|||
limit: :class:`int` |
|||
Maximum number of guilds to retrieve. |
|||
before: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] |
|||
Object before which all guilds must be. |
|||
after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] |
|||
Object after which all guilds must be. |
|||
""" |
|||
|
|||
def __init__(self, bot, limit, before=None, after=None): |
|||
|
|||
if isinstance(before, datetime.datetime): |
|||
before = Object(id=time_snowflake(before, high=False)) |
|||
if isinstance(after, datetime.datetime): |
|||
after = Object(id=time_snowflake(after, high=True)) |
|||
|
|||
self.bot = bot |
|||
self.limit = limit |
|||
self.before = before |
|||
self.after = after |
|||
|
|||
self._filter = None |
|||
|
|||
self.state = self.bot._connection |
|||
self.get_guilds = self.bot.http.get_guilds |
|||
self.guilds = asyncio.Queue() |
|||
|
|||
if self.before and self.after: |
|||
self._retrieve_guilds = self._retrieve_guilds_before_strategy # type: ignore |
|||
self._filter = lambda m: int(m['id']) > self.after.id |
|||
elif self.after: |
|||
self._retrieve_guilds = self._retrieve_guilds_after_strategy # type: ignore |
|||
else: |
|||
self._retrieve_guilds = self._retrieve_guilds_before_strategy # type: ignore |
|||
|
|||
async def next(self) -> Guild: |
|||
if self.guilds.empty(): |
|||
await self.fill_guilds() |
|||
|
|||
try: |
|||
return self.guilds.get_nowait() |
|||
except asyncio.QueueEmpty: |
|||
raise NoMoreItems() |
|||
|
|||
def _get_retrieve(self): |
|||
l = self.limit |
|||
if l is None or l > 100: |
|||
r = 100 |
|||
else: |
|||
r = l |
|||
self.retrieve = r |
|||
return r > 0 |
|||
|
|||
def create_guild(self, data): |
|||
from .guild import Guild |
|||
|
|||
return Guild(state=self.state, data=data) |
|||
|
|||
async def fill_guilds(self): |
|||
if self._get_retrieve(): |
|||
data = await self._retrieve_guilds(self.retrieve) |
|||
if self.limit is None or len(data) < 100: |
|||
self.limit = 0 |
|||
|
|||
if self._filter: |
|||
data = filter(self._filter, data) |
|||
|
|||
for element in data: |
|||
await self.guilds.put(self.create_guild(element)) |
|||
|
|||
async def _retrieve_guilds(self, retrieve) -> List[Guild]: |
|||
"""Retrieve guilds and update next parameters.""" |
|||
raise NotImplementedError |
|||
|
|||
async def _retrieve_guilds_before_strategy(self, retrieve): |
|||
"""Retrieve guilds using before parameter.""" |
|||
before = self.before.id if self.before else None |
|||
data: List[GuildPayload] = await self.get_guilds(retrieve, before=before) |
|||
if len(data): |
|||
if self.limit is not None: |
|||
self.limit -= retrieve |
|||
self.before = Object(id=int(data[-1]['id'])) |
|||
return data |
|||
|
|||
async def _retrieve_guilds_after_strategy(self, retrieve): |
|||
"""Retrieve guilds using after parameter.""" |
|||
after = self.after.id if self.after else None |
|||
data: List[GuildPayload] = await self.get_guilds(retrieve, after=after) |
|||
if len(data): |
|||
if self.limit is not None: |
|||
self.limit -= retrieve |
|||
self.after = Object(id=int(data[0]['id'])) |
|||
return data |
|||
|
|||
|
|||
class MemberIterator(_AsyncIterator['Member']): |
|||
def __init__(self, guild, limit=1000, after=None): |
|||
|
|||
if isinstance(after, datetime.datetime): |
|||
after = Object(id=time_snowflake(after, high=True)) |
|||
|
|||
self.guild = guild |
|||
self.limit = limit |
|||
self.after = after or OLDEST_OBJECT |
|||
|
|||
self.state = self.guild._state |
|||
self.get_members = self.state.http.get_members |
|||
self.members = asyncio.Queue() |
|||
|
|||
async def next(self) -> Member: |
|||
if self.members.empty(): |
|||
await self.fill_members() |
|||
|
|||
try: |
|||
return self.members.get_nowait() |
|||
except asyncio.QueueEmpty: |
|||
raise NoMoreItems() |
|||
|
|||
def _get_retrieve(self): |
|||
l = self.limit |
|||
if l is None or l > 1000: |
|||
r = 1000 |
|||
else: |
|||
r = l |
|||
self.retrieve = r |
|||
return r > 0 |
|||
|
|||
async def fill_members(self): |
|||
if self._get_retrieve(): |
|||
after = self.after.id if self.after else None |
|||
data = await self.get_members(self.guild.id, self.retrieve, after) |
|||
if not data: |
|||
# no data, terminate |
|||
return |
|||
|
|||
if len(data) < 1000: |
|||
self.limit = 0 # terminate loop |
|||
|
|||
self.after = Object(id=int(data[-1]['user']['id'])) |
|||
|
|||
for element in reversed(data): |
|||
await self.members.put(self.create_member(element)) |
|||
|
|||
def create_member(self, data): |
|||
from .member import Member |
|||
|
|||
return Member(data=data, guild=self.guild, state=self.state) |
|||
|
|||
|
|||
class ArchivedThreadIterator(_AsyncIterator['Thread']): |
|||
def __init__( |
|||
self, |
|||
channel_id: int, |
|||
guild: Guild, |
|||
limit: Optional[int], |
|||
joined: bool, |
|||
private: bool, |
|||
before: Optional[Union[Snowflake, datetime.datetime]] = None, |
|||
): |
|||
self.channel_id = channel_id |
|||
self.guild = guild |
|||
self.limit = limit |
|||
self.joined = joined |
|||
self.private = private |
|||
self.http = guild._state.http |
|||
|
|||
if joined and not private: |
|||
raise ValueError('Cannot iterate over joined public archived threads') |
|||
|
|||
self.before: Optional[str] |
|||
if before is None: |
|||
self.before = None |
|||
elif isinstance(before, datetime.datetime): |
|||
if joined: |
|||
self.before = str(time_snowflake(before, high=False)) |
|||
else: |
|||
self.before = before.isoformat() |
|||
else: |
|||
if joined: |
|||
self.before = str(before.id) |
|||
else: |
|||
self.before = snowflake_time(before.id).isoformat() |
|||
|
|||
self.update_before: Callable[[ThreadPayload], str] = self.get_archive_timestamp |
|||
|
|||
if joined: |
|||
self.endpoint = self.http.get_joined_private_archived_threads |
|||
self.update_before = self.get_thread_id |
|||
elif private: |
|||
self.endpoint = self.http.get_private_archived_threads |
|||
else: |
|||
self.endpoint = self.http.get_public_archived_threads |
|||
|
|||
self.queue: asyncio.Queue[Thread] = asyncio.Queue() |
|||
self.has_more: bool = True |
|||
|
|||
async def next(self) -> Thread: |
|||
if self.queue.empty(): |
|||
await self.fill_queue() |
|||
|
|||
try: |
|||
return self.queue.get_nowait() |
|||
except asyncio.QueueEmpty: |
|||
raise NoMoreItems() |
|||
|
|||
@staticmethod |
|||
def get_archive_timestamp(data: ThreadPayload) -> str: |
|||
return data['thread_metadata']['archive_timestamp'] |
|||
|
|||
@staticmethod |
|||
def get_thread_id(data: ThreadPayload) -> str: |
|||
return data['id'] # type: ignore |
|||
|
|||
async def fill_queue(self) -> None: |
|||
if not self.has_more: |
|||
raise NoMoreItems() |
|||
|
|||
limit = 50 if self.limit is None else max(self.limit, 50) |
|||
data = await self.endpoint(self.channel_id, before=self.before, limit=limit) |
|||
|
|||
# This stuff is obviously WIP because 'members' is always empty |
|||
threads: List[ThreadPayload] = data.get('threads', []) |
|||
for d in reversed(threads): |
|||
self.queue.put_nowait(self.create_thread(d)) |
|||
|
|||
self.has_more = data.get('has_more', False) |
|||
if self.limit is not None: |
|||
self.limit -= len(threads) |
|||
if self.limit <= 0: |
|||
self.has_more = False |
|||
|
|||
if self.has_more: |
|||
self.before = self.update_before(threads[-1]) |
|||
|
|||
def create_thread(self, data: ThreadPayload) -> Thread: |
|||
from .threads import Thread |
|||
return Thread(guild=self.guild, state=self.guild._state, data=data) |
Loading…
Reference in new issue