From eb6f5728e29f95377c8ac33b984922b1167718d1 Mon Sep 17 00:00:00 2001 From: James Hilton-Balfe Date: Sun, 20 Feb 2022 13:09:17 +0000 Subject: [PATCH] Add support for AsyncIterables in find and get --- discord/utils.py | 140 ++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 113 insertions(+), 27 deletions(-) diff --git a/discord/utils.py b/discord/utils.py index 68b14407d..364de56f6 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -26,11 +26,13 @@ from __future__ import annotations import array import asyncio import collections.abc +import inspect from typing import ( Any, AsyncIterable, AsyncIterator, Callable, + Coroutine, Dict, ForwardRef, Generic, @@ -141,6 +143,7 @@ else: T = TypeVar('T') T_co = TypeVar('T_co', covariant=True) _Iter = Union[Iterable[T], AsyncIterable[T]] +Coro = Coroutine[Any, Any, T] class CachedSlotProperty(Generic[T, T_co]): @@ -363,8 +366,30 @@ def time_snowflake(dt: datetime.datetime, high: bool = False) -> int: return (discord_millis << 22) + (2**22 - 1 if high else 0) -def find(predicate: Callable[[T], Any], seq: Iterable[T]) -> Optional[T]: - """A helper to return the first element found in the sequence +def _find(predicate: Callable[[T], Any], iterable: Iterable[T], /) -> Optional[T]: + return next((element for element in iterable if predicate(element)), None) + + +async def _afind(predicate: Callable[[T], Any], iterable: AsyncIterable[T], /) -> Optional[T]: + async for element in iterable: + if predicate(element): + return element + + return None + + +@overload +def find(predicate: Callable[[T], Any], iterable: Iterable[T], /) -> Optional[T]: + ... + + +@overload +def find(predicate: Callable[[T], Any], iterable: AsyncIterable[T], /) -> Coro[Optional[T]]: + ... + + +def find(predicate: Callable[[T], Any], iterable: _Iter[T], /) -> Union[Optional[T], Coro[Optional[T]]]: + r"""A helper to return the first element found in the sequence that meets the predicate. For example: :: member = discord.utils.find(lambda m: m.name == 'Mighty', channel.guild.members) @@ -379,17 +404,77 @@ def find(predicate: Callable[[T], Any], seq: Iterable[T]) -> Optional[T]: ----------- predicate A function that returns a boolean-like result. - seq: :class:`collections.abc.Iterable` - The iterable to search through. + iterable: Union[:class:`collections.abc.Iterable`, :class:`collections.abc.AsyncIterable`] + The iterable to search through. Using a :class:`collections.abc.AsyncIterable`, + makes this function return a :term:`coroutine`. + + .. versionchanged:: 2.0.0 + + Both parameters are now positional-only. + + .. versionchanged:: 2.0.0 + + The ``iterable`` parameter supports :term:`asynchronous iterable`\s. """ - for element in seq: - if predicate(element): - return element + return ( + _find(predicate, iterable) # type: ignore + if hasattr(iterable, '__iter__') # isinstance(iterable, collections.abc.Iterable) is too slow + else _afind(predicate, iterable) # type: ignore + ) + + +def _get(iterable: Iterable[T], /, **attrs: Any) -> Optional[T]: + # global -> local + _all = all + attrget = attrgetter + + # Special case the single element call + if len(attrs) == 1: + k, v = attrs.popitem() + pred = attrget(k.replace('__', '.')) + return next((elem for elem in iterable if pred(elem) == v), None) + + converted = [(attrget(attr.replace('__', '.')), value) for attr, value in attrs.items()] + for elem in iterable: + if _all(pred(elem) == value for pred, value in converted): + return elem + return None + + +async def _aget(iterable: AsyncIterable[T], /, **attrs: Any) -> Optional[T]: + # global -> local + _all = all + attrget = attrgetter + + # Special case the single element call + if len(attrs) == 1: + k, v = attrs.popitem() + pred = attrget(k.replace('__', '.')) + async for elem in iterable: + if pred(elem) == v: + return elem + return None + + converted = [(attrget(attr.replace('__', '.')), value) for attr, value in attrs.items()] + + async for elem in iterable: + if _all(pred(elem) == value for pred, value in converted): + return elem return None -def get(iterable: Iterable[T], **attrs: Any) -> Optional[T]: +@overload +def get(iterable: Iterable[T], /, **attrs: Any) -> Optional[T]: + ... + + +@overload +def get(iterable: AsyncIterable[T], /, **attrs: Any) -> Coro[Optional[T]]: + ... + + +def get(iterable: _Iter[T], /, **attrs: Any) -> Union[Optional[T], Coro[Optional[T]]]: r"""A helper that returns the first element in the iterable that meets all the traits passed in ``attrs``. This is an alternative for :func:`~discord.utils.find`. @@ -425,33 +510,34 @@ def get(iterable: Iterable[T], **attrs: Any) -> Optional[T]: channel = discord.utils.get(client.get_all_channels(), guild__name='Cool', name='general') + Async iterables: + + .. code-block:: python3 + + msg = await discord.utils.get(channel.history(), author__name='Dave') + Parameters ----------- - iterable - An iterable to search through. + iterable: Union[:class:`collections.abc.Iterable`, :class:`collections.abc.AsyncIterable`] + The iterable to search through. Using a :class:`collections.abc.AsyncIterable`, + makes this function return a :term:`coroutine`. \*\*attrs Keyword arguments that denote attributes to search with. - """ - # global -> local - _all = all - attrget = attrgetter + .. versionchanged:: 2.0 - # Special case the single element call - if len(attrs) == 1: - k, v = attrs.popitem() - pred = attrget(k.replace('__', '.')) - for elem in iterable: - if pred(elem) == v: - return elem - return None + The ``iterable`` parameter is now positional-only. - converted = [(attrget(attr.replace('__', '.')), value) for attr, value in attrs.items()] + .. versionchanged:: 2.0 - for elem in iterable: - if _all(pred(elem) == value for pred, value in converted): - return elem - return None + The ``iterable`` parameter supports :term:`asynchronous iterable`\s. + """ + + return ( + _get(iterable, **attrs) # type: ignore + if hasattr(iterable, '__iter__') # isinstance(iterable, collections.abc.Iterable) is too slow + else _aget(predicate, **attrs) # type: ignore + ) def _unique(iterable: Iterable[T]) -> List[T]: