|
|
@ -26,7 +26,7 @@ from __future__ import annotations |
|
|
|
|
|
|
|
import asyncio |
|
|
|
import datetime |
|
|
|
from typing import TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator, Coroutine |
|
|
|
from typing import Awaitable, TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator |
|
|
|
|
|
|
|
from .errors import NoMoreItems |
|
|
|
from .utils import time_snowflake, maybe_coroutine |
|
|
@ -50,16 +50,18 @@ if TYPE_CHECKING: |
|
|
|
|
|
|
|
T = TypeVar('T') |
|
|
|
OT = TypeVar('OT') |
|
|
|
_Func = Callable[[T], Union[OT, Coroutine[Any, Any, OT]]] |
|
|
|
_Predicate = Callable[[T], Union[T, Coroutine[Any, Any, T]]] |
|
|
|
_Func = Callable[[T], Union[OT, Awaitable[OT]]] |
|
|
|
|
|
|
|
OLDEST_OBJECT = Object(id=0) |
|
|
|
|
|
|
|
class _AsyncIterator(AsyncIterator[T]): |
|
|
|
__slots__ = () |
|
|
|
|
|
|
|
def get(self, **attrs: Any) -> Optional[T]: |
|
|
|
def predicate(elem): |
|
|
|
async def next(self) -> T: |
|
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
def get(self, **attrs: Any) -> Awaitable[Optional[T]]: |
|
|
|
def predicate(elem: T): |
|
|
|
for attr, val in attrs.items(): |
|
|
|
nested = attr.split('__') |
|
|
|
obj = elem |
|
|
@ -72,7 +74,7 @@ class _AsyncIterator(AsyncIterator[T]): |
|
|
|
|
|
|
|
return self.find(predicate) |
|
|
|
|
|
|
|
async def find(self, predicate: _Predicate[T]) -> Optional[T]: |
|
|
|
async def find(self, predicate: _Func[T, bool]) -> Optional[T]: |
|
|
|
while True: |
|
|
|
try: |
|
|
|
elem = await self.next() |
|
|
@ -91,7 +93,7 @@ class _AsyncIterator(AsyncIterator[T]): |
|
|
|
def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]: |
|
|
|
return _MappedAsyncIterator(self, func) |
|
|
|
|
|
|
|
def filter(self, predicate: _Predicate[T]) -> _FilteredAsyncIterator[T]: |
|
|
|
def filter(self, predicate: _Func[T, bool]) -> _FilteredAsyncIterator[T]: |
|
|
|
return _FilteredAsyncIterator(self, predicate) |
|
|
|
|
|
|
|
async def flatten(self) -> List[T]: |
|
|
@ -106,13 +108,13 @@ class _AsyncIterator(AsyncIterator[T]): |
|
|
|
def _identity(x): |
|
|
|
return x |
|
|
|
|
|
|
|
class _ChunkedAsyncIterator(_AsyncIterator[T]): |
|
|
|
class _ChunkedAsyncIterator(_AsyncIterator[List[T]]): |
|
|
|
def __init__(self, iterator, max_size): |
|
|
|
self.iterator = iterator |
|
|
|
self.max_size = max_size |
|
|
|
|
|
|
|
async def next(self) -> T: |
|
|
|
ret = [] |
|
|
|
async def next(self) -> List[T]: |
|
|
|
ret: List[T] = [] |
|
|
|
n = 0 |
|
|
|
while n < self.max_size: |
|
|
|
try: |
|
|
@ -168,7 +170,7 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]): |
|
|
|
self.channel_id = message.channel.id |
|
|
|
self.users = asyncio.Queue() |
|
|
|
|
|
|
|
async def next(self) -> T: |
|
|
|
async def next(self) -> Union[User, Member]: |
|
|
|
if self.users.empty(): |
|
|
|
await self.fill_users() |
|
|
|
|
|
|
@ -289,7 +291,7 @@ class HistoryIterator(_AsyncIterator['Message']): |
|
|
|
if (self.after and self.after != OLDEST_OBJECT): |
|
|
|
self._filter = lambda m: int(m['id']) > self.after.id |
|
|
|
|
|
|
|
async def next(self) -> T: |
|
|
|
async def next(self) -> Message: |
|
|
|
if self.messages.empty(): |
|
|
|
await self.fill_messages() |
|
|
|
|
|
|
@ -422,7 +424,7 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']): |
|
|
|
self.after = Object(id=int(entries[0]['id'])) |
|
|
|
return data.get('users', []), entries |
|
|
|
|
|
|
|
async def next(self) -> T: |
|
|
|
async def next(self) -> AuditLogEntry: |
|
|
|
if self.entries.empty(): |
|
|
|
await self._fill() |
|
|
|
|
|
|
@ -519,7 +521,7 @@ class GuildIterator(_AsyncIterator['Guild']): |
|
|
|
else: |
|
|
|
self._retrieve_guilds = self._retrieve_guilds_before_strategy |
|
|
|
|
|
|
|
async def next(self) -> T: |
|
|
|
async def next(self) -> Guild: |
|
|
|
if self.guilds.empty(): |
|
|
|
await self.fill_guilds() |
|
|
|
|
|
|
@ -591,7 +593,7 @@ class MemberIterator(_AsyncIterator['Member']): |
|
|
|
self.get_members = self.state.http.get_members |
|
|
|
self.members = asyncio.Queue() |
|
|
|
|
|
|
|
async def next(self) -> T: |
|
|
|
async def next(self) -> Member: |
|
|
|
if self.members.empty(): |
|
|
|
await self.fill_members() |
|
|
|
|
|
|
|