Browse Source

Add support for AsyncIterables in find and get

pull/7494/head
James Hilton-Balfe 3 years ago
committed by GitHub
parent
commit
eb6f5728e2
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 140
      discord/utils.py

140
discord/utils.py

@ -26,11 +26,13 @@ from __future__ import annotations
import array
import asyncio
import collections.abc
import inspect
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Callable,
Coroutine,
Dict,
ForwardRef,
Generic,
@ -141,6 +143,7 @@ else:
T = TypeVar('T')
T_co = TypeVar('T_co', covariant=True)
_Iter = Union[Iterable[T], AsyncIterable[T]]
Coro = Coroutine[Any, Any, T]
class CachedSlotProperty(Generic[T, T_co]):
@ -363,8 +366,30 @@ def time_snowflake(dt: datetime.datetime, high: bool = False) -> int:
return (discord_millis << 22) + (2**22 - 1 if high else 0)
def find(predicate: Callable[[T], Any], seq: Iterable[T]) -> Optional[T]:
"""A helper to return the first element found in the sequence
def _find(predicate: Callable[[T], Any], iterable: Iterable[T], /) -> Optional[T]:
return next((element for element in iterable if predicate(element)), None)
async def _afind(predicate: Callable[[T], Any], iterable: AsyncIterable[T], /) -> Optional[T]:
async for element in iterable:
if predicate(element):
return element
return None
@overload
def find(predicate: Callable[[T], Any], iterable: Iterable[T], /) -> Optional[T]:
...
@overload
def find(predicate: Callable[[T], Any], iterable: AsyncIterable[T], /) -> Coro[Optional[T]]:
...
def find(predicate: Callable[[T], Any], iterable: _Iter[T], /) -> Union[Optional[T], Coro[Optional[T]]]:
r"""A helper to return the first element found in the sequence
that meets the predicate. For example: ::
member = discord.utils.find(lambda m: m.name == 'Mighty', channel.guild.members)
@ -379,17 +404,77 @@ def find(predicate: Callable[[T], Any], seq: Iterable[T]) -> Optional[T]:
-----------
predicate
A function that returns a boolean-like result.
seq: :class:`collections.abc.Iterable`
The iterable to search through.
iterable: Union[:class:`collections.abc.Iterable`, :class:`collections.abc.AsyncIterable`]
The iterable to search through. Using a :class:`collections.abc.AsyncIterable`,
makes this function return a :term:`coroutine`.
.. versionchanged:: 2.0.0
Both parameters are now positional-only.
.. versionchanged:: 2.0.0
The ``iterable`` parameter supports :term:`asynchronous iterable`\s.
"""
for element in seq:
if predicate(element):
return element
return (
_find(predicate, iterable) # type: ignore
if hasattr(iterable, '__iter__') # isinstance(iterable, collections.abc.Iterable) is too slow
else _afind(predicate, iterable) # type: ignore
)
def _get(iterable: Iterable[T], /, **attrs: Any) -> Optional[T]:
# global -> local
_all = all
attrget = attrgetter
# Special case the single element call
if len(attrs) == 1:
k, v = attrs.popitem()
pred = attrget(k.replace('__', '.'))
return next((elem for elem in iterable if pred(elem) == v), None)
converted = [(attrget(attr.replace('__', '.')), value) for attr, value in attrs.items()]
for elem in iterable:
if _all(pred(elem) == value for pred, value in converted):
return elem
return None
async def _aget(iterable: AsyncIterable[T], /, **attrs: Any) -> Optional[T]:
# global -> local
_all = all
attrget = attrgetter
# Special case the single element call
if len(attrs) == 1:
k, v = attrs.popitem()
pred = attrget(k.replace('__', '.'))
async for elem in iterable:
if pred(elem) == v:
return elem
return None
converted = [(attrget(attr.replace('__', '.')), value) for attr, value in attrs.items()]
async for elem in iterable:
if _all(pred(elem) == value for pred, value in converted):
return elem
return None
def get(iterable: Iterable[T], **attrs: Any) -> Optional[T]:
@overload
def get(iterable: Iterable[T], /, **attrs: Any) -> Optional[T]:
...
@overload
def get(iterable: AsyncIterable[T], /, **attrs: Any) -> Coro[Optional[T]]:
...
def get(iterable: _Iter[T], /, **attrs: Any) -> Union[Optional[T], Coro[Optional[T]]]:
r"""A helper that returns the first element in the iterable that meets
all the traits passed in ``attrs``. This is an alternative for
:func:`~discord.utils.find`.
@ -425,33 +510,34 @@ def get(iterable: Iterable[T], **attrs: Any) -> Optional[T]:
channel = discord.utils.get(client.get_all_channels(), guild__name='Cool', name='general')
Async iterables:
.. code-block:: python3
msg = await discord.utils.get(channel.history(), author__name='Dave')
Parameters
-----------
iterable
An iterable to search through.
iterable: Union[:class:`collections.abc.Iterable`, :class:`collections.abc.AsyncIterable`]
The iterable to search through. Using a :class:`collections.abc.AsyncIterable`,
makes this function return a :term:`coroutine`.
\*\*attrs
Keyword arguments that denote attributes to search with.
"""
# global -> local
_all = all
attrget = attrgetter
.. versionchanged:: 2.0
# Special case the single element call
if len(attrs) == 1:
k, v = attrs.popitem()
pred = attrget(k.replace('__', '.'))
for elem in iterable:
if pred(elem) == v:
return elem
return None
The ``iterable`` parameter is now positional-only.
converted = [(attrget(attr.replace('__', '.')), value) for attr, value in attrs.items()]
.. versionchanged:: 2.0
for elem in iterable:
if _all(pred(elem) == value for pred, value in converted):
return elem
return None
The ``iterable`` parameter supports :term:`asynchronous iterable`\s.
"""
return (
_get(iterable, **attrs) # type: ignore
if hasattr(iterable, '__iter__') # isinstance(iterable, collections.abc.Iterable) is too slow
else _aget(predicate, **attrs) # type: ignore
)
def _unique(iterable: Iterable[T]) -> List[T]:

Loading…
Cancel
Save