From f8bea3bb05fc9e3960d5ac4b2773e8dbe4f083c0 Mon Sep 17 00:00:00 2001 From: Nadir Chowdhury Date: Thu, 8 Apr 2021 01:28:12 +0100 Subject: [PATCH] Fix inaccuracies with `AsyncIterator` typings --- discord/iterators.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/discord/iterators.py b/discord/iterators.py index 0bf474604..d717d83f5 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -26,7 +26,7 @@ from __future__ import annotations import asyncio import datetime -from typing import TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator, Coroutine +from typing import Awaitable, TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator from .errors import NoMoreItems from .utils import time_snowflake, maybe_coroutine @@ -50,16 +50,18 @@ if TYPE_CHECKING: T = TypeVar('T') OT = TypeVar('OT') -_Func = Callable[[T], Union[OT, Coroutine[Any, Any, OT]]] -_Predicate = Callable[[T], Union[T, Coroutine[Any, Any, T]]] +_Func = Callable[[T], Union[OT, Awaitable[OT]]] OLDEST_OBJECT = Object(id=0) class _AsyncIterator(AsyncIterator[T]): __slots__ = () - def get(self, **attrs: Any) -> Optional[T]: - def predicate(elem): + async def next(self) -> T: + raise NotImplementedError + + def get(self, **attrs: Any) -> Awaitable[Optional[T]]: + def predicate(elem: T): for attr, val in attrs.items(): nested = attr.split('__') obj = elem @@ -72,7 +74,7 @@ class _AsyncIterator(AsyncIterator[T]): return self.find(predicate) - async def find(self, predicate: _Predicate[T]) -> Optional[T]: + async def find(self, predicate: _Func[T, bool]) -> Optional[T]: while True: try: elem = await self.next() @@ -91,7 +93,7 @@ class _AsyncIterator(AsyncIterator[T]): def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]: return _MappedAsyncIterator(self, func) - def filter(self, predicate: _Predicate[T]) -> _FilteredAsyncIterator[T]: + def filter(self, predicate: _Func[T, bool]) -> _FilteredAsyncIterator[T]: return _FilteredAsyncIterator(self, predicate) async def flatten(self) -> List[T]: @@ -106,13 +108,13 @@ class _AsyncIterator(AsyncIterator[T]): def _identity(x): return x -class _ChunkedAsyncIterator(_AsyncIterator[T]): +class _ChunkedAsyncIterator(_AsyncIterator[List[T]]): def __init__(self, iterator, max_size): self.iterator = iterator self.max_size = max_size - async def next(self) -> T: - ret = [] + async def next(self) -> List[T]: + ret: List[T] = [] n = 0 while n < self.max_size: try: @@ -168,7 +170,7 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]): self.channel_id = message.channel.id self.users = asyncio.Queue() - async def next(self) -> T: + async def next(self) -> Union[User, Member]: if self.users.empty(): await self.fill_users() @@ -289,7 +291,7 @@ class HistoryIterator(_AsyncIterator['Message']): if (self.after and self.after != OLDEST_OBJECT): self._filter = lambda m: int(m['id']) > self.after.id - async def next(self) -> T: + async def next(self) -> Message: if self.messages.empty(): await self.fill_messages() @@ -422,7 +424,7 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']): self.after = Object(id=int(entries[0]['id'])) return data.get('users', []), entries - async def next(self) -> T: + async def next(self) -> AuditLogEntry: if self.entries.empty(): await self._fill() @@ -519,7 +521,7 @@ class GuildIterator(_AsyncIterator['Guild']): else: self._retrieve_guilds = self._retrieve_guilds_before_strategy - async def next(self) -> T: + async def next(self) -> Guild: if self.guilds.empty(): await self.fill_guilds() @@ -591,7 +593,7 @@ class MemberIterator(_AsyncIterator['Member']): self.get_members = self.state.http.get_members self.members = asyncio.Queue() - async def next(self) -> T: + async def next(self) -> Member: if self.members.empty(): await self.fill_members()