From 3433e13848112c0b26b396cd3c50973743442ef1 Mon Sep 17 00:00:00 2001 From: James Hilton-Balfe Date: Fri, 8 Jul 2022 00:30:21 +0100 Subject: [PATCH] Prioritise async iteration before sync iteration in utils.find/get --- discord/utils.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/discord/utils.py b/discord/utils.py index 67b3b7fa7..69041ffcc 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -398,12 +398,12 @@ async def _afind(predicate: Callable[[T], Any], iterable: AsyncIterable[T], /) - @overload -def find(predicate: Callable[[T], Any], iterable: Iterable[T], /) -> Optional[T]: +def find(predicate: Callable[[T], Any], iterable: AsyncIterable[T], /) -> Coro[Optional[T]]: ... @overload -def find(predicate: Callable[[T], Any], iterable: AsyncIterable[T], /) -> Coro[Optional[T]]: +def find(predicate: Callable[[T], Any], iterable: Iterable[T], /) -> Optional[T]: ... @@ -437,9 +437,9 @@ def find(predicate: Callable[[T], Any], iterable: _Iter[T], /) -> Union[Optional """ return ( - _find(predicate, iterable) # type: ignore - if hasattr(iterable, '__iter__') # isinstance(iterable, collections.abc.Iterable) is too slow - else _afind(predicate, iterable) # type: ignore + _afind(predicate, iterable) # type: ignore + if hasattr(iterable, '__aiter__') # isinstance(iterable, collections.abc.AsyncIterable) is too slow + else _find(predicate, iterable) # type: ignore ) @@ -484,12 +484,12 @@ async def _aget(iterable: AsyncIterable[T], /, **attrs: Any) -> Optional[T]: @overload -def get(iterable: Iterable[T], /, **attrs: Any) -> Optional[T]: +def get(iterable: AsyncIterable[T], /, **attrs: Any) -> Coro[Optional[T]]: ... @overload -def get(iterable: AsyncIterable[T], /, **attrs: Any) -> Coro[Optional[T]]: +def get(iterable: Iterable[T], /, **attrs: Any) -> Optional[T]: ... @@ -553,9 +553,9 @@ def get(iterable: _Iter[T], /, **attrs: Any) -> Union[Optional[T], Coro[Optional """ return ( - _get(iterable, **attrs) # type: ignore - if hasattr(iterable, '__iter__') # isinstance(iterable, collections.abc.Iterable) is too slow - else _aget(iterable, **attrs) # type: ignore + _aget(iterable, **attrs) # type: ignore + if hasattr(iterable, '__aiter__') # isinstance(iterable, collections.abc.AsyncIterable) is too slow + else _get(iterable, **attrs) # type: ignore )