From 588cda09960530f73d9baecc121758b7cc798c42 Mon Sep 17 00:00:00 2001 From: Kaylynn Morgan <51037748+kaylynn234@users.noreply.github.com> Date: Sun, 20 Feb 2022 13:58:13 +1100 Subject: [PATCH] Refactor AsyncIter to use 3.6+ asynchronous generators --- discord/abc.py | 110 ++++++- discord/channel.py | 54 +++- discord/client.py | 88 ++++- discord/guild.py | 133 +++++++- discord/iterators.py | 753 ------------------------------------------- discord/object.py | 7 +- discord/reaction.py | 42 ++- docs/api.rst | 129 -------- 8 files changed, 386 insertions(+), 930 deletions(-) delete mode 100644 discord/iterators.py diff --git a/discord/abc.py b/discord/abc.py index 2bb5f3f30..8514181a3 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -26,8 +26,10 @@ from __future__ import annotations import copy import asyncio +from datetime import datetime from typing import ( Any, + AsyncIterator, Callable, Dict, List, @@ -42,7 +44,7 @@ from typing import ( runtime_checkable, ) -from .iterators import HistoryIterator +from .object import OLDEST_OBJECT, Object from .context_managers import Typing from .enums import ChannelType from .errors import InvalidArgument, ClientException @@ -68,8 +70,6 @@ __all__ = ( T = TypeVar('T', bound=VoiceProtocol) if TYPE_CHECKING: - from datetime import datetime - from .client import Client from .user import ClientUser from .asset import Asset @@ -1465,7 +1465,7 @@ class Messageable: data = await state.http.pins_from(channel.id) return [state.create_message(channel=channel, data=m) for m in data] - def history( + async def history( self, *, limit: Optional[int] = 100, @@ -1473,8 +1473,8 @@ class Messageable: after: Optional[SnowflakeTime] = None, around: Optional[SnowflakeTime] = None, oldest_first: Optional[bool] = None, - ) -> HistoryIterator: - """Returns an :class:`~discord.AsyncIterator` that enables receiving the destination's message history. + ) -> AsyncIterator[Message]: + """Returns an :term:`asynchronous iterator` that enables receiving the destination's message history. You must have :attr:`~discord.Permissions.read_message_history` permissions to use this. @@ -1490,7 +1490,7 @@ class Messageable: Flattening into a list: :: - messages = await channel.history(limit=123).flatten() + messages = [message async for message in channel.history(limit=123)] # messages is now a list of Message... All parameters are optional. @@ -1531,7 +1531,101 @@ class Messageable: :class:`~discord.Message` The message with the message data parsed. """ - return HistoryIterator(self, limit=limit, before=before, after=after, around=around, oldest_first=oldest_first) + + async def _around_strategy(retrieve, around, limit): + if not around: + return [] + + around_id = around.id if around else None + data = await self._state.http.logs_from(channel.id, retrieve, around=around_id) + + return data, None, limit + + async def _after_strategy(retrieve, after, limit): + after_id = after.id if after else None + data = await self._state.http.logs_from(channel.id, retrieve, after=after_id) + + if data: + if limit is not None: + limit -= len(data) + + after = Object(id=int(data[0]['id'])) + + return data, after, limit + + async def _before_strategy(retrieve, before, limit): + before_id = before.id if before else None + data = await self._state.http.logs_from(channel.id, retrieve, before=before_id) + + if data: + if limit is not None: + limit -= len(data) + + before = Object(id=int(data[-1]['id'])) + + return data, before, limit + + if isinstance(before, datetime): + before = Object(id=utils.time_snowflake(before, high=False)) + if isinstance(after, datetime): + after = Object(id=utils.time_snowflake(after, high=True)) + if isinstance(around, datetime): + around = Object(id=utils.time_snowflake(around)) + + if oldest_first is None: + reverse = after is not None + else: + reverse = oldest_first + + after = after or OLDEST_OBJECT + predicate = None + + if around: + if limit is None: + raise ValueError('history does not support around with limit=None') + if limit > 101: + raise ValueError("history max limit 101 when specifying around parameter") + + # Strange Discord quirk + limit = 100 if limit == 101 else limit + + strategy, state = _around_strategy, around + + if before and after: + predicate = lambda m: after.id < int(m['id']) < before.id + elif before: + predicate = lambda m: int(m['id']) < before.id + elif after: + predicate = lambda m: after.id < int(m['id']) + elif reverse: + strategy, state = _after_strategy, after + if before: + predicate = lambda m: int(m['id']) < before.id + else: + strategy, state = _before_strategy, before + if after and after != OLDEST_OBJECT: + predicate = lambda m: int(m['id']) > after.id + + channel = await self._get_channel() + + while True: + retrieve = min(100 if limit is None else limit, 100) + if retrieve < 1: + return + + data, state, limit = await strategy(retrieve, state, limit) + + # Terminate loop on next iteration; there's no data left after this + if len(data) < 100: + limit = 0 + + if reverse: + data = reversed(data) + if predicate: + data = filter(predicate, data) + + for raw_message in data: + yield self._state.create_message(channel=channel, data=raw_message) class Connectable(Protocol): diff --git a/discord/channel.py b/discord/channel.py index 229e7531a..d2409cd34 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -28,6 +28,7 @@ import time import asyncio from typing import ( Any, + AsyncIterator, Callable, Dict, Iterable, @@ -54,7 +55,6 @@ from .asset import Asset from .errors import ClientException, InvalidArgument from .stage_instance import StageInstance from .threads import Thread -from .iterators import ArchivedThreadIterator __all__ = ( 'TextChannel', @@ -755,15 +755,15 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): return Thread(guild=self.guild, state=self._state, data=data) - def archived_threads( + async def archived_threads( self, *, private: bool = False, joined: bool = False, limit: Optional[int] = 50, before: Optional[Union[Snowflake, datetime.datetime]] = None, - ) -> ArchivedThreadIterator: - """Returns an :class:`~discord.AsyncIterator` that iterates over all archived threads in the guild. + ) -> AsyncIterator[Thread]: + """Returns an :term:`asynchronous iterator` that iterates over all archived threads in the guild. You must have :attr:`~Permissions.read_message_history` to use this. If iterating over private threads then :attr:`~Permissions.manage_threads` is also required. @@ -790,13 +790,57 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): You do not have permissions to get archived threads. HTTPException The request to get the archived threads failed. + ValueError + `joined`` was set to ``True`` and ``private`` was set to ``False``. You cannot retrieve public archived + threads that you have joined. Yields ------- :class:`Thread` The archived threads. """ - return ArchivedThreadIterator(self.id, self.guild, limit=limit, joined=joined, private=private, before=before) + if joined and not private: + raise ValueError('Cannot retrieve joined public archived threads') + + before_timestamp = None + + if isinstance(before, datetime.datetime): + if joined: + before_timestamp = str(utils.time_snowflake(before, high=False)) + else: + before_timestamp = before.isoformat() + elif before is not None: + if joined: + before_timestamp = str(before.id) + else: + before_timestamp = utils.snowflake_time(before.id).isoformat() + + update_before = lambda data: data['thread_metadata']['archive_timestamp'] + endpoint = self.guild._state.http.get_public_archived_threads + + if joined: + update_before = lambda data: data['id'] + endpoint = self.guild._state.http.get_joined_private_archived_threads + elif private: + endpoint = self.guild._state.http.get_private_archived_threads + + while True: + retrieve = 50 if limit is None else max(limit, 50) + data = await endpoint(self.id, before=before_timestamp, limit=retrieve) + + threads = data.get('threads', []) + for raw_thread in reversed(threads): + yield Thread(guild=self.guild, state=self.guild._state, data=raw_thread) + + if not data.get('has_more', False): + return + + if limit is not None: + limit -= len(threads) + if limit <= 0: + return + + before = update_before(threads[-1]) class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable): diff --git a/discord/client.py b/discord/client.py index 0e2a106a6..e7cdeac3e 100644 --- a/discord/client.py +++ b/discord/client.py @@ -25,11 +25,26 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations import asyncio +import datetime import logging import signal import sys import traceback -from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union +from typing import ( + Any, + AsyncIterator, + Callable, + Coroutine, + Dict, + Generator, + List, + Optional, + Sequence, + TYPE_CHECKING, + Tuple, + TypeVar, + Union +) import aiohttp @@ -51,11 +66,10 @@ from .voice_client import VoiceClient from .http import HTTPClient from .state import ConnectionState from . import utils -from .utils import MISSING +from .utils import MISSING, time_snowflake from .object import Object from .backoff import ExponentialBackoff from .webhook import Webhook -from .iterators import GuildIterator from .appinfo import AppInfo from .ui.view import View from .stage_instance import StageInstance @@ -63,6 +77,7 @@ from .threads import Thread from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory if TYPE_CHECKING: + from .types.guild import Guild as GuildPayload from .abc import SnowflakeTime, PrivateChannel, GuildChannel, Snowflake from .channel import DMChannel from .message import Message @@ -1120,14 +1135,14 @@ class Client: # Guild stuff - def fetch_guilds( + async def fetch_guilds( self, *, limit: Optional[int] = 100, - before: SnowflakeTime = None, - after: SnowflakeTime = None - ) -> GuildIterator: - """Retrieves an :class:`.AsyncIterator` that enables receiving your guilds. + before: Optional[SnowflakeTime] = None, + after: Optional[SnowflakeTime] = None, + ) -> AsyncIterator[Guild]: + """Retrieves an :term:`asynchronous iterator` that enables receiving your guilds. .. note:: @@ -1148,7 +1163,7 @@ class Client: Flattening into a list :: - guilds = await client.fetch_guilds(limit=150).flatten() + guilds = [guild async for guild in client.fetch_guilds(limit=150)] # guilds is now a list of Guild... All parameters are optional. @@ -1179,7 +1194,60 @@ class Client: :class:`.Guild` The guild with the guild data parsed. """ - return GuildIterator(self, limit=limit, before=before, after=after) + + async def _before_strategy(retrieve, before, limit): + before_id = before.id if before else None + data = await self.http.get_guilds(retrieve, before=before_id) + + if data: + if limit is not None: + limit -= len(data) + + before = Object(id=int(data[-1]['id'])) + + return data, before, limit + + async def _after_strategy(retrieve, after, limit): + after_id = after.id if after else None + data = await self.http.get_guilds(retrieve, after=after_id) + + if data: + if limit is not None: + limit -= len(data) + + after = Object(id=int(data[0]['id'])) + + return data, after, limit + + if isinstance(before, datetime.datetime): + before = Object(id=time_snowflake(before, high=False)) + if isinstance(after, datetime.datetime): + after = Object(id=time_snowflake(after, high=True)) + + predicate = None + strategy, state = _before_strategy, before + + if before and after: + predicate = lambda m: int(m['id']) > after.id # type: ignore + elif after: + strategy, state = _after_strategy, after + + while True: + retrieve = min(100 if limit is None else limit, 100) + if retrieve < 1: + return + + data, state, limit = await strategy(retrieve, state, limit) + + # Terminate loop on next iteration; there's no data left after this + if len(data) < 100: + limit = 0 + + if predicate: + data = filter(predicate, data) + + for raw_guild in data: + yield Guild(state=self._connection, data=raw_guild) async def fetch_template(self, code: Union[Template, str]) -> Template: """|coro| diff --git a/discord/guild.py b/discord/guild.py index c6606e171..062b28f37 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -25,9 +25,11 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations import copy +import datetime import unicodedata from typing import ( Any, + AsyncIterator, ClassVar, Dict, List, @@ -67,7 +69,6 @@ from .enums import ( from .mixins import Hashable from .user import User from .invite import Invite -from .iterators import AuditLogIterator, MemberIterator from .widget import Widget from .asset import Asset from .flags import SystemChannelFlags @@ -76,6 +77,8 @@ from .stage_instance import StageInstance from .threads import Thread, ThreadMember from .sticker import GuildSticker from .file import File +from .audit_logs import AuditLogEntry +from .object import OLDEST_OBJECT, Object __all__ = ( @@ -98,8 +101,6 @@ if TYPE_CHECKING: from .state import ConnectionState from .voice_client import VoiceProtocol - import datetime - VocalGuildChannel = Union[VoiceChannel, StageChannel] GuildChannel = Union[VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel] ByCategoryItem = Tuple[Optional[CategoryChannel], List[GuildChannel]] @@ -1649,9 +1650,8 @@ class Guild(Hashable): return threads - # TODO: Remove Optional typing here when async iterators are refactored - def fetch_members(self, *, limit: int = 1000, after: Optional[SnowflakeTime] = None) -> MemberIterator: - """Retrieves an :class:`.AsyncIterator` that enables receiving the guild's members. In order to use this, + async def fetch_members(self, *, limit: int = 1000, after: SnowflakeTime = MISSING) -> AsyncIterator[Member]: + """Retrieves an :term:`asynchronous iterator` that enables receiving the guild's members. In order to use this, :meth:`Intents.members` must be enabled. .. note:: @@ -1701,7 +1701,30 @@ class Guild(Hashable): if not self._state._intents.members: raise ClientException('Intents.members must be enabled to use this.') - return MemberIterator(self, limit=limit, after=after) + while True: + retrieve = min(1000 if limit is None else limit, 1000) + if retrieve < 1: + return + + if isinstance(after, datetime.datetime): + after = Object(id=utils.time_snowflake(after, high=True)) + + after = after or OLDEST_OBJECT + after_id = after.id if after else None + state = self._state + + data = await state.http.get_members(self.id, retrieve, after_id) + if not data: + return + + # Terminate loop on next iteration; there's no data left after this + if len(data) < 1000: + limit = 0 + + after = Object(id=int(data[-1]['user']['id'])) + + for raw_member in reversed(data): + yield Member(data=raw_member, guild=self, state=state) async def fetch_member(self, member_id: int, /) -> Member: """|coro| @@ -2731,18 +2754,17 @@ class Guild(Hashable): payload['uses'] = payload.get('uses', 0) return Invite(state=self._state, data=payload, guild=self, channel=channel) - # TODO: use MISSING when async iterators get refactored - def audit_logs( + async def audit_logs( self, *, limit: int = 100, before: Optional[SnowflakeTime] = None, after: Optional[SnowflakeTime] = None, oldest_first: Optional[bool] = None, - user: Snowflake = None, - action: AuditLogAction = None, - ) -> AuditLogIterator: - """Returns an :class:`AsyncIterator` that enables receiving the guild's audit logs. + user: Snowflake = MISSING, + action: AuditLogAction = MISSING, + ) -> AsyncIterator[AuditLogEntry]: + """Returns an :term:`asynchronous iterator` that enables receiving the guild's audit logs. You must have the :attr:`~Permissions.view_audit_log` permission to use this. @@ -2761,7 +2783,7 @@ class Guild(Hashable): Getting entries made by a specific user: :: - entries = await guild.audit_logs(limit=None, user=guild.me).flatten() + entries = [entry async for entry in guild.audit_logs(limit=None, user=guild.me)] await channel.send(f'I made {len(entries)} moderation actions.') Parameters @@ -2796,6 +2818,39 @@ class Guild(Hashable): :class:`AuditLogEntry` The audit log entry. """ + + async def _before_strategy(retrieve, before, limit): + before_id = before.id if before else None + data = await self._state.http.get_audit_logs( + self.id, limit=retrieve, user_id=user_id, action_type=action, before=before_id + ) + + entries = data.get('audit_log_entries', []) + + if data and entries: + if limit is not None: + limit -= len(data) + + before = Object(id=int(entries[-1]['id'])) + + return data.get('users', []), entries, before, limit + + async def _after_strategy(retrieve, after, limit): + after_id = after.id if after else None + data = await self._state.http.get_audit_logs( + self.id, limit=retrieve, user_id=user_id, action_type=action, after=after_id + ) + + entries = data.get('audit_log_entries', []) + + if data and entries: + if limit is not None: + limit -= len(data) + + after = Object(id=int(entries[0]['id'])) + + return data.get('users', []), entries, after, limit + if user is not None: user_id = user.id else: @@ -2804,9 +2859,53 @@ class Guild(Hashable): if action: action = action.value - return AuditLogIterator( - self, before=before, after=after, limit=limit, oldest_first=oldest_first, user_id=user_id, action_type=action - ) + if isinstance(before, datetime.datetime): + before = Object(id=utils.time_snowflake(before, high=False)) + if isinstance(after, datetime.datetime): + after = Object(id=utils.time_snowflake(after, high=True)) + + + if oldest_first is None: + reverse = after is not None + else: + reverse = oldest_first + + predicate = None + + if reverse: + strategy, state = _after_strategy, after + if before: + predicate = lambda m: int(m['id']) < before.id + else: + strategy, state = _before_strategy, before + if after and after != OLDEST_OBJECT: + predicate = lambda m: int(m['id']) > after.id + + while True: + retrieve = min(100 if limit is None else limit, 100) + if retrieve < 1: + return + + raw_users, data, state, limit = await strategy(retrieve, state, limit) + + # Terminate loop on next iteration; there's no data left after this + if len(data) < 100: + limit = 0 + + if reverse: + data = reversed(data) + if predicate: + data = filter(predicate, data) + + users = (User(data=raw_user, state=self._state) for raw_user in raw_users) + user_map = {user.id: user for user in users} + + for raw_entry in data: + # Weird Discord quirk + if raw_entry['action_type'] is None: + continue + + yield AuditLogEntry(data=raw_entry, users=user_map, guild=self) async def widget(self) -> Widget: """|coro| diff --git a/discord/iterators.py b/discord/iterators.py deleted file mode 100644 index f725d527e..000000000 --- a/discord/iterators.py +++ /dev/null @@ -1,753 +0,0 @@ -""" -The MIT License (MIT) - -Copyright (c) 2015-present Rapptz - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and associated documentation files (the "Software"), -to deal in the Software without restriction, including without limitation -the rights to use, copy, modify, merge, publish, distribute, sublicense, -and/or sell copies of the Software, and to permit persons to whom the -Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. -""" - -from __future__ import annotations - -import asyncio -import datetime -from typing import Awaitable, TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator - -from .errors import NoMoreItems -from .utils import snowflake_time, time_snowflake, maybe_coroutine -from .object import Object -from .audit_logs import AuditLogEntry - -__all__ = ( - 'ReactionIterator', - 'HistoryIterator', - 'AuditLogIterator', - 'GuildIterator', - 'MemberIterator', -) - -if TYPE_CHECKING: - from .types.audit_log import ( - AuditLog as AuditLogPayload, - ) - from .types.guild import ( - Guild as GuildPayload, - ) - from .types.message import ( - Message as MessagePayload, - ) - from .types.user import ( - PartialUser as PartialUserPayload, - ) - - from .types.threads import ( - Thread as ThreadPayload, - ) - - from .member import Member - from .user import User - from .message import Message - from .audit_logs import AuditLogEntry - from .guild import Guild - from .threads import Thread - from .abc import Snowflake - -T = TypeVar('T') -OT = TypeVar('OT') -_Func = Callable[[T], Union[OT, Awaitable[OT]]] - -OLDEST_OBJECT = Object(id=0) - - -class _AsyncIterator(AsyncIterator[T]): - __slots__ = () - - async def next(self) -> T: - raise NotImplementedError - - def get(self, **attrs: Any) -> Awaitable[Optional[T]]: - def predicate(elem: T): - for attr, val in attrs.items(): - nested = attr.split('__') - obj = elem - for attribute in nested: - obj = getattr(obj, attribute) - - if obj != val: - return False - return True - - return self.find(predicate) - - async def find(self, predicate: _Func[T, bool]) -> Optional[T]: - while True: - try: - elem = await self.next() - except NoMoreItems: - return None - - ret = await maybe_coroutine(predicate, elem) - if ret: - return elem - - def chunk(self, max_size: int) -> _ChunkedAsyncIterator[T]: - if max_size <= 0: - raise ValueError('async iterator chunk sizes must be greater than 0.') - return _ChunkedAsyncIterator(self, max_size) - - def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]: - return _MappedAsyncIterator(self, func) - - def filter(self, predicate: _Func[T, bool]) -> _FilteredAsyncIterator[T]: - return _FilteredAsyncIterator(self, predicate) - - async def flatten(self) -> List[T]: - return [element async for element in self] - - async def __anext__(self) -> T: - try: - return await self.next() - except NoMoreItems: - raise StopAsyncIteration() - - -def _identity(x): - return x - - -class _ChunkedAsyncIterator(_AsyncIterator[List[T]]): - def __init__(self, iterator, max_size): - self.iterator = iterator - self.max_size = max_size - - async def next(self) -> List[T]: - ret: List[T] = [] - n = 0 - while n < self.max_size: - try: - item = await self.iterator.next() - except NoMoreItems: - if ret: - return ret - raise - else: - ret.append(item) - n += 1 - return ret - - -class _MappedAsyncIterator(_AsyncIterator[T]): - def __init__(self, iterator, func): - self.iterator = iterator - self.func = func - - async def next(self) -> T: - # this raises NoMoreItems and will propagate appropriately - item = await self.iterator.next() - return await maybe_coroutine(self.func, item) - - -class _FilteredAsyncIterator(_AsyncIterator[T]): - def __init__(self, iterator, predicate): - self.iterator = iterator - - if predicate is None: - predicate = _identity - - self.predicate = predicate - - async def next(self) -> T: - getter = self.iterator.next - pred = self.predicate - while True: - # propagate NoMoreItems similar to _MappedAsyncIterator - item = await getter() - ret = await maybe_coroutine(pred, item) - if ret: - return item - - -class ReactionIterator(_AsyncIterator[Union['User', 'Member']]): - def __init__(self, message, emoji, limit=100, after=None): - self.message = message - self.limit = limit - self.after = after - state = message._state - self.getter = state.http.get_reaction_users - self.state = state - self.emoji = emoji - self.guild = message.guild - self.channel_id = message.channel.id - self.users = asyncio.Queue() - - async def next(self) -> Union[User, Member]: - if self.users.empty(): - await self.fill_users() - - try: - return self.users.get_nowait() - except asyncio.QueueEmpty: - raise NoMoreItems() - - async def fill_users(self): - # this is a hack because >circular imports< - from .user import User - - if self.limit > 0: - retrieve = self.limit if self.limit <= 100 else 100 - - after = self.after.id if self.after else None - data: List[PartialUserPayload] = await self.getter( - self.channel_id, self.message.id, self.emoji, retrieve, after=after - ) - - if data: - self.limit -= retrieve - self.after = Object(id=int(data[-1]['id'])) - - if self.guild is None or isinstance(self.guild, Object): - for element in reversed(data): - await self.users.put(User(state=self.state, data=element)) - else: - for element in reversed(data): - member_id = int(element['id']) - member = self.guild.get_member(member_id) - if member is not None: - await self.users.put(member) - else: - await self.users.put(User(state=self.state, data=element)) - - -class HistoryIterator(_AsyncIterator['Message']): - """Iterator for receiving a channel's message history. - - The messages endpoint has two behaviours we care about here: - If ``before`` is specified, the messages endpoint returns the `limit` - newest messages before ``before``, sorted with newest first. For filling over - 100 messages, update the ``before`` parameter to the oldest message received. - Messages will be returned in order by time. - If ``after`` is specified, it returns the ``limit`` oldest messages after - ``after``, sorted with newest first. For filling over 100 messages, update the - ``after`` parameter to the newest message received. If messages are not - reversed, they will be out of order (99-0, 199-100, so on) - - A note that if both ``before`` and ``after`` are specified, ``before`` is ignored by the - messages endpoint. - - Parameters - ----------- - messageable: :class:`abc.Messageable` - Messageable class to retrieve message history from. - limit: :class:`int` - Maximum number of messages to retrieve - before: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] - Message before which all messages must be. - after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] - Message after which all messages must be. - around: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] - Message around which all messages must be. Limit max 101. Note that if - limit is an even number, this will return at most limit+1 messages. - oldest_first: Optional[:class:`bool`] - If set to ``True``, return messages in oldest->newest order. Defaults to - ``True`` if `after` is specified, otherwise ``False``. - """ - - def __init__(self, messageable, limit, before=None, after=None, around=None, oldest_first=None): - - if isinstance(before, datetime.datetime): - before = Object(id=time_snowflake(before, high=False)) - if isinstance(after, datetime.datetime): - after = Object(id=time_snowflake(after, high=True)) - if isinstance(around, datetime.datetime): - around = Object(id=time_snowflake(around)) - - if oldest_first is None: - self.reverse = after is not None - else: - self.reverse = oldest_first - - self.messageable = messageable - self.limit = limit - self.before = before - self.after = after or OLDEST_OBJECT - self.around = around - - self._filter = None # message dict -> bool - - self.state = self.messageable._state - self.logs_from = self.state.http.logs_from - self.messages = asyncio.Queue() - - if self.around: - if self.limit is None: - raise ValueError('history does not support around with limit=None') - if self.limit > 101: - raise ValueError("history max limit 101 when specifying around parameter") - elif self.limit == 101: - self.limit = 100 # Thanks discord - - self._retrieve_messages = self._retrieve_messages_around_strategy # type: ignore - if self.before and self.after: - self._filter = lambda m: self.after.id < int(m['id']) < self.before.id - elif self.before: - self._filter = lambda m: int(m['id']) < self.before.id - elif self.after: - self._filter = lambda m: self.after.id < int(m['id']) - else: - if self.reverse: - self._retrieve_messages = self._retrieve_messages_after_strategy # type: ignore - if self.before: - self._filter = lambda m: int(m['id']) < self.before.id - else: - self._retrieve_messages = self._retrieve_messages_before_strategy # type: ignore - if self.after and self.after != OLDEST_OBJECT: - self._filter = lambda m: int(m['id']) > self.after.id - - async def next(self) -> Message: - if self.messages.empty(): - await self.fill_messages() - - try: - return self.messages.get_nowait() - except asyncio.QueueEmpty: - raise NoMoreItems() - - def _get_retrieve(self): - l = self.limit - if l is None or l > 100: - r = 100 - else: - r = l - self.retrieve = r - return r > 0 - - async def fill_messages(self): - if not hasattr(self, 'channel'): - # do the required set up - channel = await self.messageable._get_channel() - self.channel = channel - - if self._get_retrieve(): - data = await self._retrieve_messages(self.retrieve) - if len(data) < 100: - self.limit = 0 # terminate the infinite loop - - if self.reverse: - data = reversed(data) - if self._filter: - data = filter(self._filter, data) - - channel = self.channel - for element in data: - await self.messages.put(self.state.create_message(channel=channel, data=element)) - - async def _retrieve_messages(self, retrieve) -> List[Message]: - """Retrieve messages and update next parameters.""" - raise NotImplementedError - - async def _retrieve_messages_before_strategy(self, retrieve): - """Retrieve messages using before parameter.""" - before = self.before.id if self.before else None - data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, before=before) - if len(data): - if self.limit is not None: - self.limit -= retrieve - self.before = Object(id=int(data[-1]['id'])) - return data - - async def _retrieve_messages_after_strategy(self, retrieve): - """Retrieve messages using after parameter.""" - after = self.after.id if self.after else None - data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, after=after) - if len(data): - if self.limit is not None: - self.limit -= retrieve - self.after = Object(id=int(data[0]['id'])) - return data - - async def _retrieve_messages_around_strategy(self, retrieve): - """Retrieve messages using around parameter.""" - if self.around: - around = self.around.id if self.around else None - data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, around=around) - self.around = None - return data - return [] - - -class AuditLogIterator(_AsyncIterator['AuditLogEntry']): - def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None): - if isinstance(before, datetime.datetime): - before = Object(id=time_snowflake(before, high=False)) - if isinstance(after, datetime.datetime): - after = Object(id=time_snowflake(after, high=True)) - - if oldest_first is None: - self.reverse = after is not None - else: - self.reverse = oldest_first - - self.guild = guild - self.loop = guild._state.loop - self.request = guild._state.http.get_audit_logs - self.limit = limit - self.before = before - self.user_id = user_id - self.action_type = action_type - self.after = OLDEST_OBJECT - self._users = {} - self._state = guild._state - - self._filter = None # entry dict -> bool - - self.entries = asyncio.Queue() - - if self.reverse: - self._strategy = self._after_strategy - if self.before: - self._filter = lambda m: int(m['id']) < self.before.id - else: - self._strategy = self._before_strategy - if self.after and self.after != OLDEST_OBJECT: - self._filter = lambda m: int(m['id']) > self.after.id - - async def _before_strategy(self, retrieve): - before = self.before.id if self.before else None - data: AuditLogPayload = await self.request( - self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, before=before - ) - - entries = data.get('audit_log_entries', []) - if len(data) and entries: - if self.limit is not None: - self.limit -= retrieve - self.before = Object(id=int(entries[-1]['id'])) - return data.get('users', []), entries - - async def _after_strategy(self, retrieve): - after = self.after.id if self.after else None - data: AuditLogPayload = await self.request( - self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, after=after - ) - entries = data.get('audit_log_entries', []) - if len(data) and entries: - if self.limit is not None: - self.limit -= retrieve - self.after = Object(id=int(entries[0]['id'])) - return data.get('users', []), entries - - async def next(self) -> AuditLogEntry: - if self.entries.empty(): - await self._fill() - - try: - return self.entries.get_nowait() - except asyncio.QueueEmpty: - raise NoMoreItems() - - def _get_retrieve(self): - l = self.limit - if l is None or l > 100: - r = 100 - else: - r = l - self.retrieve = r - return r > 0 - - async def _fill(self): - from .user import User - - if self._get_retrieve(): - users, data = await self._strategy(self.retrieve) - if len(data) < 100: - self.limit = 0 # terminate the infinite loop - - if self.reverse: - data = reversed(data) - if self._filter: - data = filter(self._filter, data) - - for user in users: - u = User(data=user, state=self._state) - self._users[u.id] = u - - for element in data: - # TODO: remove this if statement later - if element['action_type'] is None: - continue - - await self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild)) - - -class GuildIterator(_AsyncIterator['Guild']): - """Iterator for receiving the client's guilds. - - The guilds endpoint has the same two behaviours as described - in :class:`HistoryIterator`: - If ``before`` is specified, the guilds endpoint returns the ``limit`` - newest guilds before ``before``, sorted with newest first. For filling over - 100 guilds, update the ``before`` parameter to the oldest guild received. - Guilds will be returned in order by time. - If `after` is specified, it returns the ``limit`` oldest guilds after ``after``, - sorted with newest first. For filling over 100 guilds, update the ``after`` - parameter to the newest guild received, If guilds are not reversed, they - will be out of order (99-0, 199-100, so on) - - Not that if both ``before`` and ``after`` are specified, ``before`` is ignored by the - guilds endpoint. - - Parameters - ----------- - bot: :class:`discord.Client` - The client to retrieve the guilds from. - limit: :class:`int` - Maximum number of guilds to retrieve. - before: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] - Object before which all guilds must be. - after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] - Object after which all guilds must be. - """ - - def __init__(self, bot, limit, before=None, after=None): - - if isinstance(before, datetime.datetime): - before = Object(id=time_snowflake(before, high=False)) - if isinstance(after, datetime.datetime): - after = Object(id=time_snowflake(after, high=True)) - - self.bot = bot - self.limit = limit - self.before = before - self.after = after - - self._filter = None - - self.state = self.bot._connection - self.get_guilds = self.bot.http.get_guilds - self.guilds = asyncio.Queue() - - if self.before and self.after: - self._retrieve_guilds = self._retrieve_guilds_before_strategy # type: ignore - self._filter = lambda m: int(m['id']) > self.after.id - elif self.after: - self._retrieve_guilds = self._retrieve_guilds_after_strategy # type: ignore - else: - self._retrieve_guilds = self._retrieve_guilds_before_strategy # type: ignore - - async def next(self) -> Guild: - if self.guilds.empty(): - await self.fill_guilds() - - try: - return self.guilds.get_nowait() - except asyncio.QueueEmpty: - raise NoMoreItems() - - def _get_retrieve(self): - l = self.limit - if l is None or l > 100: - r = 100 - else: - r = l - self.retrieve = r - return r > 0 - - def create_guild(self, data): - from .guild import Guild - - return Guild(state=self.state, data=data) - - async def fill_guilds(self): - if self._get_retrieve(): - data = await self._retrieve_guilds(self.retrieve) - if self.limit is None or len(data) < 100: - self.limit = 0 - - if self._filter: - data = filter(self._filter, data) - - for element in data: - await self.guilds.put(self.create_guild(element)) - - async def _retrieve_guilds(self, retrieve) -> List[Guild]: - """Retrieve guilds and update next parameters.""" - raise NotImplementedError - - async def _retrieve_guilds_before_strategy(self, retrieve): - """Retrieve guilds using before parameter.""" - before = self.before.id if self.before else None - data: List[GuildPayload] = await self.get_guilds(retrieve, before=before) - if len(data): - if self.limit is not None: - self.limit -= retrieve - self.before = Object(id=int(data[-1]['id'])) - return data - - async def _retrieve_guilds_after_strategy(self, retrieve): - """Retrieve guilds using after parameter.""" - after = self.after.id if self.after else None - data: List[GuildPayload] = await self.get_guilds(retrieve, after=after) - if len(data): - if self.limit is not None: - self.limit -= retrieve - self.after = Object(id=int(data[0]['id'])) - return data - - -class MemberIterator(_AsyncIterator['Member']): - def __init__(self, guild, limit=1000, after=None): - - if isinstance(after, datetime.datetime): - after = Object(id=time_snowflake(after, high=True)) - - self.guild = guild - self.limit = limit - self.after = after or OLDEST_OBJECT - - self.state = self.guild._state - self.get_members = self.state.http.get_members - self.members = asyncio.Queue() - - async def next(self) -> Member: - if self.members.empty(): - await self.fill_members() - - try: - return self.members.get_nowait() - except asyncio.QueueEmpty: - raise NoMoreItems() - - def _get_retrieve(self): - l = self.limit - if l is None or l > 1000: - r = 1000 - else: - r = l - self.retrieve = r - return r > 0 - - async def fill_members(self): - if self._get_retrieve(): - after = self.after.id if self.after else None - data = await self.get_members(self.guild.id, self.retrieve, after) - if not data: - # no data, terminate - return - - if len(data) < 1000: - self.limit = 0 # terminate loop - - self.after = Object(id=int(data[-1]['user']['id'])) - - for element in reversed(data): - await self.members.put(self.create_member(element)) - - def create_member(self, data): - from .member import Member - - return Member(data=data, guild=self.guild, state=self.state) - - -class ArchivedThreadIterator(_AsyncIterator['Thread']): - def __init__( - self, - channel_id: int, - guild: Guild, - limit: Optional[int], - joined: bool, - private: bool, - before: Optional[Union[Snowflake, datetime.datetime]] = None, - ): - self.channel_id = channel_id - self.guild = guild - self.limit = limit - self.joined = joined - self.private = private - self.http = guild._state.http - - if joined and not private: - raise ValueError('Cannot iterate over joined public archived threads') - - self.before: Optional[str] - if before is None: - self.before = None - elif isinstance(before, datetime.datetime): - if joined: - self.before = str(time_snowflake(before, high=False)) - else: - self.before = before.isoformat() - else: - if joined: - self.before = str(before.id) - else: - self.before = snowflake_time(before.id).isoformat() - - self.update_before: Callable[[ThreadPayload], str] = self.get_archive_timestamp - - if joined: - self.endpoint = self.http.get_joined_private_archived_threads - self.update_before = self.get_thread_id - elif private: - self.endpoint = self.http.get_private_archived_threads - else: - self.endpoint = self.http.get_public_archived_threads - - self.queue: asyncio.Queue[Thread] = asyncio.Queue() - self.has_more: bool = True - - async def next(self) -> Thread: - if self.queue.empty(): - await self.fill_queue() - - try: - return self.queue.get_nowait() - except asyncio.QueueEmpty: - raise NoMoreItems() - - @staticmethod - def get_archive_timestamp(data: ThreadPayload) -> str: - return data['thread_metadata']['archive_timestamp'] - - @staticmethod - def get_thread_id(data: ThreadPayload) -> str: - return data['id'] # type: ignore - - async def fill_queue(self) -> None: - if not self.has_more: - raise NoMoreItems() - - limit = 50 if self.limit is None else max(self.limit, 50) - data = await self.endpoint(self.channel_id, before=self.before, limit=limit) - - # This stuff is obviously WIP because 'members' is always empty - threads: List[ThreadPayload] = data.get('threads', []) - for d in reversed(threads): - self.queue.put_nowait(self.create_thread(d)) - - self.has_more = data.get('has_more', False) - if self.limit is not None: - self.limit -= len(threads) - if self.limit <= 0: - self.has_more = False - - if self.has_more: - self.before = self.update_before(threads[-1]) - - def create_thread(self, data: ThreadPayload) -> Thread: - from .threads import Thread - return Thread(guild=self.guild, state=self.guild._state, data=data) diff --git a/discord/object.py b/discord/object.py index 3795425f6..8ba0afd72 100644 --- a/discord/object.py +++ b/discord/object.py @@ -24,8 +24,8 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from . import utils from .mixins import Hashable +from .utils import snowflake_time from typing import ( SupportsInt, @@ -89,4 +89,7 @@ class Object(Hashable): @property def created_at(self) -> datetime.datetime: """:class:`datetime.datetime`: Returns the snowflake's creation time in UTC.""" - return utils.snowflake_time(self.id) + return snowflake_time(self.id) + + +OLDEST_OBJECT = Object(id=0) diff --git a/discord/reaction.py b/discord/reaction.py index 0bf885e6b..3afa310b2 100644 --- a/discord/reaction.py +++ b/discord/reaction.py @@ -23,15 +23,18 @@ DEALINGS IN THE SOFTWARE. """ from __future__ import annotations -from typing import Any, TYPE_CHECKING, Union, Optional +from typing import Any, TYPE_CHECKING, AsyncIterator, List, Union, Optional +from typing_extensions import reveal_type -from .iterators import ReactionIterator +from .object import Object __all__ = ( 'Reaction', ) if TYPE_CHECKING: + from .user import User + from .member import Member from .types.message import Reaction as ReactionPayload from .message import Message from .partial_emoji import PartialEmoji @@ -155,8 +158,8 @@ class Reaction: """ await self.message.clear_reaction(self.emoji) - def users(self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None) -> ReactionIterator: - """Returns an :class:`AsyncIterator` representing the users that have reacted to the message. + async def users(self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None) -> AsyncIterator[Union[Member, User]]: + """Returns an :term:`asynchronous iterator` representing the users that have reacted to the message. The ``after`` parameter must represent a member and meet the :class:`abc.Snowflake` abc. @@ -176,7 +179,7 @@ class Reaction: Flattening into a list: :: - users = await reaction.users().flatten() + users = [user async for user in reaction.users()] # users is now a list of User... winner = random.choice(users) await channel.send(f'{winner} has won the raffle.') @@ -212,4 +215,31 @@ class Reaction: if limit is None: limit = self.count - return ReactionIterator(self.message, emoji, limit, after) + while limit > 0: + retrieve = min(limit, 100) + + message = self.message + guild = message.guild + state = message._state + after_id = after.id if after else None + + data = await state.http.get_reaction_users( + message.channel.id, message.id, emoji, retrieve, after=after_id + ) + + if data: + limit -= len(data) + after = Object(id=int(data[-1]['id'])) + + if guild is None or isinstance(guild, Object): + for raw_user in reversed(data): + yield User(state=state, data=raw_user) + + continue + + for raw_user in reversed(data): + member_id = int(raw_user['id']) + member = guild.get_member(member_id) + + yield member or User(state=state, data=raw_user) + diff --git a/docs/api.rst b/docs/api.rst index a76573c4f..ce43a92cd 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -2614,135 +2614,6 @@ of :class:`enum.Enum`. The guild may contain NSFW content. -Async Iterator ----------------- - -Some API functions return an "async iterator". An async iterator is something that is -capable of being used in an :ref:`async for statement `. - -These async iterators can be used as follows: :: - - async for elem in channel.history(): - # do stuff with elem here - -Certain utilities make working with async iterators easier, detailed below. - -.. class:: AsyncIterator - - Represents the "AsyncIterator" concept. Note that no such class exists, - it is purely abstract. - - .. container:: operations - - .. describe:: async for x in y - - Iterates over the contents of the async iterator. - - - .. method:: next() - :async: - - |coro| - - Advances the iterator by one, if possible. If no more items are found - then this raises :exc:`NoMoreItems`. - - .. method:: get(**attrs) - :async: - - |coro| - - Similar to :func:`utils.get` except run over the async iterator. - - Getting the last message by a user named 'Dave' or ``None``: :: - - msg = await channel.history().get(author__name='Dave') - - .. method:: find(predicate) - :async: - - |coro| - - Similar to :func:`utils.find` except run over the async iterator. - - Unlike :func:`utils.find`\, the predicate provided can be a - |coroutine_link|_. - - Getting the last audit log with a reason or ``None``: :: - - def predicate(event): - return event.reason is not None - - event = await guild.audit_logs().find(predicate) - - :param predicate: The predicate to use. Could be a |coroutine_link|_. - :return: The first element that returns ``True`` for the predicate or ``None``. - - .. method:: flatten() - :async: - - |coro| - - Flattens the async iterator into a :class:`list` with all the elements. - - :return: A list of every element in the async iterator. - :rtype: list - - .. method:: chunk(max_size) - - Collects items into chunks of up to a given maximum size. - Another :class:`AsyncIterator` is returned which collects items into - :class:`list`\s of a given size. The maximum chunk size must be a positive integer. - - .. versionadded:: 1.6 - - Collecting groups of users: :: - - async for leader, *users in reaction.users().chunk(3): - ... - - .. warning:: - - The last chunk collected may not be as large as ``max_size``. - - :param max_size: The size of individual chunks. - :rtype: :class:`AsyncIterator` - - .. method:: map(func) - - This is similar to the built-in :func:`map ` function. Another - :class:`AsyncIterator` is returned that executes the function on - every element it is iterating over. This function can either be a - regular function or a |coroutine_link|_. - - Creating a content iterator: :: - - def transform(message): - return message.content - - async for content in channel.history().map(transform): - message_length = len(content) - - :param func: The function to call on every element. Could be a |coroutine_link|_. - :rtype: :class:`AsyncIterator` - - .. method:: filter(predicate) - - This is similar to the built-in :func:`filter ` function. Another - :class:`AsyncIterator` is returned that filters over the original - async iterator. This predicate can be a regular function or a |coroutine_link|_. - - Getting messages by non-bot accounts: :: - - def predicate(message): - return not message.author.bot - - async for elem in channel.history().filter(predicate): - ... - - :param predicate: The predicate to call on every element. Could be a |coroutine_link|_. - :rtype: :class:`AsyncIterator` - .. _discord-api-audit-logs: Audit Log Data