From b35596f7c8cbff9a31338a8deafb4ae9d911720d Mon Sep 17 00:00:00 2001
From: Rapptz <rapptz@gmail.com>
Date: Sun, 18 Apr 2021 08:43:09 -0400
Subject: [PATCH] Add typings for discord.utils

---
 discord/utils.py | 182 ++++++++++++++++++++++++++++++++++-------------
 1 file changed, 131 insertions(+), 51 deletions(-)

diff --git a/discord/utils.py b/discord/utils.py
index c756260e1..4650cfe7b 100644
--- a/discord/utils.py
+++ b/discord/utils.py
@@ -26,7 +26,23 @@ from __future__ import annotations
 import array
 import asyncio
 import collections.abc
-from typing import Any, Callable, Generic, Optional, Type, TypeVar, overload, TYPE_CHECKING
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generic,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Protocol,
+    Sequence,
+    Type,
+    TypeVar,
+    Union,
+    overload,
+    TYPE_CHECKING,
+)
 import unicodedata
 from base64 import b64encode
 from bisect import bisect_left
@@ -52,8 +68,10 @@ __all__ = (
     'escape_markdown',
     'escape_mentions',
 )
+
 DISCORD_EPOCH = 1420070400000
 
+
 class cached_property:
     def __init__(self, function):
         self.function = function
@@ -68,13 +86,24 @@ class cached_property:
 
         return value
 
+
 if TYPE_CHECKING:
     from functools import cached_property
+    from .permissions import Permissions
+    from .abc import Snowflake
+    from .invite import Invite
+    from .template import Template
+    from types import FunctionType as _Func
+
+    class _RequestLike(Protocol):
+        headers: Dict[str, Any]
+
 
 T = TypeVar('T')
 T_co = TypeVar('T_co', covariant=True)
 CSP = TypeVar('CSP', bound='CachedSlotProperty')
 
+
 class CachedSlotProperty(Generic[T, T_co]):
     def __init__(self, name: str, function: Callable[[T], T_co]) -> None:
         self.name = name
@@ -100,74 +129,93 @@ class CachedSlotProperty(Generic[T, T_co]):
             setattr(instance, self.name, value)
             return value
 
+
 def cached_slot_property(name: str) -> Callable[[Callable[[T], T_co]], CachedSlotProperty[T, T_co]]:
     def decorator(func: Callable[[T], T_co]) -> CachedSlotProperty[T, T_co]:
         return CachedSlotProperty(name, func)
+
     return decorator
 
-class SequenceProxy(collections.abc.Sequence):
+
+class SequenceProxy(Generic[T_co], collections.abc.Sequence):
     """Read-only proxy of a Sequence."""
-    def __init__(self, proxied):
+
+    def __init__(self, proxied: Sequence[T_co]):
         self.__proxied = proxied
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int) -> T_co:
         return self.__proxied[idx]
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.__proxied)
 
-    def __contains__(self, item):
+    def __contains__(self, item: Any) -> bool:
         return item in self.__proxied
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[T_co]:
         return iter(self.__proxied)
 
-    def __reversed__(self):
+    def __reversed__(self) -> Iterator[T_co]:
         return reversed(self.__proxied)
 
-    def index(self, value, *args, **kwargs):
+    def index(self, value: Any, *args, **kwargs) -> int:
         return self.__proxied.index(value, *args, **kwargs)
 
-    def count(self, value):
+    def count(self, value: Any) -> int:
         return self.__proxied.count(value)
 
+
 @overload
 def parse_time(timestamp: None) -> None:
     ...
 
+
 @overload
 def parse_time(timestamp: str) -> datetime.datetime:
     ...
 
+
 def parse_time(timestamp: Optional[str]) -> Optional[datetime.datetime]:
     if timestamp:
         return datetime.datetime.fromisoformat(timestamp)
     return None
 
-def copy_doc(original):
-    def decorator(overriden):
+
+def copy_doc(original: _Func) -> Callable[[_Func], _Func]:
+    def decorator(overriden: _Func) -> _Func:
         overriden.__doc__ = original.__doc__
-        overriden.__signature__ = _signature(original)
+        overriden.__signature__ = _signature(original)  # type: ignore
         return overriden
+
     return decorator
 
-def deprecated(instead=None):
-    def actual_decorator(func):
+
+def deprecated(instead: Optional[str] = None) -> Callable[[Callable[..., T]], Callable[..., T]]:
+    def actual_decorator(func: Callable[..., T]) -> Callable[..., T]:
         @functools.wraps(func)
-        def decorated(*args, **kwargs):
-            warnings.simplefilter('always', DeprecationWarning) # turn off filter
+        def decorated(*args, **kwargs) -> T:
+            warnings.simplefilter('always', DeprecationWarning)  # turn off filter
             if instead:
                 fmt = "{0.__name__} is deprecated, use {1} instead."
             else:
                 fmt = '{0.__name__} is deprecated.'
 
             warnings.warn(fmt.format(func, instead), stacklevel=3, category=DeprecationWarning)
-            warnings.simplefilter('default', DeprecationWarning) # reset filter
+            warnings.simplefilter('default', DeprecationWarning)  # reset filter
             return func(*args, **kwargs)
+
         return decorated
+
     return actual_decorator
 
-def oauth_url(client_id, permissions=None, guild=None, redirect_uri=None, scopes=None):
+
+def oauth_url(
+    client_id: str,
+    permissions: Optional[Permissions] = None,
+    guild: Optional[Snowflake] = None,
+    redirect_uri: Optional[str] = None,
+    scopes: Optional[Iterable[str]] = None,
+):
     """A helper function that returns the OAuth2 URL for inviting the bot
     into guilds.
 
@@ -178,7 +226,7 @@ def oauth_url(client_id, permissions=None, guild=None, redirect_uri=None, scopes
     permissions: :class:`~discord.Permissions`
         The permissions you're requesting. If not given then you won't be requesting any
         permissions.
-    guild: :class:`~discord.Guild`
+    guild: :class:`~discord.abc.Snowflake`
         The guild to pre-select in the authorization screen, if available.
     redirect_uri: :class:`str`
         An optional valid redirect URI.
@@ -200,6 +248,7 @@ def oauth_url(client_id, permissions=None, guild=None, redirect_uri=None, scopes
         url = url + "&guild_id=" + str(guild.id)
     if redirect_uri is not None:
         from urllib.parse import urlencode
+
         url = url + "&response_type=code&" + urlencode({'redirect_uri': redirect_uri})
     return url
 
@@ -219,6 +268,7 @@ def snowflake_time(id: int) -> datetime.datetime:
     timestamp = ((id >> 22) + DISCORD_EPOCH) / 1000
     return datetime.datetime.utcfromtimestamp(timestamp).replace(tzinfo=datetime.timezone.utc)
 
+
 def time_snowflake(dt: datetime.datetime, high: bool = False) -> int:
     """Returns a numeric snowflake pretending to be created at the given date.
 
@@ -242,9 +292,10 @@ def time_snowflake(dt: datetime.datetime, high: bool = False) -> int:
         The snowflake representing the time given.
     """
     discord_millis = int(dt.timestamp() * 1000 - DISCORD_EPOCH)
-    return (discord_millis << 22) + (2**22-1 if high else 0)
+    return (discord_millis << 22) + (2 ** 22 - 1 if high else 0)
+
 
-def find(predicate, seq):
+def find(predicate: Callable[[T], bool], seq: Iterable[T]) -> Optional[T]:
     """A helper to return the first element found in the sequence
     that meets the predicate. For example: ::
 
@@ -269,7 +320,8 @@ def find(predicate, seq):
             return element
     return None
 
-def get(iterable, **attrs):
+
+def get(iterable: Iterable[T], **attrs: Any) -> Optional[T]:
     r"""A helper that returns the first element in the iterable that meets
     all the traits passed in ``attrs``. This is an alternative for
     :func:`~discord.utils.find`.
@@ -326,21 +378,20 @@ def get(iterable, **attrs):
                 return elem
         return None
 
-    converted = [
-        (attrget(attr.replace('__', '.')), value)
-        for attr, value in attrs.items()
-    ]
+    converted = [(attrget(attr.replace('__', '.')), value) for attr, value in attrs.items()]
 
     for elem in iterable:
         if _all(pred(elem) == value for pred, value in converted):
             return elem
     return None
 
-def _unique(iterable):
+
+def _unique(iterable: Iterable[T]) -> List[T]:
     seen = set()
     adder = seen.add
     return [x for x in iterable if not (x in seen or adder(x))]
 
+
 def _get_as_snowflake(data: Any, key: str) -> Optional[int]:
     try:
         value = data[key]
@@ -349,7 +400,8 @@ def _get_as_snowflake(data: Any, key: str) -> Optional[int]:
     else:
         return value and int(value)
 
-def _get_mime_type_for_image(data):
+
+def _get_mime_type_for_image(data: bytes):
     if data.startswith(b'\x89\x50\x4E\x47\x0D\x0A\x1A\x0A'):
         return 'image/png'
     elif data[0:3] == b'\xff\xd8\xff' or data[6:10] in (b'JFIF', b'Exif'):
@@ -361,16 +413,19 @@ def _get_mime_type_for_image(data):
     else:
         raise InvalidArgument('Unsupported image type given')
 
-def _bytes_to_base64_data(data):
+
+def _bytes_to_base64_data(data: bytes) -> str:
     fmt = 'data:{mime};base64,{data}'
     mime = _get_mime_type_for_image(data)
     b64 = b64encode(data).decode('ascii')
     return fmt.format(mime=mime, data=b64)
 
-def to_json(obj):
+
+def to_json(obj: Any) -> str:
     return json.dumps(obj, separators=(',', ':'), ensure_ascii=True)
 
-def _parse_ratelimit_header(request, *, use_clock=False):
+
+def _parse_ratelimit_header(request: _RequestLike, *, use_clock: bool = False) -> float:
     reset_after = request.headers.get('X-Ratelimit-Reset-After')
     if use_clock or not reset_after:
         utc = datetime.timezone.utc
@@ -380,6 +435,7 @@ def _parse_ratelimit_header(request, *, use_clock=False):
     else:
         return float(reset_after)
 
+
 async def maybe_coroutine(f, *args, **kwargs):
     value = f(*args, **kwargs)
     if _isawaitable(value):
@@ -387,6 +443,7 @@ async def maybe_coroutine(f, *args, **kwargs):
     else:
         return value
 
+
 async def async_all(gen, *, check=_isawaitable):
     for elem in gen:
         if check(elem):
@@ -395,10 +452,9 @@ async def async_all(gen, *, check=_isawaitable):
             return False
     return True
 
+
 async def sane_wait_for(futures, *, timeout):
-    ensured = [
-        asyncio.ensure_future(fut) for fut in futures
-    ]
+    ensured = [asyncio.ensure_future(fut) for fut in futures]
     done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED)
 
     if len(pending) != 0:
@@ -406,7 +462,18 @@ async def sane_wait_for(futures, *, timeout):
 
     return done
 
-async def sleep_until(when, result=None):
+
+@overload
+async def sleep_until(when: datetime.datetime, result: None) -> None:
+    ...
+
+
+@overload
+async def sleep_until(when: datetime.datetime, result: T) -> T:
+    ...
+
+
+async def sleep_until(when: datetime.datetime, result: Optional[T] = None) -> Optional[T]:
     """|coro|
 
     Sleep until a specified time.
@@ -429,6 +496,7 @@ async def sleep_until(when, result=None):
     delta = (when - now).total_seconds()
     return await asyncio.sleep(max(delta, 0), result)
 
+
 def utcnow() -> datetime.datetime:
     """A helper function to return an aware UTC datetime representing the current time.
 
@@ -444,10 +512,12 @@ def utcnow() -> datetime.datetime:
     """
     return datetime.datetime.now(datetime.timezone.utc)
 
-def valid_icon_size(size):
+
+def valid_icon_size(size: int) -> bool:
     """Icons must be power of 2 within [16, 4096]."""
     return not size & (size - 1) and size in range(16, 4097)
 
+
 class SnowflakeList(array.array):
     """Internal data storage class to efficiently store a list of snowflakes.
 
@@ -462,24 +532,26 @@ class SnowflakeList(array.array):
 
     __slots__ = ()
 
-    def __new__(cls, data, *, is_sorted=False):
-        return array.array.__new__(cls, 'Q', data if is_sorted else sorted(data))
+    def __new__(cls, data: Sequence[int], *, is_sorted: bool = False):
+        return array.array.__new__(cls, 'Q', data if is_sorted else sorted(data))  # type: ignore
 
-    def add(self, element):
+    def add(self, element: int) -> None:
         i = bisect_left(self, element)
         self.insert(i, element)
 
-    def get(self, element):
+    def get(self, element: int) -> Optional[int]:
         i = bisect_left(self, element)
         return self[i] if i != len(self) and self[i] == element else None
 
-    def has(self, element):
+    def has(self, element: int) -> bool:
         i = bisect_left(self, element)
         return i != len(self) and self[i] == element
 
+
 _IS_ASCII = re.compile(r'^[\x00-\x7f]+$')
 
-def _string_width(string, *, _IS_ASCII=_IS_ASCII):
+
+def _string_width(string: str, *, _IS_ASCII=_IS_ASCII) -> int:
     """Returns string's width."""
     match = _IS_ASCII.match(string)
     if match:
@@ -489,7 +561,8 @@ def _string_width(string, *, _IS_ASCII=_IS_ASCII):
     func = unicodedata.east_asian_width
     return sum(2 if func(char) in UNICODE_WIDE_CHAR_TYPE else 1 for char in string)
 
-def resolve_invite(invite):
+
+def resolve_invite(invite: Union[Invite, str]) -> str:
     """
     Resolves an invite from a :class:`~discord.Invite`, URL or code.
 
@@ -504,6 +577,7 @@ def resolve_invite(invite):
         The invite code.
     """
     from .invite import Invite  # circular import
+
     if isinstance(invite, Invite):
         return invite.code
     else:
@@ -513,7 +587,8 @@ def resolve_invite(invite):
             return m.group(1)
     return invite
 
-def resolve_template(code):
+
+def resolve_template(code: Union[Template, str]) -> str:
     """
     Resolves a template code from a :class:`~discord.Template`, URL or code.
 
@@ -529,7 +604,8 @@ def resolve_template(code):
     :class:`str`
         The template code.
     """
-    from .template import Template # circular import
+    from .template import Template  # circular import
+
     if isinstance(code, Template):
         return code.code
     else:
@@ -539,8 +615,8 @@ def resolve_template(code):
             return m.group(1)
     return code
 
-_MARKDOWN_ESCAPE_SUBREGEX = '|'.join(r'\{0}(?=([\s\S]*((?<!\{0})\{0})))'.format(c)
-                                     for c in ('*', '`', '_', '~', '|'))
+
+_MARKDOWN_ESCAPE_SUBREGEX = '|'.join(r'\{0}(?=([\s\S]*((?<!\{0})\{0})))'.format(c) for c in ('*', '`', '_', '~', '|'))
 
 _MARKDOWN_ESCAPE_COMMON = r'^>(?:>>)?\s|\[.+\]\(.+\)'
 
@@ -550,7 +626,8 @@ _URL_REGEX = r'(?P<url><[^: >]+:\/[^ >]+>|(?:https?|steam):\/\/[^\s<]+[^<.,:;\"\
 
 _MARKDOWN_STOCK_REGEX = fr'(?P<markdown>[_\\~|\*`]|{_MARKDOWN_ESCAPE_COMMON})'
 
-def remove_markdown(text, *, ignore_links=True):
+
+def remove_markdown(text: str, *, ignore_links: bool = True) -> str:
     """A helper function that removes markdown characters.
 
     .. versionadded:: 1.7
@@ -583,7 +660,8 @@ def remove_markdown(text, *, ignore_links=True):
         regex = f'(?:{_URL_REGEX}|{regex})'
     return re.sub(regex, replacement, text, 0, re.MULTILINE)
 
-def escape_markdown(text, *, as_needed=False, ignore_links=True):
+
+def escape_markdown(text: str, *, as_needed: bool = False, ignore_links: bool = True) -> str:
     r"""A helper function that escapes Discord's markdown.
 
     Parameters
@@ -609,6 +687,7 @@ def escape_markdown(text, *, as_needed=False, ignore_links=True):
     """
 
     if not as_needed:
+
         def replacement(match):
             groupdict = match.groupdict()
             is_url = groupdict.get('url')
@@ -624,7 +703,8 @@ def escape_markdown(text, *, as_needed=False, ignore_links=True):
         text = re.sub(r'\\', r'\\\\', text)
         return _MARKDOWN_ESCAPE_REGEX.sub(r'\\\1', text)
 
-def escape_mentions(text):
+
+def escape_mentions(text: str) -> str:
     """A helper function that escapes everyone, here, role, and user mentions.
 
     .. note::