Browse Source

Refactor AsyncIter to use 3.6+ asynchronous generators

pull/7494/head
Kaylynn Morgan 3 years ago
committed by GitHub
parent
commit
588cda0996
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 110
      discord/abc.py
  2. 54
      discord/channel.py
  3. 88
      discord/client.py
  4. 133
      discord/guild.py
  5. 753
      discord/iterators.py
  6. 7
      discord/object.py
  7. 42
      discord/reaction.py
  8. 129
      docs/api.rst

110
discord/abc.py

@ -26,8 +26,10 @@ from __future__ import annotations
import copy
import asyncio
from datetime import datetime
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
List,
@ -42,7 +44,7 @@ from typing import (
runtime_checkable,
)
from .iterators import HistoryIterator
from .object import OLDEST_OBJECT, Object
from .context_managers import Typing
from .enums import ChannelType
from .errors import InvalidArgument, ClientException
@ -68,8 +70,6 @@ __all__ = (
T = TypeVar('T', bound=VoiceProtocol)
if TYPE_CHECKING:
from datetime import datetime
from .client import Client
from .user import ClientUser
from .asset import Asset
@ -1465,7 +1465,7 @@ class Messageable:
data = await state.http.pins_from(channel.id)
return [state.create_message(channel=channel, data=m) for m in data]
def history(
async def history(
self,
*,
limit: Optional[int] = 100,
@ -1473,8 +1473,8 @@ class Messageable:
after: Optional[SnowflakeTime] = None,
around: Optional[SnowflakeTime] = None,
oldest_first: Optional[bool] = None,
) -> HistoryIterator:
"""Returns an :class:`~discord.AsyncIterator` that enables receiving the destination's message history.
) -> AsyncIterator[Message]:
"""Returns an :term:`asynchronous iterator` that enables receiving the destination's message history.
You must have :attr:`~discord.Permissions.read_message_history` permissions to use this.
@ -1490,7 +1490,7 @@ class Messageable:
Flattening into a list: ::
messages = await channel.history(limit=123).flatten()
messages = [message async for message in channel.history(limit=123)]
# messages is now a list of Message...
All parameters are optional.
@ -1531,7 +1531,101 @@ class Messageable:
:class:`~discord.Message`
The message with the message data parsed.
"""
return HistoryIterator(self, limit=limit, before=before, after=after, around=around, oldest_first=oldest_first)
async def _around_strategy(retrieve, around, limit):
if not around:
return []
around_id = around.id if around else None
data = await self._state.http.logs_from(channel.id, retrieve, around=around_id)
return data, None, limit
async def _after_strategy(retrieve, after, limit):
after_id = after.id if after else None
data = await self._state.http.logs_from(channel.id, retrieve, after=after_id)
if data:
if limit is not None:
limit -= len(data)
after = Object(id=int(data[0]['id']))
return data, after, limit
async def _before_strategy(retrieve, before, limit):
before_id = before.id if before else None
data = await self._state.http.logs_from(channel.id, retrieve, before=before_id)
if data:
if limit is not None:
limit -= len(data)
before = Object(id=int(data[-1]['id']))
return data, before, limit
if isinstance(before, datetime):
before = Object(id=utils.time_snowflake(before, high=False))
if isinstance(after, datetime):
after = Object(id=utils.time_snowflake(after, high=True))
if isinstance(around, datetime):
around = Object(id=utils.time_snowflake(around))
if oldest_first is None:
reverse = after is not None
else:
reverse = oldest_first
after = after or OLDEST_OBJECT
predicate = None
if around:
if limit is None:
raise ValueError('history does not support around with limit=None')
if limit > 101:
raise ValueError("history max limit 101 when specifying around parameter")
# Strange Discord quirk
limit = 100 if limit == 101 else limit
strategy, state = _around_strategy, around
if before and after:
predicate = lambda m: after.id < int(m['id']) < before.id
elif before:
predicate = lambda m: int(m['id']) < before.id
elif after:
predicate = lambda m: after.id < int(m['id'])
elif reverse:
strategy, state = _after_strategy, after
if before:
predicate = lambda m: int(m['id']) < before.id
else:
strategy, state = _before_strategy, before
if after and after != OLDEST_OBJECT:
predicate = lambda m: int(m['id']) > after.id
channel = await self._get_channel()
while True:
retrieve = min(100 if limit is None else limit, 100)
if retrieve < 1:
return
data, state, limit = await strategy(retrieve, state, limit)
# Terminate loop on next iteration; there's no data left after this
if len(data) < 100:
limit = 0
if reverse:
data = reversed(data)
if predicate:
data = filter(predicate, data)
for raw_message in data:
yield self._state.create_message(channel=channel, data=raw_message)
class Connectable(Protocol):

54
discord/channel.py

@ -28,6 +28,7 @@ import time
import asyncio
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterable,
@ -54,7 +55,6 @@ from .asset import Asset
from .errors import ClientException, InvalidArgument
from .stage_instance import StageInstance
from .threads import Thread
from .iterators import ArchivedThreadIterator
__all__ = (
'TextChannel',
@ -755,15 +755,15 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
return Thread(guild=self.guild, state=self._state, data=data)
def archived_threads(
async def archived_threads(
self,
*,
private: bool = False,
joined: bool = False,
limit: Optional[int] = 50,
before: Optional[Union[Snowflake, datetime.datetime]] = None,
) -> ArchivedThreadIterator:
"""Returns an :class:`~discord.AsyncIterator` that iterates over all archived threads in the guild.
) -> AsyncIterator[Thread]:
"""Returns an :term:`asynchronous iterator` that iterates over all archived threads in the guild.
You must have :attr:`~Permissions.read_message_history` to use this. If iterating over private threads
then :attr:`~Permissions.manage_threads` is also required.
@ -790,13 +790,57 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
You do not have permissions to get archived threads.
HTTPException
The request to get the archived threads failed.
ValueError
`joined`` was set to ``True`` and ``private`` was set to ``False``. You cannot retrieve public archived
threads that you have joined.
Yields
-------
:class:`Thread`
The archived threads.
"""
return ArchivedThreadIterator(self.id, self.guild, limit=limit, joined=joined, private=private, before=before)
if joined and not private:
raise ValueError('Cannot retrieve joined public archived threads')
before_timestamp = None
if isinstance(before, datetime.datetime):
if joined:
before_timestamp = str(utils.time_snowflake(before, high=False))
else:
before_timestamp = before.isoformat()
elif before is not None:
if joined:
before_timestamp = str(before.id)
else:
before_timestamp = utils.snowflake_time(before.id).isoformat()
update_before = lambda data: data['thread_metadata']['archive_timestamp']
endpoint = self.guild._state.http.get_public_archived_threads
if joined:
update_before = lambda data: data['id']
endpoint = self.guild._state.http.get_joined_private_archived_threads
elif private:
endpoint = self.guild._state.http.get_private_archived_threads
while True:
retrieve = 50 if limit is None else max(limit, 50)
data = await endpoint(self.id, before=before_timestamp, limit=retrieve)
threads = data.get('threads', [])
for raw_thread in reversed(threads):
yield Thread(guild=self.guild, state=self.guild._state, data=raw_thread)
if not data.get('has_more', False):
return
if limit is not None:
limit -= len(threads)
if limit <= 0:
return
before = update_before(threads[-1])
class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable):

88
discord/client.py

@ -25,11 +25,26 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import asyncio
import datetime
import logging
import signal
import sys
import traceback
from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union
from typing import (
Any,
AsyncIterator,
Callable,
Coroutine,
Dict,
Generator,
List,
Optional,
Sequence,
TYPE_CHECKING,
Tuple,
TypeVar,
Union
)
import aiohttp
@ -51,11 +66,10 @@ from .voice_client import VoiceClient
from .http import HTTPClient
from .state import ConnectionState
from . import utils
from .utils import MISSING
from .utils import MISSING, time_snowflake
from .object import Object
from .backoff import ExponentialBackoff
from .webhook import Webhook
from .iterators import GuildIterator
from .appinfo import AppInfo
from .ui.view import View
from .stage_instance import StageInstance
@ -63,6 +77,7 @@ from .threads import Thread
from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory
if TYPE_CHECKING:
from .types.guild import Guild as GuildPayload
from .abc import SnowflakeTime, PrivateChannel, GuildChannel, Snowflake
from .channel import DMChannel
from .message import Message
@ -1120,14 +1135,14 @@ class Client:
# Guild stuff
def fetch_guilds(
async def fetch_guilds(
self,
*,
limit: Optional[int] = 100,
before: SnowflakeTime = None,
after: SnowflakeTime = None
) -> GuildIterator:
"""Retrieves an :class:`.AsyncIterator` that enables receiving your guilds.
before: Optional[SnowflakeTime] = None,
after: Optional[SnowflakeTime] = None,
) -> AsyncIterator[Guild]:
"""Retrieves an :term:`asynchronous iterator` that enables receiving your guilds.
.. note::
@ -1148,7 +1163,7 @@ class Client:
Flattening into a list ::
guilds = await client.fetch_guilds(limit=150).flatten()
guilds = [guild async for guild in client.fetch_guilds(limit=150)]
# guilds is now a list of Guild...
All parameters are optional.
@ -1179,7 +1194,60 @@ class Client:
:class:`.Guild`
The guild with the guild data parsed.
"""
return GuildIterator(self, limit=limit, before=before, after=after)
async def _before_strategy(retrieve, before, limit):
before_id = before.id if before else None
data = await self.http.get_guilds(retrieve, before=before_id)
if data:
if limit is not None:
limit -= len(data)
before = Object(id=int(data[-1]['id']))
return data, before, limit
async def _after_strategy(retrieve, after, limit):
after_id = after.id if after else None
data = await self.http.get_guilds(retrieve, after=after_id)
if data:
if limit is not None:
limit -= len(data)
after = Object(id=int(data[0]['id']))
return data, after, limit
if isinstance(before, datetime.datetime):
before = Object(id=time_snowflake(before, high=False))
if isinstance(after, datetime.datetime):
after = Object(id=time_snowflake(after, high=True))
predicate = None
strategy, state = _before_strategy, before
if before and after:
predicate = lambda m: int(m['id']) > after.id # type: ignore
elif after:
strategy, state = _after_strategy, after
while True:
retrieve = min(100 if limit is None else limit, 100)
if retrieve < 1:
return
data, state, limit = await strategy(retrieve, state, limit)
# Terminate loop on next iteration; there's no data left after this
if len(data) < 100:
limit = 0
if predicate:
data = filter(predicate, data)
for raw_guild in data:
yield Guild(state=self._connection, data=raw_guild)
async def fetch_template(self, code: Union[Template, str]) -> Template:
"""|coro|

133
discord/guild.py

@ -25,9 +25,11 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import copy
import datetime
import unicodedata
from typing import (
Any,
AsyncIterator,
ClassVar,
Dict,
List,
@ -67,7 +69,6 @@ from .enums import (
from .mixins import Hashable
from .user import User
from .invite import Invite
from .iterators import AuditLogIterator, MemberIterator
from .widget import Widget
from .asset import Asset
from .flags import SystemChannelFlags
@ -76,6 +77,8 @@ from .stage_instance import StageInstance
from .threads import Thread, ThreadMember
from .sticker import GuildSticker
from .file import File
from .audit_logs import AuditLogEntry
from .object import OLDEST_OBJECT, Object
__all__ = (
@ -98,8 +101,6 @@ if TYPE_CHECKING:
from .state import ConnectionState
from .voice_client import VoiceProtocol
import datetime
VocalGuildChannel = Union[VoiceChannel, StageChannel]
GuildChannel = Union[VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel]
ByCategoryItem = Tuple[Optional[CategoryChannel], List[GuildChannel]]
@ -1649,9 +1650,8 @@ class Guild(Hashable):
return threads
# TODO: Remove Optional typing here when async iterators are refactored
def fetch_members(self, *, limit: int = 1000, after: Optional[SnowflakeTime] = None) -> MemberIterator:
"""Retrieves an :class:`.AsyncIterator` that enables receiving the guild's members. In order to use this,
async def fetch_members(self, *, limit: int = 1000, after: SnowflakeTime = MISSING) -> AsyncIterator[Member]:
"""Retrieves an :term:`asynchronous iterator` that enables receiving the guild's members. In order to use this,
:meth:`Intents.members` must be enabled.
.. note::
@ -1701,7 +1701,30 @@ class Guild(Hashable):
if not self._state._intents.members:
raise ClientException('Intents.members must be enabled to use this.')
return MemberIterator(self, limit=limit, after=after)
while True:
retrieve = min(1000 if limit is None else limit, 1000)
if retrieve < 1:
return
if isinstance(after, datetime.datetime):
after = Object(id=utils.time_snowflake(after, high=True))
after = after or OLDEST_OBJECT
after_id = after.id if after else None
state = self._state
data = await state.http.get_members(self.id, retrieve, after_id)
if not data:
return
# Terminate loop on next iteration; there's no data left after this
if len(data) < 1000:
limit = 0
after = Object(id=int(data[-1]['user']['id']))
for raw_member in reversed(data):
yield Member(data=raw_member, guild=self, state=state)
async def fetch_member(self, member_id: int, /) -> Member:
"""|coro|
@ -2731,18 +2754,17 @@ class Guild(Hashable):
payload['uses'] = payload.get('uses', 0)
return Invite(state=self._state, data=payload, guild=self, channel=channel)
# TODO: use MISSING when async iterators get refactored
def audit_logs(
async def audit_logs(
self,
*,
limit: int = 100,
before: Optional[SnowflakeTime] = None,
after: Optional[SnowflakeTime] = None,
oldest_first: Optional[bool] = None,
user: Snowflake = None,
action: AuditLogAction = None,
) -> AuditLogIterator:
"""Returns an :class:`AsyncIterator` that enables receiving the guild's audit logs.
user: Snowflake = MISSING,
action: AuditLogAction = MISSING,
) -> AsyncIterator[AuditLogEntry]:
"""Returns an :term:`asynchronous iterator` that enables receiving the guild's audit logs.
You must have the :attr:`~Permissions.view_audit_log` permission to use this.
@ -2761,7 +2783,7 @@ class Guild(Hashable):
Getting entries made by a specific user: ::
entries = await guild.audit_logs(limit=None, user=guild.me).flatten()
entries = [entry async for entry in guild.audit_logs(limit=None, user=guild.me)]
await channel.send(f'I made {len(entries)} moderation actions.')
Parameters
@ -2796,6 +2818,39 @@ class Guild(Hashable):
:class:`AuditLogEntry`
The audit log entry.
"""
async def _before_strategy(retrieve, before, limit):
before_id = before.id if before else None
data = await self._state.http.get_audit_logs(
self.id, limit=retrieve, user_id=user_id, action_type=action, before=before_id
)
entries = data.get('audit_log_entries', [])
if data and entries:
if limit is not None:
limit -= len(data)
before = Object(id=int(entries[-1]['id']))
return data.get('users', []), entries, before, limit
async def _after_strategy(retrieve, after, limit):
after_id = after.id if after else None
data = await self._state.http.get_audit_logs(
self.id, limit=retrieve, user_id=user_id, action_type=action, after=after_id
)
entries = data.get('audit_log_entries', [])
if data and entries:
if limit is not None:
limit -= len(data)
after = Object(id=int(entries[0]['id']))
return data.get('users', []), entries, after, limit
if user is not None:
user_id = user.id
else:
@ -2804,9 +2859,53 @@ class Guild(Hashable):
if action:
action = action.value
return AuditLogIterator(
self, before=before, after=after, limit=limit, oldest_first=oldest_first, user_id=user_id, action_type=action
)
if isinstance(before, datetime.datetime):
before = Object(id=utils.time_snowflake(before, high=False))
if isinstance(after, datetime.datetime):
after = Object(id=utils.time_snowflake(after, high=True))
if oldest_first is None:
reverse = after is not None
else:
reverse = oldest_first
predicate = None
if reverse:
strategy, state = _after_strategy, after
if before:
predicate = lambda m: int(m['id']) < before.id
else:
strategy, state = _before_strategy, before
if after and after != OLDEST_OBJECT:
predicate = lambda m: int(m['id']) > after.id
while True:
retrieve = min(100 if limit is None else limit, 100)
if retrieve < 1:
return
raw_users, data, state, limit = await strategy(retrieve, state, limit)
# Terminate loop on next iteration; there's no data left after this
if len(data) < 100:
limit = 0
if reverse:
data = reversed(data)
if predicate:
data = filter(predicate, data)
users = (User(data=raw_user, state=self._state) for raw_user in raw_users)
user_map = {user.id: user for user in users}
for raw_entry in data:
# Weird Discord quirk
if raw_entry['action_type'] is None:
continue
yield AuditLogEntry(data=raw_entry, users=user_map, guild=self)
async def widget(self) -> Widget:
"""|coro|

753
discord/iterators.py

@ -1,753 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import asyncio
import datetime
from typing import Awaitable, TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator
from .errors import NoMoreItems
from .utils import snowflake_time, time_snowflake, maybe_coroutine
from .object import Object
from .audit_logs import AuditLogEntry
__all__ = (
'ReactionIterator',
'HistoryIterator',
'AuditLogIterator',
'GuildIterator',
'MemberIterator',
)
if TYPE_CHECKING:
from .types.audit_log import (
AuditLog as AuditLogPayload,
)
from .types.guild import (
Guild as GuildPayload,
)
from .types.message import (
Message as MessagePayload,
)
from .types.user import (
PartialUser as PartialUserPayload,
)
from .types.threads import (
Thread as ThreadPayload,
)
from .member import Member
from .user import User
from .message import Message
from .audit_logs import AuditLogEntry
from .guild import Guild
from .threads import Thread
from .abc import Snowflake
T = TypeVar('T')
OT = TypeVar('OT')
_Func = Callable[[T], Union[OT, Awaitable[OT]]]
OLDEST_OBJECT = Object(id=0)
class _AsyncIterator(AsyncIterator[T]):
__slots__ = ()
async def next(self) -> T:
raise NotImplementedError
def get(self, **attrs: Any) -> Awaitable[Optional[T]]:
def predicate(elem: T):
for attr, val in attrs.items():
nested = attr.split('__')
obj = elem
for attribute in nested:
obj = getattr(obj, attribute)
if obj != val:
return False
return True
return self.find(predicate)
async def find(self, predicate: _Func[T, bool]) -> Optional[T]:
while True:
try:
elem = await self.next()
except NoMoreItems:
return None
ret = await maybe_coroutine(predicate, elem)
if ret:
return elem
def chunk(self, max_size: int) -> _ChunkedAsyncIterator[T]:
if max_size <= 0:
raise ValueError('async iterator chunk sizes must be greater than 0.')
return _ChunkedAsyncIterator(self, max_size)
def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]:
return _MappedAsyncIterator(self, func)
def filter(self, predicate: _Func[T, bool]) -> _FilteredAsyncIterator[T]:
return _FilteredAsyncIterator(self, predicate)
async def flatten(self) -> List[T]:
return [element async for element in self]
async def __anext__(self) -> T:
try:
return await self.next()
except NoMoreItems:
raise StopAsyncIteration()
def _identity(x):
return x
class _ChunkedAsyncIterator(_AsyncIterator[List[T]]):
def __init__(self, iterator, max_size):
self.iterator = iterator
self.max_size = max_size
async def next(self) -> List[T]:
ret: List[T] = []
n = 0
while n < self.max_size:
try:
item = await self.iterator.next()
except NoMoreItems:
if ret:
return ret
raise
else:
ret.append(item)
n += 1
return ret
class _MappedAsyncIterator(_AsyncIterator[T]):
def __init__(self, iterator, func):
self.iterator = iterator
self.func = func
async def next(self) -> T:
# this raises NoMoreItems and will propagate appropriately
item = await self.iterator.next()
return await maybe_coroutine(self.func, item)
class _FilteredAsyncIterator(_AsyncIterator[T]):
def __init__(self, iterator, predicate):
self.iterator = iterator
if predicate is None:
predicate = _identity
self.predicate = predicate
async def next(self) -> T:
getter = self.iterator.next
pred = self.predicate
while True:
# propagate NoMoreItems similar to _MappedAsyncIterator
item = await getter()
ret = await maybe_coroutine(pred, item)
if ret:
return item
class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
def __init__(self, message, emoji, limit=100, after=None):
self.message = message
self.limit = limit
self.after = after
state = message._state
self.getter = state.http.get_reaction_users
self.state = state
self.emoji = emoji
self.guild = message.guild
self.channel_id = message.channel.id
self.users = asyncio.Queue()
async def next(self) -> Union[User, Member]:
if self.users.empty():
await self.fill_users()
try:
return self.users.get_nowait()
except asyncio.QueueEmpty:
raise NoMoreItems()
async def fill_users(self):
# this is a hack because >circular imports<
from .user import User
if self.limit > 0:
retrieve = self.limit if self.limit <= 100 else 100
after = self.after.id if self.after else None
data: List[PartialUserPayload] = await self.getter(
self.channel_id, self.message.id, self.emoji, retrieve, after=after
)
if data:
self.limit -= retrieve
self.after = Object(id=int(data[-1]['id']))
if self.guild is None or isinstance(self.guild, Object):
for element in reversed(data):
await self.users.put(User(state=self.state, data=element))
else:
for element in reversed(data):
member_id = int(element['id'])
member = self.guild.get_member(member_id)
if member is not None:
await self.users.put(member)
else:
await self.users.put(User(state=self.state, data=element))
class HistoryIterator(_AsyncIterator['Message']):
"""Iterator for receiving a channel's message history.
The messages endpoint has two behaviours we care about here:
If ``before`` is specified, the messages endpoint returns the `limit`
newest messages before ``before``, sorted with newest first. For filling over
100 messages, update the ``before`` parameter to the oldest message received.
Messages will be returned in order by time.
If ``after`` is specified, it returns the ``limit`` oldest messages after
``after``, sorted with newest first. For filling over 100 messages, update the
``after`` parameter to the newest message received. If messages are not
reversed, they will be out of order (99-0, 199-100, so on)
A note that if both ``before`` and ``after`` are specified, ``before`` is ignored by the
messages endpoint.
Parameters
-----------
messageable: :class:`abc.Messageable`
Messageable class to retrieve message history from.
limit: :class:`int`
Maximum number of messages to retrieve
before: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
Message before which all messages must be.
after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
Message after which all messages must be.
around: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
Message around which all messages must be. Limit max 101. Note that if
limit is an even number, this will return at most limit+1 messages.
oldest_first: Optional[:class:`bool`]
If set to ``True``, return messages in oldest->newest order. Defaults to
``True`` if `after` is specified, otherwise ``False``.
"""
def __init__(self, messageable, limit, before=None, after=None, around=None, oldest_first=None):
if isinstance(before, datetime.datetime):
before = Object(id=time_snowflake(before, high=False))
if isinstance(after, datetime.datetime):
after = Object(id=time_snowflake(after, high=True))
if isinstance(around, datetime.datetime):
around = Object(id=time_snowflake(around))
if oldest_first is None:
self.reverse = after is not None
else:
self.reverse = oldest_first
self.messageable = messageable
self.limit = limit
self.before = before
self.after = after or OLDEST_OBJECT
self.around = around
self._filter = None # message dict -> bool
self.state = self.messageable._state
self.logs_from = self.state.http.logs_from
self.messages = asyncio.Queue()
if self.around:
if self.limit is None:
raise ValueError('history does not support around with limit=None')
if self.limit > 101:
raise ValueError("history max limit 101 when specifying around parameter")
elif self.limit == 101:
self.limit = 100 # Thanks discord
self._retrieve_messages = self._retrieve_messages_around_strategy # type: ignore
if self.before and self.after:
self._filter = lambda m: self.after.id < int(m['id']) < self.before.id
elif self.before:
self._filter = lambda m: int(m['id']) < self.before.id
elif self.after:
self._filter = lambda m: self.after.id < int(m['id'])
else:
if self.reverse:
self._retrieve_messages = self._retrieve_messages_after_strategy # type: ignore
if self.before:
self._filter = lambda m: int(m['id']) < self.before.id
else:
self._retrieve_messages = self._retrieve_messages_before_strategy # type: ignore
if self.after and self.after != OLDEST_OBJECT:
self._filter = lambda m: int(m['id']) > self.after.id
async def next(self) -> Message:
if self.messages.empty():
await self.fill_messages()
try:
return self.messages.get_nowait()
except asyncio.QueueEmpty:
raise NoMoreItems()
def _get_retrieve(self):
l = self.limit
if l is None or l > 100:
r = 100
else:
r = l
self.retrieve = r
return r > 0
async def fill_messages(self):
if not hasattr(self, 'channel'):
# do the required set up
channel = await self.messageable._get_channel()
self.channel = channel
if self._get_retrieve():
data = await self._retrieve_messages(self.retrieve)
if len(data) < 100:
self.limit = 0 # terminate the infinite loop
if self.reverse:
data = reversed(data)
if self._filter:
data = filter(self._filter, data)
channel = self.channel
for element in data:
await self.messages.put(self.state.create_message(channel=channel, data=element))
async def _retrieve_messages(self, retrieve) -> List[Message]:
"""Retrieve messages and update next parameters."""
raise NotImplementedError
async def _retrieve_messages_before_strategy(self, retrieve):
"""Retrieve messages using before parameter."""
before = self.before.id if self.before else None
data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, before=before)
if len(data):
if self.limit is not None:
self.limit -= retrieve
self.before = Object(id=int(data[-1]['id']))
return data
async def _retrieve_messages_after_strategy(self, retrieve):
"""Retrieve messages using after parameter."""
after = self.after.id if self.after else None
data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, after=after)
if len(data):
if self.limit is not None:
self.limit -= retrieve
self.after = Object(id=int(data[0]['id']))
return data
async def _retrieve_messages_around_strategy(self, retrieve):
"""Retrieve messages using around parameter."""
if self.around:
around = self.around.id if self.around else None
data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, around=around)
self.around = None
return data
return []
class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None):
if isinstance(before, datetime.datetime):
before = Object(id=time_snowflake(before, high=False))
if isinstance(after, datetime.datetime):
after = Object(id=time_snowflake(after, high=True))
if oldest_first is None:
self.reverse = after is not None
else:
self.reverse = oldest_first
self.guild = guild
self.loop = guild._state.loop
self.request = guild._state.http.get_audit_logs
self.limit = limit
self.before = before
self.user_id = user_id
self.action_type = action_type
self.after = OLDEST_OBJECT
self._users = {}
self._state = guild._state
self._filter = None # entry dict -> bool
self.entries = asyncio.Queue()
if self.reverse:
self._strategy = self._after_strategy
if self.before:
self._filter = lambda m: int(m['id']) < self.before.id
else:
self._strategy = self._before_strategy
if self.after and self.after != OLDEST_OBJECT:
self._filter = lambda m: int(m['id']) > self.after.id
async def _before_strategy(self, retrieve):
before = self.before.id if self.before else None
data: AuditLogPayload = await self.request(
self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, before=before
)
entries = data.get('audit_log_entries', [])
if len(data) and entries:
if self.limit is not None:
self.limit -= retrieve
self.before = Object(id=int(entries[-1]['id']))
return data.get('users', []), entries
async def _after_strategy(self, retrieve):
after = self.after.id if self.after else None
data: AuditLogPayload = await self.request(
self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, after=after
)
entries = data.get('audit_log_entries', [])
if len(data) and entries:
if self.limit is not None:
self.limit -= retrieve
self.after = Object(id=int(entries[0]['id']))
return data.get('users', []), entries
async def next(self) -> AuditLogEntry:
if self.entries.empty():
await self._fill()
try:
return self.entries.get_nowait()
except asyncio.QueueEmpty:
raise NoMoreItems()
def _get_retrieve(self):
l = self.limit
if l is None or l > 100:
r = 100
else:
r = l
self.retrieve = r
return r > 0
async def _fill(self):
from .user import User
if self._get_retrieve():
users, data = await self._strategy(self.retrieve)
if len(data) < 100:
self.limit = 0 # terminate the infinite loop
if self.reverse:
data = reversed(data)
if self._filter:
data = filter(self._filter, data)
for user in users:
u = User(data=user, state=self._state)
self._users[u.id] = u
for element in data:
# TODO: remove this if statement later
if element['action_type'] is None:
continue
await self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild))
class GuildIterator(_AsyncIterator['Guild']):
"""Iterator for receiving the client's guilds.
The guilds endpoint has the same two behaviours as described
in :class:`HistoryIterator`:
If ``before`` is specified, the guilds endpoint returns the ``limit``
newest guilds before ``before``, sorted with newest first. For filling over
100 guilds, update the ``before`` parameter to the oldest guild received.
Guilds will be returned in order by time.
If `after` is specified, it returns the ``limit`` oldest guilds after ``after``,
sorted with newest first. For filling over 100 guilds, update the ``after``
parameter to the newest guild received, If guilds are not reversed, they
will be out of order (99-0, 199-100, so on)
Not that if both ``before`` and ``after`` are specified, ``before`` is ignored by the
guilds endpoint.
Parameters
-----------
bot: :class:`discord.Client`
The client to retrieve the guilds from.
limit: :class:`int`
Maximum number of guilds to retrieve.
before: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
Object before which all guilds must be.
after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
Object after which all guilds must be.
"""
def __init__(self, bot, limit, before=None, after=None):
if isinstance(before, datetime.datetime):
before = Object(id=time_snowflake(before, high=False))
if isinstance(after, datetime.datetime):
after = Object(id=time_snowflake(after, high=True))
self.bot = bot
self.limit = limit
self.before = before
self.after = after
self._filter = None
self.state = self.bot._connection
self.get_guilds = self.bot.http.get_guilds
self.guilds = asyncio.Queue()
if self.before and self.after:
self._retrieve_guilds = self._retrieve_guilds_before_strategy # type: ignore
self._filter = lambda m: int(m['id']) > self.after.id
elif self.after:
self._retrieve_guilds = self._retrieve_guilds_after_strategy # type: ignore
else:
self._retrieve_guilds = self._retrieve_guilds_before_strategy # type: ignore
async def next(self) -> Guild:
if self.guilds.empty():
await self.fill_guilds()
try:
return self.guilds.get_nowait()
except asyncio.QueueEmpty:
raise NoMoreItems()
def _get_retrieve(self):
l = self.limit
if l is None or l > 100:
r = 100
else:
r = l
self.retrieve = r
return r > 0
def create_guild(self, data):
from .guild import Guild
return Guild(state=self.state, data=data)
async def fill_guilds(self):
if self._get_retrieve():
data = await self._retrieve_guilds(self.retrieve)
if self.limit is None or len(data) < 100:
self.limit = 0
if self._filter:
data = filter(self._filter, data)
for element in data:
await self.guilds.put(self.create_guild(element))
async def _retrieve_guilds(self, retrieve) -> List[Guild]:
"""Retrieve guilds and update next parameters."""
raise NotImplementedError
async def _retrieve_guilds_before_strategy(self, retrieve):
"""Retrieve guilds using before parameter."""
before = self.before.id if self.before else None
data: List[GuildPayload] = await self.get_guilds(retrieve, before=before)
if len(data):
if self.limit is not None:
self.limit -= retrieve
self.before = Object(id=int(data[-1]['id']))
return data
async def _retrieve_guilds_after_strategy(self, retrieve):
"""Retrieve guilds using after parameter."""
after = self.after.id if self.after else None
data: List[GuildPayload] = await self.get_guilds(retrieve, after=after)
if len(data):
if self.limit is not None:
self.limit -= retrieve
self.after = Object(id=int(data[0]['id']))
return data
class MemberIterator(_AsyncIterator['Member']):
def __init__(self, guild, limit=1000, after=None):
if isinstance(after, datetime.datetime):
after = Object(id=time_snowflake(after, high=True))
self.guild = guild
self.limit = limit
self.after = after or OLDEST_OBJECT
self.state = self.guild._state
self.get_members = self.state.http.get_members
self.members = asyncio.Queue()
async def next(self) -> Member:
if self.members.empty():
await self.fill_members()
try:
return self.members.get_nowait()
except asyncio.QueueEmpty:
raise NoMoreItems()
def _get_retrieve(self):
l = self.limit
if l is None or l > 1000:
r = 1000
else:
r = l
self.retrieve = r
return r > 0
async def fill_members(self):
if self._get_retrieve():
after = self.after.id if self.after else None
data = await self.get_members(self.guild.id, self.retrieve, after)
if not data:
# no data, terminate
return
if len(data) < 1000:
self.limit = 0 # terminate loop
self.after = Object(id=int(data[-1]['user']['id']))
for element in reversed(data):
await self.members.put(self.create_member(element))
def create_member(self, data):
from .member import Member
return Member(data=data, guild=self.guild, state=self.state)
class ArchivedThreadIterator(_AsyncIterator['Thread']):
def __init__(
self,
channel_id: int,
guild: Guild,
limit: Optional[int],
joined: bool,
private: bool,
before: Optional[Union[Snowflake, datetime.datetime]] = None,
):
self.channel_id = channel_id
self.guild = guild
self.limit = limit
self.joined = joined
self.private = private
self.http = guild._state.http
if joined and not private:
raise ValueError('Cannot iterate over joined public archived threads')
self.before: Optional[str]
if before is None:
self.before = None
elif isinstance(before, datetime.datetime):
if joined:
self.before = str(time_snowflake(before, high=False))
else:
self.before = before.isoformat()
else:
if joined:
self.before = str(before.id)
else:
self.before = snowflake_time(before.id).isoformat()
self.update_before: Callable[[ThreadPayload], str] = self.get_archive_timestamp
if joined:
self.endpoint = self.http.get_joined_private_archived_threads
self.update_before = self.get_thread_id
elif private:
self.endpoint = self.http.get_private_archived_threads
else:
self.endpoint = self.http.get_public_archived_threads
self.queue: asyncio.Queue[Thread] = asyncio.Queue()
self.has_more: bool = True
async def next(self) -> Thread:
if self.queue.empty():
await self.fill_queue()
try:
return self.queue.get_nowait()
except asyncio.QueueEmpty:
raise NoMoreItems()
@staticmethod
def get_archive_timestamp(data: ThreadPayload) -> str:
return data['thread_metadata']['archive_timestamp']
@staticmethod
def get_thread_id(data: ThreadPayload) -> str:
return data['id'] # type: ignore
async def fill_queue(self) -> None:
if not self.has_more:
raise NoMoreItems()
limit = 50 if self.limit is None else max(self.limit, 50)
data = await self.endpoint(self.channel_id, before=self.before, limit=limit)
# This stuff is obviously WIP because 'members' is always empty
threads: List[ThreadPayload] = data.get('threads', [])
for d in reversed(threads):
self.queue.put_nowait(self.create_thread(d))
self.has_more = data.get('has_more', False)
if self.limit is not None:
self.limit -= len(threads)
if self.limit <= 0:
self.has_more = False
if self.has_more:
self.before = self.update_before(threads[-1])
def create_thread(self, data: ThreadPayload) -> Thread:
from .threads import Thread
return Thread(guild=self.guild, state=self.guild._state, data=data)

7
discord/object.py

@ -24,8 +24,8 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from . import utils
from .mixins import Hashable
from .utils import snowflake_time
from typing import (
SupportsInt,
@ -89,4 +89,7 @@ class Object(Hashable):
@property
def created_at(self) -> datetime.datetime:
""":class:`datetime.datetime`: Returns the snowflake's creation time in UTC."""
return utils.snowflake_time(self.id)
return snowflake_time(self.id)
OLDEST_OBJECT = Object(id=0)

42
discord/reaction.py

@ -23,15 +23,18 @@ DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Any, TYPE_CHECKING, Union, Optional
from typing import Any, TYPE_CHECKING, AsyncIterator, List, Union, Optional
from typing_extensions import reveal_type
from .iterators import ReactionIterator
from .object import Object
__all__ = (
'Reaction',
)
if TYPE_CHECKING:
from .user import User
from .member import Member
from .types.message import Reaction as ReactionPayload
from .message import Message
from .partial_emoji import PartialEmoji
@ -155,8 +158,8 @@ class Reaction:
"""
await self.message.clear_reaction(self.emoji)
def users(self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None) -> ReactionIterator:
"""Returns an :class:`AsyncIterator` representing the users that have reacted to the message.
async def users(self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None) -> AsyncIterator[Union[Member, User]]:
"""Returns an :term:`asynchronous iterator` representing the users that have reacted to the message.
The ``after`` parameter must represent a member
and meet the :class:`abc.Snowflake` abc.
@ -176,7 +179,7 @@ class Reaction:
Flattening into a list: ::
users = await reaction.users().flatten()
users = [user async for user in reaction.users()]
# users is now a list of User...
winner = random.choice(users)
await channel.send(f'{winner} has won the raffle.')
@ -212,4 +215,31 @@ class Reaction:
if limit is None:
limit = self.count
return ReactionIterator(self.message, emoji, limit, after)
while limit > 0:
retrieve = min(limit, 100)
message = self.message
guild = message.guild
state = message._state
after_id = after.id if after else None
data = await state.http.get_reaction_users(
message.channel.id, message.id, emoji, retrieve, after=after_id
)
if data:
limit -= len(data)
after = Object(id=int(data[-1]['id']))
if guild is None or isinstance(guild, Object):
for raw_user in reversed(data):
yield User(state=state, data=raw_user)
continue
for raw_user in reversed(data):
member_id = int(raw_user['id'])
member = guild.get_member(member_id)
yield member or User(state=state, data=raw_user)

129
docs/api.rst

@ -2614,135 +2614,6 @@ of :class:`enum.Enum`.
The guild may contain NSFW content.
Async Iterator
----------------
Some API functions return an "async iterator". An async iterator is something that is
capable of being used in an :ref:`async for statement <py:async for>`.
These async iterators can be used as follows: ::
async for elem in channel.history():
# do stuff with elem here
Certain utilities make working with async iterators easier, detailed below.
.. class:: AsyncIterator
Represents the "AsyncIterator" concept. Note that no such class exists,
it is purely abstract.
.. container:: operations
.. describe:: async for x in y
Iterates over the contents of the async iterator.
.. method:: next()
:async:
|coro|
Advances the iterator by one, if possible. If no more items are found
then this raises :exc:`NoMoreItems`.
.. method:: get(**attrs)
:async:
|coro|
Similar to :func:`utils.get` except run over the async iterator.
Getting the last message by a user named 'Dave' or ``None``: ::
msg = await channel.history().get(author__name='Dave')
.. method:: find(predicate)
:async:
|coro|
Similar to :func:`utils.find` except run over the async iterator.
Unlike :func:`utils.find`\, the predicate provided can be a
|coroutine_link|_.
Getting the last audit log with a reason or ``None``: ::
def predicate(event):
return event.reason is not None
event = await guild.audit_logs().find(predicate)
:param predicate: The predicate to use. Could be a |coroutine_link|_.
:return: The first element that returns ``True`` for the predicate or ``None``.
.. method:: flatten()
:async:
|coro|
Flattens the async iterator into a :class:`list` with all the elements.
:return: A list of every element in the async iterator.
:rtype: list
.. method:: chunk(max_size)
Collects items into chunks of up to a given maximum size.
Another :class:`AsyncIterator` is returned which collects items into
:class:`list`\s of a given size. The maximum chunk size must be a positive integer.
.. versionadded:: 1.6
Collecting groups of users: ::
async for leader, *users in reaction.users().chunk(3):
...
.. warning::
The last chunk collected may not be as large as ``max_size``.
:param max_size: The size of individual chunks.
:rtype: :class:`AsyncIterator`
.. method:: map(func)
This is similar to the built-in :func:`map <py:map>` function. Another
:class:`AsyncIterator` is returned that executes the function on
every element it is iterating over. This function can either be a
regular function or a |coroutine_link|_.
Creating a content iterator: ::
def transform(message):
return message.content
async for content in channel.history().map(transform):
message_length = len(content)
:param func: The function to call on every element. Could be a |coroutine_link|_.
:rtype: :class:`AsyncIterator`
.. method:: filter(predicate)
This is similar to the built-in :func:`filter <py:filter>` function. Another
:class:`AsyncIterator` is returned that filters over the original
async iterator. This predicate can be a regular function or a |coroutine_link|_.
Getting messages by non-bot accounts: ::
def predicate(message):
return not message.author.bot
async for elem in channel.history().filter(predicate):
...
:param predicate: The predicate to call on every element. Could be a |coroutine_link|_.
:rtype: :class:`AsyncIterator`
.. _discord-api-audit-logs:
Audit Log Data

Loading…
Cancel
Save