From b35596f7c8cbff9a31338a8deafb4ae9d911720d Mon Sep 17 00:00:00 2001 From: Rapptz <rapptz@gmail.com> Date: Sun, 18 Apr 2021 08:43:09 -0400 Subject: [PATCH] Add typings for discord.utils --- discord/utils.py | 182 ++++++++++++++++++++++++++++++++++------------- 1 file changed, 131 insertions(+), 51 deletions(-) diff --git a/discord/utils.py b/discord/utils.py index c756260e1..4650cfe7b 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -26,7 +26,23 @@ from __future__ import annotations import array import asyncio import collections.abc -from typing import Any, Callable, Generic, Optional, Type, TypeVar, overload, TYPE_CHECKING +from typing import ( + Any, + Callable, + Dict, + Generic, + Iterable, + Iterator, + List, + Optional, + Protocol, + Sequence, + Type, + TypeVar, + Union, + overload, + TYPE_CHECKING, +) import unicodedata from base64 import b64encode from bisect import bisect_left @@ -52,8 +68,10 @@ __all__ = ( 'escape_markdown', 'escape_mentions', ) + DISCORD_EPOCH = 1420070400000 + class cached_property: def __init__(self, function): self.function = function @@ -68,13 +86,24 @@ class cached_property: return value + if TYPE_CHECKING: from functools import cached_property + from .permissions import Permissions + from .abc import Snowflake + from .invite import Invite + from .template import Template + from types import FunctionType as _Func + + class _RequestLike(Protocol): + headers: Dict[str, Any] + T = TypeVar('T') T_co = TypeVar('T_co', covariant=True) CSP = TypeVar('CSP', bound='CachedSlotProperty') + class CachedSlotProperty(Generic[T, T_co]): def __init__(self, name: str, function: Callable[[T], T_co]) -> None: self.name = name @@ -100,74 +129,93 @@ class CachedSlotProperty(Generic[T, T_co]): setattr(instance, self.name, value) return value + def cached_slot_property(name: str) -> Callable[[Callable[[T], T_co]], CachedSlotProperty[T, T_co]]: def decorator(func: Callable[[T], T_co]) -> CachedSlotProperty[T, T_co]: return CachedSlotProperty(name, func) + return decorator -class SequenceProxy(collections.abc.Sequence): + +class SequenceProxy(Generic[T_co], collections.abc.Sequence): """Read-only proxy of a Sequence.""" - def __init__(self, proxied): + + def __init__(self, proxied: Sequence[T_co]): self.__proxied = proxied - def __getitem__(self, idx): + def __getitem__(self, idx: int) -> T_co: return self.__proxied[idx] - def __len__(self): + def __len__(self) -> int: return len(self.__proxied) - def __contains__(self, item): + def __contains__(self, item: Any) -> bool: return item in self.__proxied - def __iter__(self): + def __iter__(self) -> Iterator[T_co]: return iter(self.__proxied) - def __reversed__(self): + def __reversed__(self) -> Iterator[T_co]: return reversed(self.__proxied) - def index(self, value, *args, **kwargs): + def index(self, value: Any, *args, **kwargs) -> int: return self.__proxied.index(value, *args, **kwargs) - def count(self, value): + def count(self, value: Any) -> int: return self.__proxied.count(value) + @overload def parse_time(timestamp: None) -> None: ... + @overload def parse_time(timestamp: str) -> datetime.datetime: ... + def parse_time(timestamp: Optional[str]) -> Optional[datetime.datetime]: if timestamp: return datetime.datetime.fromisoformat(timestamp) return None -def copy_doc(original): - def decorator(overriden): + +def copy_doc(original: _Func) -> Callable[[_Func], _Func]: + def decorator(overriden: _Func) -> _Func: overriden.__doc__ = original.__doc__ - overriden.__signature__ = _signature(original) + overriden.__signature__ = _signature(original) # type: ignore return overriden + return decorator -def deprecated(instead=None): - def actual_decorator(func): + +def deprecated(instead: Optional[str] = None) -> Callable[[Callable[..., T]], Callable[..., T]]: + def actual_decorator(func: Callable[..., T]) -> Callable[..., T]: @functools.wraps(func) - def decorated(*args, **kwargs): - warnings.simplefilter('always', DeprecationWarning) # turn off filter + def decorated(*args, **kwargs) -> T: + warnings.simplefilter('always', DeprecationWarning) # turn off filter if instead: fmt = "{0.__name__} is deprecated, use {1} instead." else: fmt = '{0.__name__} is deprecated.' warnings.warn(fmt.format(func, instead), stacklevel=3, category=DeprecationWarning) - warnings.simplefilter('default', DeprecationWarning) # reset filter + warnings.simplefilter('default', DeprecationWarning) # reset filter return func(*args, **kwargs) + return decorated + return actual_decorator -def oauth_url(client_id, permissions=None, guild=None, redirect_uri=None, scopes=None): + +def oauth_url( + client_id: str, + permissions: Optional[Permissions] = None, + guild: Optional[Snowflake] = None, + redirect_uri: Optional[str] = None, + scopes: Optional[Iterable[str]] = None, +): """A helper function that returns the OAuth2 URL for inviting the bot into guilds. @@ -178,7 +226,7 @@ def oauth_url(client_id, permissions=None, guild=None, redirect_uri=None, scopes permissions: :class:`~discord.Permissions` The permissions you're requesting. If not given then you won't be requesting any permissions. - guild: :class:`~discord.Guild` + guild: :class:`~discord.abc.Snowflake` The guild to pre-select in the authorization screen, if available. redirect_uri: :class:`str` An optional valid redirect URI. @@ -200,6 +248,7 @@ def oauth_url(client_id, permissions=None, guild=None, redirect_uri=None, scopes url = url + "&guild_id=" + str(guild.id) if redirect_uri is not None: from urllib.parse import urlencode + url = url + "&response_type=code&" + urlencode({'redirect_uri': redirect_uri}) return url @@ -219,6 +268,7 @@ def snowflake_time(id: int) -> datetime.datetime: timestamp = ((id >> 22) + DISCORD_EPOCH) / 1000 return datetime.datetime.utcfromtimestamp(timestamp).replace(tzinfo=datetime.timezone.utc) + def time_snowflake(dt: datetime.datetime, high: bool = False) -> int: """Returns a numeric snowflake pretending to be created at the given date. @@ -242,9 +292,10 @@ def time_snowflake(dt: datetime.datetime, high: bool = False) -> int: The snowflake representing the time given. """ discord_millis = int(dt.timestamp() * 1000 - DISCORD_EPOCH) - return (discord_millis << 22) + (2**22-1 if high else 0) + return (discord_millis << 22) + (2 ** 22 - 1 if high else 0) + -def find(predicate, seq): +def find(predicate: Callable[[T], bool], seq: Iterable[T]) -> Optional[T]: """A helper to return the first element found in the sequence that meets the predicate. For example: :: @@ -269,7 +320,8 @@ def find(predicate, seq): return element return None -def get(iterable, **attrs): + +def get(iterable: Iterable[T], **attrs: Any) -> Optional[T]: r"""A helper that returns the first element in the iterable that meets all the traits passed in ``attrs``. This is an alternative for :func:`~discord.utils.find`. @@ -326,21 +378,20 @@ def get(iterable, **attrs): return elem return None - converted = [ - (attrget(attr.replace('__', '.')), value) - for attr, value in attrs.items() - ] + converted = [(attrget(attr.replace('__', '.')), value) for attr, value in attrs.items()] for elem in iterable: if _all(pred(elem) == value for pred, value in converted): return elem return None -def _unique(iterable): + +def _unique(iterable: Iterable[T]) -> List[T]: seen = set() adder = seen.add return [x for x in iterable if not (x in seen or adder(x))] + def _get_as_snowflake(data: Any, key: str) -> Optional[int]: try: value = data[key] @@ -349,7 +400,8 @@ def _get_as_snowflake(data: Any, key: str) -> Optional[int]: else: return value and int(value) -def _get_mime_type_for_image(data): + +def _get_mime_type_for_image(data: bytes): if data.startswith(b'\x89\x50\x4E\x47\x0D\x0A\x1A\x0A'): return 'image/png' elif data[0:3] == b'\xff\xd8\xff' or data[6:10] in (b'JFIF', b'Exif'): @@ -361,16 +413,19 @@ def _get_mime_type_for_image(data): else: raise InvalidArgument('Unsupported image type given') -def _bytes_to_base64_data(data): + +def _bytes_to_base64_data(data: bytes) -> str: fmt = 'data:{mime};base64,{data}' mime = _get_mime_type_for_image(data) b64 = b64encode(data).decode('ascii') return fmt.format(mime=mime, data=b64) -def to_json(obj): + +def to_json(obj: Any) -> str: return json.dumps(obj, separators=(',', ':'), ensure_ascii=True) -def _parse_ratelimit_header(request, *, use_clock=False): + +def _parse_ratelimit_header(request: _RequestLike, *, use_clock: bool = False) -> float: reset_after = request.headers.get('X-Ratelimit-Reset-After') if use_clock or not reset_after: utc = datetime.timezone.utc @@ -380,6 +435,7 @@ def _parse_ratelimit_header(request, *, use_clock=False): else: return float(reset_after) + async def maybe_coroutine(f, *args, **kwargs): value = f(*args, **kwargs) if _isawaitable(value): @@ -387,6 +443,7 @@ async def maybe_coroutine(f, *args, **kwargs): else: return value + async def async_all(gen, *, check=_isawaitable): for elem in gen: if check(elem): @@ -395,10 +452,9 @@ async def async_all(gen, *, check=_isawaitable): return False return True + async def sane_wait_for(futures, *, timeout): - ensured = [ - asyncio.ensure_future(fut) for fut in futures - ] + ensured = [asyncio.ensure_future(fut) for fut in futures] done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED) if len(pending) != 0: @@ -406,7 +462,18 @@ async def sane_wait_for(futures, *, timeout): return done -async def sleep_until(when, result=None): + +@overload +async def sleep_until(when: datetime.datetime, result: None) -> None: + ... + + +@overload +async def sleep_until(when: datetime.datetime, result: T) -> T: + ... + + +async def sleep_until(when: datetime.datetime, result: Optional[T] = None) -> Optional[T]: """|coro| Sleep until a specified time. @@ -429,6 +496,7 @@ async def sleep_until(when, result=None): delta = (when - now).total_seconds() return await asyncio.sleep(max(delta, 0), result) + def utcnow() -> datetime.datetime: """A helper function to return an aware UTC datetime representing the current time. @@ -444,10 +512,12 @@ def utcnow() -> datetime.datetime: """ return datetime.datetime.now(datetime.timezone.utc) -def valid_icon_size(size): + +def valid_icon_size(size: int) -> bool: """Icons must be power of 2 within [16, 4096].""" return not size & (size - 1) and size in range(16, 4097) + class SnowflakeList(array.array): """Internal data storage class to efficiently store a list of snowflakes. @@ -462,24 +532,26 @@ class SnowflakeList(array.array): __slots__ = () - def __new__(cls, data, *, is_sorted=False): - return array.array.__new__(cls, 'Q', data if is_sorted else sorted(data)) + def __new__(cls, data: Sequence[int], *, is_sorted: bool = False): + return array.array.__new__(cls, 'Q', data if is_sorted else sorted(data)) # type: ignore - def add(self, element): + def add(self, element: int) -> None: i = bisect_left(self, element) self.insert(i, element) - def get(self, element): + def get(self, element: int) -> Optional[int]: i = bisect_left(self, element) return self[i] if i != len(self) and self[i] == element else None - def has(self, element): + def has(self, element: int) -> bool: i = bisect_left(self, element) return i != len(self) and self[i] == element + _IS_ASCII = re.compile(r'^[\x00-\x7f]+$') -def _string_width(string, *, _IS_ASCII=_IS_ASCII): + +def _string_width(string: str, *, _IS_ASCII=_IS_ASCII) -> int: """Returns string's width.""" match = _IS_ASCII.match(string) if match: @@ -489,7 +561,8 @@ def _string_width(string, *, _IS_ASCII=_IS_ASCII): func = unicodedata.east_asian_width return sum(2 if func(char) in UNICODE_WIDE_CHAR_TYPE else 1 for char in string) -def resolve_invite(invite): + +def resolve_invite(invite: Union[Invite, str]) -> str: """ Resolves an invite from a :class:`~discord.Invite`, URL or code. @@ -504,6 +577,7 @@ def resolve_invite(invite): The invite code. """ from .invite import Invite # circular import + if isinstance(invite, Invite): return invite.code else: @@ -513,7 +587,8 @@ def resolve_invite(invite): return m.group(1) return invite -def resolve_template(code): + +def resolve_template(code: Union[Template, str]) -> str: """ Resolves a template code from a :class:`~discord.Template`, URL or code. @@ -529,7 +604,8 @@ def resolve_template(code): :class:`str` The template code. """ - from .template import Template # circular import + from .template import Template # circular import + if isinstance(code, Template): return code.code else: @@ -539,8 +615,8 @@ def resolve_template(code): return m.group(1) return code -_MARKDOWN_ESCAPE_SUBREGEX = '|'.join(r'\{0}(?=([\s\S]*((?<!\{0})\{0})))'.format(c) - for c in ('*', '`', '_', '~', '|')) + +_MARKDOWN_ESCAPE_SUBREGEX = '|'.join(r'\{0}(?=([\s\S]*((?<!\{0})\{0})))'.format(c) for c in ('*', '`', '_', '~', '|')) _MARKDOWN_ESCAPE_COMMON = r'^>(?:>>)?\s|\[.+\]\(.+\)' @@ -550,7 +626,8 @@ _URL_REGEX = r'(?P<url><[^: >]+:\/[^ >]+>|(?:https?|steam):\/\/[^\s<]+[^<.,:;\"\ _MARKDOWN_STOCK_REGEX = fr'(?P<markdown>[_\\~|\*`]|{_MARKDOWN_ESCAPE_COMMON})' -def remove_markdown(text, *, ignore_links=True): + +def remove_markdown(text: str, *, ignore_links: bool = True) -> str: """A helper function that removes markdown characters. .. versionadded:: 1.7 @@ -583,7 +660,8 @@ def remove_markdown(text, *, ignore_links=True): regex = f'(?:{_URL_REGEX}|{regex})' return re.sub(regex, replacement, text, 0, re.MULTILINE) -def escape_markdown(text, *, as_needed=False, ignore_links=True): + +def escape_markdown(text: str, *, as_needed: bool = False, ignore_links: bool = True) -> str: r"""A helper function that escapes Discord's markdown. Parameters @@ -609,6 +687,7 @@ def escape_markdown(text, *, as_needed=False, ignore_links=True): """ if not as_needed: + def replacement(match): groupdict = match.groupdict() is_url = groupdict.get('url') @@ -624,7 +703,8 @@ def escape_markdown(text, *, as_needed=False, ignore_links=True): text = re.sub(r'\\', r'\\\\', text) return _MARKDOWN_ESCAPE_REGEX.sub(r'\\\1', text) -def escape_mentions(text): + +def escape_mentions(text: str) -> str: """A helper function that escapes everyone, here, role, and user mentions. .. note::