Browse Source

Reformat code using black

Segments where readability was hampered were fixed by appropriate
format skipping directives. New code should hopefully be black
compatible. The moment they remove the -S option is probably the moment
I stop using black though.
pull/7494/head
Rapptz 3 years ago
parent
commit
88b520b5ab
  1. 37
      discord/__main__.py
  2. 7
      discord/activity.py
  3. 5
      discord/asset.py
  4. 8
      discord/audit_logs.py
  5. 9
      discord/backoff.py
  6. 4
      discord/channel.py
  7. 91
      discord/client.py
  8. 44
      discord/colour.py
  9. 10
      discord/context_managers.py
  10. 6
      discord/embeds.py
  11. 2
      discord/emoji.py
  12. 1
      discord/enums.py
  13. 4
      discord/ext/commands/_types.py
  14. 30
      discord/ext/commands/bot.py
  15. 14
      discord/ext/commands/cog.py
  16. 7
      discord/ext/commands/context.py
  17. 20
      discord/ext/commands/cooldowns.py
  18. 162
      discord/ext/commands/core.py
  19. 108
      discord/ext/commands/errors.py
  20. 2
      discord/ext/commands/help.py
  21. 10
      discord/ext/commands/view.py
  22. 2
      discord/ext/tasks/__init__.py
  23. 2
      discord/file.py
  24. 4
      discord/flags.py
  25. 97
      discord/gateway.py
  26. 9
      discord/guild.py
  27. 4
      discord/http.py
  28. 8
      discord/member.py
  29. 2
      discord/mentions.py
  30. 2
      discord/message.py
  31. 2
      discord/mixins.py
  32. 4
      discord/object.py
  33. 12
      discord/oggparse.py
  34. 114
      discord/opus.py
  35. 3
      discord/partial_emoji.py
  36. 10
      discord/permissions.py
  37. 33
      discord/player.py
  38. 5
      discord/raw_models.py
  39. 13
      discord/reaction.py
  40. 12
      discord/stage_instance.py
  41. 6
      discord/state.py
  42. 2
      discord/sticker.py
  43. 4
      discord/template.py
  44. 5
      discord/types/appinfo.py
  45. 11
      discord/types/embed.py
  46. 2
      discord/types/team.py
  47. 2
      discord/ui/item.py
  48. 2
      discord/ui/modal.py
  49. 1
      discord/ui/select.py
  50. 5
      discord/ui/text_input.py
  51. 13
      discord/ui/view.py
  52. 2
      discord/utils.py
  53. 19
      discord/voice_client.py
  54. 4
      discord/webhook/async_.py
  55. 2
      discord/webhook/sync.py
  56. 28
      discord/widget.py

37
discord/__main__.py

@ -31,6 +31,7 @@ import pkg_resources
import aiohttp import aiohttp
import platform import platform
def show_version(): def show_version():
entries = [] entries = []
@ -47,10 +48,12 @@ def show_version():
entries.append('- system info: {0.system} {0.release} {0.version}'.format(uname)) entries.append('- system info: {0.system} {0.release} {0.version}'.format(uname))
print('\n'.join(entries)) print('\n'.join(entries))
def core(parser, args): def core(parser, args):
if args.version: if args.version:
show_version() show_version()
_bot_template = """#!/usr/bin/env python3 _bot_template = """#!/usr/bin/env python3
from discord.ext import commands from discord.ext import commands
@ -172,13 +175,36 @@ _base_table.update((chr(i), None) for i in range(32))
_translation_table = str.maketrans(_base_table) _translation_table = str.maketrans(_base_table)
def to_path(parser, name, *, replace_spaces=False): def to_path(parser, name, *, replace_spaces=False):
if isinstance(name, Path): if isinstance(name, Path):
return name return name
if sys.platform == 'win32': if sys.platform == 'win32':
forbidden = ('CON', 'PRN', 'AUX', 'NUL', 'COM1', 'COM2', 'COM3', 'COM4', 'COM5', 'COM6', 'COM7', \ forbidden = (
'COM8', 'COM9', 'LPT1', 'LPT2', 'LPT3', 'LPT4', 'LPT5', 'LPT6', 'LPT7', 'LPT8', 'LPT9') 'CON',
'PRN',
'AUX',
'NUL',
'COM1',
'COM2',
'COM3',
'COM4',
'COM5',
'COM6',
'COM7',
'COM8',
'COM9',
'LPT1',
'LPT2',
'LPT3',
'LPT4',
'LPT5',
'LPT6',
'LPT7',
'LPT8',
'LPT9',
)
if len(name) <= 4 and name.upper() in forbidden: if len(name) <= 4 and name.upper() in forbidden:
parser.error('invalid directory name given, use a different one') parser.error('invalid directory name given, use a different one')
@ -187,6 +213,7 @@ def to_path(parser, name, *, replace_spaces=False):
name = name.replace(' ', '-') name = name.replace(' ', '-')
return Path(name) return Path(name)
def newbot(parser, args): def newbot(parser, args):
new_directory = to_path(parser, args.directory) / to_path(parser, args.name) new_directory = to_path(parser, args.directory) / to_path(parser, args.name)
@ -228,6 +255,7 @@ def newbot(parser, args):
print('successfully made bot at', new_directory) print('successfully made bot at', new_directory)
def newcog(parser, args): def newcog(parser, args):
cog_dir = to_path(parser, args.directory) cog_dir = to_path(parser, args.directory)
try: try:
@ -261,6 +289,7 @@ def newcog(parser, args):
else: else:
print('successfully made cog at', directory) print('successfully made cog at', directory)
def add_newbot_args(subparser): def add_newbot_args(subparser):
parser = subparser.add_parser('newbot', help='creates a command bot project quickly') parser = subparser.add_parser('newbot', help='creates a command bot project quickly')
parser.set_defaults(func=newbot) parser.set_defaults(func=newbot)
@ -271,6 +300,7 @@ def add_newbot_args(subparser):
parser.add_argument('--sharded', help='whether to use AutoShardedBot', action='store_true') parser.add_argument('--sharded', help='whether to use AutoShardedBot', action='store_true')
parser.add_argument('--no-git', help='do not create a .gitignore file', action='store_true', dest='no_git') parser.add_argument('--no-git', help='do not create a .gitignore file', action='store_true', dest='no_git')
def add_newcog_args(subparser): def add_newcog_args(subparser):
parser = subparser.add_parser('newcog', help='creates a new cog template quickly') parser = subparser.add_parser('newcog', help='creates a new cog template quickly')
parser.set_defaults(func=newcog) parser.set_defaults(func=newcog)
@ -282,6 +312,7 @@ def add_newcog_args(subparser):
parser.add_argument('--hide-commands', help='whether to hide all commands in the cog', action='store_true') parser.add_argument('--hide-commands', help='whether to hide all commands in the cog', action='store_true')
parser.add_argument('--full', help='add all special methods as well', action='store_true') parser.add_argument('--full', help='add all special methods as well', action='store_true')
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(prog='discord', description='Tools for helping with discord.py') parser = argparse.ArgumentParser(prog='discord', description='Tools for helping with discord.py')
parser.add_argument('-v', '--version', action='store_true', help='shows the library version') parser.add_argument('-v', '--version', action='store_true', help='shows the library version')
@ -292,9 +323,11 @@ def parse_args():
add_newcog_args(subparser) add_newcog_args(subparser)
return parser, parser.parse_args() return parser, parser.parse_args()
def main(): def main():
parser, args = parse_args() parser, args = parse_args()
args.func(parser, args) args.func(parser, args)
if __name__ == '__main__': if __name__ == '__main__':
main() main()

7
discord/activity.py

@ -807,14 +807,17 @@ class CustomActivity(BaseActivity):
ActivityTypes = Union[Activity, Game, CustomActivity, Streaming, Spotify] ActivityTypes = Union[Activity, Game, CustomActivity, Streaming, Spotify]
@overload @overload
def create_activity(data: ActivityPayload) -> ActivityTypes: def create_activity(data: ActivityPayload) -> ActivityTypes:
... ...
@overload @overload
def create_activity(data: None) -> None: def create_activity(data: None) -> None:
... ...
def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]: def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]:
if not data: if not data:
return None return None
@ -831,11 +834,11 @@ def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]:
return Activity(**data) return Activity(**data)
else: else:
# we removed the name key from data already # we removed the name key from data already
return CustomActivity(name=name, **data) # type: ignore return CustomActivity(name=name, **data) # type: ignore
elif game_type is ActivityType.streaming: elif game_type is ActivityType.streaming:
if 'url' in data: if 'url' in data:
# the url won't be None here # the url won't be None here
return Streaming(**data) # type: ignore return Streaming(**data) # type: ignore
return Activity(**data) return Activity(**data)
elif game_type is ActivityType.listening and 'sync_id' in data and 'session_id' in data: elif game_type is ActivityType.listening and 'sync_id' in data and 'session_id' in data:
return Spotify(**data) return Spotify(**data)

5
discord/asset.py

@ -33,9 +33,11 @@ from . import utils
import yarl import yarl
# fmt: off
__all__ = ( __all__ = (
'Asset', 'Asset',
) )
# fmt: on
if TYPE_CHECKING: if TYPE_CHECKING:
ValidStaticFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png'] ValidStaticFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png']
@ -47,6 +49,7 @@ VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"}
MISSING = utils.MISSING MISSING = utils.MISSING
class AssetMixin: class AssetMixin:
url: str url: str
_state: Optional[Any] _state: Optional[Any]
@ -245,7 +248,7 @@ class Asset(AssetMixin):
state, state,
url=f'{cls.BASE}/banners/{user_id}/{banner_hash}.{format}?size=512', url=f'{cls.BASE}/banners/{user_id}/{banner_hash}.{format}?size=512',
key=banner_hash, key=banner_hash,
animated=animated animated=animated,
) )
def __str__(self) -> str: def __str__(self) -> str:

8
discord/audit_logs.py

@ -61,6 +61,10 @@ if TYPE_CHECKING:
from .sticker import GuildSticker from .sticker import GuildSticker
from .threads import Thread from .threads import Thread
TargetType = Union[
Guild, abc.GuildChannel, Member, User, Role, Invite, Emoji, StageInstance, GuildSticker, Thread, Object, None
]
def _transform_timestamp(entry: AuditLogEntry, data: Optional[str]) -> Optional[datetime.datetime]: def _transform_timestamp(entry: AuditLogEntry, data: Optional[str]) -> Optional[datetime.datetime]:
return utils.parse_time(data) return utils.parse_time(data)
@ -154,12 +158,14 @@ def _enum_transformer(enum: Type[T]) -> Callable[[AuditLogEntry, int], T]:
return _transform return _transform
def _transform_type(entry: AuditLogEntry, data: int) -> Union[enums.ChannelType, enums.StickerType]: def _transform_type(entry: AuditLogEntry, data: int) -> Union[enums.ChannelType, enums.StickerType]:
if entry.action.name.startswith('sticker_'): if entry.action.name.startswith('sticker_'):
return enums.try_enum(enums.StickerType, data) return enums.try_enum(enums.StickerType, data)
else: else:
return enums.try_enum(enums.ChannelType, data) return enums.try_enum(enums.ChannelType, data)
class AuditLogDiff: class AuditLogDiff:
def __len__(self) -> int: def __len__(self) -> int:
return len(self.__dict__) return len(self.__dict__)
@ -456,7 +462,7 @@ class AuditLogEntry(Hashable):
return utils.snowflake_time(self.id) return utils.snowflake_time(self.id)
@utils.cached_property @utils.cached_property
def target(self) -> Union[Guild, abc.GuildChannel, Member, User, Role, Invite, Emoji, StageInstance, GuildSticker, Thread, Object, None]: def target(self) -> TargetType:
try: try:
converter = getattr(self, '_convert_target_' + self.action.target_type) converter = getattr(self, '_convert_target_' + self.action.target_type)
except AttributeError: except AttributeError:

9
discord/backoff.py

@ -31,9 +31,12 @@ from typing import Callable, Generic, Literal, TypeVar, overload, Union
T = TypeVar('T', bool, Literal[True], Literal[False]) T = TypeVar('T', bool, Literal[True], Literal[False])
# fmt: off
__all__ = ( __all__ = (
'ExponentialBackoff', 'ExponentialBackoff',
) )
# fmt: on
class ExponentialBackoff(Generic[T]): class ExponentialBackoff(Generic[T]):
"""An implementation of the exponential backoff algorithm """An implementation of the exponential backoff algorithm
@ -62,14 +65,14 @@ class ExponentialBackoff(Generic[T]):
self._exp: int = 0 self._exp: int = 0
self._max: int = 10 self._max: int = 10
self._reset_time: int = base * 2 ** 11 self._reset_time: int = base * 2**11
self._last_invocation: float = time.monotonic() self._last_invocation: float = time.monotonic()
# Use our own random instance to avoid messing with global one # Use our own random instance to avoid messing with global one
rand = random.Random() rand = random.Random()
rand.seed() rand.seed()
self._randfunc: Callable[..., Union[int, float]] = rand.randrange if integral else rand.uniform # type: ignore self._randfunc: Callable[..., Union[int, float]] = rand.randrange if integral else rand.uniform # type: ignore
@overload @overload
def delay(self: ExponentialBackoff[Literal[False]]) -> float: def delay(self: ExponentialBackoff[Literal[False]]) -> float:
@ -102,4 +105,4 @@ class ExponentialBackoff(Generic[T]):
self._exp = 0 self._exp = 0
self._exp = min(self._exp + 1, self._max) self._exp = min(self._exp + 1, self._max)
return self._randfunc(0, self._base * 2 ** self._exp) return self._randfunc(0, self._base * 2**self._exp)

4
discord/channel.py

@ -814,7 +814,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
before_timestamp = str(before.id) before_timestamp = str(before.id)
else: else:
before_timestamp = utils.snowflake_time(before.id).isoformat() before_timestamp = utils.snowflake_time(before.id).isoformat()
update_before = lambda data: data['thread_metadata']['archive_timestamp'] update_before = lambda data: data['thread_metadata']['archive_timestamp']
endpoint = self.guild._state.http.get_public_archived_threads endpoint = self.guild._state.http.get_public_archived_threads
@ -823,7 +823,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
endpoint = self.guild._state.http.get_joined_private_archived_threads endpoint = self.guild._state.http.get_joined_private_archived_threads
elif private: elif private:
endpoint = self.guild._state.http.get_private_archived_threads endpoint = self.guild._state.http.get_private_archived_threads
while True: while True:
retrieve = 50 if limit is None else max(limit, 50) retrieve = 50 if limit is None else max(limit, 50)
data = await endpoint(self.id, before=before_timestamp, limit=retrieve) data = await endpoint(self.id, before=before_timestamp, limit=retrieve)

91
discord/client.py

@ -43,7 +43,7 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Tuple, Tuple,
TypeVar, TypeVar,
Union Union,
) )
import aiohttp import aiohttp
@ -84,15 +84,18 @@ if TYPE_CHECKING:
from .member import Member from .member import Member
from .voice_client import VoiceProtocol from .voice_client import VoiceProtocol
# fmt: off
__all__ = ( __all__ = (
'Client', 'Client',
) )
# fmt: on
Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]]) Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]])
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None: def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None:
tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()} tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()}
@ -110,11 +113,14 @@ def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None:
if task.cancelled(): if task.cancelled():
continue continue
if task.exception() is not None: if task.exception() is not None:
loop.call_exception_handler({ loop.call_exception_handler(
'message': 'Unhandled exception during Client.run shutdown.', {
'exception': task.exception(), 'message': 'Unhandled exception during Client.run shutdown.',
'task': task 'exception': task.exception(),
}) 'task': task,
}
)
def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None: def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None:
try: try:
@ -124,6 +130,7 @@ def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None:
_log.info('Closing the event loop.') _log.info('Closing the event loop.')
loop.close() loop.close()
class Client: class Client:
r"""Represents a client connection that connects to Discord. r"""Represents a client connection that connects to Discord.
This class is used to interact with the Discord WebSocket and API. This class is used to interact with the Discord WebSocket and API.
@ -215,6 +222,7 @@ class Client:
loop: :class:`asyncio.AbstractEventLoop` loop: :class:`asyncio.AbstractEventLoop`
The event loop that the client uses for asynchronous operations. The event loop that the client uses for asynchronous operations.
""" """
def __init__( def __init__(
self, self,
*, *,
@ -232,14 +240,16 @@ class Client:
proxy: Optional[str] = options.pop('proxy', None) proxy: Optional[str] = options.pop('proxy', None)
proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None) proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None)
unsync_clock: bool = options.pop('assume_unsync_clock', True) unsync_clock: bool = options.pop('assume_unsync_clock', True)
self.http: HTTPClient = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop) self.http: HTTPClient = HTTPClient(
connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop
)
self._handlers: Dict[str, Callable] = { self._handlers: Dict[str, Callable] = {
'ready': self._handle_ready 'ready': self._handle_ready,
} }
self._hooks: Dict[str, Callable] = { self._hooks: Dict[str, Callable] = {
'before_identify': self._call_before_identify_hook 'before_identify': self._call_before_identify_hook,
} }
self._enable_debug_events: bool = options.pop('enable_debug_events', False) self._enable_debug_events: bool = options.pop('enable_debug_events', False)
@ -260,8 +270,9 @@ class Client:
return self.ws return self.ws
def _get_state(self, **options: Any) -> ConnectionState: def _get_state(self, **options: Any) -> ConnectionState:
return ConnectionState(dispatch=self.dispatch, handlers=self._handlers, return ConnectionState(
hooks=self._hooks, http=self.http, loop=self.loop, **options) dispatch=self.dispatch, handlers=self._handlers, hooks=self._hooks, http=self.http, loop=self.loop, **options
)
def _handle_ready(self) -> None: def _handle_ready(self) -> None:
self._ready.set() self._ready.set()
@ -344,7 +355,7 @@ class Client:
If this is not passed via ``__init__`` then this is retrieved If this is not passed via ``__init__`` then this is retrieved
through the gateway when an event contains the data. Usually through the gateway when an event contains the data. Usually
after :func:`~discord.on_connect` is called. after :func:`~discord.on_connect` is called.
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
return self._connection.application_id return self._connection.application_id
@ -361,7 +372,13 @@ class Client:
""":class:`bool`: Specifies if the client's internal cache is ready for use.""" """:class:`bool`: Specifies if the client's internal cache is ready for use."""
return self._ready.is_set() return self._ready.is_set()
async def _run_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> None: async def _run_event(
self,
coro: Callable[..., Coroutine[Any, Any, Any]],
event_name: str,
*args: Any,
**kwargs: Any,
) -> None:
try: try:
await coro(*args, **kwargs) await coro(*args, **kwargs)
except asyncio.CancelledError: except asyncio.CancelledError:
@ -372,7 +389,13 @@ class Client:
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
def _schedule_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> asyncio.Task: def _schedule_event(
self,
coro: Callable[..., Coroutine[Any, Any, Any]],
event_name: str,
*args: Any,
**kwargs: Any,
) -> asyncio.Task:
wrapped = self._run_event(coro, event_name, *args, **kwargs) wrapped = self._run_event(coro, event_name, *args, **kwargs)
# Schedules the task # Schedules the task
return asyncio.create_task(wrapped, name=f'discord.py: {event_name}') return asyncio.create_task(wrapped, name=f'discord.py: {event_name}')
@ -530,12 +553,14 @@ class Client:
self.dispatch('disconnect') self.dispatch('disconnect')
ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id) ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id)
continue continue
except (OSError, except (
HTTPException, OSError,
GatewayNotFound, HTTPException,
ConnectionClosed, GatewayNotFound,
aiohttp.ClientError, ConnectionClosed,
asyncio.TimeoutError) as exc: aiohttp.ClientError,
asyncio.TimeoutError,
) as exc:
self.dispatch('disconnect') self.dispatch('disconnect')
if not reconnect: if not reconnect:
@ -699,10 +724,10 @@ class Client:
self._connection._activity = None self._connection._activity = None
elif isinstance(value, BaseActivity): elif isinstance(value, BaseActivity):
# ConnectionState._activity is typehinted as ActivityPayload, we're passing Dict[str, Any] # ConnectionState._activity is typehinted as ActivityPayload, we're passing Dict[str, Any]
self._connection._activity = value.to_dict() # type: ignore self._connection._activity = value.to_dict() # type: ignore
else: else:
raise TypeError('activity must derive from BaseActivity.') raise TypeError('activity must derive from BaseActivity.')
@property @property
def status(self): def status(self):
""":class:`.Status`: """:class:`.Status`:
@ -777,7 +802,7 @@ class Client:
This is useful if you have a channel_id but don't want to do an API call This is useful if you have a channel_id but don't want to do an API call
to send messages to it. to send messages to it.
.. versionadded:: 2.0 .. versionadded:: 2.0
Parameters Parameters
@ -1030,8 +1055,10 @@ class Client:
future = self.loop.create_future() future = self.loop.create_future()
if check is None: if check is None:
def _check(*args): def _check(*args):
return True return True
check = _check check = _check
ev = event.lower() ev = event.lower()
@ -1273,7 +1300,7 @@ class Client:
""" """
code = utils.resolve_template(code) code = utils.resolve_template(code)
data = await self.http.get_template(code) data = await self.http.get_template(code)
return Template(data=data, state=self._connection) # type: ignore return Template(data=data, state=self._connection) # type: ignore
async def fetch_guild(self, guild_id: int, /) -> Guild: async def fetch_guild(self, guild_id: int, /) -> Guild:
"""|coro| """|coro|
@ -1402,7 +1429,9 @@ class Client:
# Invite management # Invite management
async def fetch_invite(self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True) -> Invite: async def fetch_invite(
self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True
) -> Invite:
"""|coro| """|coro|
Gets an :class:`.Invite` from a discord.gg URL or ID. Gets an :class:`.Invite` from a discord.gg URL or ID.
@ -1604,13 +1633,13 @@ class Client:
if ch_type in (ChannelType.group, ChannelType.private): if ch_type in (ChannelType.group, ChannelType.private):
# the factory will be a DMChannel or GroupChannel here # the factory will be a DMChannel or GroupChannel here
channel = factory(me=self.user, data=data, state=self._connection) # type: ignore channel = factory(me=self.user, data=data, state=self._connection) # type: ignore
else: else:
# the factory can't be a DMChannel or GroupChannel here # the factory can't be a DMChannel or GroupChannel here
guild_id = int(data['guild_id']) # type: ignore guild_id = int(data['guild_id']) # type: ignore
guild = self.get_guild(guild_id) or Object(id=guild_id) guild = self.get_guild(guild_id) or Object(id=guild_id)
# GuildChannels expect a Guild, we may be passing an Object # GuildChannels expect a Guild, we may be passing an Object
channel = factory(guild=guild, state=self._connection, data=data) # type: ignore channel = factory(guild=guild, state=self._connection, data=data) # type: ignore
return channel return channel
@ -1661,7 +1690,7 @@ class Client:
""" """
data = await self.http.get_sticker(sticker_id) data = await self.http.get_sticker(sticker_id)
cls, _ = _sticker_factory(data['type']) # type: ignore cls, _ = _sticker_factory(data['type']) # type: ignore
return cls(state=self._connection, data=data) # type: ignore return cls(state=self._connection, data=data) # type: ignore
async def fetch_premium_sticker_packs(self) -> List[StickerPack]: async def fetch_premium_sticker_packs(self) -> List[StickerPack]:
"""|coro| """|coro|
@ -1716,7 +1745,7 @@ class Client:
This method should be used for when a view is comprised of components This method should be used for when a view is comprised of components
that last longer than the lifecycle of the program. that last longer than the lifecycle of the program.
.. versionadded:: 2.0 .. versionadded:: 2.0
Parameters Parameters
@ -1748,7 +1777,7 @@ class Client:
@property @property
def persistent_views(self) -> Sequence[View]: def persistent_views(self) -> Sequence[View]:
"""Sequence[:class:`.View`]: A sequence of persistent views added to the client. """Sequence[:class:`.View`]: A sequence of persistent views added to the client.
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
return self._connection.persistent_views return self._connection.persistent_views

44
discord/colour.py

@ -85,7 +85,7 @@ class Colour:
self.value: int = value self.value: int = value
def _get_byte(self, byte: int) -> int: def _get_byte(self, byte: int) -> int:
return (self.value >> (8 * byte)) & 0xff return (self.value >> (8 * byte)) & 0xFF
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
return isinstance(other, Colour) and self.value == other.value return isinstance(other, Colour) and self.value == other.value
@ -164,12 +164,12 @@ class Colour:
@classmethod @classmethod
def teal(cls: Type[CT]) -> CT: def teal(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x1abc9c``.""" """A factory method that returns a :class:`Colour` with a value of ``0x1abc9c``."""
return cls(0x1abc9c) return cls(0x1ABC9C)
@classmethod @classmethod
def dark_teal(cls: Type[CT]) -> CT: def dark_teal(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x11806a``.""" """A factory method that returns a :class:`Colour` with a value of ``0x11806a``."""
return cls(0x11806a) return cls(0x11806A)
@classmethod @classmethod
def brand_green(cls: Type[CT]) -> CT: def brand_green(cls: Type[CT]) -> CT:
@ -182,17 +182,17 @@ class Colour:
@classmethod @classmethod
def green(cls: Type[CT]) -> CT: def green(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x2ecc71``.""" """A factory method that returns a :class:`Colour` with a value of ``0x2ecc71``."""
return cls(0x2ecc71) return cls(0x2ECC71)
@classmethod @classmethod
def dark_green(cls: Type[CT]) -> CT: def dark_green(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x1f8b4c``.""" """A factory method that returns a :class:`Colour` with a value of ``0x1f8b4c``."""
return cls(0x1f8b4c) return cls(0x1F8B4C)
@classmethod @classmethod
def blue(cls: Type[CT]) -> CT: def blue(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x3498db``.""" """A factory method that returns a :class:`Colour` with a value of ``0x3498db``."""
return cls(0x3498db) return cls(0x3498DB)
@classmethod @classmethod
def dark_blue(cls: Type[CT]) -> CT: def dark_blue(cls: Type[CT]) -> CT:
@ -202,42 +202,42 @@ class Colour:
@classmethod @classmethod
def purple(cls: Type[CT]) -> CT: def purple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x9b59b6``.""" """A factory method that returns a :class:`Colour` with a value of ``0x9b59b6``."""
return cls(0x9b59b6) return cls(0x9B59B6)
@classmethod @classmethod
def dark_purple(cls: Type[CT]) -> CT: def dark_purple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x71368a``.""" """A factory method that returns a :class:`Colour` with a value of ``0x71368a``."""
return cls(0x71368a) return cls(0x71368A)
@classmethod @classmethod
def magenta(cls: Type[CT]) -> CT: def magenta(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xe91e63``.""" """A factory method that returns a :class:`Colour` with a value of ``0xe91e63``."""
return cls(0xe91e63) return cls(0xE91E63)
@classmethod @classmethod
def dark_magenta(cls: Type[CT]) -> CT: def dark_magenta(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xad1457``.""" """A factory method that returns a :class:`Colour` with a value of ``0xad1457``."""
return cls(0xad1457) return cls(0xAD1457)
@classmethod @classmethod
def gold(cls: Type[CT]) -> CT: def gold(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xf1c40f``.""" """A factory method that returns a :class:`Colour` with a value of ``0xf1c40f``."""
return cls(0xf1c40f) return cls(0xF1C40F)
@classmethod @classmethod
def dark_gold(cls: Type[CT]) -> CT: def dark_gold(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xc27c0e``.""" """A factory method that returns a :class:`Colour` with a value of ``0xc27c0e``."""
return cls(0xc27c0e) return cls(0xC27C0E)
@classmethod @classmethod
def orange(cls: Type[CT]) -> CT: def orange(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xe67e22``.""" """A factory method that returns a :class:`Colour` with a value of ``0xe67e22``."""
return cls(0xe67e22) return cls(0xE67E22)
@classmethod @classmethod
def dark_orange(cls: Type[CT]) -> CT: def dark_orange(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xa84300``.""" """A factory method that returns a :class:`Colour` with a value of ``0xa84300``."""
return cls(0xa84300) return cls(0xA84300)
@classmethod @classmethod
def brand_red(cls: Type[CT]) -> CT: def brand_red(cls: Type[CT]) -> CT:
@ -250,45 +250,45 @@ class Colour:
@classmethod @classmethod
def red(cls: Type[CT]) -> CT: def red(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xe74c3c``.""" """A factory method that returns a :class:`Colour` with a value of ``0xe74c3c``."""
return cls(0xe74c3c) return cls(0xE74C3C)
@classmethod @classmethod
def dark_red(cls: Type[CT]) -> CT: def dark_red(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x992d22``.""" """A factory method that returns a :class:`Colour` with a value of ``0x992d22``."""
return cls(0x992d22) return cls(0x992D22)
@classmethod @classmethod
def lighter_grey(cls: Type[CT]) -> CT: def lighter_grey(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x95a5a6``.""" """A factory method that returns a :class:`Colour` with a value of ``0x95a5a6``."""
return cls(0x95a5a6) return cls(0x95A5A6)
lighter_gray = lighter_grey lighter_gray = lighter_grey
@classmethod @classmethod
def dark_grey(cls: Type[CT]) -> CT: def dark_grey(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x607d8b``.""" """A factory method that returns a :class:`Colour` with a value of ``0x607d8b``."""
return cls(0x607d8b) return cls(0x607D8B)
dark_gray = dark_grey dark_gray = dark_grey
@classmethod @classmethod
def light_grey(cls: Type[CT]) -> CT: def light_grey(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x979c9f``.""" """A factory method that returns a :class:`Colour` with a value of ``0x979c9f``."""
return cls(0x979c9f) return cls(0x979C9F)
light_gray = light_grey light_gray = light_grey
@classmethod @classmethod
def darker_grey(cls: Type[CT]) -> CT: def darker_grey(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x546e7a``.""" """A factory method that returns a :class:`Colour` with a value of ``0x546e7a``."""
return cls(0x546e7a) return cls(0x546E7A)
darker_gray = darker_grey darker_gray = darker_grey
@classmethod @classmethod
def og_blurple(cls: Type[CT]) -> CT: def og_blurple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x7289da``.""" """A factory method that returns a :class:`Colour` with a value of ``0x7289da``."""
return cls(0x7289da) return cls(0x7289DA)
@classmethod @classmethod
def blurple(cls: Type[CT]) -> CT: def blurple(cls: Type[CT]) -> CT:
@ -298,7 +298,7 @@ class Colour:
@classmethod @classmethod
def greyple(cls: Type[CT]) -> CT: def greyple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x99aab5``.""" """A factory method that returns a :class:`Colour` with a value of ``0x99aab5``."""
return cls(0x99aab5) return cls(0x99AAB5)
@classmethod @classmethod
def dark_theme(cls: Type[CT]) -> CT: def dark_theme(cls: Type[CT]) -> CT:

10
discord/context_managers.py

@ -34,9 +34,12 @@ if TYPE_CHECKING:
TypingT = TypeVar('TypingT', bound='Typing') TypingT = TypeVar('TypingT', bound='Typing')
# fmt: off
__all__ = ( __all__ = (
'Typing', 'Typing',
) )
# fmt: on
def _typing_done_callback(fut: asyncio.Future) -> None: def _typing_done_callback(fut: asyncio.Future) -> None:
# just retrieve any exception and call it a day # just retrieve any exception and call it a day
@ -45,6 +48,7 @@ def _typing_done_callback(fut: asyncio.Future) -> None:
except (asyncio.CancelledError, Exception): except (asyncio.CancelledError, Exception):
pass pass
class Typing: class Typing:
def __init__(self, messageable: Messageable) -> None: def __init__(self, messageable: Messageable) -> None:
self.loop: asyncio.AbstractEventLoop = messageable._state.loop self.loop: asyncio.AbstractEventLoop = messageable._state.loop
@ -67,7 +71,8 @@ class Typing:
self.task.add_done_callback(_typing_done_callback) self.task.add_done_callback(_typing_done_callback)
return self return self
def __exit__(self, def __exit__(
self,
exc_type: Optional[Type[BaseException]], exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException], exc_value: Optional[BaseException],
traceback: Optional[TracebackType], traceback: Optional[TracebackType],
@ -79,7 +84,8 @@ class Typing:
await channel._state.http.send_typing(channel.id) await channel._state.http.send_typing(channel.id)
return self.__enter__() return self.__enter__()
async def __aexit__(self, async def __aexit__(
self,
exc_type: Optional[Type[BaseException]], exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException], exc_value: Optional[BaseException],
traceback: Optional[TracebackType], traceback: Optional[TracebackType],

6
discord/embeds.py

@ -30,9 +30,11 @@ from typing import Any, Dict, Final, List, Mapping, Protocol, TYPE_CHECKING, Typ
from . import utils from . import utils
from .colour import Colour from .colour import Colour
# fmt: off
__all__ = ( __all__ = (
'Embed', 'Embed',
) )
# fmt: on
class _EmptyEmbed: class _EmptyEmbed:
@ -366,7 +368,7 @@ class Embed:
self._footer['icon_url'] = str(icon_url) self._footer['icon_url'] = str(icon_url)
return self return self
def remove_footer(self: E) -> E: def remove_footer(self: E) -> E:
"""Clears embed's footer information. """Clears embed's footer information.
@ -381,7 +383,7 @@ class Embed:
pass pass
return self return self
@property @property
def image(self) -> _EmbedMediaProxy: def image(self) -> _EmbedMediaProxy:
"""Returns an ``EmbedProxy`` denoting the image contents. """Returns an ``EmbedProxy`` denoting the image contents.

2
discord/emoji.py

@ -30,9 +30,11 @@ from .utils import SnowflakeList, snowflake_time, MISSING
from .partial_emoji import _EmojiTag, PartialEmoji from .partial_emoji import _EmojiTag, PartialEmoji
from .user import User from .user import User
# fmt: off
__all__ = ( __all__ = (
'Emoji', 'Emoji',
) )
# fmt: on
if TYPE_CHECKING: if TYPE_CHECKING:
from .types.emoji import Emoji as EmojiPayload from .types.emoji import Emoji as EmojiPayload

1
discord/enums.py

@ -70,6 +70,7 @@ def _create_value_cls(name, comparable):
cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value
return cls return cls
def _is_descriptor(obj): def _is_descriptor(obj):
return hasattr(obj, '__get__') or hasattr(obj, '__set__') or hasattr(obj, '__delete__') return hasattr(obj, '__get__') or hasattr(obj, '__set__') or hasattr(obj, '__delete__')

4
discord/ext/commands/_types.py

@ -39,7 +39,9 @@ CoroFunc = Callable[..., Coro[Any]]
Check = Union[Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]]] Check = Union[Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]]]
Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]] Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]]
Error = Union[Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]]] Error = Union[
Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]]
]
# This is merely a tag type to avoid circular import issues. # This is merely a tag type to avoid circular import issues.

30
discord/ext/commands/bot.py

@ -66,6 +66,7 @@ T = TypeVar('T')
CFT = TypeVar('CFT', bound='CoroFunc') CFT = TypeVar('CFT', bound='CoroFunc')
CXT = TypeVar('CXT', bound='Context') CXT = TypeVar('CXT', bound='Context')
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]: def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
"""A callable that implements a command prefix equivalent to being mentioned. """A callable that implements a command prefix equivalent to being mentioned.
@ -74,6 +75,7 @@ def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
# bot.user will never be None when this is called # bot.user will never be None when this is called
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore
def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]: def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]:
"""A callable that implements when mentioned or other prefixes provided. """A callable that implements when mentioned or other prefixes provided.
@ -103,6 +105,7 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M
---------- ----------
:func:`.when_mentioned` :func:`.when_mentioned`
""" """
def inner(bot, msg): def inner(bot, msg):
r = list(prefixes) r = list(prefixes)
r = when_mentioned(bot, msg) + r r = when_mentioned(bot, msg) + r
@ -110,15 +113,19 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M
return inner return inner
def _is_submodule(parent: str, child: str) -> bool: def _is_submodule(parent: str, child: str) -> bool:
return parent == child or child.startswith(parent + ".") return parent == child or child.startswith(parent + ".")
class _DefaultRepr: class _DefaultRepr:
def __repr__(self): def __repr__(self):
return '<default-help-command>' return '<default-help-command>'
_default = _DefaultRepr() _default = _DefaultRepr()
class BotBase(GroupMixin): class BotBase(GroupMixin):
def __init__(self, command_prefix, help_command=_default, description=None, **options): def __init__(self, command_prefix, help_command=_default, description=None, **options):
super().__init__(**options) super().__init__(**options)
@ -833,11 +840,13 @@ class BotBase(GroupMixin):
raise errors.ExtensionNotLoaded(name) raise errors.ExtensionNotLoaded(name)
# get the previous module states from sys modules # get the previous module states from sys modules
# fmt: off
modules = { modules = {
name: module name: module
for name, module in sys.modules.items() for name, module in sys.modules.items()
if _is_submodule(lib.__name__, name) if _is_submodule(lib.__name__, name)
} }
# fmt: on
try: try:
# Unload and then load the module... # Unload and then load the module...
@ -913,8 +922,10 @@ class BotBase(GroupMixin):
if isinstance(ret, collections.abc.Iterable): if isinstance(ret, collections.abc.Iterable):
raise raise
raise TypeError("command_prefix must be plain string, iterable of strings, or callable " raise TypeError(
f"returning either of these, not {ret.__class__.__name__}") "command_prefix must be plain string, iterable of strings, or callable "
f"returning either of these, not {ret.__class__.__name__}"
)
if not ret: if not ret:
raise ValueError("Iterable command_prefix must contain at least one prefix") raise ValueError("Iterable command_prefix must contain at least one prefix")
@ -974,14 +985,17 @@ class BotBase(GroupMixin):
except TypeError: except TypeError:
if not isinstance(prefix, list): if not isinstance(prefix, list):
raise TypeError("get_prefix must return either a string or a list of string, " raise TypeError(
f"not {prefix.__class__.__name__}") "get_prefix must return either a string or a list of string, " f"not {prefix.__class__.__name__}"
)
# It's possible a bad command_prefix got us here. # It's possible a bad command_prefix got us here.
for value in prefix: for value in prefix:
if not isinstance(value, str): if not isinstance(value, str):
raise TypeError("Iterable command_prefix or list returned from get_prefix must " raise TypeError(
f"contain only strings, not {value.__class__.__name__}") "Iterable command_prefix or list returned from get_prefix must "
f"contain only strings, not {value.__class__.__name__}"
)
# Getting here shouldn't happen # Getting here shouldn't happen
raise raise
@ -1053,6 +1067,7 @@ class BotBase(GroupMixin):
async def on_message(self, message): async def on_message(self, message):
await self.process_commands(message) await self.process_commands(message)
class Bot(BotBase, discord.Client): class Bot(BotBase, discord.Client):
"""Represents a discord bot. """Represents a discord bot.
@ -1123,10 +1138,13 @@ class Bot(BotBase, discord.Client):
.. versionadded:: 1.7 .. versionadded:: 1.7
""" """
pass pass
class AutoShardedBot(BotBase, discord.AutoShardedClient): class AutoShardedBot(BotBase, discord.AutoShardedClient):
"""This is similar to :class:`.Bot` except that it is inherited from """This is similar to :class:`.Bot` except that it is inherited from
:class:`discord.AutoShardedClient` instead. :class:`discord.AutoShardedClient` instead.
""" """
pass pass

14
discord/ext/commands/cog.py

@ -45,6 +45,7 @@ FuncT = TypeVar('FuncT', bound=Callable[..., Any])
MISSING: Any = discord.utils.MISSING MISSING: Any = discord.utils.MISSING
class CogMeta(type): class CogMeta(type):
"""A metaclass for defining a cog. """A metaclass for defining a cog.
@ -104,6 +105,7 @@ class CogMeta(type):
async def bar(self, ctx): async def bar(self, ctx):
pass # hidden -> False pass # hidden -> False
""" """
__cog_name__: str __cog_name__: str
__cog_settings__: Dict[str, Any] __cog_settings__: Dict[str, Any]
__cog_commands__: List[Command] __cog_commands__: List[Command]
@ -150,7 +152,7 @@ class CogMeta(type):
raise TypeError(no_bot_cog.format(base, elem)) raise TypeError(no_bot_cog.format(base, elem))
listeners[elem] = value listeners[elem] = value
new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__ new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__
listeners_as_list = [] listeners_as_list = []
for listener in listeners.values(): for listener in listeners.values():
@ -169,10 +171,12 @@ class CogMeta(type):
def qualified_name(cls) -> str: def qualified_name(cls) -> str:
return cls.__cog_name__ return cls.__cog_name__
def _cog_special_method(func: FuncT) -> FuncT: def _cog_special_method(func: FuncT) -> FuncT:
func.__cog_special_method__ = None func.__cog_special_method__ = None
return func return func
class Cog(metaclass=CogMeta): class Cog(metaclass=CogMeta):
"""The base class that all cogs must inherit from. """The base class that all cogs must inherit from.
@ -183,6 +187,7 @@ class Cog(metaclass=CogMeta):
When inheriting from this class, the options shown in :class:`CogMeta` When inheriting from this class, the options shown in :class:`CogMeta`
are equally valid here. are equally valid here.
""" """
__cog_name__: ClassVar[str] __cog_name__: ClassVar[str]
__cog_settings__: ClassVar[Dict[str, Any]] __cog_settings__: ClassVar[Dict[str, Any]]
__cog_commands__: ClassVar[List[Command]] __cog_commands__: ClassVar[List[Command]]
@ -199,10 +204,7 @@ class Cog(metaclass=CogMeta):
# r.e type ignore, type-checker complains about overriding a ClassVar # r.e type ignore, type-checker complains about overriding a ClassVar
self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) # type: ignore self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) # type: ignore
lookup = { lookup = {cmd.qualified_name: cmd for cmd in self.__cog_commands__}
cmd.qualified_name: cmd
for cmd in self.__cog_commands__
}
# Update the Command instances dynamically as well # Update the Command instances dynamically as well
for command in self.__cog_commands__: for command in self.__cog_commands__:
@ -255,6 +257,7 @@ class Cog(metaclass=CogMeta):
A command or group from the cog. A command or group from the cog.
""" """
from .core import GroupMixin from .core import GroupMixin
for command in self.__cog_commands__: for command in self.__cog_commands__:
if command.parent is None: if command.parent is None:
yield command yield command
@ -315,6 +318,7 @@ class Cog(metaclass=CogMeta):
# to pick it up but the metaclass unfurls the function and # to pick it up but the metaclass unfurls the function and
# thus the assignments need to be on the actual function # thus the assignments need to be on the actual function
return func return func
return decorator return decorator
def has_error_handler(self) -> bool: def has_error_handler(self) -> bool:

7
discord/ext/commands/context.py

@ -49,9 +49,11 @@ if TYPE_CHECKING:
from .help import HelpCommand from .help import HelpCommand
from .view import StringView from .view import StringView
# fmt: off
__all__ = ( __all__ = (
'Context', 'Context',
) )
# fmt: on
MISSING: Any = discord.utils.MISSING MISSING: Any = discord.utils.MISSING
@ -122,7 +124,8 @@ class Context(discord.abc.Messageable, Generic[BotT]):
or invoked. or invoked.
""" """
def __init__(self, def __init__(
self,
*, *,
message: Message, message: Message,
bot: BotT, bot: BotT,
@ -237,7 +240,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
view.index = len(self.prefix or '') view.index = len(self.prefix or '')
view.previous = 0 view.previous = 0
self.invoked_parents = [] self.invoked_parents = []
self.invoked_with = view.get_word() # advance to get the root command self.invoked_with = view.get_word() # advance to get the root command
else: else:
to_call = cmd to_call = cmd

20
discord/ext/commands/cooldowns.py

@ -48,14 +48,15 @@ __all__ = (
C = TypeVar('C', bound='CooldownMapping') C = TypeVar('C', bound='CooldownMapping')
MC = TypeVar('MC', bound='MaxConcurrency') MC = TypeVar('MC', bound='MaxConcurrency')
class BucketType(Enum): class BucketType(Enum):
default = 0 default = 0
user = 1 user = 1
guild = 2 guild = 2
channel = 3 channel = 3
member = 4 member = 4
category = 5 category = 5
role = 6 role = 6
def get_key(self, msg: Message) -> Any: def get_key(self, msg: Message) -> Any:
if self is BucketType.user: if self is BucketType.user:
@ -192,6 +193,7 @@ class Cooldown:
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>' return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>'
class CooldownMapping: class CooldownMapping:
def __init__( def __init__(
self, self,
@ -256,12 +258,12 @@ class CooldownMapping:
bucket = self.get_bucket(message, current) bucket = self.get_bucket(message, current)
return bucket.update_rate_limit(current) return bucket.update_rate_limit(current)
class DynamicCooldownMapping(CooldownMapping):
class DynamicCooldownMapping(CooldownMapping):
def __init__( def __init__(
self, self,
factory: Callable[[Message], Cooldown], factory: Callable[[Message], Cooldown],
type: Callable[[Message], Any] type: Callable[[Message], Any],
) -> None: ) -> None:
super().__init__(None, type) super().__init__(None, type)
self._factory: Callable[[Message], Cooldown] = factory self._factory: Callable[[Message], Cooldown] = factory
@ -278,6 +280,7 @@ class DynamicCooldownMapping(CooldownMapping):
def create_bucket(self, message: Message) -> Cooldown: def create_bucket(self, message: Message) -> Cooldown:
return self._factory(message) return self._factory(message)
class _Semaphore: class _Semaphore:
"""This class is a version of a semaphore. """This class is a version of a semaphore.
@ -337,6 +340,7 @@ class _Semaphore:
self.value += 1 self.value += 1
self.wake_up() self.wake_up()
class MaxConcurrency: class MaxConcurrency:
__slots__ = ('number', 'per', 'wait', '_mapping') __slots__ = ('number', 'per', 'wait', '_mapping')

162
discord/ext/commands/core.py

@ -93,7 +93,7 @@ __all__ = (
'is_owner', 'is_owner',
'is_nsfw', 'is_nsfw',
'has_guild_permissions', 'has_guild_permissions',
'bot_has_guild_permissions' 'bot_has_guild_permissions',
) )
MISSING: Any = discord.utils.MISSING MISSING: Any = discord.utils.MISSING
@ -112,6 +112,7 @@ if TYPE_CHECKING:
else: else:
P = TypeVar('P') P = TypeVar('P')
def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]:
partial = functools.partial partial = functools.partial
while True: while True:
@ -158,8 +159,10 @@ def wrap_callback(coro):
except Exception as exc: except Exception as exc:
raise CommandInvokeError(exc) from exc raise CommandInvokeError(exc) from exc
return ret return ret
return wrapped return wrapped
def hooked_wrapped_callback(command, ctx, coro): def hooked_wrapped_callback(command, ctx, coro):
@functools.wraps(coro) @functools.wraps(coro)
async def wrapped(*args, **kwargs): async def wrapped(*args, **kwargs):
@ -180,6 +183,7 @@ def hooked_wrapped_callback(command, ctx, coro):
await command.call_after_hooks(ctx) await command.call_after_hooks(ctx)
return ret return ret
return wrapped return wrapped
@ -202,6 +206,7 @@ class _CaseInsensitiveDict(dict):
def __setitem__(self, k, v): def __setitem__(self, k, v):
super().__setitem__(k.casefold(), v) super().__setitem__(k.casefold(), v)
class Command(_BaseCommand, Generic[CogT, P, T]): class Command(_BaseCommand, Generic[CogT, P, T]):
r"""A class that implements the protocol for a bot text command. r"""A class that implements the protocol for a bot text command.
@ -269,8 +274,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
which calls converters. If ``False`` then cooldown processing is done which calls converters. If ``False`` then cooldown processing is done
first and then the converters are called second. Defaults to ``False``. first and then the converters are called second. Defaults to ``False``.
extras: :class:`dict` extras: :class:`dict`
A dict of user provided extras to attach to the Command. A dict of user provided extras to attach to the Command.
.. note:: .. note::
This object may be copied by the library. This object may be copied by the library.
@ -295,10 +300,14 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.__original_kwargs__ = kwargs.copy() self.__original_kwargs__ = kwargs.copy()
return self return self
def __init__(self, func: Union[ def __init__(
self,
func: Union[
Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]],
], **kwargs: Any): ],
**kwargs: Any,
):
if not asyncio.iscoroutinefunction(func): if not asyncio.iscoroutinefunction(func):
raise TypeError('Callback must be a coroutine.') raise TypeError('Callback must be a coroutine.')
@ -344,7 +353,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
cooldown = func.__commands_cooldown__ cooldown = func.__commands_cooldown__
except AttributeError: except AttributeError:
cooldown = kwargs.get('cooldown') cooldown = kwargs.get('cooldown')
if cooldown is None: if cooldown is None:
buckets = CooldownMapping(cooldown, BucketType.default) buckets = CooldownMapping(cooldown, BucketType.default)
elif isinstance(cooldown, CooldownMapping): elif isinstance(cooldown, CooldownMapping):
@ -386,17 +395,19 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.after_invoke(after_invoke) self.after_invoke(after_invoke)
@property @property
def callback(self) -> Union[ def callback(
Callable[Concatenate[CogT, Context, P], Coro[T]], self,
Callable[Concatenate[Context, P], Coro[T]], ) -> Union[Callable[Concatenate[CogT, Context, P], Coro[T]], Callable[Concatenate[Context, P], Coro[T]],]:
]:
return self._callback return self._callback
@callback.setter @callback.setter
def callback(self, function: Union[ def callback(
self,
function: Union[
Callable[Concatenate[CogT, Context, P], Coro[T]], Callable[Concatenate[CogT, Context, P], Coro[T]],
Callable[Concatenate[Context, P], Coro[T]], Callable[Concatenate[Context, P], Coro[T]],
]) -> None: ],
) -> None:
self._callback = function self._callback = function
unwrap = unwrap_function(function) unwrap = unwrap_function(function)
self.module = unwrap.__module__ self.module = unwrap.__module__
@ -561,7 +572,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if view.eof: if view.eof:
if param.kind == param.VAR_POSITIONAL: if param.kind == param.VAR_POSITIONAL:
raise RuntimeError() # break the loop raise RuntimeError() # break the loop
if required: if required:
if self._is_typing_optional(param.annotation): if self._is_typing_optional(param.annotation):
return None return None
@ -616,7 +627,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
value = await run_converters(ctx, converter, argument, param) # type: ignore value = await run_converters(ctx, converter, argument, param) # type: ignore
except (CommandError, ArgumentParsingError): except (CommandError, ArgumentParsingError):
view.index = previous view.index = previous
raise RuntimeError() from None # break loop raise RuntimeError() from None # break loop
else: else:
return value return value
@ -653,9 +664,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
entries = [] entries = []
command = self command = self
# command.parent is type-hinted as GroupMixin some attributes are resolved via MRO # command.parent is type-hinted as GroupMixin some attributes are resolved via MRO
while command.parent is not None: # type: ignore while command.parent is not None: # type: ignore
command = command.parent # type: ignore command = command.parent # type: ignore
entries.append(command.name) # type: ignore entries.append(command.name) # type: ignore
return ' '.join(reversed(entries)) return ' '.join(reversed(entries))
@ -671,8 +682,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
""" """
entries = [] entries = []
command = self command = self
while command.parent is not None: # type: ignore while command.parent is not None: # type: ignore
command = command.parent # type: ignore command = command.parent # type: ignore
entries.append(command) entries.append(command)
return entries return entries
@ -1061,8 +1072,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
# do [name] since [name=None] or [name=] are not exactly useful for the user. # do [name] since [name=None] or [name=] are not exactly useful for the user.
should_print = param.default if isinstance(param.default, str) else param.default is not None should_print = param.default if isinstance(param.default, str) else param.default is not None
if should_print: if should_print:
result.append(f'[{name}={param.default}]' if not greedy else result.append(f'[{name}={param.default}]' if not greedy else f'[{name}={param.default}]...')
f'[{name}={param.default}]...')
continue continue
else: else:
result.append(f'[{name}]') result.append(f'[{name}]')
@ -1135,6 +1145,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
finally: finally:
ctx.command = original ctx.command = original
class GroupMixin(Generic[CogT]): class GroupMixin(Generic[CogT]):
"""A mixin that implements common functionality for classes that behave """A mixin that implements common functionality for classes that behave
similar to :class:`.Group` and are allowed to register commands. similar to :class:`.Group` and are allowed to register commands.
@ -1147,6 +1158,7 @@ class GroupMixin(Generic[CogT]):
case_insensitive: :class:`bool` case_insensitive: :class:`bool`
Whether the commands should be case insensitive. Defaults to ``False``. Whether the commands should be case insensitive. Defaults to ``False``.
""" """
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
case_insensitive = kwargs.get('case_insensitive', False) case_insensitive = kwargs.get('case_insensitive', False)
self.all_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {} self.all_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {}
@ -1320,7 +1332,9 @@ class GroupMixin(Generic[CogT]):
Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]],
] ]
], Command[CogT, P, T]]: ],
Command[CogT, P, T],
]:
... ...
@overload @overload
@ -1348,6 +1362,7 @@ class GroupMixin(Generic[CogT]):
Callable[..., :class:`Command`] Callable[..., :class:`Command`]
A decorator that converts the provided method into a Command, adds it to the bot, then returns it. A decorator that converts the provided method into a Command, adds it to the bot, then returns it.
""" """
def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> CommandT: def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> CommandT:
kwargs.setdefault('parent', self) kwargs.setdefault('parent', self)
result = command(name=name, cls=cls, *args, **kwargs)(func) result = command(name=name, cls=cls, *args, **kwargs)(func)
@ -1363,12 +1378,10 @@ class GroupMixin(Generic[CogT]):
cls: Type[Group[CogT, P, T]] = ..., cls: Type[Group[CogT, P, T]] = ...,
*args: Any, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Callable[[ ) -> Callable[
Union[ [Union[Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]]]],
Callable[Concatenate[CogT, ContextT, P], Coro[T]], Group[CogT, P, T],
Callable[Concatenate[ContextT, P], Coro[T]] ]:
]
], Group[CogT, P, T]]:
... ...
@overload @overload
@ -1396,6 +1409,7 @@ class GroupMixin(Generic[CogT]):
Callable[..., :class:`Group`] Callable[..., :class:`Group`]
A decorator that converts the provided method into a Group, adds it to the bot, then returns it. A decorator that converts the provided method into a Group, adds it to the bot, then returns it.
""" """
def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> GroupT: def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> GroupT:
kwargs.setdefault('parent', self) kwargs.setdefault('parent', self)
result = group(name=name, cls=cls, *args, **kwargs)(func) result = group(name=name, cls=cls, *args, **kwargs)(func)
@ -1404,6 +1418,7 @@ class GroupMixin(Generic[CogT]):
return decorator return decorator
class Group(GroupMixin[CogT], Command[CogT, P, T]): class Group(GroupMixin[CogT], Command[CogT, P, T]):
"""A class that implements a grouping protocol for commands to be """A class that implements a grouping protocol for commands to be
executed as subcommands. executed as subcommands.
@ -1426,6 +1441,7 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
Indicates if the group's commands should be case insensitive. Indicates if the group's commands should be case insensitive.
Defaults to ``False``. Defaults to ``False``.
""" """
def __init__(self, *args: Any, **attrs: Any) -> None: def __init__(self, *args: Any, **attrs: Any) -> None:
self.invoke_without_command: bool = attrs.pop('invoke_without_command', False) self.invoke_without_command: bool = attrs.pop('invoke_without_command', False)
super().__init__(*args, **attrs) super().__init__(*args, **attrs)
@ -1514,8 +1530,10 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
view.previous = previous view.previous = previous
await super().reinvoke(ctx, call_hooks=call_hooks) await super().reinvoke(ctx, call_hooks=call_hooks)
# Decorators # Decorators
@overload @overload
def command( def command(
name: str = ..., name: str = ...,
@ -1527,10 +1545,12 @@ def command(
Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]],
] ]
] ],
, Command[CogT, P, T]]: Command[CogT, P, T],
]:
... ...
@overload @overload
def command( def command(
name: str = ..., name: str = ...,
@ -1542,22 +1562,25 @@ def command(
Callable[Concatenate[CogT, ContextT, P], Coro[Any]], Callable[Concatenate[CogT, ContextT, P], Coro[Any]],
Callable[Concatenate[ContextT, P], Coro[Any]], Callable[Concatenate[ContextT, P], Coro[Any]],
] ]
] ],
, CommandT]: CommandT,
]:
... ...
def command( def command(
name: str = MISSING, name: str = MISSING,
cls: Type[CommandT] = MISSING, cls: Type[CommandT] = MISSING,
**attrs: Any **attrs: Any,
) -> Callable[ ) -> Callable[
[ [
Union[ Union[
Callable[Concatenate[ContextT, P], Coro[Any]], Callable[Concatenate[ContextT, P], Coro[Any]],
Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[CogT, ContextT, P], Coro[T]],
] ]
] ],
, Union[Command[CogT, P, T], CommandT]]: Union[Command[CogT, P, T], CommandT],
]:
"""A decorator that transforms a function into a :class:`.Command` """A decorator that transforms a function into a :class:`.Command`
or if called with :func:`.group`, :class:`.Group`. or if called with :func:`.group`, :class:`.Group`.
@ -1590,16 +1613,19 @@ def command(
if cls is MISSING: if cls is MISSING:
cls = Command # type: ignore cls = Command # type: ignore
def decorator(func: Union[ def decorator(
func: Union[
Callable[Concatenate[ContextT, P], Coro[Any]], Callable[Concatenate[ContextT, P], Coro[Any]],
Callable[Concatenate[CogT, ContextT, P], Coro[Any]], Callable[Concatenate[CogT, ContextT, P], Coro[Any]],
]) -> CommandT: ]
) -> CommandT:
if isinstance(func, Command): if isinstance(func, Command):
raise TypeError('Callback is already a command.') raise TypeError('Callback is already a command.')
return cls(func, name=name, **attrs) return cls(func, name=name, **attrs)
return decorator return decorator
@overload @overload
def group( def group(
name: str = ..., name: str = ...,
@ -1611,10 +1637,12 @@ def group(
Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]],
] ]
] ],
, Group[CogT, P, T]]: Group[CogT, P, T],
]:
... ...
@overload @overload
def group( def group(
name: str = ..., name: str = ...,
@ -1626,10 +1654,12 @@ def group(
Callable[Concatenate[CogT, ContextT, P], Coro[Any]], Callable[Concatenate[CogT, ContextT, P], Coro[Any]],
Callable[Concatenate[ContextT, P], Coro[Any]], Callable[Concatenate[ContextT, P], Coro[Any]],
] ]
] ],
, GroupT]: GroupT,
]:
... ...
def group( def group(
name: str = MISSING, name: str = MISSING,
cls: Type[GroupT] = MISSING, cls: Type[GroupT] = MISSING,
@ -1640,8 +1670,9 @@ def group(
Callable[Concatenate[ContextT, P], Coro[Any]], Callable[Concatenate[ContextT, P], Coro[Any]],
Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[CogT, ContextT, P], Coro[T]],
] ]
] ],
, Union[Group[CogT, P, T], GroupT]]: Union[Group[CogT, P, T], GroupT],
]:
"""A decorator that transforms a function into a :class:`.Group`. """A decorator that transforms a function into a :class:`.Group`.
This is similar to the :func:`.command` decorator but the ``cls`` This is similar to the :func:`.command` decorator but the ``cls``
@ -1654,6 +1685,7 @@ def group(
cls = Group # type: ignore cls = Group # type: ignore
return command(name=name, cls=cls, **attrs) # type: ignore return command(name=name, cls=cls, **attrs) # type: ignore
def check(predicate: Check) -> Callable[[T], T]: def check(predicate: Check) -> Callable[[T], T]:
r"""A decorator that adds a check to the :class:`.Command` or its r"""A decorator that adds a check to the :class:`.Command` or its
subclasses. These checks could be accessed via :attr:`.Command.checks`. subclasses. These checks could be accessed via :attr:`.Command.checks`.
@ -1739,13 +1771,16 @@ def check(predicate: Check) -> Callable[[T], T]:
if inspect.iscoroutinefunction(predicate): if inspect.iscoroutinefunction(predicate):
decorator.predicate = predicate decorator.predicate = predicate
else: else:
@functools.wraps(predicate) @functools.wraps(predicate)
async def wrapper(ctx): async def wrapper(ctx):
return predicate(ctx) # type: ignore return predicate(ctx) # type: ignore
decorator.predicate = wrapper decorator.predicate = wrapper
return decorator # type: ignore return decorator # type: ignore
def check_any(*checks: Check) -> Callable[[T], T]: def check_any(*checks: Check) -> Callable[[T], T]:
r"""A :func:`check` that is added that checks if any of the checks passed r"""A :func:`check` that is added that checks if any of the checks passed
will pass, i.e. using logical OR. will pass, i.e. using logical OR.
@ -1814,6 +1849,7 @@ def check_any(*checks: Check) -> Callable[[T], T]:
return check(predicate) return check(predicate)
def has_role(item: Union[int, str]) -> Callable[[T], T]: def has_role(item: Union[int, str]) -> Callable[[T], T]:
"""A :func:`.check` that is added that checks if the member invoking the """A :func:`.check` that is added that checks if the member invoking the
command has the role specified via the name or ID specified. command has the role specified via the name or ID specified.
@ -1856,6 +1892,7 @@ def has_role(item: Union[int, str]) -> Callable[[T], T]:
return check(predicate) return check(predicate)
def has_any_role(*items: Union[int, str]) -> Callable[[T], T]: def has_any_role(*items: Union[int, str]) -> Callable[[T], T]:
r"""A :func:`.check` that is added that checks if the member invoking the r"""A :func:`.check` that is added that checks if the member invoking the
command has **any** of the roles specified. This means that if they have command has **any** of the roles specified. This means that if they have
@ -1887,6 +1924,7 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]:
async def cool(ctx): async def cool(ctx):
await ctx.send('You are cool indeed') await ctx.send('You are cool indeed')
""" """
def predicate(ctx): def predicate(ctx):
if ctx.guild is None: if ctx.guild is None:
raise NoPrivateMessage() raise NoPrivateMessage()
@ -1899,6 +1937,7 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]:
return check(predicate) return check(predicate)
def bot_has_role(item: int) -> Callable[[T], T]: def bot_has_role(item: int) -> Callable[[T], T]:
"""Similar to :func:`.has_role` except checks if the bot itself has the """Similar to :func:`.has_role` except checks if the bot itself has the
role. role.
@ -1925,8 +1964,10 @@ def bot_has_role(item: int) -> Callable[[T], T]:
if role is None: if role is None:
raise BotMissingRole(item) raise BotMissingRole(item)
return True return True
return check(predicate) return check(predicate)
def bot_has_any_role(*items: int) -> Callable[[T], T]: def bot_has_any_role(*items: int) -> Callable[[T], T]:
"""Similar to :func:`.has_any_role` except checks if the bot itself has """Similar to :func:`.has_any_role` except checks if the bot itself has
any of the roles listed. any of the roles listed.
@ -1940,6 +1981,7 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]:
Raise :exc:`.BotMissingAnyRole` or :exc:`.NoPrivateMessage` Raise :exc:`.BotMissingAnyRole` or :exc:`.NoPrivateMessage`
instead of generic checkfailure instead of generic checkfailure
""" """
def predicate(ctx): def predicate(ctx):
if ctx.guild is None: if ctx.guild is None:
raise NoPrivateMessage() raise NoPrivateMessage()
@ -1949,8 +1991,10 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]:
if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items): if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items):
return True return True
raise BotMissingAnyRole(list(items)) raise BotMissingAnyRole(list(items))
return check(predicate) return check(predicate)
def has_permissions(**perms: bool) -> Callable[[T], T]: def has_permissions(**perms: bool) -> Callable[[T], T]:
"""A :func:`.check` that is added that checks if the member has all of """A :func:`.check` that is added that checks if the member has all of
the permissions necessary. the permissions necessary.
@ -1998,6 +2042,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]:
return check(predicate) return check(predicate)
def bot_has_permissions(**perms: bool) -> Callable[[T], T]: def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
"""Similar to :func:`.has_permissions` except checks if the bot itself has """Similar to :func:`.has_permissions` except checks if the bot itself has
the permissions listed. the permissions listed.
@ -2024,6 +2069,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
return check(predicate) return check(predicate)
def has_guild_permissions(**perms: bool) -> Callable[[T], T]: def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
"""Similar to :func:`.has_permissions`, but operates on guild wide """Similar to :func:`.has_permissions`, but operates on guild wide
permissions instead of the current channel permissions. permissions instead of the current channel permissions.
@ -2052,6 +2098,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
return check(predicate) return check(predicate)
def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]: def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
"""Similar to :func:`.has_guild_permissions`, but checks the bot """Similar to :func:`.has_guild_permissions`, but checks the bot
members guild permissions. members guild permissions.
@ -2077,6 +2124,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
return check(predicate) return check(predicate)
def dm_only() -> Callable[[T], T]: def dm_only() -> Callable[[T], T]:
"""A :func:`.check` that indicates this command must only be used in a """A :func:`.check` that indicates this command must only be used in a
DM context. Only private messages are allowed when DM context. Only private messages are allowed when
@ -2095,6 +2143,7 @@ def dm_only() -> Callable[[T], T]:
return check(predicate) return check(predicate)
def guild_only() -> Callable[[T], T]: def guild_only() -> Callable[[T], T]:
"""A :func:`.check` that indicates this command must only be used in a """A :func:`.check` that indicates this command must only be used in a
guild context only. Basically, no private messages are allowed when guild context only. Basically, no private messages are allowed when
@ -2111,6 +2160,7 @@ def guild_only() -> Callable[[T], T]:
return check(predicate) return check(predicate)
def is_owner() -> Callable[[T], T]: def is_owner() -> Callable[[T], T]:
"""A :func:`.check` that checks if the person invoking this command is the """A :func:`.check` that checks if the person invoking this command is the
owner of the bot. owner of the bot.
@ -2128,6 +2178,7 @@ def is_owner() -> Callable[[T], T]:
return check(predicate) return check(predicate)
def is_nsfw() -> Callable[[T], T]: def is_nsfw() -> Callable[[T], T]:
"""A :func:`.check` that checks if the channel is a NSFW channel. """A :func:`.check` that checks if the channel is a NSFW channel.
@ -2139,14 +2190,21 @@ def is_nsfw() -> Callable[[T], T]:
Raise :exc:`.NSFWChannelRequired` instead of generic :exc:`.CheckFailure`. Raise :exc:`.NSFWChannelRequired` instead of generic :exc:`.CheckFailure`.
DM channels will also now pass this check. DM channels will also now pass this check.
""" """
def pred(ctx: Context) -> bool: def pred(ctx: Context) -> bool:
ch = ctx.channel ch = ctx.channel
if ctx.guild is None or (isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()): if ctx.guild is None or (isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()):
return True return True
raise NSFWChannelRequired(ch) # type: ignore raise NSFWChannelRequired(ch) # type: ignore
return check(pred) return check(pred)
def cooldown(rate: int, per: float, type: Union[BucketType, Callable[[Message], Any]] = BucketType.default) -> Callable[[T], T]:
def cooldown(
rate: int,
per: float,
type: Union[BucketType, Callable[[Message], Any]] = BucketType.default,
) -> Callable[[T], T]:
"""A decorator that adds a cooldown to a :class:`.Command` """A decorator that adds a cooldown to a :class:`.Command`
A cooldown allows a command to only be used a specific amount A cooldown allows a command to only be used a specific amount
@ -2179,9 +2237,14 @@ def cooldown(rate: int, per: float, type: Union[BucketType, Callable[[Message],
else: else:
func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type) func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type)
return func return func
return decorator # type: ignore return decorator # type: ignore
def dynamic_cooldown(cooldown: Union[BucketType, Callable[[Message], Any]], type: BucketType = BucketType.default) -> Callable[[T], T]:
def dynamic_cooldown(
cooldown: Union[BucketType, Callable[[Message], Any]],
type: BucketType = BucketType.default,
) -> Callable[[T], T]:
"""A decorator that adds a dynamic cooldown to a :class:`.Command` """A decorator that adds a dynamic cooldown to a :class:`.Command`
This differs from :func:`.cooldown` in that it takes a function that This differs from :func:`.cooldown` in that it takes a function that
@ -2219,8 +2282,10 @@ def dynamic_cooldown(cooldown: Union[BucketType, Callable[[Message], Any]], type
else: else:
func.__commands_cooldown__ = DynamicCooldownMapping(cooldown, type) func.__commands_cooldown__ = DynamicCooldownMapping(cooldown, type)
return func return func
return decorator # type: ignore return decorator # type: ignore
def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait: bool = False) -> Callable[[T], T]: def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait: bool = False) -> Callable[[T], T]:
"""A decorator that adds a maximum concurrency to a :class:`.Command` or its subclasses. """A decorator that adds a maximum concurrency to a :class:`.Command` or its subclasses.
@ -2252,8 +2317,10 @@ def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait:
else: else:
func.__commands_max_concurrency__ = value func.__commands_max_concurrency__ = value
return func return func
return decorator # type: ignore return decorator # type: ignore
def before_invoke(coro) -> Callable[[T], T]: def before_invoke(coro) -> Callable[[T], T]:
"""A decorator that registers a coroutine as a pre-invoke hook. """A decorator that registers a coroutine as a pre-invoke hook.
@ -2292,14 +2359,17 @@ def before_invoke(coro) -> Callable[[T], T]:
bot.add_cog(What()) bot.add_cog(What())
""" """
def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
if isinstance(func, Command): if isinstance(func, Command):
func.before_invoke(coro) func.before_invoke(coro)
else: else:
func.__before_invoke__ = coro func.__before_invoke__ = coro
return func return func
return decorator # type: ignore return decorator # type: ignore
def after_invoke(coro) -> Callable[[T], T]: def after_invoke(coro) -> Callable[[T], T]:
"""A decorator that registers a coroutine as a post-invoke hook. """A decorator that registers a coroutine as a post-invoke hook.
@ -2308,10 +2378,12 @@ def after_invoke(coro) -> Callable[[T], T]:
.. versionadded:: 1.4 .. versionadded:: 1.4
""" """
def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
if isinstance(func, Command): if isinstance(func, Command):
func.after_invoke(coro) func.after_invoke(coro)
else: else:
func.__after_invoke__ = coro func.__after_invoke__ = coro
return func return func
return decorator # type: ignore return decorator # type: ignore

108
discord/ext/commands/errors.py

@ -100,6 +100,7 @@ __all__ = (
'MissingRequiredFlag', 'MissingRequiredFlag',
) )
class CommandError(DiscordException): class CommandError(DiscordException):
r"""The base exception type for all command related errors. r"""The base exception type for all command related errors.
@ -109,6 +110,7 @@ class CommandError(DiscordException):
in a special way as they are caught and passed into a special event in a special way as they are caught and passed into a special event
from :class:`.Bot`\, :func:`.on_command_error`. from :class:`.Bot`\, :func:`.on_command_error`.
""" """
def __init__(self, message: Optional[str] = None, *args: Any) -> None: def __init__(self, message: Optional[str] = None, *args: Any) -> None:
if message is not None: if message is not None:
# clean-up @everyone and @here mentions # clean-up @everyone and @here mentions
@ -117,6 +119,7 @@ class CommandError(DiscordException):
else: else:
super().__init__(*args) super().__init__(*args)
class ConversionError(CommandError): class ConversionError(CommandError):
"""Exception raised when a Converter class raises non-CommandError. """Exception raised when a Converter class raises non-CommandError.
@ -130,18 +133,22 @@ class ConversionError(CommandError):
The original exception that was raised. You can also get this via The original exception that was raised. You can also get this via
the ``__cause__`` attribute. the ``__cause__`` attribute.
""" """
def __init__(self, converter: Converter, original: Exception) -> None: def __init__(self, converter: Converter, original: Exception) -> None:
self.converter: Converter = converter self.converter: Converter = converter
self.original: Exception = original self.original: Exception = original
class UserInputError(CommandError): class UserInputError(CommandError):
"""The base exception type for errors that involve errors """The base exception type for errors that involve errors
regarding user input. regarding user input.
This inherits from :exc:`CommandError`. This inherits from :exc:`CommandError`.
""" """
pass pass
class CommandNotFound(CommandError): class CommandNotFound(CommandError):
"""Exception raised when a command is attempted to be invoked """Exception raised when a command is attempted to be invoked
but no command under that name is found. but no command under that name is found.
@ -151,8 +158,10 @@ class CommandNotFound(CommandError):
This inherits from :exc:`CommandError`. This inherits from :exc:`CommandError`.
""" """
pass pass
class MissingRequiredArgument(UserInputError): class MissingRequiredArgument(UserInputError):
"""Exception raised when parsing a command and a parameter """Exception raised when parsing a command and a parameter
that is required is not encountered. that is required is not encountered.
@ -164,33 +173,41 @@ class MissingRequiredArgument(UserInputError):
param: :class:`inspect.Parameter` param: :class:`inspect.Parameter`
The argument that is missing. The argument that is missing.
""" """
def __init__(self, param: Parameter) -> None: def __init__(self, param: Parameter) -> None:
self.param: Parameter = param self.param: Parameter = param
super().__init__(f'{param.name} is a required argument that is missing.') super().__init__(f'{param.name} is a required argument that is missing.')
class TooManyArguments(UserInputError): class TooManyArguments(UserInputError):
"""Exception raised when the command was passed too many arguments and its """Exception raised when the command was passed too many arguments and its
:attr:`.Command.ignore_extra` attribute was not set to ``True``. :attr:`.Command.ignore_extra` attribute was not set to ``True``.
This inherits from :exc:`UserInputError` This inherits from :exc:`UserInputError`
""" """
pass pass
class BadArgument(UserInputError): class BadArgument(UserInputError):
"""Exception raised when a parsing or conversion failure is encountered """Exception raised when a parsing or conversion failure is encountered
on an argument to pass into a command. on an argument to pass into a command.
This inherits from :exc:`UserInputError` This inherits from :exc:`UserInputError`
""" """
pass pass
class CheckFailure(CommandError): class CheckFailure(CommandError):
"""Exception raised when the predicates in :attr:`.Command.checks` have failed. """Exception raised when the predicates in :attr:`.Command.checks` have failed.
This inherits from :exc:`CommandError` This inherits from :exc:`CommandError`
""" """
pass pass
class CheckAnyFailure(CheckFailure): class CheckAnyFailure(CheckFailure):
"""Exception raised when all predicates in :func:`check_any` fail. """Exception raised when all predicates in :func:`check_any` fail.
@ -211,15 +228,18 @@ class CheckAnyFailure(CheckFailure):
self.errors: List[Callable[[Context], bool]] = errors self.errors: List[Callable[[Context], bool]] = errors
super().__init__('You do not have permission to run this command.') super().__init__('You do not have permission to run this command.')
class PrivateMessageOnly(CheckFailure): class PrivateMessageOnly(CheckFailure):
"""Exception raised when an operation does not work outside of private """Exception raised when an operation does not work outside of private
message contexts. message contexts.
This inherits from :exc:`CheckFailure` This inherits from :exc:`CheckFailure`
""" """
def __init__(self, message: Optional[str] = None) -> None: def __init__(self, message: Optional[str] = None) -> None:
super().__init__(message or 'This command can only be used in private messages.') super().__init__(message or 'This command can only be used in private messages.')
class NoPrivateMessage(CheckFailure): class NoPrivateMessage(CheckFailure):
"""Exception raised when an operation does not work in private message """Exception raised when an operation does not work in private message
contexts. contexts.
@ -230,13 +250,16 @@ class NoPrivateMessage(CheckFailure):
def __init__(self, message: Optional[str] = None) -> None: def __init__(self, message: Optional[str] = None) -> None:
super().__init__(message or 'This command cannot be used in private messages.') super().__init__(message or 'This command cannot be used in private messages.')
class NotOwner(CheckFailure): class NotOwner(CheckFailure):
"""Exception raised when the message author is not the owner of the bot. """Exception raised when the message author is not the owner of the bot.
This inherits from :exc:`CheckFailure` This inherits from :exc:`CheckFailure`
""" """
pass pass
class ObjectNotFound(BadArgument): class ObjectNotFound(BadArgument):
"""Exception raised when the argument provided did not match the format """Exception raised when the argument provided did not match the format
of an ID or a mention. of an ID or a mention.
@ -250,10 +273,12 @@ class ObjectNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The argument supplied by the caller that was not matched The argument supplied by the caller that was not matched
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'{argument!r} does not follow a valid ID or mention format.') super().__init__(f'{argument!r} does not follow a valid ID or mention format.')
class MemberNotFound(BadArgument): class MemberNotFound(BadArgument):
"""Exception raised when the member provided was not found in the bot's """Exception raised when the member provided was not found in the bot's
cache. cache.
@ -267,10 +292,12 @@ class MemberNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The member supplied by the caller that was not found The member supplied by the caller that was not found
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'Member "{argument}" not found.') super().__init__(f'Member "{argument}" not found.')
class GuildNotFound(BadArgument): class GuildNotFound(BadArgument):
"""Exception raised when the guild provided was not found in the bot's cache. """Exception raised when the guild provided was not found in the bot's cache.
@ -283,10 +310,12 @@ class GuildNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The guild supplied by the called that was not found The guild supplied by the called that was not found
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'Guild "{argument}" not found.') super().__init__(f'Guild "{argument}" not found.')
class UserNotFound(BadArgument): class UserNotFound(BadArgument):
"""Exception raised when the user provided was not found in the bot's """Exception raised when the user provided was not found in the bot's
cache. cache.
@ -300,10 +329,12 @@ class UserNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The user supplied by the caller that was not found The user supplied by the caller that was not found
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'User "{argument}" not found.') super().__init__(f'User "{argument}" not found.')
class MessageNotFound(BadArgument): class MessageNotFound(BadArgument):
"""Exception raised when the message provided was not found in the channel. """Exception raised when the message provided was not found in the channel.
@ -316,10 +347,12 @@ class MessageNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The message supplied by the caller that was not found The message supplied by the caller that was not found
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'Message "{argument}" not found.') super().__init__(f'Message "{argument}" not found.')
class ChannelNotReadable(BadArgument): class ChannelNotReadable(BadArgument):
"""Exception raised when the bot does not have permission to read messages """Exception raised when the bot does not have permission to read messages
in the channel. in the channel.
@ -333,10 +366,12 @@ class ChannelNotReadable(BadArgument):
argument: Union[:class:`.abc.GuildChannel`, :class:`.Thread`] argument: Union[:class:`.abc.GuildChannel`, :class:`.Thread`]
The channel supplied by the caller that was not readable The channel supplied by the caller that was not readable
""" """
def __init__(self, argument: Union[GuildChannel, Thread]) -> None: def __init__(self, argument: Union[GuildChannel, Thread]) -> None:
self.argument: Union[GuildChannel, Thread] = argument self.argument: Union[GuildChannel, Thread] = argument
super().__init__(f"Can't read messages in {argument.mention}.") super().__init__(f"Can't read messages in {argument.mention}.")
class ChannelNotFound(BadArgument): class ChannelNotFound(BadArgument):
"""Exception raised when the bot can not find the channel. """Exception raised when the bot can not find the channel.
@ -349,10 +384,12 @@ class ChannelNotFound(BadArgument):
argument: Union[:class:`int`, :class:`str`] argument: Union[:class:`int`, :class:`str`]
The channel supplied by the caller that was not found The channel supplied by the caller that was not found
""" """
def __init__(self, argument: Union[int, str]) -> None: def __init__(self, argument: Union[int, str]) -> None:
self.argument: Union[int, str] = argument self.argument: Union[int, str] = argument
super().__init__(f'Channel "{argument}" not found.') super().__init__(f'Channel "{argument}" not found.')
class ThreadNotFound(BadArgument): class ThreadNotFound(BadArgument):
"""Exception raised when the bot can not find the thread. """Exception raised when the bot can not find the thread.
@ -365,10 +402,12 @@ class ThreadNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The thread supplied by the caller that was not found The thread supplied by the caller that was not found
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'Thread "{argument}" not found.') super().__init__(f'Thread "{argument}" not found.')
class BadColourArgument(BadArgument): class BadColourArgument(BadArgument):
"""Exception raised when the colour is not valid. """Exception raised when the colour is not valid.
@ -381,12 +420,15 @@ class BadColourArgument(BadArgument):
argument: :class:`str` argument: :class:`str`
The colour supplied by the caller that was not valid The colour supplied by the caller that was not valid
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'Colour "{argument}" is invalid.') super().__init__(f'Colour "{argument}" is invalid.')
BadColorArgument = BadColourArgument BadColorArgument = BadColourArgument
class RoleNotFound(BadArgument): class RoleNotFound(BadArgument):
"""Exception raised when the bot can not find the role. """Exception raised when the bot can not find the role.
@ -399,10 +441,12 @@ class RoleNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The role supplied by the caller that was not found The role supplied by the caller that was not found
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'Role "{argument}" not found.') super().__init__(f'Role "{argument}" not found.')
class BadInviteArgument(BadArgument): class BadInviteArgument(BadArgument):
"""Exception raised when the invite is invalid or expired. """Exception raised when the invite is invalid or expired.
@ -410,10 +454,12 @@ class BadInviteArgument(BadArgument):
.. versionadded:: 1.5 .. versionadded:: 1.5
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'Invite "{argument}" is invalid or expired.') super().__init__(f'Invite "{argument}" is invalid or expired.')
class EmojiNotFound(BadArgument): class EmojiNotFound(BadArgument):
"""Exception raised when the bot can not find the emoji. """Exception raised when the bot can not find the emoji.
@ -426,10 +472,12 @@ class EmojiNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The emoji supplied by the caller that was not found The emoji supplied by the caller that was not found
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'Emoji "{argument}" not found.') super().__init__(f'Emoji "{argument}" not found.')
class PartialEmojiConversionFailure(BadArgument): class PartialEmojiConversionFailure(BadArgument):
"""Exception raised when the emoji provided does not match the correct """Exception raised when the emoji provided does not match the correct
format. format.
@ -443,10 +491,12 @@ class PartialEmojiConversionFailure(BadArgument):
argument: :class:`str` argument: :class:`str`
The emoji supplied by the caller that did not match the regex The emoji supplied by the caller that did not match the regex
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'Couldn\'t convert "{argument}" to PartialEmoji.') super().__init__(f'Couldn\'t convert "{argument}" to PartialEmoji.')
class GuildStickerNotFound(BadArgument): class GuildStickerNotFound(BadArgument):
"""Exception raised when the bot can not find the sticker. """Exception raised when the bot can not find the sticker.
@ -459,10 +509,12 @@ class GuildStickerNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The sticker supplied by the caller that was not found The sticker supplied by the caller that was not found
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'Sticker "{argument}" not found.') super().__init__(f'Sticker "{argument}" not found.')
class BadBoolArgument(BadArgument): class BadBoolArgument(BadArgument):
"""Exception raised when a boolean argument was not convertable. """Exception raised when a boolean argument was not convertable.
@ -475,17 +527,21 @@ class BadBoolArgument(BadArgument):
argument: :class:`str` argument: :class:`str`
The boolean argument supplied by the caller that is not in the predefined list The boolean argument supplied by the caller that is not in the predefined list
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'{argument} is not a recognised boolean option') super().__init__(f'{argument} is not a recognised boolean option')
class DisabledCommand(CommandError): class DisabledCommand(CommandError):
"""Exception raised when the command being invoked is disabled. """Exception raised when the command being invoked is disabled.
This inherits from :exc:`CommandError` This inherits from :exc:`CommandError`
""" """
pass pass
class CommandInvokeError(CommandError): class CommandInvokeError(CommandError):
"""Exception raised when the command being invoked raised an exception. """Exception raised when the command being invoked raised an exception.
@ -497,10 +553,12 @@ class CommandInvokeError(CommandError):
The original exception that was raised. You can also get this via The original exception that was raised. You can also get this via
the ``__cause__`` attribute. the ``__cause__`` attribute.
""" """
def __init__(self, e: Exception) -> None: def __init__(self, e: Exception) -> None:
self.original: Exception = e self.original: Exception = e
super().__init__(f'Command raised an exception: {e.__class__.__name__}: {e}') super().__init__(f'Command raised an exception: {e.__class__.__name__}: {e}')
class CommandOnCooldown(CommandError): class CommandOnCooldown(CommandError):
"""Exception raised when the command being invoked is on cooldown. """Exception raised when the command being invoked is on cooldown.
@ -516,12 +574,14 @@ class CommandOnCooldown(CommandError):
retry_after: :class:`float` retry_after: :class:`float`
The amount of seconds to wait before you can retry again. The amount of seconds to wait before you can retry again.
""" """
def __init__(self, cooldown: Cooldown, retry_after: float, type: BucketType) -> None: def __init__(self, cooldown: Cooldown, retry_after: float, type: BucketType) -> None:
self.cooldown: Cooldown = cooldown self.cooldown: Cooldown = cooldown
self.retry_after: float = retry_after self.retry_after: float = retry_after
self.type: BucketType = type self.type: BucketType = type
super().__init__(f'You are on cooldown. Try again in {retry_after:.2f}s') super().__init__(f'You are on cooldown. Try again in {retry_after:.2f}s')
class MaxConcurrencyReached(CommandError): class MaxConcurrencyReached(CommandError):
"""Exception raised when the command being invoked has reached its maximum concurrency. """Exception raised when the command being invoked has reached its maximum concurrency.
@ -544,6 +604,7 @@ class MaxConcurrencyReached(CommandError):
fmt = plural % (number, suffix) fmt = plural % (number, suffix)
super().__init__(f'Too many people are using this command. It can only be used {fmt} concurrently.') super().__init__(f'Too many people are using this command. It can only be used {fmt} concurrently.')
class MissingRole(CheckFailure): class MissingRole(CheckFailure):
"""Exception raised when the command invoker lacks a role to run a command. """Exception raised when the command invoker lacks a role to run a command.
@ -557,11 +618,13 @@ class MissingRole(CheckFailure):
The required role that is missing. The required role that is missing.
This is the parameter passed to :func:`~.commands.has_role`. This is the parameter passed to :func:`~.commands.has_role`.
""" """
def __init__(self, missing_role: Snowflake) -> None: def __init__(self, missing_role: Snowflake) -> None:
self.missing_role: Snowflake = missing_role self.missing_role: Snowflake = missing_role
message = f'Role {missing_role!r} is required to run this command.' message = f'Role {missing_role!r} is required to run this command.'
super().__init__(message) super().__init__(message)
class BotMissingRole(CheckFailure): class BotMissingRole(CheckFailure):
"""Exception raised when the bot's member lacks a role to run a command. """Exception raised when the bot's member lacks a role to run a command.
@ -575,11 +638,13 @@ class BotMissingRole(CheckFailure):
The required role that is missing. The required role that is missing.
This is the parameter passed to :func:`~.commands.has_role`. This is the parameter passed to :func:`~.commands.has_role`.
""" """
def __init__(self, missing_role: Snowflake) -> None: def __init__(self, missing_role: Snowflake) -> None:
self.missing_role: Snowflake = missing_role self.missing_role: Snowflake = missing_role
message = f'Bot requires the role {missing_role!r} to run this command' message = f'Bot requires the role {missing_role!r} to run this command'
super().__init__(message) super().__init__(message)
class MissingAnyRole(CheckFailure): class MissingAnyRole(CheckFailure):
"""Exception raised when the command invoker lacks any of """Exception raised when the command invoker lacks any of
the roles specified to run a command. the roles specified to run a command.
@ -594,6 +659,7 @@ class MissingAnyRole(CheckFailure):
The roles that the invoker is missing. The roles that the invoker is missing.
These are the parameters passed to :func:`~.commands.has_any_role`. These are the parameters passed to :func:`~.commands.has_any_role`.
""" """
def __init__(self, missing_roles: SnowflakeList) -> None: def __init__(self, missing_roles: SnowflakeList) -> None:
self.missing_roles: SnowflakeList = missing_roles self.missing_roles: SnowflakeList = missing_roles
@ -623,6 +689,7 @@ class BotMissingAnyRole(CheckFailure):
These are the parameters passed to :func:`~.commands.has_any_role`. These are the parameters passed to :func:`~.commands.has_any_role`.
""" """
def __init__(self, missing_roles: SnowflakeList) -> None: def __init__(self, missing_roles: SnowflakeList) -> None:
self.missing_roles: SnowflakeList = missing_roles self.missing_roles: SnowflakeList = missing_roles
@ -636,6 +703,7 @@ class BotMissingAnyRole(CheckFailure):
message = f"Bot is missing at least one of the required roles: {fmt}" message = f"Bot is missing at least one of the required roles: {fmt}"
super().__init__(message) super().__init__(message)
class NSFWChannelRequired(CheckFailure): class NSFWChannelRequired(CheckFailure):
"""Exception raised when a channel does not have the required NSFW setting. """Exception raised when a channel does not have the required NSFW setting.
@ -648,10 +716,12 @@ class NSFWChannelRequired(CheckFailure):
channel: Union[:class:`.abc.GuildChannel`, :class:`.Thread`] channel: Union[:class:`.abc.GuildChannel`, :class:`.Thread`]
The channel that does not have NSFW enabled. The channel that does not have NSFW enabled.
""" """
def __init__(self, channel: Union[GuildChannel, Thread]) -> None: def __init__(self, channel: Union[GuildChannel, Thread]) -> None:
self.channel: Union[GuildChannel, Thread] = channel self.channel: Union[GuildChannel, Thread] = channel
super().__init__(f"Channel '{channel}' needs to be NSFW for this command to work.") super().__init__(f"Channel '{channel}' needs to be NSFW for this command to work.")
class MissingPermissions(CheckFailure): class MissingPermissions(CheckFailure):
"""Exception raised when the command invoker lacks permissions to run a """Exception raised when the command invoker lacks permissions to run a
command. command.
@ -663,6 +733,7 @@ class MissingPermissions(CheckFailure):
missing_permissions: List[:class:`str`] missing_permissions: List[:class:`str`]
The required permissions that are missing. The required permissions that are missing.
""" """
def __init__(self, missing_permissions: List[str], *args: Any) -> None: def __init__(self, missing_permissions: List[str], *args: Any) -> None:
self.missing_permissions: List[str] = missing_permissions self.missing_permissions: List[str] = missing_permissions
@ -675,6 +746,7 @@ class MissingPermissions(CheckFailure):
message = f'You are missing {fmt} permission(s) to run this command.' message = f'You are missing {fmt} permission(s) to run this command.'
super().__init__(message, *args) super().__init__(message, *args)
class BotMissingPermissions(CheckFailure): class BotMissingPermissions(CheckFailure):
"""Exception raised when the bot's member lacks permissions to run a """Exception raised when the bot's member lacks permissions to run a
command. command.
@ -686,6 +758,7 @@ class BotMissingPermissions(CheckFailure):
missing_permissions: List[:class:`str`] missing_permissions: List[:class:`str`]
The required permissions that are missing. The required permissions that are missing.
""" """
def __init__(self, missing_permissions: List[str], *args: Any) -> None: def __init__(self, missing_permissions: List[str], *args: Any) -> None:
self.missing_permissions: List[str] = missing_permissions self.missing_permissions: List[str] = missing_permissions
@ -698,6 +771,7 @@ class BotMissingPermissions(CheckFailure):
message = f'Bot requires {fmt} permission(s) to run this command.' message = f'Bot requires {fmt} permission(s) to run this command.'
super().__init__(message, *args) super().__init__(message, *args)
class BadUnionArgument(UserInputError): class BadUnionArgument(UserInputError):
"""Exception raised when a :data:`typing.Union` converter fails for all """Exception raised when a :data:`typing.Union` converter fails for all
its associated types. its associated types.
@ -713,6 +787,7 @@ class BadUnionArgument(UserInputError):
errors: List[:class:`CommandError`] errors: List[:class:`CommandError`]
A list of errors that were caught from failing the conversion. A list of errors that were caught from failing the conversion.
""" """
def __init__(self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]) -> None: def __init__(self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]) -> None:
self.param: Parameter = param self.param: Parameter = param
self.converters: Tuple[Type, ...] = converters self.converters: Tuple[Type, ...] = converters
@ -734,6 +809,7 @@ class BadUnionArgument(UserInputError):
super().__init__(f'Could not convert "{param.name}" into {fmt}.') super().__init__(f'Could not convert "{param.name}" into {fmt}.')
class BadLiteralArgument(UserInputError): class BadLiteralArgument(UserInputError):
"""Exception raised when a :data:`typing.Literal` converter fails for all """Exception raised when a :data:`typing.Literal` converter fails for all
its associated values. its associated values.
@ -751,6 +827,7 @@ class BadLiteralArgument(UserInputError):
errors: List[:class:`CommandError`] errors: List[:class:`CommandError`]
A list of errors that were caught from failing the conversion. A list of errors that were caught from failing the conversion.
""" """
def __init__(self, param: Parameter, literals: Tuple[Any, ...], errors: List[CommandError]) -> None: def __init__(self, param: Parameter, literals: Tuple[Any, ...], errors: List[CommandError]) -> None:
self.param: Parameter = param self.param: Parameter = param
self.literals: Tuple[Any, ...] = literals self.literals: Tuple[Any, ...] = literals
@ -764,6 +841,7 @@ class BadLiteralArgument(UserInputError):
super().__init__(f'Could not convert "{param.name}" into the literal {fmt}.') super().__init__(f'Could not convert "{param.name}" into the literal {fmt}.')
class ArgumentParsingError(UserInputError): class ArgumentParsingError(UserInputError):
"""An exception raised when the parser fails to parse a user's input. """An exception raised when the parser fails to parse a user's input.
@ -772,8 +850,10 @@ class ArgumentParsingError(UserInputError):
There are child classes that implement more granular parsing errors for There are child classes that implement more granular parsing errors for
i18n purposes. i18n purposes.
""" """
pass pass
class UnexpectedQuoteError(ArgumentParsingError): class UnexpectedQuoteError(ArgumentParsingError):
"""An exception raised when the parser encounters a quote mark inside a non-quoted string. """An exception raised when the parser encounters a quote mark inside a non-quoted string.
@ -784,10 +864,12 @@ class UnexpectedQuoteError(ArgumentParsingError):
quote: :class:`str` quote: :class:`str`
The quote mark that was found inside the non-quoted string. The quote mark that was found inside the non-quoted string.
""" """
def __init__(self, quote: str) -> None: def __init__(self, quote: str) -> None:
self.quote: str = quote self.quote: str = quote
super().__init__(f'Unexpected quote mark, {quote!r}, in non-quoted string') super().__init__(f'Unexpected quote mark, {quote!r}, in non-quoted string')
class InvalidEndOfQuotedStringError(ArgumentParsingError): class InvalidEndOfQuotedStringError(ArgumentParsingError):
"""An exception raised when a space is expected after the closing quote in a string """An exception raised when a space is expected after the closing quote in a string
but a different character is found. but a different character is found.
@ -799,10 +881,12 @@ class InvalidEndOfQuotedStringError(ArgumentParsingError):
char: :class:`str` char: :class:`str`
The character found instead of the expected string. The character found instead of the expected string.
""" """
def __init__(self, char: str) -> None: def __init__(self, char: str) -> None:
self.char: str = char self.char: str = char
super().__init__(f'Expected space after closing quotation but received {char!r}') super().__init__(f'Expected space after closing quotation but received {char!r}')
class ExpectedClosingQuoteError(ArgumentParsingError): class ExpectedClosingQuoteError(ArgumentParsingError):
"""An exception raised when a quote character is expected but not found. """An exception raised when a quote character is expected but not found.
@ -818,6 +902,7 @@ class ExpectedClosingQuoteError(ArgumentParsingError):
self.close_quote: str = close_quote self.close_quote: str = close_quote
super().__init__(f'Expected closing {close_quote}.') super().__init__(f'Expected closing {close_quote}.')
class ExtensionError(DiscordException): class ExtensionError(DiscordException):
"""Base exception for extension related errors. """Base exception for extension related errors.
@ -828,6 +913,7 @@ class ExtensionError(DiscordException):
name: :class:`str` name: :class:`str`
The extension that had an error. The extension that had an error.
""" """
def __init__(self, message: Optional[str] = None, *args: Any, name: str) -> None: def __init__(self, message: Optional[str] = None, *args: Any, name: str) -> None:
self.name: str = name self.name: str = name
message = message or f'Extension {name!r} had an error.' message = message or f'Extension {name!r} had an error.'
@ -835,30 +921,37 @@ class ExtensionError(DiscordException):
m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere') m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere')
super().__init__(m, *args) super().__init__(m, *args)
class ExtensionAlreadyLoaded(ExtensionError): class ExtensionAlreadyLoaded(ExtensionError):
"""An exception raised when an extension has already been loaded. """An exception raised when an extension has already been loaded.
This inherits from :exc:`ExtensionError` This inherits from :exc:`ExtensionError`
""" """
def __init__(self, name: str) -> None: def __init__(self, name: str) -> None:
super().__init__(f'Extension {name!r} is already loaded.', name=name) super().__init__(f'Extension {name!r} is already loaded.', name=name)
class ExtensionNotLoaded(ExtensionError): class ExtensionNotLoaded(ExtensionError):
"""An exception raised when an extension was not loaded. """An exception raised when an extension was not loaded.
This inherits from :exc:`ExtensionError` This inherits from :exc:`ExtensionError`
""" """
def __init__(self, name: str) -> None: def __init__(self, name: str) -> None:
super().__init__(f'Extension {name!r} has not been loaded.', name=name) super().__init__(f'Extension {name!r} has not been loaded.', name=name)
class NoEntryPointError(ExtensionError): class NoEntryPointError(ExtensionError):
"""An exception raised when an extension does not have a ``setup`` entry point function. """An exception raised when an extension does not have a ``setup`` entry point function.
This inherits from :exc:`ExtensionError` This inherits from :exc:`ExtensionError`
""" """
def __init__(self, name: str) -> None: def __init__(self, name: str) -> None:
super().__init__(f"Extension {name!r} has no 'setup' function.", name=name) super().__init__(f"Extension {name!r} has no 'setup' function.", name=name)
class ExtensionFailed(ExtensionError): class ExtensionFailed(ExtensionError):
"""An exception raised when an extension failed to load during execution of the module or ``setup`` entry point. """An exception raised when an extension failed to load during execution of the module or ``setup`` entry point.
@ -872,11 +965,13 @@ class ExtensionFailed(ExtensionError):
The original exception that was raised. You can also get this via The original exception that was raised. You can also get this via
the ``__cause__`` attribute. the ``__cause__`` attribute.
""" """
def __init__(self, name: str, original: Exception) -> None: def __init__(self, name: str, original: Exception) -> None:
self.original: Exception = original self.original: Exception = original
msg = f'Extension {name!r} raised an error: {original.__class__.__name__}: {original}' msg = f'Extension {name!r} raised an error: {original.__class__.__name__}: {original}'
super().__init__(msg, name=name) super().__init__(msg, name=name)
class ExtensionNotFound(ExtensionError): class ExtensionNotFound(ExtensionError):
"""An exception raised when an extension is not found. """An exception raised when an extension is not found.
@ -890,10 +985,12 @@ class ExtensionNotFound(ExtensionError):
name: :class:`str` name: :class:`str`
The extension that had the error. The extension that had the error.
""" """
def __init__(self, name: str) -> None: def __init__(self, name: str) -> None:
msg = f'Extension {name!r} could not be loaded.' msg = f'Extension {name!r} could not be loaded.'
super().__init__(msg, name=name) super().__init__(msg, name=name)
class CommandRegistrationError(ClientException): class CommandRegistrationError(ClientException):
"""An exception raised when the command can't be added """An exception raised when the command can't be added
because the name is already taken by a different command. because the name is already taken by a different command.
@ -909,12 +1006,14 @@ class CommandRegistrationError(ClientException):
alias_conflict: :class:`bool` alias_conflict: :class:`bool`
Whether the name that conflicts is an alias of the command we try to add. Whether the name that conflicts is an alias of the command we try to add.
""" """
def __init__(self, name: str, *, alias_conflict: bool = False) -> None: def __init__(self, name: str, *, alias_conflict: bool = False) -> None:
self.name: str = name self.name: str = name
self.alias_conflict: bool = alias_conflict self.alias_conflict: bool = alias_conflict
type_ = 'alias' if alias_conflict else 'command' type_ = 'alias' if alias_conflict else 'command'
super().__init__(f'The {type_} {name} is already an existing command or alias.') super().__init__(f'The {type_} {name} is already an existing command or alias.')
class FlagError(BadArgument): class FlagError(BadArgument):
"""The base exception type for all flag parsing related errors. """The base exception type for all flag parsing related errors.
@ -922,8 +1021,10 @@ class FlagError(BadArgument):
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
pass pass
class TooManyFlags(FlagError): class TooManyFlags(FlagError):
"""An exception raised when a flag has received too many values. """An exception raised when a flag has received too many values.
@ -938,11 +1039,13 @@ class TooManyFlags(FlagError):
values: List[:class:`str`] values: List[:class:`str`]
The values that were passed. The values that were passed.
""" """
def __init__(self, flag: Flag, values: List[str]) -> None: def __init__(self, flag: Flag, values: List[str]) -> None:
self.flag: Flag = flag self.flag: Flag = flag
self.values: List[str] = values self.values: List[str] = values
super().__init__(f'Too many flag values, expected {flag.max_args} but received {len(values)}.') super().__init__(f'Too many flag values, expected {flag.max_args} but received {len(values)}.')
class BadFlagArgument(FlagError): class BadFlagArgument(FlagError):
"""An exception raised when a flag failed to convert a value. """An exception raised when a flag failed to convert a value.
@ -955,6 +1058,7 @@ class BadFlagArgument(FlagError):
flag: :class:`~discord.ext.commands.Flag` flag: :class:`~discord.ext.commands.Flag`
The flag that failed to convert. The flag that failed to convert.
""" """
def __init__(self, flag: Flag) -> None: def __init__(self, flag: Flag) -> None:
self.flag: Flag = flag self.flag: Flag = flag
try: try:
@ -964,6 +1068,7 @@ class BadFlagArgument(FlagError):
super().__init__(f'Could not convert to {name!r} for flag {flag.name!r}') super().__init__(f'Could not convert to {name!r} for flag {flag.name!r}')
class MissingRequiredFlag(FlagError): class MissingRequiredFlag(FlagError):
"""An exception raised when a required flag was not given. """An exception raised when a required flag was not given.
@ -976,10 +1081,12 @@ class MissingRequiredFlag(FlagError):
flag: :class:`~discord.ext.commands.Flag` flag: :class:`~discord.ext.commands.Flag`
The required flag that was not found. The required flag that was not found.
""" """
def __init__(self, flag: Flag) -> None: def __init__(self, flag: Flag) -> None:
self.flag: Flag = flag self.flag: Flag = flag
super().__init__(f'Flag {flag.name!r} is required and missing') super().__init__(f'Flag {flag.name!r} is required and missing')
class MissingFlagArgument(FlagError): class MissingFlagArgument(FlagError):
"""An exception raised when a flag did not get a value. """An exception raised when a flag did not get a value.
@ -992,6 +1099,7 @@ class MissingFlagArgument(FlagError):
flag: :class:`~discord.ext.commands.Flag` flag: :class:`~discord.ext.commands.Flag`
The flag that did not get a value. The flag that did not get a value.
""" """
def __init__(self, flag: Flag) -> None: def __init__(self, flag: Flag) -> None:
self.flag: Flag = flag self.flag: Flag = flag
super().__init__(f'Flag {flag.name!r} does not have an argument') super().__init__(f'Flag {flag.name!r} does not have an argument')

2
discord/ext/commands/help.py

@ -932,7 +932,7 @@ class DefaultHelpCommand(HelpCommand):
def shorten_text(self, text): def shorten_text(self, text):
""":class:`str`: Shortens text to fit into the :attr:`width`.""" """:class:`str`: Shortens text to fit into the :attr:`width`."""
if len(text) > self.width: if len(text) > self.width:
return text[:self.width - 3].rstrip() + '...' return text[: self.width - 3].rstrip() + '...'
return text return text
def get_ending_note(self): def get_ending_note(self):

10
discord/ext/commands/view.py

@ -46,6 +46,7 @@ _quotes = {
} }
_all_quotes = set(_quotes.keys()) | set(_quotes.values()) _all_quotes = set(_quotes.keys()) | set(_quotes.values())
class StringView: class StringView:
def __init__(self, buffer): def __init__(self, buffer):
self.index = 0 self.index = 0
@ -81,20 +82,20 @@ class StringView:
def skip_string(self, string): def skip_string(self, string):
strlen = len(string) strlen = len(string)
if self.buffer[self.index:self.index + strlen] == string: if self.buffer[self.index : self.index + strlen] == string:
self.previous = self.index self.previous = self.index
self.index += strlen self.index += strlen
return True return True
return False return False
def read_rest(self): def read_rest(self):
result = self.buffer[self.index:] result = self.buffer[self.index :]
self.previous = self.index self.previous = self.index
self.index = self.end self.index = self.end
return result return result
def read(self, n): def read(self, n):
result = self.buffer[self.index:self.index + n] result = self.buffer[self.index : self.index + n]
self.previous = self.index self.previous = self.index
self.index += n self.index += n
return result return result
@ -120,7 +121,7 @@ class StringView:
except IndexError: except IndexError:
break break
self.previous = self.index self.previous = self.index
result = self.buffer[self.index:self.index + pos] result = self.buffer[self.index : self.index + pos]
self.index += pos self.index += pos
return result return result
@ -187,6 +188,5 @@ class StringView:
result.append(current) result.append(current)
def __repr__(self): def __repr__(self):
return f'<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>' return f'<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>'

2
discord/ext/tasks/__init__.py

@ -48,9 +48,11 @@ from collections.abc import Sequence
from discord.backoff import ExponentialBackoff from discord.backoff import ExponentialBackoff
from discord.utils import MISSING from discord.utils import MISSING
# fmt: off
__all__ = ( __all__ = (
'loop', 'loop',
) )
# fmt: on
T = TypeVar('T') T = TypeVar('T')
_func = Callable[..., Awaitable[Any]] _func = Callable[..., Awaitable[Any]]

2
discord/file.py

@ -28,9 +28,11 @@ from typing import Any, Dict, Optional, Union
import os import os
import io import io
# fmt: off
__all__ = ( __all__ = (
'File', 'File',
) )
# fmt: on
class File: class File:

4
discord/flags.py

@ -82,7 +82,7 @@ def fill_with_flags(*, inverted: bool = False):
if inverted: if inverted:
max_bits = max(cls.VALID_FLAGS.values()).bit_length() max_bits = max(cls.VALID_FLAGS.values()).bit_length()
cls.DEFAULT_VALUE = -1 + (2 ** max_bits) cls.DEFAULT_VALUE = -1 + (2**max_bits)
else: else:
cls.DEFAULT_VALUE = 0 cls.DEFAULT_VALUE = 0
@ -908,7 +908,7 @@ class Intents(BaseFlags):
- :func:`on_message_edit` - :func:`on_message_edit`
- :func:`on_message_delete` - :func:`on_message_delete`
- :func:`on_raw_message_edit` - :func:`on_raw_message_edit`
For more information go to the :ref:`message content intent documentation <need_message_content_intent>`. For more information go to the :ref:`message content intent documentation <need_message_content_intent>`.
.. note:: .. note::

97
discord/gateway.py

@ -50,19 +50,25 @@ __all__ = (
'ReconnectWebSocket', 'ReconnectWebSocket',
) )
class ReconnectWebSocket(Exception): class ReconnectWebSocket(Exception):
"""Signals to safely reconnect the websocket.""" """Signals to safely reconnect the websocket."""
def __init__(self, shard_id, *, resume=True): def __init__(self, shard_id, *, resume=True):
self.shard_id = shard_id self.shard_id = shard_id
self.resume = resume self.resume = resume
self.op = 'RESUME' if resume else 'IDENTIFY' self.op = 'RESUME' if resume else 'IDENTIFY'
class WebSocketClosure(Exception): class WebSocketClosure(Exception):
"""An exception to make up for the fact that aiohttp doesn't signal closure.""" """An exception to make up for the fact that aiohttp doesn't signal closure."""
pass pass
EventListener = namedtuple('EventListener', 'predicate event result future') EventListener = namedtuple('EventListener', 'predicate event result future')
class GatewayRatelimiter: class GatewayRatelimiter:
def __init__(self, count=110, per=60.0): def __init__(self, count=110, per=60.0):
# The default is 110 to give room for at least 10 heartbeats per minute # The default is 110 to give room for at least 10 heartbeats per minute
@ -171,7 +177,7 @@ class KeepAliveHandler(threading.Thread):
def get_payload(self): def get_payload(self):
return { return {
'op': self.ws.HEARTBEAT, 'op': self.ws.HEARTBEAT,
'd': self.ws.sequence 'd': self.ws.sequence,
} }
def stop(self): def stop(self):
@ -187,6 +193,7 @@ class KeepAliveHandler(threading.Thread):
if self.latency > 10: if self.latency > 10:
_log.warning(self.behind_msg, self.shard_id, self.latency) _log.warning(self.behind_msg, self.shard_id, self.latency)
class VoiceKeepAliveHandler(KeepAliveHandler): class VoiceKeepAliveHandler(KeepAliveHandler):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -198,7 +205,7 @@ class VoiceKeepAliveHandler(KeepAliveHandler):
def get_payload(self): def get_payload(self):
return { return {
'op': self.ws.HEARTBEAT, 'op': self.ws.HEARTBEAT,
'd': int(time.time() * 1000) 'd': int(time.time() * 1000),
} }
def ack(self): def ack(self):
@ -208,10 +215,12 @@ class VoiceKeepAliveHandler(KeepAliveHandler):
self.latency = ack_time - self._last_send self.latency = ack_time - self._last_send
self.recent_ack_latencies.append(self.latency) self.recent_ack_latencies.append(self.latency)
class DiscordClientWebSocketResponse(aiohttp.ClientWebSocketResponse): class DiscordClientWebSocketResponse(aiohttp.ClientWebSocketResponse):
async def close(self, *, code: int = 4000, message: bytes = b'') -> bool: async def close(self, *, code: int = 4000, message: bytes = b'') -> bool:
return await super().close(code=code, message=message) return await super().close(code=code, message=message)
class DiscordWebSocket: class DiscordWebSocket:
"""Implements a WebSocket for Discord's gateway v6. """Implements a WebSocket for Discord's gateway v6.
@ -252,6 +261,7 @@ class DiscordWebSocket:
The authentication token for discord. The authentication token for discord.
""" """
# fmt: off
DISPATCH = 0 DISPATCH = 0
HEARTBEAT = 1 HEARTBEAT = 1
IDENTIFY = 2 IDENTIFY = 2
@ -265,6 +275,7 @@ class DiscordWebSocket:
HELLO = 10 HELLO = 10
HEARTBEAT_ACK = 11 HEARTBEAT_ACK = 11
GUILD_SYNC = 12 GUILD_SYNC = 12
# fmt: on
def __init__(self, socket, *, loop): def __init__(self, socket, *, loop):
self.socket = socket self.socket = socket
@ -300,7 +311,17 @@ class DiscordWebSocket:
pass pass
@classmethod @classmethod
async def from_client(cls, client, *, initial=False, gateway=None, shard_id=None, session=None, sequence=None, resume=False): async def from_client(
cls,
client,
*,
initial=False,
gateway=None,
shard_id=None,
session=None,
sequence=None,
resume=False,
):
"""Creates a main websocket for Discord from a :class:`Client`. """Creates a main websocket for Discord from a :class:`Client`.
This is for internal use only. This is for internal use only.
@ -378,12 +399,12 @@ class DiscordWebSocket:
'$browser': 'discord.py', '$browser': 'discord.py',
'$device': 'discord.py', '$device': 'discord.py',
'$referrer': '', '$referrer': '',
'$referring_domain': '' '$referring_domain': '',
}, },
'compress': True, 'compress': True,
'large_threshold': 250, 'large_threshold': 250,
'v': 3 'v': 3,
} },
} }
if self.shard_id is not None and self.shard_count is not None: if self.shard_id is not None and self.shard_count is not None:
@ -395,7 +416,7 @@ class DiscordWebSocket:
'status': state._status, 'status': state._status,
'game': state._activity, 'game': state._activity,
'since': 0, 'since': 0,
'afk': False 'afk': False,
} }
if state._intents is not None: if state._intents is not None:
@ -412,8 +433,8 @@ class DiscordWebSocket:
'd': { 'd': {
'seq': self.sequence, 'seq': self.sequence,
'session_id': self.session_id, 'session_id': self.session_id,
'token': self.token 'token': self.token,
} },
} }
await self.send_as_json(payload) await self.send_as_json(payload)
@ -494,15 +515,23 @@ class DiscordWebSocket:
self.session_id = data['session_id'] self.session_id = data['session_id']
# pass back shard ID to ready handler # pass back shard ID to ready handler
data['__shard_id__'] = self.shard_id data['__shard_id__'] = self.shard_id
_log.info('Shard ID %s has connected to Gateway: %s (Session ID: %s).', _log.info(
self.shard_id, ', '.join(trace), self.session_id) 'Shard ID %s has connected to Gateway: %s (Session ID: %s).',
self.shard_id,
', '.join(trace),
self.session_id,
)
elif event == 'RESUMED': elif event == 'RESUMED':
self._trace = trace = data.get('_trace', []) self._trace = trace = data.get('_trace', [])
# pass back the shard ID to the resumed handler # pass back the shard ID to the resumed handler
data['__shard_id__'] = self.shard_id data['__shard_id__'] = self.shard_id
_log.info('Shard ID %s has successfully RESUMED session %s under trace %s.', _log.info(
self.shard_id, self.session_id, ', '.join(trace)) 'Shard ID %s has successfully RESUMED session %s under trace %s.',
self.shard_id,
self.session_id,
', '.join(trace),
)
try: try:
func = self._discord_parsers[event] func = self._discord_parsers[event]
@ -625,8 +654,8 @@ class DiscordWebSocket:
'activities': activity, 'activities': activity,
'afk': False, 'afk': False,
'since': since, 'since': since,
'status': status 'status': status,
} },
} }
sent = utils._to_json(payload) sent = utils._to_json(payload)
@ -639,8 +668,8 @@ class DiscordWebSocket:
'd': { 'd': {
'guild_id': guild_id, 'guild_id': guild_id,
'presences': presences, 'presences': presences,
'limit': limit 'limit': limit,
} },
} }
if nonce: if nonce:
@ -652,7 +681,6 @@ class DiscordWebSocket:
if query is not None: if query is not None:
payload['d']['query'] = query payload['d']['query'] = query
await self.send_as_json(payload) await self.send_as_json(payload)
async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False): async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False):
@ -662,8 +690,8 @@ class DiscordWebSocket:
'guild_id': guild_id, 'guild_id': guild_id,
'channel_id': channel_id, 'channel_id': channel_id,
'self_mute': self_mute, 'self_mute': self_mute,
'self_deaf': self_deaf 'self_deaf': self_deaf,
} },
} }
_log.debug('Updating our voice state to %s.', payload) _log.debug('Updating our voice state to %s.', payload)
@ -677,6 +705,7 @@ class DiscordWebSocket:
self._close_code = code self._close_code = code
await self.socket.close(code=code) await self.socket.close(code=code)
class DiscordVoiceWebSocket: class DiscordVoiceWebSocket:
"""Implements the websocket protocol for handling voice connections. """Implements the websocket protocol for handling voice connections.
@ -708,6 +737,7 @@ class DiscordVoiceWebSocket:
Receive only. Indicates a user has disconnected from voice. Receive only. Indicates a user has disconnected from voice.
""" """
# fmt: off
IDENTIFY = 0 IDENTIFY = 0
SELECT_PROTOCOL = 1 SELECT_PROTOCOL = 1
READY = 2 READY = 2
@ -720,6 +750,7 @@ class DiscordVoiceWebSocket:
RESUMED = 9 RESUMED = 9
CLIENT_CONNECT = 12 CLIENT_CONNECT = 12
CLIENT_DISCONNECT = 13 CLIENT_DISCONNECT = 13
# fmt: on
def __init__(self, socket, loop, *, hook=None): def __init__(self, socket, loop, *, hook=None):
self.ws = socket self.ws = socket
@ -746,8 +777,8 @@ class DiscordVoiceWebSocket:
'd': { 'd': {
'token': state.token, 'token': state.token,
'server_id': str(state.server_id), 'server_id': str(state.server_id),
'session_id': state.session_id 'session_id': state.session_id,
} },
} }
await self.send_as_json(payload) await self.send_as_json(payload)
@ -759,8 +790,8 @@ class DiscordVoiceWebSocket:
'server_id': str(state.server_id), 'server_id': str(state.server_id),
'user_id': str(state.user.id), 'user_id': str(state.user.id),
'session_id': state.session_id, 'session_id': state.session_id,
'token': state.token 'token': state.token,
} },
} }
await self.send_as_json(payload) await self.send_as_json(payload)
@ -791,9 +822,9 @@ class DiscordVoiceWebSocket:
'data': { 'data': {
'address': ip, 'address': ip,
'port': port, 'port': port,
'mode': mode 'mode': mode,
} },
} },
} }
await self.send_as_json(payload) await self.send_as_json(payload)
@ -802,8 +833,8 @@ class DiscordVoiceWebSocket:
payload = { payload = {
'op': self.CLIENT_CONNECT, 'op': self.CLIENT_CONNECT,
'd': { 'd': {
'audio_ssrc': self._connection.ssrc 'audio_ssrc': self._connection.ssrc,
} },
} }
await self.send_as_json(payload) await self.send_as_json(payload)
@ -813,8 +844,8 @@ class DiscordVoiceWebSocket:
'op': self.SPEAKING, 'op': self.SPEAKING,
'd': { 'd': {
'speaking': int(state), 'speaking': int(state),
'delay': 0 'delay': 0,
} },
} }
await self.send_as_json(payload) await self.send_as_json(payload)
@ -847,8 +878,8 @@ class DiscordVoiceWebSocket:
state.endpoint_ip = data['ip'] state.endpoint_ip = data['ip']
packet = bytearray(70) packet = bytearray(70)
struct.pack_into('>H', packet, 0, 1) # 1 = Send struct.pack_into('>H', packet, 0, 1) # 1 = Send
struct.pack_into('>H', packet, 2, 70) # 70 = Length struct.pack_into('>H', packet, 2, 70) # 70 = Length
struct.pack_into('>I', packet, 4, state.ssrc) struct.pack_into('>I', packet, 4, state.ssrc)
state.socket.sendto(packet, (state.endpoint_ip, state.voice_port)) state.socket.sendto(packet, (state.endpoint_ip, state.voice_port))
recv = await self.loop.sock_recv(state.socket, 70) recv = await self.loop.sock_recv(state.socket, 70)

9
discord/guild.py

@ -81,9 +81,11 @@ from .audit_logs import AuditLogEntry
from .object import OLDEST_OBJECT, Object from .object import OLDEST_OBJECT, Object
# fmt: off
__all__ = ( __all__ = (
'Guild', 'Guild',
) )
# fmt: on
MISSING = utils.MISSING MISSING = utils.MISSING
@ -2830,9 +2832,9 @@ class Guild(Hashable):
if data and entries: if data and entries:
if limit is not None: if limit is not None:
limit -= len(data) limit -= len(data)
before = Object(id=int(entries[-1]['id'])) before = Object(id=int(entries[-1]['id']))
return data.get('users', []), entries, before, limit return data.get('users', []), entries, before, limit
async def _after_strategy(retrieve, after, limit): async def _after_strategy(retrieve, after, limit):
@ -2846,7 +2848,7 @@ class Guild(Hashable):
if data and entries: if data and entries:
if limit is not None: if limit is not None:
limit -= len(data) limit -= len(data)
after = Object(id=int(entries[0]['id'])) after = Object(id=int(entries[0]['id']))
return data.get('users', []), entries, after, limit return data.get('users', []), entries, after, limit
@ -2864,7 +2866,6 @@ class Guild(Hashable):
if isinstance(after, datetime.datetime): if isinstance(after, datetime.datetime):
after = Object(id=utils.time_snowflake(after, high=True)) after = Object(id=utils.time_snowflake(after, high=True))
if oldest_first is None: if oldest_first is None:
reverse = after is not None reverse = after is not None
else: else:

4
discord/http.py

@ -594,7 +594,9 @@ class HTTPClient:
return self.request(r, json=payload, reason=reason) return self.request(r, json=payload, reason=reason)
def edit_message(self, channel_id: Snowflake, message_id: Snowflake, *, params: MultipartParameters) -> Response[message.Message]: def edit_message(
self, channel_id: Snowflake, message_id: Snowflake, *, params: MultipartParameters
) -> Response[message.Message]:
r = Route('PATCH', '/channels/{channel_id}/messages/{message_id}', channel_id=channel_id, message_id=message_id) r = Route('PATCH', '/channels/{channel_id}/messages/{message_id}', channel_id=channel_id, message_id=message_id)
if params.files: if params.files:
return self.request(r, files=params.files, form=params.multipart) return self.request(r, files=params.files, form=params.multipart)

8
discord/member.py

@ -767,13 +767,15 @@ class Member(discord.abc.Messageable, _UserTag):
if roles is not MISSING: if roles is not MISSING:
payload['roles'] = tuple(r.id for r in roles) payload['roles'] = tuple(r.id for r in roles)
if timed_out_until is not MISSING: if timed_out_until is not MISSING:
if timed_out_until is None: if timed_out_until is None:
payload['communication_disabled_until'] = None payload['communication_disabled_until'] = None
else: else:
if timed_out_until.tzinfo is None: if timed_out_until.tzinfo is None:
raise TypeError('timed_out_until must be an aware datetime. Consider using discord.utils.utcnow() or datetime.datetime.now().astimezone() for local time.') raise TypeError(
'timed_out_until must be an aware datetime. Consider using discord.utils.utcnow() or datetime.datetime.now().astimezone() for local time.'
)
payload['communication_disabled_until'] = timed_out_until.isoformat() payload['communication_disabled_until'] = timed_out_until.isoformat()
if payload: if payload:
@ -940,7 +942,7 @@ class Member(discord.abc.Messageable, _UserTag):
"""Returns whether this member is timed out. """Returns whether this member is timed out.
.. versionadded:: 2.0 .. versionadded:: 2.0
Returns Returns
-------- --------
:class:`bool` :class:`bool`

2
discord/mentions.py

@ -25,9 +25,11 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations from __future__ import annotations
from typing import Type, TypeVar, Union, List, TYPE_CHECKING, Any, Union from typing import Type, TypeVar, Union, List, TYPE_CHECKING, Any, Union
# fmt: off
__all__ = ( __all__ = (
'AllowedMentions', 'AllowedMentions',
) )
# fmt: on
if TYPE_CHECKING: if TYPE_CHECKING:
from .types.message import AllowedMentions as AllowedMentionsPayload from .types.message import AllowedMentions as AllowedMentionsPayload

2
discord/message.py

@ -1348,7 +1348,7 @@ class Message(Hashable):
----------- -----------
\*attachments: :class:`Attachment` \*attachments: :class:`Attachment`
Attachments to remove from the message. Attachments to remove from the message.
Raises Raises
------- -------
HTTPException HTTPException

2
discord/mixins.py

@ -27,6 +27,7 @@ __all__ = (
'Hashable', 'Hashable',
) )
class EqualityComparable: class EqualityComparable:
__slots__ = () __slots__ = ()
@ -40,6 +41,7 @@ class EqualityComparable:
return other.id != self.id return other.id != self.id
return True return True
class Hashable(EqualityComparable): class Hashable(EqualityComparable):
__slots__ = () __slots__ = ()

4
discord/object.py

@ -35,11 +35,15 @@ from typing import (
if TYPE_CHECKING: if TYPE_CHECKING:
import datetime import datetime
SupportsIntCast = Union[SupportsInt, str, bytes, bytearray] SupportsIntCast = Union[SupportsInt, str, bytes, bytearray]
# fmt: off
__all__ = ( __all__ = (
'Object', 'Object',
) )
# fmt: on
class Object(Hashable): class Object(Hashable):
"""Represents a generic Discord object. """Represents a generic Discord object.

12
discord/oggparse.py

@ -36,13 +36,17 @@ __all__ = (
'OggStream', 'OggStream',
) )
class OggError(DiscordException): class OggError(DiscordException):
"""An exception that is thrown for Ogg stream parsing errors.""" """An exception that is thrown for Ogg stream parsing errors."""
pass pass
# https://tools.ietf.org/html/rfc3533 # https://tools.ietf.org/html/rfc3533
# https://tools.ietf.org/html/rfc7845 # https://tools.ietf.org/html/rfc7845
class OggPage: class OggPage:
_header: ClassVar[struct.Struct] = struct.Struct('<xBQIIIB') _header: ClassVar[struct.Struct] = struct.Struct('<xBQIIIB')
if TYPE_CHECKING: if TYPE_CHECKING:
@ -57,11 +61,10 @@ class OggPage:
try: try:
header = stream.read(struct.calcsize(self._header.format)) header = stream.read(struct.calcsize(self._header.format))
self.flag, self.gran_pos, self.serial, \ self.flag, self.gran_pos, self.serial, self.pagenum, self.crc, self.segnum = self._header.unpack(header)
self.pagenum, self.crc, self.segnum = self._header.unpack(header)
self.segtable: bytes = stream.read(self.segnum) self.segtable: bytes = stream.read(self.segnum)
bodylen = sum(struct.unpack('B'*self.segnum, self.segtable)) bodylen = sum(struct.unpack('B' * self.segnum, self.segtable))
self.data: bytes = stream.read(bodylen) self.data: bytes = stream.read(bodylen)
except Exception: except Exception:
raise OggError('bad data stream') from None raise OggError('bad data stream') from None
@ -76,7 +79,7 @@ class OggPage:
partial = True partial = True
else: else:
packetlen += seg packetlen += seg
yield self.data[offset:offset+packetlen], True yield self.data[offset : offset + packetlen], True
offset += packetlen offset += packetlen
packetlen = 0 packetlen = 0
partial = False partial = False
@ -84,6 +87,7 @@ class OggPage:
if partial: if partial:
yield self.data[offset:], False yield self.data[offset:], False
class OggStream: class OggStream:
def __init__(self, stream: IO[bytes]) -> None: def __init__(self, stream: IO[bytes]) -> None:
self.stream: IO[bytes] = stream self.stream: IO[bytes] = stream

114
discord/opus.py

@ -42,6 +42,7 @@ if TYPE_CHECKING:
BAND_CTL = Literal['narrow', 'medium', 'wide', 'superwide', 'full'] BAND_CTL = Literal['narrow', 'medium', 'wide', 'superwide', 'full']
SIGNAL_CTL = Literal['auto', 'voice', 'music'] SIGNAL_CTL = Literal['auto', 'voice', 'music']
class BandCtl(TypedDict): class BandCtl(TypedDict):
narrow: int narrow: int
medium: int medium: int
@ -49,11 +50,13 @@ class BandCtl(TypedDict):
superwide: int superwide: int
full: int full: int
class SignalCtl(TypedDict): class SignalCtl(TypedDict):
auto: int auto: int
voice: int voice: int
music: int music: int
__all__ = ( __all__ = (
'Encoder', 'Encoder',
'OpusError', 'OpusError',
@ -62,23 +65,27 @@ __all__ = (
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
c_int_ptr = ctypes.POINTER(ctypes.c_int) c_int_ptr = ctypes.POINTER(ctypes.c_int)
c_int16_ptr = ctypes.POINTER(ctypes.c_int16) c_int16_ptr = ctypes.POINTER(ctypes.c_int16)
c_float_ptr = ctypes.POINTER(ctypes.c_float) c_float_ptr = ctypes.POINTER(ctypes.c_float)
_lib = None _lib = None
class EncoderStruct(ctypes.Structure): class EncoderStruct(ctypes.Structure):
pass pass
class DecoderStruct(ctypes.Structure): class DecoderStruct(ctypes.Structure):
pass pass
EncoderStructPtr = ctypes.POINTER(EncoderStruct) EncoderStructPtr = ctypes.POINTER(EncoderStruct)
DecoderStructPtr = ctypes.POINTER(DecoderStruct) DecoderStructPtr = ctypes.POINTER(DecoderStruct)
## Some constants from opus_defines.h ## Some constants from opus_defines.h
# Error codes # Error codes
# fmt: off
OK = 0 OK = 0
BAD_ARG = -1 BAD_ARG = -1
@ -96,6 +103,7 @@ CTL_SET_SIGNAL = 4024
# Decoder CTLs # Decoder CTLs
CTL_SET_GAIN = 4034 CTL_SET_GAIN = 4034
CTL_LAST_PACKET_DURATION = 4039 CTL_LAST_PACKET_DURATION = 4039
# fmt: on
band_ctl: BandCtl = { band_ctl: BandCtl = {
'narrow': 1101, 'narrow': 1101,
@ -111,12 +119,14 @@ signal_ctl: SignalCtl = {
'music': 3002, 'music': 3002,
} }
def _err_lt(result: int, func: Callable, args: List) -> int: def _err_lt(result: int, func: Callable, args: List) -> int:
if result < OK: if result < OK:
_log.info('error has happened in %s', func.__name__) _log.info('error has happened in %s', func.__name__)
raise OpusError(result) raise OpusError(result)
return result return result
def _err_ne(result: T, func: Callable, args: List) -> T: def _err_ne(result: T, func: Callable, args: List) -> T:
ret = args[-1]._obj ret = args[-1]._obj
if ret.value != OK: if ret.value != OK:
@ -124,6 +134,7 @@ def _err_ne(result: T, func: Callable, args: List) -> T:
raise OpusError(ret.value) raise OpusError(ret.value)
return result return result
# A list of exported functions. # A list of exported functions.
# The first argument is obviously the name. # The first argument is obviously the name.
# The second one are the types of arguments it takes. # The second one are the types of arguments it takes.
@ -131,54 +142,46 @@ def _err_ne(result: T, func: Callable, args: List) -> T:
# The fourth is the error handler. # The fourth is the error handler.
exported_functions: List[Tuple[Any, ...]] = [ exported_functions: List[Tuple[Any, ...]] = [
# Generic # Generic
('opus_get_version_string', ('opus_get_version_string', None, ctypes.c_char_p, None),
None, ctypes.c_char_p, None), ('opus_strerror', [ctypes.c_int], ctypes.c_char_p, None),
('opus_strerror',
[ctypes.c_int], ctypes.c_char_p, None),
# Encoder functions # Encoder functions
('opus_encoder_get_size', ('opus_encoder_get_size', [ctypes.c_int], ctypes.c_int, None),
[ctypes.c_int], ctypes.c_int, None), ('opus_encoder_create', [ctypes.c_int, ctypes.c_int, ctypes.c_int, c_int_ptr], EncoderStructPtr, _err_ne),
('opus_encoder_create', ('opus_encode', [EncoderStructPtr, c_int16_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int32, _err_lt),
[ctypes.c_int, ctypes.c_int, ctypes.c_int, c_int_ptr], EncoderStructPtr, _err_ne), (
('opus_encode', 'opus_encode_float',
[EncoderStructPtr, c_int16_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int32, _err_lt), [EncoderStructPtr, c_float_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32],
('opus_encode_float', ctypes.c_int32,
[EncoderStructPtr, c_float_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int32, _err_lt), _err_lt,
('opus_encoder_ctl', ),
None, ctypes.c_int32, _err_lt), ('opus_encoder_ctl', None, ctypes.c_int32, _err_lt),
('opus_encoder_destroy', ('opus_encoder_destroy', [EncoderStructPtr], None, None),
[EncoderStructPtr], None, None),
# Decoder functions # Decoder functions
('opus_decoder_get_size', ('opus_decoder_get_size', [ctypes.c_int], ctypes.c_int, None),
[ctypes.c_int], ctypes.c_int, None), ('opus_decoder_create', [ctypes.c_int, ctypes.c_int, c_int_ptr], DecoderStructPtr, _err_ne),
('opus_decoder_create', (
[ctypes.c_int, ctypes.c_int, c_int_ptr], DecoderStructPtr, _err_ne), 'opus_decode',
('opus_decode',
[DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32, c_int16_ptr, ctypes.c_int, ctypes.c_int], [DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32, c_int16_ptr, ctypes.c_int, ctypes.c_int],
ctypes.c_int, _err_lt), ctypes.c_int,
('opus_decode_float', _err_lt,
),
(
'opus_decode_float',
[DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32, c_float_ptr, ctypes.c_int, ctypes.c_int], [DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32, c_float_ptr, ctypes.c_int, ctypes.c_int],
ctypes.c_int, _err_lt), ctypes.c_int,
('opus_decoder_ctl', _err_lt,
None, ctypes.c_int32, _err_lt), ),
('opus_decoder_destroy', ('opus_decoder_ctl', None, ctypes.c_int32, _err_lt),
[DecoderStructPtr], None, None), ('opus_decoder_destroy', [DecoderStructPtr], None, None),
('opus_decoder_get_nb_samples', ('opus_decoder_get_nb_samples', [DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int, _err_lt),
[DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int, _err_lt),
# Packet functions # Packet functions
('opus_packet_get_bandwidth', ('opus_packet_get_bandwidth', [ctypes.c_char_p], ctypes.c_int, _err_lt),
[ctypes.c_char_p], ctypes.c_int, _err_lt), ('opus_packet_get_nb_channels', [ctypes.c_char_p], ctypes.c_int, _err_lt),
('opus_packet_get_nb_channels', ('opus_packet_get_nb_frames', [ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt),
[ctypes.c_char_p], ctypes.c_int, _err_lt), ('opus_packet_get_samples_per_frame', [ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt),
('opus_packet_get_nb_frames',
[ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt),
('opus_packet_get_samples_per_frame',
[ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt),
] ]
def libopus_loader(name: str) -> Any: def libopus_loader(name: str) -> Any:
# create the library... # create the library...
lib = ctypes.cdll.LoadLibrary(name) lib = ctypes.cdll.LoadLibrary(name)
@ -203,6 +206,7 @@ def libopus_loader(name: str) -> Any:
return lib return lib
def _load_default() -> bool: def _load_default() -> bool:
global _lib global _lib
try: try:
@ -219,6 +223,7 @@ def _load_default() -> bool:
return _lib is not None return _lib is not None
def load_opus(name: str) -> None: def load_opus(name: str) -> None:
"""Loads the libopus shared library for use with voice. """Loads the libopus shared library for use with voice.
@ -257,6 +262,7 @@ def load_opus(name: str) -> None:
global _lib global _lib
_lib = libopus_loader(name) _lib = libopus_loader(name)
def is_loaded() -> bool: def is_loaded() -> bool:
"""Function to check if opus lib is successfully loaded either """Function to check if opus lib is successfully loaded either
via the :func:`ctypes.util.find_library` call of :func:`load_opus`. via the :func:`ctypes.util.find_library` call of :func:`load_opus`.
@ -271,6 +277,7 @@ def is_loaded() -> bool:
global _lib global _lib
return _lib is not None return _lib is not None
class OpusError(DiscordException): class OpusError(DiscordException):
"""An exception that is thrown for libopus related errors. """An exception that is thrown for libopus related errors.
@ -286,10 +293,13 @@ class OpusError(DiscordException):
_log.info('"%s" has happened', msg) _log.info('"%s" has happened', msg)
super().__init__(msg) super().__init__(msg)
class OpusNotLoaded(DiscordException): class OpusNotLoaded(DiscordException):
"""An exception that is thrown for when libopus is not loaded.""" """An exception that is thrown for when libopus is not loaded."""
pass pass
class _OpusStruct: class _OpusStruct:
SAMPLING_RATE = 48000 SAMPLING_RATE = 48000
CHANNELS = 2 CHANNELS = 2
@ -306,6 +316,7 @@ class _OpusStruct:
return _lib.opus_get_version_string().decode('utf-8') return _lib.opus_get_version_string().decode('utf-8')
class Encoder(_OpusStruct): class Encoder(_OpusStruct):
def __init__(self, application: int = APPLICATION_AUDIO): def __init__(self, application: int = APPLICATION_AUDIO):
_OpusStruct.get_opus_version() _OpusStruct.get_opus_version()
@ -322,7 +333,7 @@ class Encoder(_OpusStruct):
if hasattr(self, '_state'): if hasattr(self, '_state'):
_lib.opus_encoder_destroy(self._state) _lib.opus_encoder_destroy(self._state)
# This is a destructor, so it's okay to assign None # This is a destructor, so it's okay to assign None
self._state = None # type: ignore self._state = None # type: ignore
def _create_state(self) -> EncoderStruct: def _create_state(self) -> EncoderStruct:
ret = ctypes.c_int() ret = ctypes.c_int()
@ -352,18 +363,19 @@ class Encoder(_OpusStruct):
_lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0) _lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0)
def set_expected_packet_loss_percent(self, percentage: float) -> None: def set_expected_packet_loss_percent(self, percentage: float) -> None:
_lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore _lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore
def encode(self, pcm: bytes, frame_size: int) -> bytes: def encode(self, pcm: bytes, frame_size: int) -> bytes:
max_data_bytes = len(pcm) max_data_bytes = len(pcm)
# bytes can be used to reference pointer # bytes can be used to reference pointer
pcm_ptr = ctypes.cast(pcm, c_int16_ptr) # type: ignore pcm_ptr = ctypes.cast(pcm, c_int16_ptr) # type: ignore
data = (ctypes.c_char * max_data_bytes)() data = (ctypes.c_char * max_data_bytes)()
ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes) ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes)
# array can be initialized with bytes but mypy doesn't know # array can be initialized with bytes but mypy doesn't know
return array.array('b', data[:ret]).tobytes() # type: ignore return array.array('b', data[:ret]).tobytes() # type: ignore
class Decoder(_OpusStruct): class Decoder(_OpusStruct):
def __init__(self): def __init__(self):
@ -375,7 +387,7 @@ class Decoder(_OpusStruct):
if hasattr(self, '_state'): if hasattr(self, '_state'):
_lib.opus_decoder_destroy(self._state) _lib.opus_decoder_destroy(self._state)
# This is a destructor, so it's okay to assign None # This is a destructor, so it's okay to assign None
self._state = None # type: ignore self._state = None # type: ignore
def _create_state(self) -> DecoderStruct: def _create_state(self) -> DecoderStruct:
ret = ctypes.c_int() ret = ctypes.c_int()
@ -411,12 +423,12 @@ class Decoder(_OpusStruct):
def set_gain(self, dB: float) -> int: def set_gain(self, dB: float) -> int:
"""Sets the decoder gain in dB, from -128 to 128.""" """Sets the decoder gain in dB, from -128 to 128."""
dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8) dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8)
return self._set_gain(dB_Q8) return self._set_gain(dB_Q8)
def set_volume(self, mult: float) -> int: def set_volume(self, mult: float) -> int:
"""Sets the output volume as a float percent, i.e. 0.5 for 50%, 1.75 for 175%, etc.""" """Sets the output volume as a float percent, i.e. 0.5 for 50%, 1.75 for 175%, etc."""
return self.set_gain(20 * math.log10(mult)) # amplitude ratio return self.set_gain(20 * math.log10(mult)) # amplitude ratio
def _get_last_packet_duration(self) -> int: def _get_last_packet_duration(self) -> int:
"""Gets the duration (in samples) of the last packet successfully decoded or concealed.""" """Gets the duration (in samples) of the last packet successfully decoded or concealed."""
@ -428,7 +440,7 @@ class Decoder(_OpusStruct):
@overload @overload
def decode(self, data: bytes, *, fec: bool) -> bytes: def decode(self, data: bytes, *, fec: bool) -> bytes:
... ...
@overload @overload
def decode(self, data: Literal[None], *, fec: Literal[False]) -> bytes: def decode(self, data: Literal[None], *, fec: Literal[False]) -> bytes:
... ...
@ -451,4 +463,4 @@ class Decoder(_OpusStruct):
ret = _lib.opus_decode(self._state, data, len(data) if data else 0, pcm_ptr, frame_size, fec) ret = _lib.opus_decode(self._state, data, len(data) if data else 0, pcm_ptr, frame_size, fec)
return array.array('h', pcm[:ret * channel_count]).tobytes() return array.array('h', pcm[: ret * channel_count]).tobytes()

3
discord/partial_emoji.py

@ -31,15 +31,18 @@ from .asset import Asset, AssetMixin
from .errors import InvalidArgument from .errors import InvalidArgument
from . import utils from . import utils
# fmt: off
__all__ = ( __all__ = (
'PartialEmoji', 'PartialEmoji',
) )
# fmt: on
if TYPE_CHECKING: if TYPE_CHECKING:
from .state import ConnectionState from .state import ConnectionState
from datetime import datetime from datetime import datetime
from .types.message import PartialEmoji as PartialEmojiPayload from .types.message import PartialEmoji as PartialEmojiPayload
class _EmojiTag: class _EmojiTag:
__slots__ = () __slots__ = ()

10
discord/permissions.py

@ -46,8 +46,10 @@ def make_permission_alias(alias: str) -> Callable[[Callable[[Any], int]], permis
return decorator return decorator
P = TypeVar('P', bound='Permissions') P = TypeVar('P', bound='Permissions')
@fill_with_flags() @fill_with_flags()
class Permissions(BaseFlags): class Permissions(BaseFlags):
"""Wraps up the Discord permission value. """Wraps up the Discord permission value.
@ -554,21 +556,23 @@ class Permissions(BaseFlags):
@flag_value @flag_value
def start_embedded_activities(self) -> int: def start_embedded_activities(self) -> int:
""":class:`bool`: Returns ``True`` if a user can launch an embedded application in a Voice channel. """:class:`bool`: Returns ``True`` if a user can launch an embedded application in a Voice channel.
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
return 1 << 39 return 1 << 39
@flag_value @flag_value
def moderate_members(self) -> int: def moderate_members(self) -> int:
""":class:`bool`: Returns ``True`` if a user can time out other members. """:class:`bool`: Returns ``True`` if a user can time out other members.
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
return 1 << 40 return 1 << 40
PO = TypeVar('PO', bound='PermissionOverwrite') PO = TypeVar('PO', bound='PermissionOverwrite')
def _augment_from_permissions(cls): def _augment_from_permissions(cls):
cls.VALID_NAMES = set(Permissions.VALID_FLAGS) cls.VALID_NAMES = set(Permissions.VALID_FLAGS)
aliases = set() aliases = set()

33
discord/player.py

@ -36,7 +36,7 @@ import sys
import re import re
import io import io
from typing import Any, Callable, Generic, IO, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union from typing import Any, Callable, Generic, IO, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union
from .errors import ClientException from .errors import ClientException
from .opus import Encoder as OpusEncoder from .opus import Encoder as OpusEncoder
@ -68,6 +68,7 @@ if sys.platform != 'win32':
else: else:
CREATE_NO_WINDOW = 0x08000000 CREATE_NO_WINDOW = 0x08000000
class AudioSource: class AudioSource:
"""Represents an audio stream. """Represents an audio stream.
@ -114,6 +115,7 @@ class AudioSource:
def __del__(self) -> None: def __del__(self) -> None:
self.cleanup() self.cleanup()
class PCMAudio(AudioSource): class PCMAudio(AudioSource):
"""Represents raw 16-bit 48KHz stereo PCM audio source. """Represents raw 16-bit 48KHz stereo PCM audio source.
@ -122,6 +124,7 @@ class PCMAudio(AudioSource):
stream: :term:`py:file object` stream: :term:`py:file object`
A file-like object that reads byte data representing raw PCM. A file-like object that reads byte data representing raw PCM.
""" """
def __init__(self, stream: io.BufferedIOBase) -> None: def __init__(self, stream: io.BufferedIOBase) -> None:
self.stream: io.BufferedIOBase = stream self.stream: io.BufferedIOBase = stream
@ -131,6 +134,7 @@ class PCMAudio(AudioSource):
return b'' return b''
return ret return ret
class FFmpegAudio(AudioSource): class FFmpegAudio(AudioSource):
"""Represents an FFmpeg (or AVConv) based AudioSource. """Represents an FFmpeg (or AVConv) based AudioSource.
@ -140,7 +144,14 @@ class FFmpegAudio(AudioSource):
.. versionadded:: 1.3 .. versionadded:: 1.3
""" """
def __init__(self, source: Union[str, io.BufferedIOBase], *, executable: str = 'ffmpeg', args: Any, **subprocess_kwargs: Any): def __init__(
self,
source: Union[str, io.BufferedIOBase],
*,
executable: str = 'ffmpeg',
args: Any,
**subprocess_kwargs: Any,
):
piping = subprocess_kwargs.get('stdin') == subprocess.PIPE piping = subprocess_kwargs.get('stdin') == subprocess.PIPE
if piping and isinstance(source, str): if piping and isinstance(source, str):
raise TypeError("parameter conflict: 'source' parameter cannot be a string when piping to stdin") raise TypeError("parameter conflict: 'source' parameter cannot be a string when piping to stdin")
@ -191,7 +202,6 @@ class FFmpegAudio(AudioSource):
else: else:
_log.info('ffmpeg process %s successfully terminated with return code of %s.', proc.pid, proc.returncode) _log.info('ffmpeg process %s successfully terminated with return code of %s.', proc.pid, proc.returncode)
def _pipe_writer(self, source: io.BufferedIOBase) -> None: def _pipe_writer(self, source: io.BufferedIOBase) -> None:
while self._process: while self._process:
# arbitrarily large read size # arbitrarily large read size
@ -211,6 +221,7 @@ class FFmpegAudio(AudioSource):
self._kill_process() self._kill_process()
self._process = self._stdout = self._stdin = MISSING self._process = self._stdout = self._stdin = MISSING
class FFmpegPCMAudio(FFmpegAudio): class FFmpegPCMAudio(FFmpegAudio):
"""An audio source from FFmpeg (or AVConv). """An audio source from FFmpeg (or AVConv).
@ -254,7 +265,7 @@ class FFmpegPCMAudio(FFmpegAudio):
pipe: bool = False, pipe: bool = False,
stderr: Optional[IO[str]] = None, stderr: Optional[IO[str]] = None,
before_options: Optional[str] = None, before_options: Optional[str] = None,
options: Optional[str] = None options: Optional[str] = None,
) -> None: ) -> None:
args = [] args = []
subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr} subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr}
@ -282,6 +293,7 @@ class FFmpegPCMAudio(FFmpegAudio):
def is_opus(self) -> bool: def is_opus(self) -> bool:
return False return False
class FFmpegOpusAudio(FFmpegAudio): class FFmpegOpusAudio(FFmpegAudio):
"""An audio source from FFmpeg (or AVConv). """An audio source from FFmpeg (or AVConv).
@ -367,6 +379,7 @@ class FFmpegOpusAudio(FFmpegAudio):
codec = 'copy' if codec in ('opus', 'libopus') else 'libopus' codec = 'copy' if codec in ('opus', 'libopus') else 'libopus'
# fmt: off
args.extend(('-map_metadata', '-1', args.extend(('-map_metadata', '-1',
'-f', 'opus', '-f', 'opus',
'-c:a', codec, '-c:a', codec,
@ -374,6 +387,7 @@ class FFmpegOpusAudio(FFmpegAudio):
'-ac', '2', '-ac', '2',
'-b:a', f'{bitrate}k', '-b:a', f'{bitrate}k',
'-loglevel', 'warning')) '-loglevel', 'warning'))
# fmt: on
if isinstance(options, str): if isinstance(options, str):
args.extend(shlex.split(options)) args.extend(shlex.split(options))
@ -500,8 +514,7 @@ class FFmpegOpusAudio(FFmpegAudio):
probefunc = method probefunc = method
fallback = cls._probe_codec_fallback fallback = cls._probe_codec_fallback
else: else:
raise TypeError("Expected str or callable for parameter 'probe', " \ raise TypeError(f"Expected str or callable for parameter 'probe', not '{method.__class__.__name__}'")
f"not '{method.__class__.__name__}'")
codec = bitrate = None codec = bitrate = None
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -537,13 +550,13 @@ class FFmpegOpusAudio(FFmpegAudio):
codec = streamdata.get('codec_name') codec = streamdata.get('codec_name')
bitrate = int(streamdata.get('bit_rate', 0)) bitrate = int(streamdata.get('bit_rate', 0))
bitrate = max(round(bitrate/1000), 512) bitrate = max(round(bitrate / 1000), 512)
return codec, bitrate return codec, bitrate
@staticmethod @staticmethod
def _probe_codec_fallback(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]: def _probe_codec_fallback(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]:
args = [executable, '-hide_banner', '-i', source] args = [executable, '-hide_banner', '-i', source]
proc = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) proc = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
out, _ = proc.communicate(timeout=20) out, _ = proc.communicate(timeout=20)
output = out.decode('utf8') output = out.decode('utf8')
@ -565,6 +578,7 @@ class FFmpegOpusAudio(FFmpegAudio):
def is_opus(self) -> bool: def is_opus(self) -> bool:
return True return True
class PCMVolumeTransformer(AudioSource, Generic[AT]): class PCMVolumeTransformer(AudioSource, Generic[AT]):
"""Transforms a previous :class:`AudioSource` to have volume controls. """Transforms a previous :class:`AudioSource` to have volume controls.
@ -613,6 +627,7 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]):
ret = self.original.read() ret = self.original.read()
return audioop.mul(ret, 2, min(self._volume, 2.0)) return audioop.mul(ret, 2, min(self._volume, 2.0))
class AudioPlayer(threading.Thread): class AudioPlayer(threading.Thread):
DELAY: float = OpusEncoder.FRAME_LENGTH / 1000.0 DELAY: float = OpusEncoder.FRAME_LENGTH / 1000.0
@ -625,7 +640,7 @@ class AudioPlayer(threading.Thread):
self._end: threading.Event = threading.Event() self._end: threading.Event = threading.Event()
self._resumed: threading.Event = threading.Event() self._resumed: threading.Event = threading.Event()
self._resumed.set() # we are not paused self._resumed.set() # we are not paused
self._current_error: Optional[Exception] = None self._current_error: Optional[Exception] = None
self._connected: threading.Event = client._connected self._connected: threading.Event = client._connected
self._lock: threading.Lock = threading.Lock() self._lock: threading.Lock = threading.Lock()

5
discord/raw_models.py

@ -34,7 +34,7 @@ if TYPE_CHECKING:
MessageUpdateEvent, MessageUpdateEvent,
ReactionClearEvent, ReactionClearEvent,
ReactionClearEmojiEvent, ReactionClearEmojiEvent,
IntegrationDeleteEvent IntegrationDeleteEvent,
) )
from .message import Message from .message import Message
from .partial_emoji import PartialEmoji from .partial_emoji import PartialEmoji
@ -179,8 +179,7 @@ class RawReactionActionEvent(_RawReprMixin):
.. versionadded:: 1.3 .. versionadded:: 1.3
""" """
__slots__ = ('message_id', 'user_id', 'channel_id', 'guild_id', 'emoji', __slots__ = ('message_id', 'user_id', 'channel_id', 'guild_id', 'emoji', 'event_type', 'member')
'event_type', 'member')
def __init__(self, data: ReactionActionEvent, emoji: PartialEmoji, event_type: str) -> None: def __init__(self, data: ReactionActionEvent, emoji: PartialEmoji, event_type: str) -> None:
self.message_id: int = int(data['message_id']) self.message_id: int = int(data['message_id'])

13
discord/reaction.py

@ -27,9 +27,11 @@ from typing import Any, TYPE_CHECKING, AsyncIterator, List, Union, Optional
from .object import Object from .object import Object
# fmt: off
__all__ = ( __all__ = (
'Reaction', 'Reaction',
) )
# fmt: on
if TYPE_CHECKING: if TYPE_CHECKING:
from .user import User from .user import User
@ -40,6 +42,7 @@ if TYPE_CHECKING:
from .emoji import Emoji from .emoji import Emoji
from .abc import Snowflake from .abc import Snowflake
class Reaction: class Reaction:
"""Represents a reaction to a message. """Represents a reaction to a message.
@ -77,6 +80,7 @@ class Reaction:
message: :class:`Message` message: :class:`Message`
Message this reaction is for. Message this reaction is for.
""" """
__slots__ = ('message', 'count', 'emoji', 'me') __slots__ = ('message', 'count', 'emoji', 'me')
def __init__(self, *, message: Message, data: ReactionPayload, emoji: Optional[Union[PartialEmoji, Emoji, str]] = None): def __init__(self, *, message: Message, data: ReactionPayload, emoji: Optional[Union[PartialEmoji, Emoji, str]] = None):
@ -157,7 +161,9 @@ class Reaction:
""" """
await self.message.clear_reaction(self.emoji) await self.message.clear_reaction(self.emoji)
async def users(self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None) -> AsyncIterator[Union[Member, User]]: async def users(
self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None
) -> AsyncIterator[Union[Member, User]]:
"""Returns an :term:`asynchronous iterator` representing the users that have reacted to the message. """Returns an :term:`asynchronous iterator` representing the users that have reacted to the message.
The ``after`` parameter must represent a member The ``after`` parameter must represent a member
@ -222,9 +228,7 @@ class Reaction:
state = message._state state = message._state
after_id = after.id if after else None after_id = after.id if after else None
data = await state.http.get_reaction_users( data = await state.http.get_reaction_users(message.channel.id, message.id, emoji, retrieve, after=after_id)
message.channel.id, message.id, emoji, retrieve, after=after_id
)
if data: if data:
limit -= len(data) limit -= len(data)
@ -241,4 +245,3 @@ class Reaction:
member = guild.get_member(member_id) member = guild.get_member(member_id)
yield member or User(state=state, data=raw_user) yield member or User(state=state, data=raw_user)

12
discord/stage_instance.py

@ -31,9 +31,11 @@ from .mixins import Hashable
from .errors import InvalidArgument from .errors import InvalidArgument
from .enums import StagePrivacyLevel, try_enum from .enums import StagePrivacyLevel, try_enum
# fmt: off
__all__ = ( __all__ = (
'StageInstance', 'StageInstance',
) )
# fmt: on
if TYPE_CHECKING: if TYPE_CHECKING:
from .types.channel import StageInstance as StageInstancePayload from .types.channel import StageInstance as StageInstancePayload
@ -107,12 +109,18 @@ class StageInstance(Hashable):
def channel(self) -> Optional[StageChannel]: def channel(self) -> Optional[StageChannel]:
"""Optional[:class:`StageChannel`]: The channel that stage instance is running in.""" """Optional[:class:`StageChannel`]: The channel that stage instance is running in."""
# the returned channel will always be a StageChannel or None # the returned channel will always be a StageChannel or None
return self._state.get_channel(self.channel_id) # type: ignore return self._state.get_channel(self.channel_id) # type: ignore
def is_public(self) -> bool: def is_public(self) -> bool:
return self.privacy_level is StagePrivacyLevel.public return self.privacy_level is StagePrivacyLevel.public
async def edit(self, *, topic: str = MISSING, privacy_level: StagePrivacyLevel = MISSING, reason: Optional[str] = None) -> None: async def edit(
self,
*,
topic: str = MISSING,
privacy_level: StagePrivacyLevel = MISSING,
reason: Optional[str] = None,
) -> None:
"""|coro| """|coro|
Edits the stage instance. Edits the stage instance.

6
discord/state.py

@ -446,7 +446,9 @@ class ConnectionState:
# If presences are enabled then we get back the old guild.large behaviour # If presences are enabled then we get back the old guild.large behaviour
return self._chunk_guilds and not guild.chunked and not (self._intents.presences and not guild.large) return self._chunk_guilds and not guild.chunked and not (self._intents.presences and not guild.large)
def _get_guild_channel(self, data: MessagePayload, guild_id: Optional[int] = None) -> Tuple[Union[Channel, Thread], Optional[Guild]]: def _get_guild_channel(
self, data: MessagePayload, guild_id: Optional[int] = None
) -> Tuple[Union[Channel, Thread], Optional[Guild]]:
channel_id = int(data['channel_id']) channel_id = int(data['channel_id'])
try: try:
guild_id = guild_id or int(data['guild_id']) guild_id = guild_id or int(data['guild_id'])
@ -691,7 +693,7 @@ class ConnectionState:
self._view_store.dispatch_view(component_type, custom_id, interaction) self._view_store.dispatch_view(component_type, custom_id, interaction)
elif data['type'] == 5: # modal submit elif data['type'] == 5: # modal submit
custom_id = interaction.data['custom_id'] # type: ignore custom_id = interaction.data['custom_id'] # type: ignore
components = interaction.data['components'] # type: ignore components = interaction.data['components'] # type: ignore
self._view_store.dispatch_modal(custom_id, interaction, components) # type: ignore self._view_store.dispatch_modal(custom_id, interaction, components) # type: ignore
self.dispatch('interaction', interaction) self.dispatch('interaction', interaction)

2
discord/sticker.py

@ -116,7 +116,7 @@ class StickerPack(Hashable):
self.name: str = data['name'] self.name: str = data['name']
self.sku_id: int = int(data['sku_id']) self.sku_id: int = int(data['sku_id'])
self.cover_sticker_id: int = int(data['cover_sticker_id']) self.cover_sticker_id: int = int(data['cover_sticker_id'])
self.cover_sticker: StandardSticker = get(self.stickers, id=self.cover_sticker_id) # type: ignore self.cover_sticker: StandardSticker = get(self.stickers, id=self.cover_sticker_id) # type: ignore
self.description: str = data['description'] self.description: str = data['description']
self._banner: int = int(data['banner_asset_id']) self._banner: int = int(data['banner_asset_id'])

4
discord/template.py

@ -29,9 +29,11 @@ from .utils import parse_time, _get_as_snowflake, _bytes_to_base64_data, MISSING
from .enums import VoiceRegion from .enums import VoiceRegion
from .guild import Guild from .guild import Guild
# fmt: off
__all__ = ( __all__ = (
'Template', 'Template',
) )
# fmt: on
if TYPE_CHECKING: if TYPE_CHECKING:
import datetime import datetime
@ -310,7 +312,7 @@ class Template:
@property @property
def url(self) -> str: def url(self) -> str:
""":class:`str`: The template url. """:class:`str`: The template url.
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
return f'https://discord.new/{self.code}' return f'https://discord.new/{self.code}'

5
discord/types/appinfo.py

@ -30,6 +30,7 @@ from .user import User
from .team import Team from .team import Team
from .snowflake import Snowflake from .snowflake import Snowflake
class BaseAppInfo(TypedDict): class BaseAppInfo(TypedDict):
id: Snowflake id: Snowflake
name: str name: str
@ -38,6 +39,7 @@ class BaseAppInfo(TypedDict):
summary: str summary: str
description: str description: str
class _AppInfoOptional(TypedDict, total=False): class _AppInfoOptional(TypedDict, total=False):
team: Team team: Team
guild_id: Snowflake guild_id: Snowflake
@ -48,12 +50,14 @@ class _AppInfoOptional(TypedDict, total=False):
hook: bool hook: bool
max_participants: int max_participants: int
class AppInfo(BaseAppInfo, _AppInfoOptional): class AppInfo(BaseAppInfo, _AppInfoOptional):
rpc_origins: List[str] rpc_origins: List[str]
owner: User owner: User
bot_public: bool bot_public: bool
bot_require_code_grant: bool bot_require_code_grant: bool
class _PartialAppInfoOptional(TypedDict, total=False): class _PartialAppInfoOptional(TypedDict, total=False):
rpc_origins: List[str] rpc_origins: List[str]
cover_image: str cover_image: str
@ -63,5 +67,6 @@ class _PartialAppInfoOptional(TypedDict, total=False):
max_participants: int max_participants: int
flags: int flags: int
class PartialAppInfo(_PartialAppInfoOptional, BaseAppInfo): class PartialAppInfo(_PartialAppInfoOptional, BaseAppInfo):
pass pass

11
discord/types/embed.py

@ -24,50 +24,61 @@ DEALINGS IN THE SOFTWARE.
from typing import List, Literal, TypedDict from typing import List, Literal, TypedDict
class _EmbedFooterOptional(TypedDict, total=False): class _EmbedFooterOptional(TypedDict, total=False):
icon_url: str icon_url: str
proxy_icon_url: str proxy_icon_url: str
class EmbedFooter(_EmbedFooterOptional): class EmbedFooter(_EmbedFooterOptional):
text: str text: str
class _EmbedFieldOptional(TypedDict, total=False): class _EmbedFieldOptional(TypedDict, total=False):
inline: bool inline: bool
class EmbedField(_EmbedFieldOptional): class EmbedField(_EmbedFieldOptional):
name: str name: str
value: str value: str
class EmbedThumbnail(TypedDict, total=False): class EmbedThumbnail(TypedDict, total=False):
url: str url: str
proxy_url: str proxy_url: str
height: int height: int
width: int width: int
class EmbedVideo(TypedDict, total=False): class EmbedVideo(TypedDict, total=False):
url: str url: str
proxy_url: str proxy_url: str
height: int height: int
width: int width: int
class EmbedImage(TypedDict, total=False): class EmbedImage(TypedDict, total=False):
url: str url: str
proxy_url: str proxy_url: str
height: int height: int
width: int width: int
class EmbedProvider(TypedDict, total=False): class EmbedProvider(TypedDict, total=False):
name: str name: str
url: str url: str
class EmbedAuthor(TypedDict, total=False): class EmbedAuthor(TypedDict, total=False):
name: str name: str
url: str url: str
icon_url: str icon_url: str
proxy_icon_url: str proxy_icon_url: str
EmbedType = Literal['rich', 'image', 'video', 'gifv', 'article', 'link'] EmbedType = Literal['rich', 'image', 'video', 'gifv', 'article', 'link']
class Embed(TypedDict, total=False): class Embed(TypedDict, total=False):
title: str title: str
type: EmbedType type: EmbedType

2
discord/types/team.py

@ -29,12 +29,14 @@ from typing import TypedDict, List, Optional
from .user import PartialUser from .user import PartialUser
from .snowflake import Snowflake from .snowflake import Snowflake
class TeamMember(TypedDict): class TeamMember(TypedDict):
user: PartialUser user: PartialUser
membership_state: int membership_state: int
permissions: List[str] permissions: List[str]
team_id: Snowflake team_id: Snowflake
class Team(TypedDict): class Team(TypedDict):
id: Snowflake id: Snowflake
name: str name: str

2
discord/ui/item.py

@ -28,9 +28,11 @@ from typing import Any, Callable, Coroutine, Dict, Generic, Optional, TYPE_CHECK
from ..interactions import Interaction from ..interactions import Interaction
# fmt: off
__all__ = ( __all__ = (
'Item', 'Item',
) )
# fmt: on
if TYPE_CHECKING: if TYPE_CHECKING:
from ..enums import ComponentType from ..enums import ComponentType

2
discord/ui/modal.py

@ -42,9 +42,11 @@ if TYPE_CHECKING:
from ..types.interactions import ModalSubmitComponentInteractionData as ModalSubmitComponentInteractionDataPayload from ..types.interactions import ModalSubmitComponentInteractionData as ModalSubmitComponentInteractionDataPayload
# fmt: off
__all__ = ( __all__ = (
'Modal', 'Modal',
) )
# fmt: on
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)

1
discord/ui/select.py

@ -224,7 +224,6 @@ class Select(Item[V]):
default=default, default=default,
) )
self.append_option(option) self.append_option(option)
def append_option(self, option: SelectOption): def append_option(self, option: SelectOption):

5
discord/ui/text_input.py

@ -40,10 +40,11 @@ if TYPE_CHECKING:
from .view import View from .view import View
# fmt: off
__all__ = ( __all__ = (
'TextInput', 'TextInput',
) )
# fmt: on
V = TypeVar('V', bound='View', covariant=True) V = TypeVar('V', bound='View', covariant=True)
@ -177,7 +178,7 @@ class TextInput(Item[V]):
def max_length(self) -> Optional[int]: def max_length(self) -> Optional[int]:
""":class:`int`: The maximum length of the text input.""" """:class:`int`: The maximum length of the text input."""
return self._underlying.max_length return self._underlying.max_length
@max_length.setter @max_length.setter
def max_length(self, value: Optional[int]) -> None: def max_length(self, value: Optional[int]) -> None:
self._underlying.max_length = value self._underlying.max_length = value

13
discord/ui/view.py

@ -43,9 +43,11 @@ from ..components import (
) )
from ..utils import MISSING from ..utils import MISSING
# fmt: off
__all__ = ( __all__ = (
'View', 'View',
) )
# fmt: on
if TYPE_CHECKING: if TYPE_CHECKING:
@ -81,9 +83,11 @@ def _component_to_item(component: Component) -> Item:
class _ViewWeights: class _ViewWeights:
# fmt: off
__slots__ = ( __slots__ = (
'weights', 'weights',
) )
# fmt: on
def __init__(self, children: List[Item]): def __init__(self, children: List[Item]):
self.weights: List[int] = [0, 0, 0, 0, 0] self.weights: List[int] = [0, 0, 0, 0, 0]
@ -517,7 +521,7 @@ class ViewStore:
def remove_view(self, view: View): def remove_view(self, view: View):
if view.__discord_ui_modal__: if view.__discord_ui_modal__:
self._modals.pop(view.custom_id, None) # type: ignore self._modals.pop(view.custom_id, None) # type: ignore
return return
for item in view.children: for item in view.children:
if item.is_dispatchable(): if item.is_dispatchable():
@ -542,7 +546,12 @@ class ViewStore:
item.refresh_state(interaction.data) # type: ignore item.refresh_state(interaction.data) # type: ignore
view._dispatch_item(item, interaction) view._dispatch_item(item, interaction)
def dispatch_modal(self, custom_id: str, interaction: Interaction, components: List[ModalSubmitComponentInteractionDataPayload]): def dispatch_modal(
self,
custom_id: str,
interaction: Interaction,
components: List[ModalSubmitComponentInteractionDataPayload],
):
modal = self._modals.get(custom_id) modal = self._modals.get(custom_id)
if modal is None: if modal is None:
_log.debug("Modal interaction referencing unknown custom_id %s. Discarding", custom_id) _log.debug("Modal interaction referencing unknown custom_id %s. Discarding", custom_id)

2
discord/utils.py

@ -360,7 +360,7 @@ def time_snowflake(dt: datetime.datetime, high: bool = False) -> int:
The snowflake representing the time given. The snowflake representing the time given.
""" """
discord_millis = int(dt.timestamp() * 1000 - DISCORD_EPOCH) discord_millis = int(dt.timestamp() * 1000 - DISCORD_EPOCH)
return (discord_millis << 22) + (2 ** 22 - 1 if high else 0) return (discord_millis << 22) + (2**22 - 1 if high else 0)
def find(predicate: Callable[[T], Any], seq: Iterable[T]) -> Optional[T]: def find(predicate: Callable[[T], Any], seq: Iterable[T]) -> Optional[T]:

19
discord/voice_client.py

@ -66,12 +66,13 @@ if TYPE_CHECKING:
VoiceServerUpdate as VoiceServerUpdatePayload, VoiceServerUpdate as VoiceServerUpdatePayload,
SupportedModes, SupportedModes,
) )
has_nacl: bool has_nacl: bool
try: try:
import nacl.secret # type: ignore import nacl.secret # type: ignore
has_nacl = True has_nacl = True
except ImportError: except ImportError:
has_nacl = False has_nacl = False
@ -82,10 +83,9 @@ __all__ = (
) )
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
class VoiceProtocol: class VoiceProtocol:
"""A class that represents the Discord voice protocol. """A class that represents the Discord voice protocol.
@ -195,6 +195,7 @@ class VoiceProtocol:
key_id, _ = self.channel._get_voice_client_key() key_id, _ = self.channel._get_voice_client_key()
self.client._connection._remove_voice_client(key_id) self.client._connection._remove_voice_client(key_id)
class VoiceClient(VoiceProtocol): class VoiceClient(VoiceProtocol):
"""Represents a Discord voice connection. """Represents a Discord voice connection.
@ -221,12 +222,12 @@ class VoiceClient(VoiceProtocol):
loop: :class:`asyncio.AbstractEventLoop` loop: :class:`asyncio.AbstractEventLoop`
The event loop that the voice client is running on. The event loop that the voice client is running on.
""" """
endpoint_ip: str endpoint_ip: str
voice_port: int voice_port: int
secret_key: List[int] secret_key: List[int]
ssrc: int ssrc: int
def __init__(self, client: Client, channel: abc.Connectable): def __init__(self, client: Client, channel: abc.Connectable):
if not has_nacl: if not has_nacl:
raise RuntimeError("PyNaCl library needed in order to use voice") raise RuntimeError("PyNaCl library needed in order to use voice")
@ -309,8 +310,10 @@ class VoiceClient(VoiceProtocol):
endpoint = data.get('endpoint') endpoint = data.get('endpoint')
if endpoint is None or self.token is None: if endpoint is None or self.token is None:
_log.warning('Awaiting endpoint... This requires waiting. ' \ _log.warning(
'If timeout occurred considering raising the timeout and reconnecting.') 'Awaiting endpoint... This requires waiting. '
'If timeout occurred considering raising the timeout and reconnecting.'
)
return return
self.endpoint, _, _ = endpoint.rpartition(':') self.endpoint, _, _ = endpoint.rpartition(':')
@ -359,7 +362,7 @@ class VoiceClient(VoiceProtocol):
self._connected.set() self._connected.set()
return ws return ws
async def connect(self, *, reconnect: bool, timeout: float) ->None: async def connect(self, *, reconnect: bool, timeout: float) -> None:
_log.info('Connecting to voice...') _log.info('Connecting to voice...')
self.timeout = timeout self.timeout = timeout
@ -556,7 +559,7 @@ class VoiceClient(VoiceProtocol):
return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext + nonce[:4] return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext + nonce[:4]
def play(self, source: AudioSource, *, after: Callable[[Optional[Exception]], Any]=None) -> None: def play(self, source: AudioSource, *, after: Callable[[Optional[Exception]], Any] = None) -> None:
"""Plays an :class:`AudioSource`. """Plays an :class:`AudioSource`.
The finalizer, ``after`` is called after the source has been exhausted The finalizer, ``after`` is called after the source has been exhausted

4
discord/webhook/async_.py

@ -681,7 +681,7 @@ class WebhookMessage(Message):
attachments: List[Union[:class:`Attachment`, :class:`File`]] attachments: List[Union[:class:`Attachment`, :class:`File`]]
A list of attachments to keep in the message as well as new files to upload. If ``[]`` is passed A list of attachments to keep in the message as well as new files to upload. If ``[]`` is passed
then all attachments are removed. then all attachments are removed.
.. note:: .. note::
New files will always appear after current attachments. New files will always appear after current attachments.
@ -761,7 +761,7 @@ class WebhookMessage(Message):
----------- -----------
\*attachments: :class:`Attachment` \*attachments: :class:`Attachment`
Attachments to remove from the message. Attachments to remove from the message.
Raises Raises
------- -------
HTTPException HTTPException

2
discord/webhook/sync.py

@ -469,7 +469,7 @@ class SyncWebhookMessage(Message):
----------- -----------
\*attachments: :class:`Attachment` \*attachments: :class:`Attachment`
Attachments to remove from the message. Attachments to remove from the message.
Raises Raises
------- -------
HTTPException HTTPException

28
discord/widget.py

@ -46,6 +46,7 @@ __all__ = (
'Widget', 'Widget',
) )
class WidgetChannel: class WidgetChannel:
"""Represents a "partial" widget channel. """Represents a "partial" widget channel.
@ -76,6 +77,7 @@ class WidgetChannel:
position: :class:`int` position: :class:`int`
The channel's position The channel's position
""" """
__slots__ = ('id', 'name', 'position') __slots__ = ('id', 'name', 'position')
def __init__(self, id: int, name: str, position: int) -> None: def __init__(self, id: int, name: str, position: int) -> None:
@ -99,6 +101,7 @@ class WidgetChannel:
""":class:`datetime.datetime`: Returns the channel's creation time in UTC.""" """:class:`datetime.datetime`: Returns the channel's creation time in UTC."""
return snowflake_time(self.id) return snowflake_time(self.id)
class WidgetMember(BaseUser): class WidgetMember(BaseUser):
"""Represents a "partial" member of the widget's guild. """Represents a "partial" member of the widget's guild.
@ -147,9 +150,21 @@ class WidgetMember(BaseUser):
connected_channel: Optional[:class:`WidgetChannel`] connected_channel: Optional[:class:`WidgetChannel`]
Which channel the member is connected to. Which channel the member is connected to.
""" """
__slots__ = ('name', 'status', 'nick', 'avatar', 'discriminator',
'id', 'bot', 'activity', 'deafened', 'suppress', 'muted', __slots__ = (
'connected_channel') 'name',
'status',
'nick',
'avatar',
'discriminator',
'id',
'bot',
'activity',
'deafened',
'suppress',
'muted',
'connected_channel',
)
if TYPE_CHECKING: if TYPE_CHECKING:
activity: Optional[Union[BaseActivity, Spotify]] activity: Optional[Union[BaseActivity, Spotify]]
@ -159,7 +174,7 @@ class WidgetMember(BaseUser):
*, *,
state: ConnectionState, state: ConnectionState,
data: WidgetMemberPayload, data: WidgetMemberPayload,
connected_channel: Optional[WidgetChannel] = None connected_channel: Optional[WidgetChannel] = None,
) -> None: ) -> None:
super().__init__(state=state, data=data) super().__init__(state=state, data=data)
self.nick: Optional[str] = data.get('nick') self.nick: Optional[str] = data.get('nick')
@ -181,8 +196,7 @@ class WidgetMember(BaseUser):
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f"<WidgetMember name={self.name!r} discriminator={self.discriminator!r}" f"<WidgetMember name={self.name!r} discriminator={self.discriminator!r}" f" bot={self.bot} nick={self.nick!r}>"
f" bot={self.bot} nick={self.nick!r}>"
) )
@property @property
@ -190,6 +204,7 @@ class WidgetMember(BaseUser):
""":class:`str`: Returns the member's display name.""" """:class:`str`: Returns the member's display name."""
return self.nick or self.name return self.nick or self.name
class Widget: class Widget:
"""Represents a :class:`Guild` widget. """Represents a :class:`Guild` widget.
@ -227,6 +242,7 @@ class Widget:
retrieved is capped. retrieved is capped.
""" """
__slots__ = ('_state', 'channels', '_invite', 'id', 'members', 'name') __slots__ = ('_state', 'channels', '_invite', 'id', 'members', 'name')
def __init__(self, *, state: ConnectionState, data: WidgetPayload) -> None: def __init__(self, *, state: ConnectionState, data: WidgetPayload) -> None:

Loading…
Cancel
Save