|
|
@ -26,7 +26,23 @@ from __future__ import annotations |
|
|
|
import array |
|
|
|
import asyncio |
|
|
|
import collections.abc |
|
|
|
from typing import Any, Callable, Generic, Optional, Type, TypeVar, overload, TYPE_CHECKING |
|
|
|
from typing import ( |
|
|
|
Any, |
|
|
|
Callable, |
|
|
|
Dict, |
|
|
|
Generic, |
|
|
|
Iterable, |
|
|
|
Iterator, |
|
|
|
List, |
|
|
|
Optional, |
|
|
|
Protocol, |
|
|
|
Sequence, |
|
|
|
Type, |
|
|
|
TypeVar, |
|
|
|
Union, |
|
|
|
overload, |
|
|
|
TYPE_CHECKING, |
|
|
|
) |
|
|
|
import unicodedata |
|
|
|
from base64 import b64encode |
|
|
|
from bisect import bisect_left |
|
|
@ -52,8 +68,10 @@ __all__ = ( |
|
|
|
'escape_markdown', |
|
|
|
'escape_mentions', |
|
|
|
) |
|
|
|
|
|
|
|
DISCORD_EPOCH = 1420070400000 |
|
|
|
|
|
|
|
|
|
|
|
class cached_property: |
|
|
|
def __init__(self, function): |
|
|
|
self.function = function |
|
|
@ -68,13 +86,24 @@ class cached_property: |
|
|
|
|
|
|
|
return value |
|
|
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
|
from functools import cached_property |
|
|
|
from .permissions import Permissions |
|
|
|
from .abc import Snowflake |
|
|
|
from .invite import Invite |
|
|
|
from .template import Template |
|
|
|
from types import FunctionType as _Func |
|
|
|
|
|
|
|
class _RequestLike(Protocol): |
|
|
|
headers: Dict[str, Any] |
|
|
|
|
|
|
|
|
|
|
|
T = TypeVar('T') |
|
|
|
T_co = TypeVar('T_co', covariant=True) |
|
|
|
CSP = TypeVar('CSP', bound='CachedSlotProperty') |
|
|
|
|
|
|
|
|
|
|
|
class CachedSlotProperty(Generic[T, T_co]): |
|
|
|
def __init__(self, name: str, function: Callable[[T], T_co]) -> None: |
|
|
|
self.name = name |
|
|
@ -100,74 +129,93 @@ class CachedSlotProperty(Generic[T, T_co]): |
|
|
|
setattr(instance, self.name, value) |
|
|
|
return value |
|
|
|
|
|
|
|
|
|
|
|
def cached_slot_property(name: str) -> Callable[[Callable[[T], T_co]], CachedSlotProperty[T, T_co]]: |
|
|
|
def decorator(func: Callable[[T], T_co]) -> CachedSlotProperty[T, T_co]: |
|
|
|
return CachedSlotProperty(name, func) |
|
|
|
|
|
|
|
return decorator |
|
|
|
|
|
|
|
class SequenceProxy(collections.abc.Sequence): |
|
|
|
|
|
|
|
class SequenceProxy(Generic[T_co], collections.abc.Sequence): |
|
|
|
"""Read-only proxy of a Sequence.""" |
|
|
|
def __init__(self, proxied): |
|
|
|
|
|
|
|
def __init__(self, proxied: Sequence[T_co]): |
|
|
|
self.__proxied = proxied |
|
|
|
|
|
|
|
def __getitem__(self, idx): |
|
|
|
def __getitem__(self, idx: int) -> T_co: |
|
|
|
return self.__proxied[idx] |
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
def __len__(self) -> int: |
|
|
|
return len(self.__proxied) |
|
|
|
|
|
|
|
def __contains__(self, item): |
|
|
|
def __contains__(self, item: Any) -> bool: |
|
|
|
return item in self.__proxied |
|
|
|
|
|
|
|
def __iter__(self): |
|
|
|
def __iter__(self) -> Iterator[T_co]: |
|
|
|
return iter(self.__proxied) |
|
|
|
|
|
|
|
def __reversed__(self): |
|
|
|
def __reversed__(self) -> Iterator[T_co]: |
|
|
|
return reversed(self.__proxied) |
|
|
|
|
|
|
|
def index(self, value, *args, **kwargs): |
|
|
|
def index(self, value: Any, *args, **kwargs) -> int: |
|
|
|
return self.__proxied.index(value, *args, **kwargs) |
|
|
|
|
|
|
|
def count(self, value): |
|
|
|
def count(self, value: Any) -> int: |
|
|
|
return self.__proxied.count(value) |
|
|
|
|
|
|
|
|
|
|
|
@overload |
|
|
|
def parse_time(timestamp: None) -> None: |
|
|
|
... |
|
|
|
|
|
|
|
|
|
|
|
@overload |
|
|
|
def parse_time(timestamp: str) -> datetime.datetime: |
|
|
|
... |
|
|
|
|
|
|
|
|
|
|
|
def parse_time(timestamp: Optional[str]) -> Optional[datetime.datetime]: |
|
|
|
if timestamp: |
|
|
|
return datetime.datetime.fromisoformat(timestamp) |
|
|
|
return None |
|
|
|
|
|
|
|
def copy_doc(original): |
|
|
|
def decorator(overriden): |
|
|
|
|
|
|
|
def copy_doc(original: _Func) -> Callable[[_Func], _Func]: |
|
|
|
def decorator(overriden: _Func) -> _Func: |
|
|
|
overriden.__doc__ = original.__doc__ |
|
|
|
overriden.__signature__ = _signature(original) |
|
|
|
overriden.__signature__ = _signature(original) # type: ignore |
|
|
|
return overriden |
|
|
|
|
|
|
|
return decorator |
|
|
|
|
|
|
|
def deprecated(instead=None): |
|
|
|
def actual_decorator(func): |
|
|
|
|
|
|
|
def deprecated(instead: Optional[str] = None) -> Callable[[Callable[..., T]], Callable[..., T]]: |
|
|
|
def actual_decorator(func: Callable[..., T]) -> Callable[..., T]: |
|
|
|
@functools.wraps(func) |
|
|
|
def decorated(*args, **kwargs): |
|
|
|
warnings.simplefilter('always', DeprecationWarning) # turn off filter |
|
|
|
def decorated(*args, **kwargs) -> T: |
|
|
|
warnings.simplefilter('always', DeprecationWarning) # turn off filter |
|
|
|
if instead: |
|
|
|
fmt = "{0.__name__} is deprecated, use {1} instead." |
|
|
|
else: |
|
|
|
fmt = '{0.__name__} is deprecated.' |
|
|
|
|
|
|
|
warnings.warn(fmt.format(func, instead), stacklevel=3, category=DeprecationWarning) |
|
|
|
warnings.simplefilter('default', DeprecationWarning) # reset filter |
|
|
|
warnings.simplefilter('default', DeprecationWarning) # reset filter |
|
|
|
return func(*args, **kwargs) |
|
|
|
|
|
|
|
return decorated |
|
|
|
|
|
|
|
return actual_decorator |
|
|
|
|
|
|
|
def oauth_url(client_id, permissions=None, guild=None, redirect_uri=None, scopes=None): |
|
|
|
|
|
|
|
def oauth_url( |
|
|
|
client_id: str, |
|
|
|
permissions: Optional[Permissions] = None, |
|
|
|
guild: Optional[Snowflake] = None, |
|
|
|
redirect_uri: Optional[str] = None, |
|
|
|
scopes: Optional[Iterable[str]] = None, |
|
|
|
): |
|
|
|
"""A helper function that returns the OAuth2 URL for inviting the bot |
|
|
|
into guilds. |
|
|
|
|
|
|
@ -178,7 +226,7 @@ def oauth_url(client_id, permissions=None, guild=None, redirect_uri=None, scopes |
|
|
|
permissions: :class:`~discord.Permissions` |
|
|
|
The permissions you're requesting. If not given then you won't be requesting any |
|
|
|
permissions. |
|
|
|
guild: :class:`~discord.Guild` |
|
|
|
guild: :class:`~discord.abc.Snowflake` |
|
|
|
The guild to pre-select in the authorization screen, if available. |
|
|
|
redirect_uri: :class:`str` |
|
|
|
An optional valid redirect URI. |
|
|
@ -200,6 +248,7 @@ def oauth_url(client_id, permissions=None, guild=None, redirect_uri=None, scopes |
|
|
|
url = url + "&guild_id=" + str(guild.id) |
|
|
|
if redirect_uri is not None: |
|
|
|
from urllib.parse import urlencode |
|
|
|
|
|
|
|
url = url + "&response_type=code&" + urlencode({'redirect_uri': redirect_uri}) |
|
|
|
return url |
|
|
|
|
|
|
@ -219,6 +268,7 @@ def snowflake_time(id: int) -> datetime.datetime: |
|
|
|
timestamp = ((id >> 22) + DISCORD_EPOCH) / 1000 |
|
|
|
return datetime.datetime.utcfromtimestamp(timestamp).replace(tzinfo=datetime.timezone.utc) |
|
|
|
|
|
|
|
|
|
|
|
def time_snowflake(dt: datetime.datetime, high: bool = False) -> int: |
|
|
|
"""Returns a numeric snowflake pretending to be created at the given date. |
|
|
|
|
|
|
@ -242,9 +292,10 @@ def time_snowflake(dt: datetime.datetime, high: bool = False) -> int: |
|
|
|
The snowflake representing the time given. |
|
|
|
""" |
|
|
|
discord_millis = int(dt.timestamp() * 1000 - DISCORD_EPOCH) |
|
|
|
return (discord_millis << 22) + (2**22-1 if high else 0) |
|
|
|
return (discord_millis << 22) + (2 ** 22 - 1 if high else 0) |
|
|
|
|
|
|
|
|
|
|
|
def find(predicate, seq): |
|
|
|
def find(predicate: Callable[[T], bool], seq: Iterable[T]) -> Optional[T]: |
|
|
|
"""A helper to return the first element found in the sequence |
|
|
|
that meets the predicate. For example: :: |
|
|
|
|
|
|
@ -269,7 +320,8 @@ def find(predicate, seq): |
|
|
|
return element |
|
|
|
return None |
|
|
|
|
|
|
|
def get(iterable, **attrs): |
|
|
|
|
|
|
|
def get(iterable: Iterable[T], **attrs: Any) -> Optional[T]: |
|
|
|
r"""A helper that returns the first element in the iterable that meets |
|
|
|
all the traits passed in ``attrs``. This is an alternative for |
|
|
|
:func:`~discord.utils.find`. |
|
|
@ -326,21 +378,20 @@ def get(iterable, **attrs): |
|
|
|
return elem |
|
|
|
return None |
|
|
|
|
|
|
|
converted = [ |
|
|
|
(attrget(attr.replace('__', '.')), value) |
|
|
|
for attr, value in attrs.items() |
|
|
|
] |
|
|
|
converted = [(attrget(attr.replace('__', '.')), value) for attr, value in attrs.items()] |
|
|
|
|
|
|
|
for elem in iterable: |
|
|
|
if _all(pred(elem) == value for pred, value in converted): |
|
|
|
return elem |
|
|
|
return None |
|
|
|
|
|
|
|
def _unique(iterable): |
|
|
|
|
|
|
|
def _unique(iterable: Iterable[T]) -> List[T]: |
|
|
|
seen = set() |
|
|
|
adder = seen.add |
|
|
|
return [x for x in iterable if not (x in seen or adder(x))] |
|
|
|
|
|
|
|
|
|
|
|
def _get_as_snowflake(data: Any, key: str) -> Optional[int]: |
|
|
|
try: |
|
|
|
value = data[key] |
|
|
@ -349,7 +400,8 @@ def _get_as_snowflake(data: Any, key: str) -> Optional[int]: |
|
|
|
else: |
|
|
|
return value and int(value) |
|
|
|
|
|
|
|
def _get_mime_type_for_image(data): |
|
|
|
|
|
|
|
def _get_mime_type_for_image(data: bytes): |
|
|
|
if data.startswith(b'\x89\x50\x4E\x47\x0D\x0A\x1A\x0A'): |
|
|
|
return 'image/png' |
|
|
|
elif data[0:3] == b'\xff\xd8\xff' or data[6:10] in (b'JFIF', b'Exif'): |
|
|
@ -361,16 +413,19 @@ def _get_mime_type_for_image(data): |
|
|
|
else: |
|
|
|
raise InvalidArgument('Unsupported image type given') |
|
|
|
|
|
|
|
def _bytes_to_base64_data(data): |
|
|
|
|
|
|
|
def _bytes_to_base64_data(data: bytes) -> str: |
|
|
|
fmt = 'data:{mime};base64,{data}' |
|
|
|
mime = _get_mime_type_for_image(data) |
|
|
|
b64 = b64encode(data).decode('ascii') |
|
|
|
return fmt.format(mime=mime, data=b64) |
|
|
|
|
|
|
|
def to_json(obj): |
|
|
|
|
|
|
|
def to_json(obj: Any) -> str: |
|
|
|
return json.dumps(obj, separators=(',', ':'), ensure_ascii=True) |
|
|
|
|
|
|
|
def _parse_ratelimit_header(request, *, use_clock=False): |
|
|
|
|
|
|
|
def _parse_ratelimit_header(request: _RequestLike, *, use_clock: bool = False) -> float: |
|
|
|
reset_after = request.headers.get('X-Ratelimit-Reset-After') |
|
|
|
if use_clock or not reset_after: |
|
|
|
utc = datetime.timezone.utc |
|
|
@ -380,6 +435,7 @@ def _parse_ratelimit_header(request, *, use_clock=False): |
|
|
|
else: |
|
|
|
return float(reset_after) |
|
|
|
|
|
|
|
|
|
|
|
async def maybe_coroutine(f, *args, **kwargs): |
|
|
|
value = f(*args, **kwargs) |
|
|
|
if _isawaitable(value): |
|
|
@ -387,6 +443,7 @@ async def maybe_coroutine(f, *args, **kwargs): |
|
|
|
else: |
|
|
|
return value |
|
|
|
|
|
|
|
|
|
|
|
async def async_all(gen, *, check=_isawaitable): |
|
|
|
for elem in gen: |
|
|
|
if check(elem): |
|
|
@ -395,10 +452,9 @@ async def async_all(gen, *, check=_isawaitable): |
|
|
|
return False |
|
|
|
return True |
|
|
|
|
|
|
|
|
|
|
|
async def sane_wait_for(futures, *, timeout): |
|
|
|
ensured = [ |
|
|
|
asyncio.ensure_future(fut) for fut in futures |
|
|
|
] |
|
|
|
ensured = [asyncio.ensure_future(fut) for fut in futures] |
|
|
|
done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED) |
|
|
|
|
|
|
|
if len(pending) != 0: |
|
|
@ -406,7 +462,18 @@ async def sane_wait_for(futures, *, timeout): |
|
|
|
|
|
|
|
return done |
|
|
|
|
|
|
|
async def sleep_until(when, result=None): |
|
|
|
|
|
|
|
@overload |
|
|
|
async def sleep_until(when: datetime.datetime, result: None) -> None: |
|
|
|
... |
|
|
|
|
|
|
|
|
|
|
|
@overload |
|
|
|
async def sleep_until(when: datetime.datetime, result: T) -> T: |
|
|
|
... |
|
|
|
|
|
|
|
|
|
|
|
async def sleep_until(when: datetime.datetime, result: Optional[T] = None) -> Optional[T]: |
|
|
|
"""|coro| |
|
|
|
|
|
|
|
Sleep until a specified time. |
|
|
@ -429,6 +496,7 @@ async def sleep_until(when, result=None): |
|
|
|
delta = (when - now).total_seconds() |
|
|
|
return await asyncio.sleep(max(delta, 0), result) |
|
|
|
|
|
|
|
|
|
|
|
def utcnow() -> datetime.datetime: |
|
|
|
"""A helper function to return an aware UTC datetime representing the current time. |
|
|
|
|
|
|
@ -444,10 +512,12 @@ def utcnow() -> datetime.datetime: |
|
|
|
""" |
|
|
|
return datetime.datetime.now(datetime.timezone.utc) |
|
|
|
|
|
|
|
def valid_icon_size(size): |
|
|
|
|
|
|
|
def valid_icon_size(size: int) -> bool: |
|
|
|
"""Icons must be power of 2 within [16, 4096].""" |
|
|
|
return not size & (size - 1) and size in range(16, 4097) |
|
|
|
|
|
|
|
|
|
|
|
class SnowflakeList(array.array): |
|
|
|
"""Internal data storage class to efficiently store a list of snowflakes. |
|
|
|
|
|
|
@ -462,24 +532,26 @@ class SnowflakeList(array.array): |
|
|
|
|
|
|
|
__slots__ = () |
|
|
|
|
|
|
|
def __new__(cls, data, *, is_sorted=False): |
|
|
|
return array.array.__new__(cls, 'Q', data if is_sorted else sorted(data)) |
|
|
|
def __new__(cls, data: Sequence[int], *, is_sorted: bool = False): |
|
|
|
return array.array.__new__(cls, 'Q', data if is_sorted else sorted(data)) # type: ignore |
|
|
|
|
|
|
|
def add(self, element): |
|
|
|
def add(self, element: int) -> None: |
|
|
|
i = bisect_left(self, element) |
|
|
|
self.insert(i, element) |
|
|
|
|
|
|
|
def get(self, element): |
|
|
|
def get(self, element: int) -> Optional[int]: |
|
|
|
i = bisect_left(self, element) |
|
|
|
return self[i] if i != len(self) and self[i] == element else None |
|
|
|
|
|
|
|
def has(self, element): |
|
|
|
def has(self, element: int) -> bool: |
|
|
|
i = bisect_left(self, element) |
|
|
|
return i != len(self) and self[i] == element |
|
|
|
|
|
|
|
|
|
|
|
_IS_ASCII = re.compile(r'^[\x00-\x7f]+$') |
|
|
|
|
|
|
|
def _string_width(string, *, _IS_ASCII=_IS_ASCII): |
|
|
|
|
|
|
|
def _string_width(string: str, *, _IS_ASCII=_IS_ASCII) -> int: |
|
|
|
"""Returns string's width.""" |
|
|
|
match = _IS_ASCII.match(string) |
|
|
|
if match: |
|
|
@ -489,7 +561,8 @@ def _string_width(string, *, _IS_ASCII=_IS_ASCII): |
|
|
|
func = unicodedata.east_asian_width |
|
|
|
return sum(2 if func(char) in UNICODE_WIDE_CHAR_TYPE else 1 for char in string) |
|
|
|
|
|
|
|
def resolve_invite(invite): |
|
|
|
|
|
|
|
def resolve_invite(invite: Union[Invite, str]) -> str: |
|
|
|
""" |
|
|
|
Resolves an invite from a :class:`~discord.Invite`, URL or code. |
|
|
|
|
|
|
@ -504,6 +577,7 @@ def resolve_invite(invite): |
|
|
|
The invite code. |
|
|
|
""" |
|
|
|
from .invite import Invite # circular import |
|
|
|
|
|
|
|
if isinstance(invite, Invite): |
|
|
|
return invite.code |
|
|
|
else: |
|
|
@ -513,7 +587,8 @@ def resolve_invite(invite): |
|
|
|
return m.group(1) |
|
|
|
return invite |
|
|
|
|
|
|
|
def resolve_template(code): |
|
|
|
|
|
|
|
def resolve_template(code: Union[Template, str]) -> str: |
|
|
|
""" |
|
|
|
Resolves a template code from a :class:`~discord.Template`, URL or code. |
|
|
|
|
|
|
@ -529,7 +604,8 @@ def resolve_template(code): |
|
|
|
:class:`str` |
|
|
|
The template code. |
|
|
|
""" |
|
|
|
from .template import Template # circular import |
|
|
|
from .template import Template # circular import |
|
|
|
|
|
|
|
if isinstance(code, Template): |
|
|
|
return code.code |
|
|
|
else: |
|
|
@ -539,8 +615,8 @@ def resolve_template(code): |
|
|
|
return m.group(1) |
|
|
|
return code |
|
|
|
|
|
|
|
_MARKDOWN_ESCAPE_SUBREGEX = '|'.join(r'\{0}(?=([\s\S]*((?<!\{0})\{0})))'.format(c) |
|
|
|
for c in ('*', '`', '_', '~', '|')) |
|
|
|
|
|
|
|
_MARKDOWN_ESCAPE_SUBREGEX = '|'.join(r'\{0}(?=([\s\S]*((?<!\{0})\{0})))'.format(c) for c in ('*', '`', '_', '~', '|')) |
|
|
|
|
|
|
|
_MARKDOWN_ESCAPE_COMMON = r'^>(?:>>)?\s|\[.+\]\(.+\)' |
|
|
|
|
|
|
@ -550,7 +626,8 @@ _URL_REGEX = r'(?P<url><[^: >]+:\/[^ >]+>|(?:https?|steam):\/\/[^\s<]+[^<.,:;\"\ |
|
|
|
|
|
|
|
_MARKDOWN_STOCK_REGEX = fr'(?P<markdown>[_\\~|\*`]|{_MARKDOWN_ESCAPE_COMMON})' |
|
|
|
|
|
|
|
def remove_markdown(text, *, ignore_links=True): |
|
|
|
|
|
|
|
def remove_markdown(text: str, *, ignore_links: bool = True) -> str: |
|
|
|
"""A helper function that removes markdown characters. |
|
|
|
|
|
|
|
.. versionadded:: 1.7 |
|
|
@ -583,7 +660,8 @@ def remove_markdown(text, *, ignore_links=True): |
|
|
|
regex = f'(?:{_URL_REGEX}|{regex})' |
|
|
|
return re.sub(regex, replacement, text, 0, re.MULTILINE) |
|
|
|
|
|
|
|
def escape_markdown(text, *, as_needed=False, ignore_links=True): |
|
|
|
|
|
|
|
def escape_markdown(text: str, *, as_needed: bool = False, ignore_links: bool = True) -> str: |
|
|
|
r"""A helper function that escapes Discord's markdown. |
|
|
|
|
|
|
|
Parameters |
|
|
@ -609,6 +687,7 @@ def escape_markdown(text, *, as_needed=False, ignore_links=True): |
|
|
|
""" |
|
|
|
|
|
|
|
if not as_needed: |
|
|
|
|
|
|
|
def replacement(match): |
|
|
|
groupdict = match.groupdict() |
|
|
|
is_url = groupdict.get('url') |
|
|
@ -624,7 +703,8 @@ def escape_markdown(text, *, as_needed=False, ignore_links=True): |
|
|
|
text = re.sub(r'\\', r'\\\\', text) |
|
|
|
return _MARKDOWN_ESCAPE_REGEX.sub(r'\\\1', text) |
|
|
|
|
|
|
|
def escape_mentions(text): |
|
|
|
|
|
|
|
def escape_mentions(text: str) -> str: |
|
|
|
"""A helper function that escapes everyone, here, role, and user mentions. |
|
|
|
|
|
|
|
.. note:: |
|
|
|