Browse Source

use `typing.AsyncIterator` for iterators

pull/6664/head
Nadir Chowdhury 4 years ago
committed by GitHub
parent
commit
9f0c701a7a
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 78
      discord/iterators.py

78
discord/iterators.py

@ -22,20 +22,43 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
import asyncio import asyncio
import datetime import datetime
from typing import TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator, Coroutine
from .errors import NoMoreItems from .errors import NoMoreItems
from .utils import time_snowflake, maybe_coroutine from .utils import time_snowflake, maybe_coroutine
from .object import Object from .object import Object
from .audit_logs import AuditLogEntry from .audit_logs import AuditLogEntry
__all__ = (
'ReactionIterator',
'HistoryIterator',
'AuditLogIterator',
'GuildIterator',
'MemberIterator',
)
if TYPE_CHECKING:
from .member import Member
from .user import User
from .message import Message
from .audit_logs import AuditLogEntry
from .guild import Guild
T = TypeVar('T')
OT = TypeVar('OT')
_Func = Callable[[T], Union[OT, Coroutine[Any, Any, OT]]]
_Predicate = Callable[[T], Union[T, Coroutine[Any, Any, T]]]
OLDEST_OBJECT = Object(id=0) OLDEST_OBJECT = Object(id=0)
class _AsyncIterator: class _AsyncIterator(AsyncIterator[T]):
__slots__ = () __slots__ = ()
def get(self, **attrs): def get(self, **attrs: Any) -> Optional[T]:
def predicate(elem): def predicate(elem):
for attr, val in attrs.items(): for attr, val in attrs.items():
nested = attr.split('__') nested = attr.split('__')
@ -49,7 +72,7 @@ class _AsyncIterator:
return self.find(predicate) return self.find(predicate)
async def find(self, predicate): async def find(self, predicate: _Predicate[T]) -> Optional[T]:
while True: while True:
try: try:
elem = await self.next() elem = await self.next()
@ -60,40 +83,35 @@ class _AsyncIterator:
if ret: if ret:
return elem return elem
def chunk(self, max_size): def chunk(self, max_size: int) -> _ChunkedAsyncIterator[T]:
if max_size <= 0: if max_size <= 0:
raise ValueError('async iterator chunk sizes must be greater than 0.') raise ValueError('async iterator chunk sizes must be greater than 0.')
return _ChunkedAsyncIterator(self, max_size) return _ChunkedAsyncIterator(self, max_size)
def map(self, func): def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]:
return _MappedAsyncIterator(self, func) return _MappedAsyncIterator(self, func)
def filter(self, predicate): def filter(self, predicate: _Predicate[T]) -> _FilteredAsyncIterator[T]:
return _FilteredAsyncIterator(self, predicate) return _FilteredAsyncIterator(self, predicate)
async def flatten(self): async def flatten(self) -> List[T]:
return [element async for element in self] return [element async for element in self]
def __aiter__(self): async def __anext__(self) -> T:
return self
async def __anext__(self):
try: try:
msg = await self.next() return await self.next()
except NoMoreItems: except NoMoreItems:
raise StopAsyncIteration() raise StopAsyncIteration()
else:
return msg
def _identity(x): def _identity(x):
return x return x
class _ChunkedAsyncIterator(_AsyncIterator): class _ChunkedAsyncIterator(_AsyncIterator[T]):
def __init__(self, iterator, max_size): def __init__(self, iterator, max_size):
self.iterator = iterator self.iterator = iterator
self.max_size = max_size self.max_size = max_size
async def next(self): async def next(self) -> T:
ret = [] ret = []
n = 0 n = 0
while n < self.max_size: while n < self.max_size:
@ -108,17 +126,17 @@ class _ChunkedAsyncIterator(_AsyncIterator):
n += 1 n += 1
return ret return ret
class _MappedAsyncIterator(_AsyncIterator): class _MappedAsyncIterator(_AsyncIterator[T]):
def __init__(self, iterator, func): def __init__(self, iterator, func):
self.iterator = iterator self.iterator = iterator
self.func = func self.func = func
async def next(self): async def next(self) -> T:
# this raises NoMoreItems and will propagate appropriately # this raises NoMoreItems and will propagate appropriately
item = await self.iterator.next() item = await self.iterator.next()
return await maybe_coroutine(self.func, item) return await maybe_coroutine(self.func, item)
class _FilteredAsyncIterator(_AsyncIterator): class _FilteredAsyncIterator(_AsyncIterator[T]):
def __init__(self, iterator, predicate): def __init__(self, iterator, predicate):
self.iterator = iterator self.iterator = iterator
@ -127,7 +145,7 @@ class _FilteredAsyncIterator(_AsyncIterator):
self.predicate = predicate self.predicate = predicate
async def next(self): async def next(self) -> T:
getter = self.iterator.next getter = self.iterator.next
pred = self.predicate pred = self.predicate
while True: while True:
@ -137,7 +155,7 @@ class _FilteredAsyncIterator(_AsyncIterator):
if ret: if ret:
return item return item
class ReactionIterator(_AsyncIterator): class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
def __init__(self, message, emoji, limit=100, after=None): def __init__(self, message, emoji, limit=100, after=None):
self.message = message self.message = message
self.limit = limit self.limit = limit
@ -150,7 +168,7 @@ class ReactionIterator(_AsyncIterator):
self.channel_id = message.channel.id self.channel_id = message.channel.id
self.users = asyncio.Queue() self.users = asyncio.Queue()
async def next(self): async def next(self) -> T:
if self.users.empty(): if self.users.empty():
await self.fill_users() await self.fill_users()
@ -185,7 +203,7 @@ class ReactionIterator(_AsyncIterator):
else: else:
await self.users.put(User(state=self.state, data=element)) await self.users.put(User(state=self.state, data=element))
class HistoryIterator(_AsyncIterator): class HistoryIterator(_AsyncIterator['Message']):
"""Iterator for receiving a channel's message history. """Iterator for receiving a channel's message history.
The messages endpoint has two behaviours we care about here: The messages endpoint has two behaviours we care about here:
@ -271,7 +289,7 @@ class HistoryIterator(_AsyncIterator):
if (self.after and self.after != OLDEST_OBJECT): if (self.after and self.after != OLDEST_OBJECT):
self._filter = lambda m: int(m['id']) > self.after.id self._filter = lambda m: int(m['id']) > self.after.id
async def next(self): async def next(self) -> T:
if self.messages.empty(): if self.messages.empty():
await self.fill_messages() await self.fill_messages()
@ -342,7 +360,7 @@ class HistoryIterator(_AsyncIterator):
return data return data
return [] return []
class AuditLogIterator(_AsyncIterator): class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None): def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None):
if isinstance(before, datetime.datetime): if isinstance(before, datetime.datetime):
before = Object(id=time_snowflake(before, high=False)) before = Object(id=time_snowflake(before, high=False))
@ -404,7 +422,7 @@ class AuditLogIterator(_AsyncIterator):
self.after = Object(id=int(entries[0]['id'])) self.after = Object(id=int(entries[0]['id']))
return data.get('users', []), entries return data.get('users', []), entries
async def next(self): async def next(self) -> T:
if self.entries.empty(): if self.entries.empty():
await self._fill() await self._fill()
@ -447,7 +465,7 @@ class AuditLogIterator(_AsyncIterator):
await self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild)) await self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild))
class GuildIterator(_AsyncIterator): class GuildIterator(_AsyncIterator['Guild']):
"""Iterator for receiving the client's guilds. """Iterator for receiving the client's guilds.
The guilds endpoint has the same two behaviours as described The guilds endpoint has the same two behaviours as described
@ -501,7 +519,7 @@ class GuildIterator(_AsyncIterator):
else: else:
self._retrieve_guilds = self._retrieve_guilds_before_strategy self._retrieve_guilds = self._retrieve_guilds_before_strategy
async def next(self): async def next(self) -> T:
if self.guilds.empty(): if self.guilds.empty():
await self.fill_guilds() await self.fill_guilds()
@ -559,7 +577,7 @@ class GuildIterator(_AsyncIterator):
self.after = Object(id=int(data[0]['id'])) self.after = Object(id=int(data[0]['id']))
return data return data
class MemberIterator(_AsyncIterator): class MemberIterator(_AsyncIterator['Member']):
def __init__(self, guild, limit=1000, after=None): def __init__(self, guild, limit=1000, after=None):
if isinstance(after, datetime.datetime): if isinstance(after, datetime.datetime):
@ -573,7 +591,7 @@ class MemberIterator(_AsyncIterator):
self.get_members = self.state.http.get_members self.get_members = self.state.http.get_members
self.members = asyncio.Queue() self.members = asyncio.Queue()
async def next(self): async def next(self) -> T:
if self.members.empty(): if self.members.empty():
await self.fill_members() await self.fill_members()

Loading…
Cancel
Save