Browse Source

Fix inaccuracies with `AsyncIterator` typings

pull/6671/head
Nadir Chowdhury 4 years ago
committed by GitHub
parent
commit
f8bea3bb05
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 32
      discord/iterators.py

32
discord/iterators.py

@ -26,7 +26,7 @@ from __future__ import annotations
import asyncio
import datetime
from typing import TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator, Coroutine
from typing import Awaitable, TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator
from .errors import NoMoreItems
from .utils import time_snowflake, maybe_coroutine
@ -50,16 +50,18 @@ if TYPE_CHECKING:
T = TypeVar('T')
OT = TypeVar('OT')
_Func = Callable[[T], Union[OT, Coroutine[Any, Any, OT]]]
_Predicate = Callable[[T], Union[T, Coroutine[Any, Any, T]]]
_Func = Callable[[T], Union[OT, Awaitable[OT]]]
OLDEST_OBJECT = Object(id=0)
class _AsyncIterator(AsyncIterator[T]):
__slots__ = ()
def get(self, **attrs: Any) -> Optional[T]:
def predicate(elem):
async def next(self) -> T:
raise NotImplementedError
def get(self, **attrs: Any) -> Awaitable[Optional[T]]:
def predicate(elem: T):
for attr, val in attrs.items():
nested = attr.split('__')
obj = elem
@ -72,7 +74,7 @@ class _AsyncIterator(AsyncIterator[T]):
return self.find(predicate)
async def find(self, predicate: _Predicate[T]) -> Optional[T]:
async def find(self, predicate: _Func[T, bool]) -> Optional[T]:
while True:
try:
elem = await self.next()
@ -91,7 +93,7 @@ class _AsyncIterator(AsyncIterator[T]):
def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]:
return _MappedAsyncIterator(self, func)
def filter(self, predicate: _Predicate[T]) -> _FilteredAsyncIterator[T]:
def filter(self, predicate: _Func[T, bool]) -> _FilteredAsyncIterator[T]:
return _FilteredAsyncIterator(self, predicate)
async def flatten(self) -> List[T]:
@ -106,13 +108,13 @@ class _AsyncIterator(AsyncIterator[T]):
def _identity(x):
return x
class _ChunkedAsyncIterator(_AsyncIterator[T]):
class _ChunkedAsyncIterator(_AsyncIterator[List[T]]):
def __init__(self, iterator, max_size):
self.iterator = iterator
self.max_size = max_size
async def next(self) -> T:
ret = []
async def next(self) -> List[T]:
ret: List[T] = []
n = 0
while n < self.max_size:
try:
@ -168,7 +170,7 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
self.channel_id = message.channel.id
self.users = asyncio.Queue()
async def next(self) -> T:
async def next(self) -> Union[User, Member]:
if self.users.empty():
await self.fill_users()
@ -289,7 +291,7 @@ class HistoryIterator(_AsyncIterator['Message']):
if (self.after and self.after != OLDEST_OBJECT):
self._filter = lambda m: int(m['id']) > self.after.id
async def next(self) -> T:
async def next(self) -> Message:
if self.messages.empty():
await self.fill_messages()
@ -422,7 +424,7 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
self.after = Object(id=int(entries[0]['id']))
return data.get('users', []), entries
async def next(self) -> T:
async def next(self) -> AuditLogEntry:
if self.entries.empty():
await self._fill()
@ -519,7 +521,7 @@ class GuildIterator(_AsyncIterator['Guild']):
else:
self._retrieve_guilds = self._retrieve_guilds_before_strategy
async def next(self) -> T:
async def next(self) -> Guild:
if self.guilds.empty():
await self.fill_guilds()
@ -591,7 +593,7 @@ class MemberIterator(_AsyncIterator['Member']):
self.get_members = self.state.http.get_members
self.members = asyncio.Queue()
async def next(self) -> T:
async def next(self) -> Member:
if self.members.empty():
await self.fill_members()

Loading…
Cancel
Save