diff --git a/discord/iterators.py b/discord/iterators.py index d67f3006d..0bf474604 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -22,20 +22,43 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + import asyncio import datetime +from typing import TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator, Coroutine from .errors import NoMoreItems from .utils import time_snowflake, maybe_coroutine from .object import Object from .audit_logs import AuditLogEntry +__all__ = ( + 'ReactionIterator', + 'HistoryIterator', + 'AuditLogIterator', + 'GuildIterator', + 'MemberIterator', +) + +if TYPE_CHECKING: + from .member import Member + from .user import User + from .message import Message + from .audit_logs import AuditLogEntry + from .guild import Guild + +T = TypeVar('T') +OT = TypeVar('OT') +_Func = Callable[[T], Union[OT, Coroutine[Any, Any, OT]]] +_Predicate = Callable[[T], Union[T, Coroutine[Any, Any, T]]] + OLDEST_OBJECT = Object(id=0) -class _AsyncIterator: +class _AsyncIterator(AsyncIterator[T]): __slots__ = () - def get(self, **attrs): + def get(self, **attrs: Any) -> Optional[T]: def predicate(elem): for attr, val in attrs.items(): nested = attr.split('__') @@ -49,7 +72,7 @@ class _AsyncIterator: return self.find(predicate) - async def find(self, predicate): + async def find(self, predicate: _Predicate[T]) -> Optional[T]: while True: try: elem = await self.next() @@ -60,40 +83,35 @@ class _AsyncIterator: if ret: return elem - def chunk(self, max_size): + def chunk(self, max_size: int) -> _ChunkedAsyncIterator[T]: if max_size <= 0: raise ValueError('async iterator chunk sizes must be greater than 0.') return _ChunkedAsyncIterator(self, max_size) - def map(self, func): + def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]: return _MappedAsyncIterator(self, func) - def filter(self, predicate): + def filter(self, predicate: _Predicate[T]) -> _FilteredAsyncIterator[T]: return _FilteredAsyncIterator(self, predicate) - async def flatten(self): + async def flatten(self) -> List[T]: return [element async for element in self] - def __aiter__(self): - return self - - async def __anext__(self): + async def __anext__(self) -> T: try: - msg = await self.next() + return await self.next() except NoMoreItems: raise StopAsyncIteration() - else: - return msg def _identity(x): return x -class _ChunkedAsyncIterator(_AsyncIterator): +class _ChunkedAsyncIterator(_AsyncIterator[T]): def __init__(self, iterator, max_size): self.iterator = iterator self.max_size = max_size - async def next(self): + async def next(self) -> T: ret = [] n = 0 while n < self.max_size: @@ -108,17 +126,17 @@ class _ChunkedAsyncIterator(_AsyncIterator): n += 1 return ret -class _MappedAsyncIterator(_AsyncIterator): +class _MappedAsyncIterator(_AsyncIterator[T]): def __init__(self, iterator, func): self.iterator = iterator self.func = func - async def next(self): + async def next(self) -> T: # this raises NoMoreItems and will propagate appropriately item = await self.iterator.next() return await maybe_coroutine(self.func, item) -class _FilteredAsyncIterator(_AsyncIterator): +class _FilteredAsyncIterator(_AsyncIterator[T]): def __init__(self, iterator, predicate): self.iterator = iterator @@ -127,7 +145,7 @@ class _FilteredAsyncIterator(_AsyncIterator): self.predicate = predicate - async def next(self): + async def next(self) -> T: getter = self.iterator.next pred = self.predicate while True: @@ -137,7 +155,7 @@ class _FilteredAsyncIterator(_AsyncIterator): if ret: return item -class ReactionIterator(_AsyncIterator): +class ReactionIterator(_AsyncIterator[Union['User', 'Member']]): def __init__(self, message, emoji, limit=100, after=None): self.message = message self.limit = limit @@ -150,7 +168,7 @@ class ReactionIterator(_AsyncIterator): self.channel_id = message.channel.id self.users = asyncio.Queue() - async def next(self): + async def next(self) -> T: if self.users.empty(): await self.fill_users() @@ -185,7 +203,7 @@ class ReactionIterator(_AsyncIterator): else: await self.users.put(User(state=self.state, data=element)) -class HistoryIterator(_AsyncIterator): +class HistoryIterator(_AsyncIterator['Message']): """Iterator for receiving a channel's message history. The messages endpoint has two behaviours we care about here: @@ -271,7 +289,7 @@ class HistoryIterator(_AsyncIterator): if (self.after and self.after != OLDEST_OBJECT): self._filter = lambda m: int(m['id']) > self.after.id - async def next(self): + async def next(self) -> T: if self.messages.empty(): await self.fill_messages() @@ -342,7 +360,7 @@ class HistoryIterator(_AsyncIterator): return data return [] -class AuditLogIterator(_AsyncIterator): +class AuditLogIterator(_AsyncIterator['AuditLogEntry']): def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None): if isinstance(before, datetime.datetime): before = Object(id=time_snowflake(before, high=False)) @@ -404,7 +422,7 @@ class AuditLogIterator(_AsyncIterator): self.after = Object(id=int(entries[0]['id'])) return data.get('users', []), entries - async def next(self): + async def next(self) -> T: if self.entries.empty(): await self._fill() @@ -447,7 +465,7 @@ class AuditLogIterator(_AsyncIterator): await self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild)) -class GuildIterator(_AsyncIterator): +class GuildIterator(_AsyncIterator['Guild']): """Iterator for receiving the client's guilds. The guilds endpoint has the same two behaviours as described @@ -501,7 +519,7 @@ class GuildIterator(_AsyncIterator): else: self._retrieve_guilds = self._retrieve_guilds_before_strategy - async def next(self): + async def next(self) -> T: if self.guilds.empty(): await self.fill_guilds() @@ -559,7 +577,7 @@ class GuildIterator(_AsyncIterator): self.after = Object(id=int(data[0]['id'])) return data -class MemberIterator(_AsyncIterator): +class MemberIterator(_AsyncIterator['Member']): def __init__(self, guild, limit=1000, after=None): if isinstance(after, datetime.datetime): @@ -573,7 +591,7 @@ class MemberIterator(_AsyncIterator): self.get_members = self.state.http.get_members self.members = asyncio.Queue() - async def next(self): + async def next(self) -> T: if self.members.empty(): await self.fill_members()