|
|
@ -22,20 +22,43 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER |
|
|
|
DEALINGS IN THE SOFTWARE. |
|
|
|
""" |
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
|
|
import asyncio |
|
|
|
import datetime |
|
|
|
from typing import TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator, Coroutine |
|
|
|
|
|
|
|
from .errors import NoMoreItems |
|
|
|
from .utils import time_snowflake, maybe_coroutine |
|
|
|
from .object import Object |
|
|
|
from .audit_logs import AuditLogEntry |
|
|
|
|
|
|
|
__all__ = ( |
|
|
|
'ReactionIterator', |
|
|
|
'HistoryIterator', |
|
|
|
'AuditLogIterator', |
|
|
|
'GuildIterator', |
|
|
|
'MemberIterator', |
|
|
|
) |
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
|
from .member import Member |
|
|
|
from .user import User |
|
|
|
from .message import Message |
|
|
|
from .audit_logs import AuditLogEntry |
|
|
|
from .guild import Guild |
|
|
|
|
|
|
|
T = TypeVar('T') |
|
|
|
OT = TypeVar('OT') |
|
|
|
_Func = Callable[[T], Union[OT, Coroutine[Any, Any, OT]]] |
|
|
|
_Predicate = Callable[[T], Union[T, Coroutine[Any, Any, T]]] |
|
|
|
|
|
|
|
OLDEST_OBJECT = Object(id=0) |
|
|
|
|
|
|
|
class _AsyncIterator: |
|
|
|
class _AsyncIterator(AsyncIterator[T]): |
|
|
|
__slots__ = () |
|
|
|
|
|
|
|
def get(self, **attrs): |
|
|
|
def get(self, **attrs: Any) -> Optional[T]: |
|
|
|
def predicate(elem): |
|
|
|
for attr, val in attrs.items(): |
|
|
|
nested = attr.split('__') |
|
|
@ -49,7 +72,7 @@ class _AsyncIterator: |
|
|
|
|
|
|
|
return self.find(predicate) |
|
|
|
|
|
|
|
async def find(self, predicate): |
|
|
|
async def find(self, predicate: _Predicate[T]) -> Optional[T]: |
|
|
|
while True: |
|
|
|
try: |
|
|
|
elem = await self.next() |
|
|
@ -60,40 +83,35 @@ class _AsyncIterator: |
|
|
|
if ret: |
|
|
|
return elem |
|
|
|
|
|
|
|
def chunk(self, max_size): |
|
|
|
def chunk(self, max_size: int) -> _ChunkedAsyncIterator[T]: |
|
|
|
if max_size <= 0: |
|
|
|
raise ValueError('async iterator chunk sizes must be greater than 0.') |
|
|
|
return _ChunkedAsyncIterator(self, max_size) |
|
|
|
|
|
|
|
def map(self, func): |
|
|
|
def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]: |
|
|
|
return _MappedAsyncIterator(self, func) |
|
|
|
|
|
|
|
def filter(self, predicate): |
|
|
|
def filter(self, predicate: _Predicate[T]) -> _FilteredAsyncIterator[T]: |
|
|
|
return _FilteredAsyncIterator(self, predicate) |
|
|
|
|
|
|
|
async def flatten(self): |
|
|
|
async def flatten(self) -> List[T]: |
|
|
|
return [element async for element in self] |
|
|
|
|
|
|
|
def __aiter__(self): |
|
|
|
return self |
|
|
|
|
|
|
|
async def __anext__(self): |
|
|
|
async def __anext__(self) -> T: |
|
|
|
try: |
|
|
|
msg = await self.next() |
|
|
|
return await self.next() |
|
|
|
except NoMoreItems: |
|
|
|
raise StopAsyncIteration() |
|
|
|
else: |
|
|
|
return msg |
|
|
|
|
|
|
|
def _identity(x): |
|
|
|
return x |
|
|
|
|
|
|
|
class _ChunkedAsyncIterator(_AsyncIterator): |
|
|
|
class _ChunkedAsyncIterator(_AsyncIterator[T]): |
|
|
|
def __init__(self, iterator, max_size): |
|
|
|
self.iterator = iterator |
|
|
|
self.max_size = max_size |
|
|
|
|
|
|
|
async def next(self): |
|
|
|
async def next(self) -> T: |
|
|
|
ret = [] |
|
|
|
n = 0 |
|
|
|
while n < self.max_size: |
|
|
@ -108,17 +126,17 @@ class _ChunkedAsyncIterator(_AsyncIterator): |
|
|
|
n += 1 |
|
|
|
return ret |
|
|
|
|
|
|
|
class _MappedAsyncIterator(_AsyncIterator): |
|
|
|
class _MappedAsyncIterator(_AsyncIterator[T]): |
|
|
|
def __init__(self, iterator, func): |
|
|
|
self.iterator = iterator |
|
|
|
self.func = func |
|
|
|
|
|
|
|
async def next(self): |
|
|
|
async def next(self) -> T: |
|
|
|
# this raises NoMoreItems and will propagate appropriately |
|
|
|
item = await self.iterator.next() |
|
|
|
return await maybe_coroutine(self.func, item) |
|
|
|
|
|
|
|
class _FilteredAsyncIterator(_AsyncIterator): |
|
|
|
class _FilteredAsyncIterator(_AsyncIterator[T]): |
|
|
|
def __init__(self, iterator, predicate): |
|
|
|
self.iterator = iterator |
|
|
|
|
|
|
@ -127,7 +145,7 @@ class _FilteredAsyncIterator(_AsyncIterator): |
|
|
|
|
|
|
|
self.predicate = predicate |
|
|
|
|
|
|
|
async def next(self): |
|
|
|
async def next(self) -> T: |
|
|
|
getter = self.iterator.next |
|
|
|
pred = self.predicate |
|
|
|
while True: |
|
|
@ -137,7 +155,7 @@ class _FilteredAsyncIterator(_AsyncIterator): |
|
|
|
if ret: |
|
|
|
return item |
|
|
|
|
|
|
|
class ReactionIterator(_AsyncIterator): |
|
|
|
class ReactionIterator(_AsyncIterator[Union['User', 'Member']]): |
|
|
|
def __init__(self, message, emoji, limit=100, after=None): |
|
|
|
self.message = message |
|
|
|
self.limit = limit |
|
|
@ -150,7 +168,7 @@ class ReactionIterator(_AsyncIterator): |
|
|
|
self.channel_id = message.channel.id |
|
|
|
self.users = asyncio.Queue() |
|
|
|
|
|
|
|
async def next(self): |
|
|
|
async def next(self) -> T: |
|
|
|
if self.users.empty(): |
|
|
|
await self.fill_users() |
|
|
|
|
|
|
@ -185,7 +203,7 @@ class ReactionIterator(_AsyncIterator): |
|
|
|
else: |
|
|
|
await self.users.put(User(state=self.state, data=element)) |
|
|
|
|
|
|
|
class HistoryIterator(_AsyncIterator): |
|
|
|
class HistoryIterator(_AsyncIterator['Message']): |
|
|
|
"""Iterator for receiving a channel's message history. |
|
|
|
|
|
|
|
The messages endpoint has two behaviours we care about here: |
|
|
@ -271,7 +289,7 @@ class HistoryIterator(_AsyncIterator): |
|
|
|
if (self.after and self.after != OLDEST_OBJECT): |
|
|
|
self._filter = lambda m: int(m['id']) > self.after.id |
|
|
|
|
|
|
|
async def next(self): |
|
|
|
async def next(self) -> T: |
|
|
|
if self.messages.empty(): |
|
|
|
await self.fill_messages() |
|
|
|
|
|
|
@ -342,7 +360,7 @@ class HistoryIterator(_AsyncIterator): |
|
|
|
return data |
|
|
|
return [] |
|
|
|
|
|
|
|
class AuditLogIterator(_AsyncIterator): |
|
|
|
class AuditLogIterator(_AsyncIterator['AuditLogEntry']): |
|
|
|
def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None): |
|
|
|
if isinstance(before, datetime.datetime): |
|
|
|
before = Object(id=time_snowflake(before, high=False)) |
|
|
@ -404,7 +422,7 @@ class AuditLogIterator(_AsyncIterator): |
|
|
|
self.after = Object(id=int(entries[0]['id'])) |
|
|
|
return data.get('users', []), entries |
|
|
|
|
|
|
|
async def next(self): |
|
|
|
async def next(self) -> T: |
|
|
|
if self.entries.empty(): |
|
|
|
await self._fill() |
|
|
|
|
|
|
@ -447,7 +465,7 @@ class AuditLogIterator(_AsyncIterator): |
|
|
|
await self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild)) |
|
|
|
|
|
|
|
|
|
|
|
class GuildIterator(_AsyncIterator): |
|
|
|
class GuildIterator(_AsyncIterator['Guild']): |
|
|
|
"""Iterator for receiving the client's guilds. |
|
|
|
|
|
|
|
The guilds endpoint has the same two behaviours as described |
|
|
@ -501,7 +519,7 @@ class GuildIterator(_AsyncIterator): |
|
|
|
else: |
|
|
|
self._retrieve_guilds = self._retrieve_guilds_before_strategy |
|
|
|
|
|
|
|
async def next(self): |
|
|
|
async def next(self) -> T: |
|
|
|
if self.guilds.empty(): |
|
|
|
await self.fill_guilds() |
|
|
|
|
|
|
@ -559,7 +577,7 @@ class GuildIterator(_AsyncIterator): |
|
|
|
self.after = Object(id=int(data[0]['id'])) |
|
|
|
return data |
|
|
|
|
|
|
|
class MemberIterator(_AsyncIterator): |
|
|
|
class MemberIterator(_AsyncIterator['Member']): |
|
|
|
def __init__(self, guild, limit=1000, after=None): |
|
|
|
|
|
|
|
if isinstance(after, datetime.datetime): |
|
|
@ -573,7 +591,7 @@ class MemberIterator(_AsyncIterator): |
|
|
|
self.get_members = self.state.http.get_members |
|
|
|
self.members = asyncio.Queue() |
|
|
|
|
|
|
|
async def next(self): |
|
|
|
async def next(self) -> T: |
|
|
|
if self.members.empty(): |
|
|
|
await self.fill_members() |
|
|
|
|
|
|
|