From 88b520b5abe27ce7f302d0e7f8094641303346fb Mon Sep 17 00:00:00 2001 From: Rapptz Date: Sun, 20 Feb 2022 06:29:41 -0500 Subject: [PATCH] Reformat code using black Segments where readability was hampered were fixed by appropriate format skipping directives. New code should hopefully be black compatible. The moment they remove the -S option is probably the moment I stop using black though. --- discord/__main__.py | 37 ++++++- discord/activity.py | 7 +- discord/asset.py | 5 +- discord/audit_logs.py | 8 +- discord/backoff.py | 9 +- discord/channel.py | 4 +- discord/client.py | 91 +++++++++++------ discord/colour.py | 44 ++++---- discord/context_managers.py | 10 +- discord/embeds.py | 6 +- discord/emoji.py | 2 + discord/enums.py | 1 + discord/ext/commands/_types.py | 4 +- discord/ext/commands/bot.py | 30 ++++-- discord/ext/commands/cog.py | 14 ++- discord/ext/commands/context.py | 7 +- discord/ext/commands/cooldowns.py | 20 ++-- discord/ext/commands/core.py | 162 +++++++++++++++++++++--------- discord/ext/commands/errors.py | 108 ++++++++++++++++++++ discord/ext/commands/help.py | 2 +- discord/ext/commands/view.py | 10 +- discord/ext/tasks/__init__.py | 2 + discord/file.py | 2 + discord/flags.py | 4 +- discord/gateway.py | 97 ++++++++++++------ discord/guild.py | 9 +- discord/http.py | 4 +- discord/member.py | 8 +- discord/mentions.py | 2 + discord/message.py | 2 +- discord/mixins.py | 2 + discord/object.py | 4 + discord/oggparse.py | 12 ++- discord/opus.py | 114 +++++++++++---------- discord/partial_emoji.py | 3 + discord/permissions.py | 10 +- discord/player.py | 33 ++++-- discord/raw_models.py | 5 +- discord/reaction.py | 13 ++- discord/stage_instance.py | 12 ++- discord/state.py | 6 +- discord/sticker.py | 2 +- discord/template.py | 4 +- discord/types/appinfo.py | 5 + discord/types/embed.py | 11 ++ discord/types/team.py | 2 + discord/ui/item.py | 2 + discord/ui/modal.py | 2 + discord/ui/select.py | 1 - discord/ui/text_input.py | 5 +- discord/ui/view.py | 13 ++- discord/utils.py | 2 +- discord/voice_client.py | 19 ++-- discord/webhook/async_.py | 4 +- discord/webhook/sync.py | 2 +- discord/widget.py | 28 ++++-- 56 files changed, 738 insertions(+), 289 deletions(-) diff --git a/discord/__main__.py b/discord/__main__.py index 513b0cb38..7ee5b0314 100644 --- a/discord/__main__.py +++ b/discord/__main__.py @@ -31,6 +31,7 @@ import pkg_resources import aiohttp import platform + def show_version(): entries = [] @@ -47,10 +48,12 @@ def show_version(): entries.append('- system info: {0.system} {0.release} {0.version}'.format(uname)) print('\n'.join(entries)) + def core(parser, args): if args.version: show_version() + _bot_template = """#!/usr/bin/env python3 from discord.ext import commands @@ -172,13 +175,36 @@ _base_table.update((chr(i), None) for i in range(32)) _translation_table = str.maketrans(_base_table) + def to_path(parser, name, *, replace_spaces=False): if isinstance(name, Path): return name if sys.platform == 'win32': - forbidden = ('CON', 'PRN', 'AUX', 'NUL', 'COM1', 'COM2', 'COM3', 'COM4', 'COM5', 'COM6', 'COM7', \ - 'COM8', 'COM9', 'LPT1', 'LPT2', 'LPT3', 'LPT4', 'LPT5', 'LPT6', 'LPT7', 'LPT8', 'LPT9') + forbidden = ( + 'CON', + 'PRN', + 'AUX', + 'NUL', + 'COM1', + 'COM2', + 'COM3', + 'COM4', + 'COM5', + 'COM6', + 'COM7', + 'COM8', + 'COM9', + 'LPT1', + 'LPT2', + 'LPT3', + 'LPT4', + 'LPT5', + 'LPT6', + 'LPT7', + 'LPT8', + 'LPT9', + ) if len(name) <= 4 and name.upper() in forbidden: parser.error('invalid directory name given, use a different one') @@ -187,6 +213,7 @@ def to_path(parser, name, *, replace_spaces=False): name = name.replace(' ', '-') return Path(name) + def newbot(parser, args): new_directory = to_path(parser, args.directory) / to_path(parser, args.name) @@ -228,6 +255,7 @@ def newbot(parser, args): print('successfully made bot at', new_directory) + def newcog(parser, args): cog_dir = to_path(parser, args.directory) try: @@ -261,6 +289,7 @@ def newcog(parser, args): else: print('successfully made cog at', directory) + def add_newbot_args(subparser): parser = subparser.add_parser('newbot', help='creates a command bot project quickly') parser.set_defaults(func=newbot) @@ -271,6 +300,7 @@ def add_newbot_args(subparser): parser.add_argument('--sharded', help='whether to use AutoShardedBot', action='store_true') parser.add_argument('--no-git', help='do not create a .gitignore file', action='store_true', dest='no_git') + def add_newcog_args(subparser): parser = subparser.add_parser('newcog', help='creates a new cog template quickly') parser.set_defaults(func=newcog) @@ -282,6 +312,7 @@ def add_newcog_args(subparser): parser.add_argument('--hide-commands', help='whether to hide all commands in the cog', action='store_true') parser.add_argument('--full', help='add all special methods as well', action='store_true') + def parse_args(): parser = argparse.ArgumentParser(prog='discord', description='Tools for helping with discord.py') parser.add_argument('-v', '--version', action='store_true', help='shows the library version') @@ -292,9 +323,11 @@ def parse_args(): add_newcog_args(subparser) return parser, parser.parse_args() + def main(): parser, args = parse_args() args.func(parser, args) + if __name__ == '__main__': main() diff --git a/discord/activity.py b/discord/activity.py index 512053777..f9d43ccf9 100644 --- a/discord/activity.py +++ b/discord/activity.py @@ -807,14 +807,17 @@ class CustomActivity(BaseActivity): ActivityTypes = Union[Activity, Game, CustomActivity, Streaming, Spotify] + @overload def create_activity(data: ActivityPayload) -> ActivityTypes: ... + @overload def create_activity(data: None) -> None: ... + def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]: if not data: return None @@ -831,11 +834,11 @@ def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]: return Activity(**data) else: # we removed the name key from data already - return CustomActivity(name=name, **data) # type: ignore + return CustomActivity(name=name, **data) # type: ignore elif game_type is ActivityType.streaming: if 'url' in data: # the url won't be None here - return Streaming(**data) # type: ignore + return Streaming(**data) # type: ignore return Activity(**data) elif game_type is ActivityType.listening and 'sync_id' in data and 'session_id' in data: return Spotify(**data) diff --git a/discord/asset.py b/discord/asset.py index 36ce08e27..5f9d9c3f6 100644 --- a/discord/asset.py +++ b/discord/asset.py @@ -33,9 +33,11 @@ from . import utils import yarl +# fmt: off __all__ = ( 'Asset', ) +# fmt: on if TYPE_CHECKING: ValidStaticFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png'] @@ -47,6 +49,7 @@ VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"} MISSING = utils.MISSING + class AssetMixin: url: str _state: Optional[Any] @@ -245,7 +248,7 @@ class Asset(AssetMixin): state, url=f'{cls.BASE}/banners/{user_id}/{banner_hash}.{format}?size=512', key=banner_hash, - animated=animated + animated=animated, ) def __str__(self) -> str: diff --git a/discord/audit_logs.py b/discord/audit_logs.py index b9cd2d827..10a69877f 100644 --- a/discord/audit_logs.py +++ b/discord/audit_logs.py @@ -61,6 +61,10 @@ if TYPE_CHECKING: from .sticker import GuildSticker from .threads import Thread + TargetType = Union[ + Guild, abc.GuildChannel, Member, User, Role, Invite, Emoji, StageInstance, GuildSticker, Thread, Object, None + ] + def _transform_timestamp(entry: AuditLogEntry, data: Optional[str]) -> Optional[datetime.datetime]: return utils.parse_time(data) @@ -154,12 +158,14 @@ def _enum_transformer(enum: Type[T]) -> Callable[[AuditLogEntry, int], T]: return _transform + def _transform_type(entry: AuditLogEntry, data: int) -> Union[enums.ChannelType, enums.StickerType]: if entry.action.name.startswith('sticker_'): return enums.try_enum(enums.StickerType, data) else: return enums.try_enum(enums.ChannelType, data) + class AuditLogDiff: def __len__(self) -> int: return len(self.__dict__) @@ -456,7 +462,7 @@ class AuditLogEntry(Hashable): return utils.snowflake_time(self.id) @utils.cached_property - def target(self) -> Union[Guild, abc.GuildChannel, Member, User, Role, Invite, Emoji, StageInstance, GuildSticker, Thread, Object, None]: + def target(self) -> TargetType: try: converter = getattr(self, '_convert_target_' + self.action.target_type) except AttributeError: diff --git a/discord/backoff.py b/discord/backoff.py index 903ecf769..0d0cb7c40 100644 --- a/discord/backoff.py +++ b/discord/backoff.py @@ -31,9 +31,12 @@ from typing import Callable, Generic, Literal, TypeVar, overload, Union T = TypeVar('T', bool, Literal[True], Literal[False]) +# fmt: off __all__ = ( 'ExponentialBackoff', ) +# fmt: on + class ExponentialBackoff(Generic[T]): """An implementation of the exponential backoff algorithm @@ -62,14 +65,14 @@ class ExponentialBackoff(Generic[T]): self._exp: int = 0 self._max: int = 10 - self._reset_time: int = base * 2 ** 11 + self._reset_time: int = base * 2**11 self._last_invocation: float = time.monotonic() # Use our own random instance to avoid messing with global one rand = random.Random() rand.seed() - self._randfunc: Callable[..., Union[int, float]] = rand.randrange if integral else rand.uniform # type: ignore + self._randfunc: Callable[..., Union[int, float]] = rand.randrange if integral else rand.uniform # type: ignore @overload def delay(self: ExponentialBackoff[Literal[False]]) -> float: @@ -102,4 +105,4 @@ class ExponentialBackoff(Generic[T]): self._exp = 0 self._exp = min(self._exp + 1, self._max) - return self._randfunc(0, self._base * 2 ** self._exp) + return self._randfunc(0, self._base * 2**self._exp) diff --git a/discord/channel.py b/discord/channel.py index cd7cc1abb..3229377dd 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -814,7 +814,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): before_timestamp = str(before.id) else: before_timestamp = utils.snowflake_time(before.id).isoformat() - + update_before = lambda data: data['thread_metadata']['archive_timestamp'] endpoint = self.guild._state.http.get_public_archived_threads @@ -823,7 +823,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): endpoint = self.guild._state.http.get_joined_private_archived_threads elif private: endpoint = self.guild._state.http.get_private_archived_threads - + while True: retrieve = 50 if limit is None else max(limit, 50) data = await endpoint(self.id, before=before_timestamp, limit=retrieve) diff --git a/discord/client.py b/discord/client.py index 3527dba3f..e1d583616 100644 --- a/discord/client.py +++ b/discord/client.py @@ -43,7 +43,7 @@ from typing import ( TYPE_CHECKING, Tuple, TypeVar, - Union + Union, ) import aiohttp @@ -84,15 +84,18 @@ if TYPE_CHECKING: from .member import Member from .voice_client import VoiceProtocol +# fmt: off __all__ = ( 'Client', ) +# fmt: on Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]]) _log = logging.getLogger(__name__) + def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None: tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()} @@ -110,11 +113,14 @@ def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None: if task.cancelled(): continue if task.exception() is not None: - loop.call_exception_handler({ - 'message': 'Unhandled exception during Client.run shutdown.', - 'exception': task.exception(), - 'task': task - }) + loop.call_exception_handler( + { + 'message': 'Unhandled exception during Client.run shutdown.', + 'exception': task.exception(), + 'task': task, + } + ) + def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None: try: @@ -124,6 +130,7 @@ def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None: _log.info('Closing the event loop.') loop.close() + class Client: r"""Represents a client connection that connects to Discord. This class is used to interact with the Discord WebSocket and API. @@ -215,6 +222,7 @@ class Client: loop: :class:`asyncio.AbstractEventLoop` The event loop that the client uses for asynchronous operations. """ + def __init__( self, *, @@ -232,14 +240,16 @@ class Client: proxy: Optional[str] = options.pop('proxy', None) proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None) unsync_clock: bool = options.pop('assume_unsync_clock', True) - self.http: HTTPClient = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop) + self.http: HTTPClient = HTTPClient( + connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop + ) self._handlers: Dict[str, Callable] = { - 'ready': self._handle_ready + 'ready': self._handle_ready, } self._hooks: Dict[str, Callable] = { - 'before_identify': self._call_before_identify_hook + 'before_identify': self._call_before_identify_hook, } self._enable_debug_events: bool = options.pop('enable_debug_events', False) @@ -260,8 +270,9 @@ class Client: return self.ws def _get_state(self, **options: Any) -> ConnectionState: - return ConnectionState(dispatch=self.dispatch, handlers=self._handlers, - hooks=self._hooks, http=self.http, loop=self.loop, **options) + return ConnectionState( + dispatch=self.dispatch, handlers=self._handlers, hooks=self._hooks, http=self.http, loop=self.loop, **options + ) def _handle_ready(self) -> None: self._ready.set() @@ -344,7 +355,7 @@ class Client: If this is not passed via ``__init__`` then this is retrieved through the gateway when an event contains the data. Usually after :func:`~discord.on_connect` is called. - + .. versionadded:: 2.0 """ return self._connection.application_id @@ -361,7 +372,13 @@ class Client: """:class:`bool`: Specifies if the client's internal cache is ready for use.""" return self._ready.is_set() - async def _run_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> None: + async def _run_event( + self, + coro: Callable[..., Coroutine[Any, Any, Any]], + event_name: str, + *args: Any, + **kwargs: Any, + ) -> None: try: await coro(*args, **kwargs) except asyncio.CancelledError: @@ -372,7 +389,13 @@ class Client: except asyncio.CancelledError: pass - def _schedule_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> asyncio.Task: + def _schedule_event( + self, + coro: Callable[..., Coroutine[Any, Any, Any]], + event_name: str, + *args: Any, + **kwargs: Any, + ) -> asyncio.Task: wrapped = self._run_event(coro, event_name, *args, **kwargs) # Schedules the task return asyncio.create_task(wrapped, name=f'discord.py: {event_name}') @@ -530,12 +553,14 @@ class Client: self.dispatch('disconnect') ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id) continue - except (OSError, - HTTPException, - GatewayNotFound, - ConnectionClosed, - aiohttp.ClientError, - asyncio.TimeoutError) as exc: + except ( + OSError, + HTTPException, + GatewayNotFound, + ConnectionClosed, + aiohttp.ClientError, + asyncio.TimeoutError, + ) as exc: self.dispatch('disconnect') if not reconnect: @@ -699,10 +724,10 @@ class Client: self._connection._activity = None elif isinstance(value, BaseActivity): # ConnectionState._activity is typehinted as ActivityPayload, we're passing Dict[str, Any] - self._connection._activity = value.to_dict() # type: ignore + self._connection._activity = value.to_dict() # type: ignore else: raise TypeError('activity must derive from BaseActivity.') - + @property def status(self): """:class:`.Status`: @@ -777,7 +802,7 @@ class Client: This is useful if you have a channel_id but don't want to do an API call to send messages to it. - + .. versionadded:: 2.0 Parameters @@ -1030,8 +1055,10 @@ class Client: future = self.loop.create_future() if check is None: + def _check(*args): return True + check = _check ev = event.lower() @@ -1273,7 +1300,7 @@ class Client: """ code = utils.resolve_template(code) data = await self.http.get_template(code) - return Template(data=data, state=self._connection) # type: ignore + return Template(data=data, state=self._connection) # type: ignore async def fetch_guild(self, guild_id: int, /) -> Guild: """|coro| @@ -1402,7 +1429,9 @@ class Client: # Invite management - async def fetch_invite(self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True) -> Invite: + async def fetch_invite( + self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True + ) -> Invite: """|coro| Gets an :class:`.Invite` from a discord.gg URL or ID. @@ -1604,13 +1633,13 @@ class Client: if ch_type in (ChannelType.group, ChannelType.private): # the factory will be a DMChannel or GroupChannel here - channel = factory(me=self.user, data=data, state=self._connection) # type: ignore + channel = factory(me=self.user, data=data, state=self._connection) # type: ignore else: # the factory can't be a DMChannel or GroupChannel here - guild_id = int(data['guild_id']) # type: ignore + guild_id = int(data['guild_id']) # type: ignore guild = self.get_guild(guild_id) or Object(id=guild_id) # GuildChannels expect a Guild, we may be passing an Object - channel = factory(guild=guild, state=self._connection, data=data) # type: ignore + channel = factory(guild=guild, state=self._connection, data=data) # type: ignore return channel @@ -1661,7 +1690,7 @@ class Client: """ data = await self.http.get_sticker(sticker_id) cls, _ = _sticker_factory(data['type']) # type: ignore - return cls(state=self._connection, data=data) # type: ignore + return cls(state=self._connection, data=data) # type: ignore async def fetch_premium_sticker_packs(self) -> List[StickerPack]: """|coro| @@ -1716,7 +1745,7 @@ class Client: This method should be used for when a view is comprised of components that last longer than the lifecycle of the program. - + .. versionadded:: 2.0 Parameters @@ -1748,7 +1777,7 @@ class Client: @property def persistent_views(self) -> Sequence[View]: """Sequence[:class:`.View`]: A sequence of persistent views added to the client. - + .. versionadded:: 2.0 """ return self._connection.persistent_views diff --git a/discord/colour.py b/discord/colour.py index 2833e6225..ebe24afd4 100644 --- a/discord/colour.py +++ b/discord/colour.py @@ -85,7 +85,7 @@ class Colour: self.value: int = value def _get_byte(self, byte: int) -> int: - return (self.value >> (8 * byte)) & 0xff + return (self.value >> (8 * byte)) & 0xFF def __eq__(self, other: Any) -> bool: return isinstance(other, Colour) and self.value == other.value @@ -164,12 +164,12 @@ class Colour: @classmethod def teal(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x1abc9c``.""" - return cls(0x1abc9c) + return cls(0x1ABC9C) @classmethod def dark_teal(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x11806a``.""" - return cls(0x11806a) + return cls(0x11806A) @classmethod def brand_green(cls: Type[CT]) -> CT: @@ -182,17 +182,17 @@ class Colour: @classmethod def green(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x2ecc71``.""" - return cls(0x2ecc71) + return cls(0x2ECC71) @classmethod def dark_green(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x1f8b4c``.""" - return cls(0x1f8b4c) + return cls(0x1F8B4C) @classmethod def blue(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x3498db``.""" - return cls(0x3498db) + return cls(0x3498DB) @classmethod def dark_blue(cls: Type[CT]) -> CT: @@ -202,42 +202,42 @@ class Colour: @classmethod def purple(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x9b59b6``.""" - return cls(0x9b59b6) + return cls(0x9B59B6) @classmethod def dark_purple(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x71368a``.""" - return cls(0x71368a) + return cls(0x71368A) @classmethod def magenta(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0xe91e63``.""" - return cls(0xe91e63) + return cls(0xE91E63) @classmethod def dark_magenta(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0xad1457``.""" - return cls(0xad1457) + return cls(0xAD1457) @classmethod def gold(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0xf1c40f``.""" - return cls(0xf1c40f) + return cls(0xF1C40F) @classmethod def dark_gold(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0xc27c0e``.""" - return cls(0xc27c0e) + return cls(0xC27C0E) @classmethod def orange(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0xe67e22``.""" - return cls(0xe67e22) + return cls(0xE67E22) @classmethod def dark_orange(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0xa84300``.""" - return cls(0xa84300) + return cls(0xA84300) @classmethod def brand_red(cls: Type[CT]) -> CT: @@ -250,45 +250,45 @@ class Colour: @classmethod def red(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0xe74c3c``.""" - return cls(0xe74c3c) + return cls(0xE74C3C) @classmethod def dark_red(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x992d22``.""" - return cls(0x992d22) + return cls(0x992D22) @classmethod def lighter_grey(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x95a5a6``.""" - return cls(0x95a5a6) + return cls(0x95A5A6) lighter_gray = lighter_grey @classmethod def dark_grey(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x607d8b``.""" - return cls(0x607d8b) + return cls(0x607D8B) dark_gray = dark_grey @classmethod def light_grey(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x979c9f``.""" - return cls(0x979c9f) + return cls(0x979C9F) light_gray = light_grey @classmethod def darker_grey(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x546e7a``.""" - return cls(0x546e7a) + return cls(0x546E7A) darker_gray = darker_grey @classmethod def og_blurple(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x7289da``.""" - return cls(0x7289da) + return cls(0x7289DA) @classmethod def blurple(cls: Type[CT]) -> CT: @@ -298,7 +298,7 @@ class Colour: @classmethod def greyple(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x99aab5``.""" - return cls(0x99aab5) + return cls(0x99AAB5) @classmethod def dark_theme(cls: Type[CT]) -> CT: diff --git a/discord/context_managers.py b/discord/context_managers.py index a3ab0d197..5ba6efbc1 100644 --- a/discord/context_managers.py +++ b/discord/context_managers.py @@ -34,9 +34,12 @@ if TYPE_CHECKING: TypingT = TypeVar('TypingT', bound='Typing') +# fmt: off __all__ = ( 'Typing', ) +# fmt: on + def _typing_done_callback(fut: asyncio.Future) -> None: # just retrieve any exception and call it a day @@ -45,6 +48,7 @@ def _typing_done_callback(fut: asyncio.Future) -> None: except (asyncio.CancelledError, Exception): pass + class Typing: def __init__(self, messageable: Messageable) -> None: self.loop: asyncio.AbstractEventLoop = messageable._state.loop @@ -67,7 +71,8 @@ class Typing: self.task.add_done_callback(_typing_done_callback) return self - def __exit__(self, + def __exit__( + self, exc_type: Optional[Type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType], @@ -79,7 +84,8 @@ class Typing: await channel._state.http.send_typing(channel.id) return self.__enter__() - async def __aexit__(self, + async def __aexit__( + self, exc_type: Optional[Type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType], diff --git a/discord/embeds.py b/discord/embeds.py index 7033a10e7..5ad28830d 100644 --- a/discord/embeds.py +++ b/discord/embeds.py @@ -30,9 +30,11 @@ from typing import Any, Dict, Final, List, Mapping, Protocol, TYPE_CHECKING, Typ from . import utils from .colour import Colour +# fmt: off __all__ = ( 'Embed', ) +# fmt: on class _EmptyEmbed: @@ -366,7 +368,7 @@ class Embed: self._footer['icon_url'] = str(icon_url) return self - + def remove_footer(self: E) -> E: """Clears embed's footer information. @@ -381,7 +383,7 @@ class Embed: pass return self - + @property def image(self) -> _EmbedMediaProxy: """Returns an ``EmbedProxy`` denoting the image contents. diff --git a/discord/emoji.py b/discord/emoji.py index 8b51d0fd3..cbcb87fdf 100644 --- a/discord/emoji.py +++ b/discord/emoji.py @@ -30,9 +30,11 @@ from .utils import SnowflakeList, snowflake_time, MISSING from .partial_emoji import _EmojiTag, PartialEmoji from .user import User +# fmt: off __all__ = ( 'Emoji', ) +# fmt: on if TYPE_CHECKING: from .types.emoji import Emoji as EmojiPayload diff --git a/discord/enums.py b/discord/enums.py index 103e05a20..b39832bc3 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -70,6 +70,7 @@ def _create_value_cls(name, comparable): cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value return cls + def _is_descriptor(obj): return hasattr(obj, '__get__') or hasattr(obj, '__set__') or hasattr(obj, '__delete__') diff --git a/discord/ext/commands/_types.py b/discord/ext/commands/_types.py index 9b1559870..6cd941f1b 100644 --- a/discord/ext/commands/_types.py +++ b/discord/ext/commands/_types.py @@ -39,7 +39,9 @@ CoroFunc = Callable[..., Coro[Any]] Check = Union[Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]]] Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]] -Error = Union[Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]]] +Error = Union[ + Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]] +] # This is merely a tag type to avoid circular import issues. diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index 7fef7c4e7..3563064f1 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -66,6 +66,7 @@ T = TypeVar('T') CFT = TypeVar('CFT', bound='CoroFunc') CXT = TypeVar('CXT', bound='Context') + def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]: """A callable that implements a command prefix equivalent to being mentioned. @@ -74,6 +75,7 @@ def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]: # bot.user will never be None when this is called return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore + def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]: """A callable that implements when mentioned or other prefixes provided. @@ -103,6 +105,7 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M ---------- :func:`.when_mentioned` """ + def inner(bot, msg): r = list(prefixes) r = when_mentioned(bot, msg) + r @@ -110,15 +113,19 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M return inner + def _is_submodule(parent: str, child: str) -> bool: return parent == child or child.startswith(parent + ".") + class _DefaultRepr: def __repr__(self): return '' + _default = _DefaultRepr() + class BotBase(GroupMixin): def __init__(self, command_prefix, help_command=_default, description=None, **options): super().__init__(**options) @@ -833,11 +840,13 @@ class BotBase(GroupMixin): raise errors.ExtensionNotLoaded(name) # get the previous module states from sys modules + # fmt: off modules = { name: module for name, module in sys.modules.items() if _is_submodule(lib.__name__, name) } + # fmt: on try: # Unload and then load the module... @@ -913,8 +922,10 @@ class BotBase(GroupMixin): if isinstance(ret, collections.abc.Iterable): raise - raise TypeError("command_prefix must be plain string, iterable of strings, or callable " - f"returning either of these, not {ret.__class__.__name__}") + raise TypeError( + "command_prefix must be plain string, iterable of strings, or callable " + f"returning either of these, not {ret.__class__.__name__}" + ) if not ret: raise ValueError("Iterable command_prefix must contain at least one prefix") @@ -974,14 +985,17 @@ class BotBase(GroupMixin): except TypeError: if not isinstance(prefix, list): - raise TypeError("get_prefix must return either a string or a list of string, " - f"not {prefix.__class__.__name__}") + raise TypeError( + "get_prefix must return either a string or a list of string, " f"not {prefix.__class__.__name__}" + ) # It's possible a bad command_prefix got us here. for value in prefix: if not isinstance(value, str): - raise TypeError("Iterable command_prefix or list returned from get_prefix must " - f"contain only strings, not {value.__class__.__name__}") + raise TypeError( + "Iterable command_prefix or list returned from get_prefix must " + f"contain only strings, not {value.__class__.__name__}" + ) # Getting here shouldn't happen raise @@ -1053,6 +1067,7 @@ class BotBase(GroupMixin): async def on_message(self, message): await self.process_commands(message) + class Bot(BotBase, discord.Client): """Represents a discord bot. @@ -1123,10 +1138,13 @@ class Bot(BotBase, discord.Client): .. versionadded:: 1.7 """ + pass + class AutoShardedBot(BotBase, discord.AutoShardedClient): """This is similar to :class:`.Bot` except that it is inherited from :class:`discord.AutoShardedClient` instead. """ + pass diff --git a/discord/ext/commands/cog.py b/discord/ext/commands/cog.py index 9931557db..12c9ab36d 100644 --- a/discord/ext/commands/cog.py +++ b/discord/ext/commands/cog.py @@ -45,6 +45,7 @@ FuncT = TypeVar('FuncT', bound=Callable[..., Any]) MISSING: Any = discord.utils.MISSING + class CogMeta(type): """A metaclass for defining a cog. @@ -104,6 +105,7 @@ class CogMeta(type): async def bar(self, ctx): pass # hidden -> False """ + __cog_name__: str __cog_settings__: Dict[str, Any] __cog_commands__: List[Command] @@ -150,7 +152,7 @@ class CogMeta(type): raise TypeError(no_bot_cog.format(base, elem)) listeners[elem] = value - new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__ + new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__ listeners_as_list = [] for listener in listeners.values(): @@ -169,10 +171,12 @@ class CogMeta(type): def qualified_name(cls) -> str: return cls.__cog_name__ + def _cog_special_method(func: FuncT) -> FuncT: func.__cog_special_method__ = None return func + class Cog(metaclass=CogMeta): """The base class that all cogs must inherit from. @@ -183,6 +187,7 @@ class Cog(metaclass=CogMeta): When inheriting from this class, the options shown in :class:`CogMeta` are equally valid here. """ + __cog_name__: ClassVar[str] __cog_settings__: ClassVar[Dict[str, Any]] __cog_commands__: ClassVar[List[Command]] @@ -199,10 +204,7 @@ class Cog(metaclass=CogMeta): # r.e type ignore, type-checker complains about overriding a ClassVar self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) # type: ignore - lookup = { - cmd.qualified_name: cmd - for cmd in self.__cog_commands__ - } + lookup = {cmd.qualified_name: cmd for cmd in self.__cog_commands__} # Update the Command instances dynamically as well for command in self.__cog_commands__: @@ -255,6 +257,7 @@ class Cog(metaclass=CogMeta): A command or group from the cog. """ from .core import GroupMixin + for command in self.__cog_commands__: if command.parent is None: yield command @@ -315,6 +318,7 @@ class Cog(metaclass=CogMeta): # to pick it up but the metaclass unfurls the function and # thus the assignments need to be on the actual function return func + return decorator def has_error_handler(self) -> bool: diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index 75bd99874..8e5d59851 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -49,9 +49,11 @@ if TYPE_CHECKING: from .help import HelpCommand from .view import StringView +# fmt: off __all__ = ( 'Context', ) +# fmt: on MISSING: Any = discord.utils.MISSING @@ -122,7 +124,8 @@ class Context(discord.abc.Messageable, Generic[BotT]): or invoked. """ - def __init__(self, + def __init__( + self, *, message: Message, bot: BotT, @@ -237,7 +240,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): view.index = len(self.prefix or '') view.previous = 0 self.invoked_parents = [] - self.invoked_with = view.get_word() # advance to get the root command + self.invoked_with = view.get_word() # advance to get the root command else: to_call = cmd diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py index 2e008aed4..b3af2bf04 100644 --- a/discord/ext/commands/cooldowns.py +++ b/discord/ext/commands/cooldowns.py @@ -48,14 +48,15 @@ __all__ = ( C = TypeVar('C', bound='CooldownMapping') MC = TypeVar('MC', bound='MaxConcurrency') + class BucketType(Enum): - default = 0 - user = 1 - guild = 2 - channel = 3 - member = 4 + default = 0 + user = 1 + guild = 2 + channel = 3 + member = 4 category = 5 - role = 6 + role = 6 def get_key(self, msg: Message) -> Any: if self is BucketType.user: @@ -192,6 +193,7 @@ class Cooldown: def __repr__(self) -> str: return f'' + class CooldownMapping: def __init__( self, @@ -256,12 +258,12 @@ class CooldownMapping: bucket = self.get_bucket(message, current) return bucket.update_rate_limit(current) -class DynamicCooldownMapping(CooldownMapping): +class DynamicCooldownMapping(CooldownMapping): def __init__( self, factory: Callable[[Message], Cooldown], - type: Callable[[Message], Any] + type: Callable[[Message], Any], ) -> None: super().__init__(None, type) self._factory: Callable[[Message], Cooldown] = factory @@ -278,6 +280,7 @@ class DynamicCooldownMapping(CooldownMapping): def create_bucket(self, message: Message) -> Cooldown: return self._factory(message) + class _Semaphore: """This class is a version of a semaphore. @@ -337,6 +340,7 @@ class _Semaphore: self.value += 1 self.wake_up() + class MaxConcurrency: __slots__ = ('number', 'per', 'wait', '_mapping') diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 393880d74..c420ca4fe 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -93,7 +93,7 @@ __all__ = ( 'is_owner', 'is_nsfw', 'has_guild_permissions', - 'bot_has_guild_permissions' + 'bot_has_guild_permissions', ) MISSING: Any = discord.utils.MISSING @@ -112,6 +112,7 @@ if TYPE_CHECKING: else: P = TypeVar('P') + def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: partial = functools.partial while True: @@ -158,8 +159,10 @@ def wrap_callback(coro): except Exception as exc: raise CommandInvokeError(exc) from exc return ret + return wrapped + def hooked_wrapped_callback(command, ctx, coro): @functools.wraps(coro) async def wrapped(*args, **kwargs): @@ -180,6 +183,7 @@ def hooked_wrapped_callback(command, ctx, coro): await command.call_after_hooks(ctx) return ret + return wrapped @@ -202,6 +206,7 @@ class _CaseInsensitiveDict(dict): def __setitem__(self, k, v): super().__setitem__(k.casefold(), v) + class Command(_BaseCommand, Generic[CogT, P, T]): r"""A class that implements the protocol for a bot text command. @@ -269,8 +274,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]): which calls converters. If ``False`` then cooldown processing is done first and then the converters are called second. Defaults to ``False``. extras: :class:`dict` - A dict of user provided extras to attach to the Command. - + A dict of user provided extras to attach to the Command. + .. note:: This object may be copied by the library. @@ -295,10 +300,14 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.__original_kwargs__ = kwargs.copy() return self - def __init__(self, func: Union[ + def __init__( + self, + func: Union[ Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]], - ], **kwargs: Any): + ], + **kwargs: Any, + ): if not asyncio.iscoroutinefunction(func): raise TypeError('Callback must be a coroutine.') @@ -344,7 +353,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): cooldown = func.__commands_cooldown__ except AttributeError: cooldown = kwargs.get('cooldown') - + if cooldown is None: buckets = CooldownMapping(cooldown, BucketType.default) elif isinstance(cooldown, CooldownMapping): @@ -386,17 +395,19 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.after_invoke(after_invoke) @property - def callback(self) -> Union[ - Callable[Concatenate[CogT, Context, P], Coro[T]], - Callable[Concatenate[Context, P], Coro[T]], - ]: + def callback( + self, + ) -> Union[Callable[Concatenate[CogT, Context, P], Coro[T]], Callable[Concatenate[Context, P], Coro[T]],]: return self._callback @callback.setter - def callback(self, function: Union[ + def callback( + self, + function: Union[ Callable[Concatenate[CogT, Context, P], Coro[T]], Callable[Concatenate[Context, P], Coro[T]], - ]) -> None: + ], + ) -> None: self._callback = function unwrap = unwrap_function(function) self.module = unwrap.__module__ @@ -561,7 +572,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): if view.eof: if param.kind == param.VAR_POSITIONAL: - raise RuntimeError() # break the loop + raise RuntimeError() # break the loop if required: if self._is_typing_optional(param.annotation): return None @@ -616,7 +627,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): value = await run_converters(ctx, converter, argument, param) # type: ignore except (CommandError, ArgumentParsingError): view.index = previous - raise RuntimeError() from None # break loop + raise RuntimeError() from None # break loop else: return value @@ -653,9 +664,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): entries = [] command = self # command.parent is type-hinted as GroupMixin some attributes are resolved via MRO - while command.parent is not None: # type: ignore - command = command.parent # type: ignore - entries.append(command.name) # type: ignore + while command.parent is not None: # type: ignore + command = command.parent # type: ignore + entries.append(command.name) # type: ignore return ' '.join(reversed(entries)) @@ -671,8 +682,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]): """ entries = [] command = self - while command.parent is not None: # type: ignore - command = command.parent # type: ignore + while command.parent is not None: # type: ignore + command = command.parent # type: ignore entries.append(command) return entries @@ -1061,8 +1072,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): # do [name] since [name=None] or [name=] are not exactly useful for the user. should_print = param.default if isinstance(param.default, str) else param.default is not None if should_print: - result.append(f'[{name}={param.default}]' if not greedy else - f'[{name}={param.default}]...') + result.append(f'[{name}={param.default}]' if not greedy else f'[{name}={param.default}]...') continue else: result.append(f'[{name}]') @@ -1135,6 +1145,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): finally: ctx.command = original + class GroupMixin(Generic[CogT]): """A mixin that implements common functionality for classes that behave similar to :class:`.Group` and are allowed to register commands. @@ -1147,6 +1158,7 @@ class GroupMixin(Generic[CogT]): case_insensitive: :class:`bool` Whether the commands should be case insensitive. Defaults to ``False``. """ + def __init__(self, *args: Any, **kwargs: Any) -> None: case_insensitive = kwargs.get('case_insensitive', False) self.all_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {} @@ -1320,7 +1332,9 @@ class GroupMixin(Generic[CogT]): Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]], ] - ], Command[CogT, P, T]]: + ], + Command[CogT, P, T], + ]: ... @overload @@ -1348,6 +1362,7 @@ class GroupMixin(Generic[CogT]): Callable[..., :class:`Command`] A decorator that converts the provided method into a Command, adds it to the bot, then returns it. """ + def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> CommandT: kwargs.setdefault('parent', self) result = command(name=name, cls=cls, *args, **kwargs)(func) @@ -1363,12 +1378,10 @@ class GroupMixin(Generic[CogT]): cls: Type[Group[CogT, P, T]] = ..., *args: Any, **kwargs: Any, - ) -> Callable[[ - Union[ - Callable[Concatenate[CogT, ContextT, P], Coro[T]], - Callable[Concatenate[ContextT, P], Coro[T]] - ] - ], Group[CogT, P, T]]: + ) -> Callable[ + [Union[Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]]]], + Group[CogT, P, T], + ]: ... @overload @@ -1396,6 +1409,7 @@ class GroupMixin(Generic[CogT]): Callable[..., :class:`Group`] A decorator that converts the provided method into a Group, adds it to the bot, then returns it. """ + def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> GroupT: kwargs.setdefault('parent', self) result = group(name=name, cls=cls, *args, **kwargs)(func) @@ -1404,6 +1418,7 @@ class GroupMixin(Generic[CogT]): return decorator + class Group(GroupMixin[CogT], Command[CogT, P, T]): """A class that implements a grouping protocol for commands to be executed as subcommands. @@ -1426,6 +1441,7 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]): Indicates if the group's commands should be case insensitive. Defaults to ``False``. """ + def __init__(self, *args: Any, **attrs: Any) -> None: self.invoke_without_command: bool = attrs.pop('invoke_without_command', False) super().__init__(*args, **attrs) @@ -1514,8 +1530,10 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]): view.previous = previous await super().reinvoke(ctx, call_hooks=call_hooks) + # Decorators + @overload def command( name: str = ..., @@ -1527,10 +1545,12 @@ def command( Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]], ] - ] -, Command[CogT, P, T]]: + ], + Command[CogT, P, T], +]: ... + @overload def command( name: str = ..., @@ -1542,22 +1562,25 @@ def command( Callable[Concatenate[CogT, ContextT, P], Coro[Any]], Callable[Concatenate[ContextT, P], Coro[Any]], ] - ] -, CommandT]: + ], + CommandT, +]: ... + def command( name: str = MISSING, cls: Type[CommandT] = MISSING, - **attrs: Any + **attrs: Any, ) -> Callable[ [ Union[ Callable[Concatenate[ContextT, P], Coro[Any]], Callable[Concatenate[CogT, ContextT, P], Coro[T]], ] - ] -, Union[Command[CogT, P, T], CommandT]]: + ], + Union[Command[CogT, P, T], CommandT], +]: """A decorator that transforms a function into a :class:`.Command` or if called with :func:`.group`, :class:`.Group`. @@ -1590,16 +1613,19 @@ def command( if cls is MISSING: cls = Command # type: ignore - def decorator(func: Union[ + def decorator( + func: Union[ Callable[Concatenate[ContextT, P], Coro[Any]], Callable[Concatenate[CogT, ContextT, P], Coro[Any]], - ]) -> CommandT: + ] + ) -> CommandT: if isinstance(func, Command): raise TypeError('Callback is already a command.') return cls(func, name=name, **attrs) return decorator + @overload def group( name: str = ..., @@ -1611,10 +1637,12 @@ def group( Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]], ] - ] -, Group[CogT, P, T]]: + ], + Group[CogT, P, T], +]: ... + @overload def group( name: str = ..., @@ -1626,10 +1654,12 @@ def group( Callable[Concatenate[CogT, ContextT, P], Coro[Any]], Callable[Concatenate[ContextT, P], Coro[Any]], ] - ] -, GroupT]: + ], + GroupT, +]: ... + def group( name: str = MISSING, cls: Type[GroupT] = MISSING, @@ -1640,8 +1670,9 @@ def group( Callable[Concatenate[ContextT, P], Coro[Any]], Callable[Concatenate[CogT, ContextT, P], Coro[T]], ] - ] -, Union[Group[CogT, P, T], GroupT]]: + ], + Union[Group[CogT, P, T], GroupT], +]: """A decorator that transforms a function into a :class:`.Group`. This is similar to the :func:`.command` decorator but the ``cls`` @@ -1654,6 +1685,7 @@ def group( cls = Group # type: ignore return command(name=name, cls=cls, **attrs) # type: ignore + def check(predicate: Check) -> Callable[[T], T]: r"""A decorator that adds a check to the :class:`.Command` or its subclasses. These checks could be accessed via :attr:`.Command.checks`. @@ -1739,13 +1771,16 @@ def check(predicate: Check) -> Callable[[T], T]: if inspect.iscoroutinefunction(predicate): decorator.predicate = predicate else: + @functools.wraps(predicate) async def wrapper(ctx): return predicate(ctx) # type: ignore + decorator.predicate = wrapper return decorator # type: ignore + def check_any(*checks: Check) -> Callable[[T], T]: r"""A :func:`check` that is added that checks if any of the checks passed will pass, i.e. using logical OR. @@ -1814,6 +1849,7 @@ def check_any(*checks: Check) -> Callable[[T], T]: return check(predicate) + def has_role(item: Union[int, str]) -> Callable[[T], T]: """A :func:`.check` that is added that checks if the member invoking the command has the role specified via the name or ID specified. @@ -1856,6 +1892,7 @@ def has_role(item: Union[int, str]) -> Callable[[T], T]: return check(predicate) + def has_any_role(*items: Union[int, str]) -> Callable[[T], T]: r"""A :func:`.check` that is added that checks if the member invoking the command has **any** of the roles specified. This means that if they have @@ -1887,6 +1924,7 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]: async def cool(ctx): await ctx.send('You are cool indeed') """ + def predicate(ctx): if ctx.guild is None: raise NoPrivateMessage() @@ -1899,6 +1937,7 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]: return check(predicate) + def bot_has_role(item: int) -> Callable[[T], T]: """Similar to :func:`.has_role` except checks if the bot itself has the role. @@ -1925,8 +1964,10 @@ def bot_has_role(item: int) -> Callable[[T], T]: if role is None: raise BotMissingRole(item) return True + return check(predicate) + def bot_has_any_role(*items: int) -> Callable[[T], T]: """Similar to :func:`.has_any_role` except checks if the bot itself has any of the roles listed. @@ -1940,6 +1981,7 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]: Raise :exc:`.BotMissingAnyRole` or :exc:`.NoPrivateMessage` instead of generic checkfailure """ + def predicate(ctx): if ctx.guild is None: raise NoPrivateMessage() @@ -1949,8 +1991,10 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]: if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items): return True raise BotMissingAnyRole(list(items)) + return check(predicate) + def has_permissions(**perms: bool) -> Callable[[T], T]: """A :func:`.check` that is added that checks if the member has all of the permissions necessary. @@ -1998,6 +2042,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]: return check(predicate) + def bot_has_permissions(**perms: bool) -> Callable[[T], T]: """Similar to :func:`.has_permissions` except checks if the bot itself has the permissions listed. @@ -2024,6 +2069,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]: return check(predicate) + def has_guild_permissions(**perms: bool) -> Callable[[T], T]: """Similar to :func:`.has_permissions`, but operates on guild wide permissions instead of the current channel permissions. @@ -2052,6 +2098,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]: return check(predicate) + def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]: """Similar to :func:`.has_guild_permissions`, but checks the bot members guild permissions. @@ -2077,6 +2124,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]: return check(predicate) + def dm_only() -> Callable[[T], T]: """A :func:`.check` that indicates this command must only be used in a DM context. Only private messages are allowed when @@ -2095,6 +2143,7 @@ def dm_only() -> Callable[[T], T]: return check(predicate) + def guild_only() -> Callable[[T], T]: """A :func:`.check` that indicates this command must only be used in a guild context only. Basically, no private messages are allowed when @@ -2111,6 +2160,7 @@ def guild_only() -> Callable[[T], T]: return check(predicate) + def is_owner() -> Callable[[T], T]: """A :func:`.check` that checks if the person invoking this command is the owner of the bot. @@ -2128,6 +2178,7 @@ def is_owner() -> Callable[[T], T]: return check(predicate) + def is_nsfw() -> Callable[[T], T]: """A :func:`.check` that checks if the channel is a NSFW channel. @@ -2139,14 +2190,21 @@ def is_nsfw() -> Callable[[T], T]: Raise :exc:`.NSFWChannelRequired` instead of generic :exc:`.CheckFailure`. DM channels will also now pass this check. """ + def pred(ctx: Context) -> bool: ch = ctx.channel if ctx.guild is None or (isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()): return True raise NSFWChannelRequired(ch) # type: ignore + return check(pred) -def cooldown(rate: int, per: float, type: Union[BucketType, Callable[[Message], Any]] = BucketType.default) -> Callable[[T], T]: + +def cooldown( + rate: int, + per: float, + type: Union[BucketType, Callable[[Message], Any]] = BucketType.default, +) -> Callable[[T], T]: """A decorator that adds a cooldown to a :class:`.Command` A cooldown allows a command to only be used a specific amount @@ -2179,9 +2237,14 @@ def cooldown(rate: int, per: float, type: Union[BucketType, Callable[[Message], else: func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type) return func + return decorator # type: ignore -def dynamic_cooldown(cooldown: Union[BucketType, Callable[[Message], Any]], type: BucketType = BucketType.default) -> Callable[[T], T]: + +def dynamic_cooldown( + cooldown: Union[BucketType, Callable[[Message], Any]], + type: BucketType = BucketType.default, +) -> Callable[[T], T]: """A decorator that adds a dynamic cooldown to a :class:`.Command` This differs from :func:`.cooldown` in that it takes a function that @@ -2219,8 +2282,10 @@ def dynamic_cooldown(cooldown: Union[BucketType, Callable[[Message], Any]], type else: func.__commands_cooldown__ = DynamicCooldownMapping(cooldown, type) return func + return decorator # type: ignore + def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait: bool = False) -> Callable[[T], T]: """A decorator that adds a maximum concurrency to a :class:`.Command` or its subclasses. @@ -2252,8 +2317,10 @@ def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait: else: func.__commands_max_concurrency__ = value return func + return decorator # type: ignore + def before_invoke(coro) -> Callable[[T], T]: """A decorator that registers a coroutine as a pre-invoke hook. @@ -2292,14 +2359,17 @@ def before_invoke(coro) -> Callable[[T], T]: bot.add_cog(What()) """ + def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: if isinstance(func, Command): func.before_invoke(coro) else: func.__before_invoke__ = coro return func + return decorator # type: ignore + def after_invoke(coro) -> Callable[[T], T]: """A decorator that registers a coroutine as a post-invoke hook. @@ -2308,10 +2378,12 @@ def after_invoke(coro) -> Callable[[T], T]: .. versionadded:: 1.4 """ + def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: if isinstance(func, Command): func.after_invoke(coro) else: func.__after_invoke__ = coro return func + return decorator # type: ignore diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index 11d442416..b02d97fae 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -100,6 +100,7 @@ __all__ = ( 'MissingRequiredFlag', ) + class CommandError(DiscordException): r"""The base exception type for all command related errors. @@ -109,6 +110,7 @@ class CommandError(DiscordException): in a special way as they are caught and passed into a special event from :class:`.Bot`\, :func:`.on_command_error`. """ + def __init__(self, message: Optional[str] = None, *args: Any) -> None: if message is not None: # clean-up @everyone and @here mentions @@ -117,6 +119,7 @@ class CommandError(DiscordException): else: super().__init__(*args) + class ConversionError(CommandError): """Exception raised when a Converter class raises non-CommandError. @@ -130,18 +133,22 @@ class ConversionError(CommandError): The original exception that was raised. You can also get this via the ``__cause__`` attribute. """ + def __init__(self, converter: Converter, original: Exception) -> None: self.converter: Converter = converter self.original: Exception = original + class UserInputError(CommandError): """The base exception type for errors that involve errors regarding user input. This inherits from :exc:`CommandError`. """ + pass + class CommandNotFound(CommandError): """Exception raised when a command is attempted to be invoked but no command under that name is found. @@ -151,8 +158,10 @@ class CommandNotFound(CommandError): This inherits from :exc:`CommandError`. """ + pass + class MissingRequiredArgument(UserInputError): """Exception raised when parsing a command and a parameter that is required is not encountered. @@ -164,33 +173,41 @@ class MissingRequiredArgument(UserInputError): param: :class:`inspect.Parameter` The argument that is missing. """ + def __init__(self, param: Parameter) -> None: self.param: Parameter = param super().__init__(f'{param.name} is a required argument that is missing.') + class TooManyArguments(UserInputError): """Exception raised when the command was passed too many arguments and its :attr:`.Command.ignore_extra` attribute was not set to ``True``. This inherits from :exc:`UserInputError` """ + pass + class BadArgument(UserInputError): """Exception raised when a parsing or conversion failure is encountered on an argument to pass into a command. This inherits from :exc:`UserInputError` """ + pass + class CheckFailure(CommandError): """Exception raised when the predicates in :attr:`.Command.checks` have failed. This inherits from :exc:`CommandError` """ + pass + class CheckAnyFailure(CheckFailure): """Exception raised when all predicates in :func:`check_any` fail. @@ -211,15 +228,18 @@ class CheckAnyFailure(CheckFailure): self.errors: List[Callable[[Context], bool]] = errors super().__init__('You do not have permission to run this command.') + class PrivateMessageOnly(CheckFailure): """Exception raised when an operation does not work outside of private message contexts. This inherits from :exc:`CheckFailure` """ + def __init__(self, message: Optional[str] = None) -> None: super().__init__(message or 'This command can only be used in private messages.') + class NoPrivateMessage(CheckFailure): """Exception raised when an operation does not work in private message contexts. @@ -230,13 +250,16 @@ class NoPrivateMessage(CheckFailure): def __init__(self, message: Optional[str] = None) -> None: super().__init__(message or 'This command cannot be used in private messages.') + class NotOwner(CheckFailure): """Exception raised when the message author is not the owner of the bot. This inherits from :exc:`CheckFailure` """ + pass + class ObjectNotFound(BadArgument): """Exception raised when the argument provided did not match the format of an ID or a mention. @@ -250,10 +273,12 @@ class ObjectNotFound(BadArgument): argument: :class:`str` The argument supplied by the caller that was not matched """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'{argument!r} does not follow a valid ID or mention format.') + class MemberNotFound(BadArgument): """Exception raised when the member provided was not found in the bot's cache. @@ -267,10 +292,12 @@ class MemberNotFound(BadArgument): argument: :class:`str` The member supplied by the caller that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Member "{argument}" not found.') + class GuildNotFound(BadArgument): """Exception raised when the guild provided was not found in the bot's cache. @@ -283,10 +310,12 @@ class GuildNotFound(BadArgument): argument: :class:`str` The guild supplied by the called that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Guild "{argument}" not found.') + class UserNotFound(BadArgument): """Exception raised when the user provided was not found in the bot's cache. @@ -300,10 +329,12 @@ class UserNotFound(BadArgument): argument: :class:`str` The user supplied by the caller that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'User "{argument}" not found.') + class MessageNotFound(BadArgument): """Exception raised when the message provided was not found in the channel. @@ -316,10 +347,12 @@ class MessageNotFound(BadArgument): argument: :class:`str` The message supplied by the caller that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Message "{argument}" not found.') + class ChannelNotReadable(BadArgument): """Exception raised when the bot does not have permission to read messages in the channel. @@ -333,10 +366,12 @@ class ChannelNotReadable(BadArgument): argument: Union[:class:`.abc.GuildChannel`, :class:`.Thread`] The channel supplied by the caller that was not readable """ + def __init__(self, argument: Union[GuildChannel, Thread]) -> None: self.argument: Union[GuildChannel, Thread] = argument super().__init__(f"Can't read messages in {argument.mention}.") + class ChannelNotFound(BadArgument): """Exception raised when the bot can not find the channel. @@ -349,10 +384,12 @@ class ChannelNotFound(BadArgument): argument: Union[:class:`int`, :class:`str`] The channel supplied by the caller that was not found """ + def __init__(self, argument: Union[int, str]) -> None: self.argument: Union[int, str] = argument super().__init__(f'Channel "{argument}" not found.') + class ThreadNotFound(BadArgument): """Exception raised when the bot can not find the thread. @@ -365,10 +402,12 @@ class ThreadNotFound(BadArgument): argument: :class:`str` The thread supplied by the caller that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Thread "{argument}" not found.') + class BadColourArgument(BadArgument): """Exception raised when the colour is not valid. @@ -381,12 +420,15 @@ class BadColourArgument(BadArgument): argument: :class:`str` The colour supplied by the caller that was not valid """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Colour "{argument}" is invalid.') + BadColorArgument = BadColourArgument + class RoleNotFound(BadArgument): """Exception raised when the bot can not find the role. @@ -399,10 +441,12 @@ class RoleNotFound(BadArgument): argument: :class:`str` The role supplied by the caller that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Role "{argument}" not found.') + class BadInviteArgument(BadArgument): """Exception raised when the invite is invalid or expired. @@ -410,10 +454,12 @@ class BadInviteArgument(BadArgument): .. versionadded:: 1.5 """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Invite "{argument}" is invalid or expired.') + class EmojiNotFound(BadArgument): """Exception raised when the bot can not find the emoji. @@ -426,10 +472,12 @@ class EmojiNotFound(BadArgument): argument: :class:`str` The emoji supplied by the caller that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Emoji "{argument}" not found.') + class PartialEmojiConversionFailure(BadArgument): """Exception raised when the emoji provided does not match the correct format. @@ -443,10 +491,12 @@ class PartialEmojiConversionFailure(BadArgument): argument: :class:`str` The emoji supplied by the caller that did not match the regex """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Couldn\'t convert "{argument}" to PartialEmoji.') + class GuildStickerNotFound(BadArgument): """Exception raised when the bot can not find the sticker. @@ -459,10 +509,12 @@ class GuildStickerNotFound(BadArgument): argument: :class:`str` The sticker supplied by the caller that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Sticker "{argument}" not found.') + class BadBoolArgument(BadArgument): """Exception raised when a boolean argument was not convertable. @@ -475,17 +527,21 @@ class BadBoolArgument(BadArgument): argument: :class:`str` The boolean argument supplied by the caller that is not in the predefined list """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'{argument} is not a recognised boolean option') + class DisabledCommand(CommandError): """Exception raised when the command being invoked is disabled. This inherits from :exc:`CommandError` """ + pass + class CommandInvokeError(CommandError): """Exception raised when the command being invoked raised an exception. @@ -497,10 +553,12 @@ class CommandInvokeError(CommandError): The original exception that was raised. You can also get this via the ``__cause__`` attribute. """ + def __init__(self, e: Exception) -> None: self.original: Exception = e super().__init__(f'Command raised an exception: {e.__class__.__name__}: {e}') + class CommandOnCooldown(CommandError): """Exception raised when the command being invoked is on cooldown. @@ -516,12 +574,14 @@ class CommandOnCooldown(CommandError): retry_after: :class:`float` The amount of seconds to wait before you can retry again. """ + def __init__(self, cooldown: Cooldown, retry_after: float, type: BucketType) -> None: self.cooldown: Cooldown = cooldown self.retry_after: float = retry_after self.type: BucketType = type super().__init__(f'You are on cooldown. Try again in {retry_after:.2f}s') + class MaxConcurrencyReached(CommandError): """Exception raised when the command being invoked has reached its maximum concurrency. @@ -544,6 +604,7 @@ class MaxConcurrencyReached(CommandError): fmt = plural % (number, suffix) super().__init__(f'Too many people are using this command. It can only be used {fmt} concurrently.') + class MissingRole(CheckFailure): """Exception raised when the command invoker lacks a role to run a command. @@ -557,11 +618,13 @@ class MissingRole(CheckFailure): The required role that is missing. This is the parameter passed to :func:`~.commands.has_role`. """ + def __init__(self, missing_role: Snowflake) -> None: self.missing_role: Snowflake = missing_role message = f'Role {missing_role!r} is required to run this command.' super().__init__(message) + class BotMissingRole(CheckFailure): """Exception raised when the bot's member lacks a role to run a command. @@ -575,11 +638,13 @@ class BotMissingRole(CheckFailure): The required role that is missing. This is the parameter passed to :func:`~.commands.has_role`. """ + def __init__(self, missing_role: Snowflake) -> None: self.missing_role: Snowflake = missing_role message = f'Bot requires the role {missing_role!r} to run this command' super().__init__(message) + class MissingAnyRole(CheckFailure): """Exception raised when the command invoker lacks any of the roles specified to run a command. @@ -594,6 +659,7 @@ class MissingAnyRole(CheckFailure): The roles that the invoker is missing. These are the parameters passed to :func:`~.commands.has_any_role`. """ + def __init__(self, missing_roles: SnowflakeList) -> None: self.missing_roles: SnowflakeList = missing_roles @@ -623,6 +689,7 @@ class BotMissingAnyRole(CheckFailure): These are the parameters passed to :func:`~.commands.has_any_role`. """ + def __init__(self, missing_roles: SnowflakeList) -> None: self.missing_roles: SnowflakeList = missing_roles @@ -636,6 +703,7 @@ class BotMissingAnyRole(CheckFailure): message = f"Bot is missing at least one of the required roles: {fmt}" super().__init__(message) + class NSFWChannelRequired(CheckFailure): """Exception raised when a channel does not have the required NSFW setting. @@ -648,10 +716,12 @@ class NSFWChannelRequired(CheckFailure): channel: Union[:class:`.abc.GuildChannel`, :class:`.Thread`] The channel that does not have NSFW enabled. """ + def __init__(self, channel: Union[GuildChannel, Thread]) -> None: self.channel: Union[GuildChannel, Thread] = channel super().__init__(f"Channel '{channel}' needs to be NSFW for this command to work.") + class MissingPermissions(CheckFailure): """Exception raised when the command invoker lacks permissions to run a command. @@ -663,6 +733,7 @@ class MissingPermissions(CheckFailure): missing_permissions: List[:class:`str`] The required permissions that are missing. """ + def __init__(self, missing_permissions: List[str], *args: Any) -> None: self.missing_permissions: List[str] = missing_permissions @@ -675,6 +746,7 @@ class MissingPermissions(CheckFailure): message = f'You are missing {fmt} permission(s) to run this command.' super().__init__(message, *args) + class BotMissingPermissions(CheckFailure): """Exception raised when the bot's member lacks permissions to run a command. @@ -686,6 +758,7 @@ class BotMissingPermissions(CheckFailure): missing_permissions: List[:class:`str`] The required permissions that are missing. """ + def __init__(self, missing_permissions: List[str], *args: Any) -> None: self.missing_permissions: List[str] = missing_permissions @@ -698,6 +771,7 @@ class BotMissingPermissions(CheckFailure): message = f'Bot requires {fmt} permission(s) to run this command.' super().__init__(message, *args) + class BadUnionArgument(UserInputError): """Exception raised when a :data:`typing.Union` converter fails for all its associated types. @@ -713,6 +787,7 @@ class BadUnionArgument(UserInputError): errors: List[:class:`CommandError`] A list of errors that were caught from failing the conversion. """ + def __init__(self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]) -> None: self.param: Parameter = param self.converters: Tuple[Type, ...] = converters @@ -734,6 +809,7 @@ class BadUnionArgument(UserInputError): super().__init__(f'Could not convert "{param.name}" into {fmt}.') + class BadLiteralArgument(UserInputError): """Exception raised when a :data:`typing.Literal` converter fails for all its associated values. @@ -751,6 +827,7 @@ class BadLiteralArgument(UserInputError): errors: List[:class:`CommandError`] A list of errors that were caught from failing the conversion. """ + def __init__(self, param: Parameter, literals: Tuple[Any, ...], errors: List[CommandError]) -> None: self.param: Parameter = param self.literals: Tuple[Any, ...] = literals @@ -764,6 +841,7 @@ class BadLiteralArgument(UserInputError): super().__init__(f'Could not convert "{param.name}" into the literal {fmt}.') + class ArgumentParsingError(UserInputError): """An exception raised when the parser fails to parse a user's input. @@ -772,8 +850,10 @@ class ArgumentParsingError(UserInputError): There are child classes that implement more granular parsing errors for i18n purposes. """ + pass + class UnexpectedQuoteError(ArgumentParsingError): """An exception raised when the parser encounters a quote mark inside a non-quoted string. @@ -784,10 +864,12 @@ class UnexpectedQuoteError(ArgumentParsingError): quote: :class:`str` The quote mark that was found inside the non-quoted string. """ + def __init__(self, quote: str) -> None: self.quote: str = quote super().__init__(f'Unexpected quote mark, {quote!r}, in non-quoted string') + class InvalidEndOfQuotedStringError(ArgumentParsingError): """An exception raised when a space is expected after the closing quote in a string but a different character is found. @@ -799,10 +881,12 @@ class InvalidEndOfQuotedStringError(ArgumentParsingError): char: :class:`str` The character found instead of the expected string. """ + def __init__(self, char: str) -> None: self.char: str = char super().__init__(f'Expected space after closing quotation but received {char!r}') + class ExpectedClosingQuoteError(ArgumentParsingError): """An exception raised when a quote character is expected but not found. @@ -818,6 +902,7 @@ class ExpectedClosingQuoteError(ArgumentParsingError): self.close_quote: str = close_quote super().__init__(f'Expected closing {close_quote}.') + class ExtensionError(DiscordException): """Base exception for extension related errors. @@ -828,6 +913,7 @@ class ExtensionError(DiscordException): name: :class:`str` The extension that had an error. """ + def __init__(self, message: Optional[str] = None, *args: Any, name: str) -> None: self.name: str = name message = message or f'Extension {name!r} had an error.' @@ -835,30 +921,37 @@ class ExtensionError(DiscordException): m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere') super().__init__(m, *args) + class ExtensionAlreadyLoaded(ExtensionError): """An exception raised when an extension has already been loaded. This inherits from :exc:`ExtensionError` """ + def __init__(self, name: str) -> None: super().__init__(f'Extension {name!r} is already loaded.', name=name) + class ExtensionNotLoaded(ExtensionError): """An exception raised when an extension was not loaded. This inherits from :exc:`ExtensionError` """ + def __init__(self, name: str) -> None: super().__init__(f'Extension {name!r} has not been loaded.', name=name) + class NoEntryPointError(ExtensionError): """An exception raised when an extension does not have a ``setup`` entry point function. This inherits from :exc:`ExtensionError` """ + def __init__(self, name: str) -> None: super().__init__(f"Extension {name!r} has no 'setup' function.", name=name) + class ExtensionFailed(ExtensionError): """An exception raised when an extension failed to load during execution of the module or ``setup`` entry point. @@ -872,11 +965,13 @@ class ExtensionFailed(ExtensionError): The original exception that was raised. You can also get this via the ``__cause__`` attribute. """ + def __init__(self, name: str, original: Exception) -> None: self.original: Exception = original msg = f'Extension {name!r} raised an error: {original.__class__.__name__}: {original}' super().__init__(msg, name=name) + class ExtensionNotFound(ExtensionError): """An exception raised when an extension is not found. @@ -890,10 +985,12 @@ class ExtensionNotFound(ExtensionError): name: :class:`str` The extension that had the error. """ + def __init__(self, name: str) -> None: msg = f'Extension {name!r} could not be loaded.' super().__init__(msg, name=name) + class CommandRegistrationError(ClientException): """An exception raised when the command can't be added because the name is already taken by a different command. @@ -909,12 +1006,14 @@ class CommandRegistrationError(ClientException): alias_conflict: :class:`bool` Whether the name that conflicts is an alias of the command we try to add. """ + def __init__(self, name: str, *, alias_conflict: bool = False) -> None: self.name: str = name self.alias_conflict: bool = alias_conflict type_ = 'alias' if alias_conflict else 'command' super().__init__(f'The {type_} {name} is already an existing command or alias.') + class FlagError(BadArgument): """The base exception type for all flag parsing related errors. @@ -922,8 +1021,10 @@ class FlagError(BadArgument): .. versionadded:: 2.0 """ + pass + class TooManyFlags(FlagError): """An exception raised when a flag has received too many values. @@ -938,11 +1039,13 @@ class TooManyFlags(FlagError): values: List[:class:`str`] The values that were passed. """ + def __init__(self, flag: Flag, values: List[str]) -> None: self.flag: Flag = flag self.values: List[str] = values super().__init__(f'Too many flag values, expected {flag.max_args} but received {len(values)}.') + class BadFlagArgument(FlagError): """An exception raised when a flag failed to convert a value. @@ -955,6 +1058,7 @@ class BadFlagArgument(FlagError): flag: :class:`~discord.ext.commands.Flag` The flag that failed to convert. """ + def __init__(self, flag: Flag) -> None: self.flag: Flag = flag try: @@ -964,6 +1068,7 @@ class BadFlagArgument(FlagError): super().__init__(f'Could not convert to {name!r} for flag {flag.name!r}') + class MissingRequiredFlag(FlagError): """An exception raised when a required flag was not given. @@ -976,10 +1081,12 @@ class MissingRequiredFlag(FlagError): flag: :class:`~discord.ext.commands.Flag` The required flag that was not found. """ + def __init__(self, flag: Flag) -> None: self.flag: Flag = flag super().__init__(f'Flag {flag.name!r} is required and missing') + class MissingFlagArgument(FlagError): """An exception raised when a flag did not get a value. @@ -992,6 +1099,7 @@ class MissingFlagArgument(FlagError): flag: :class:`~discord.ext.commands.Flag` The flag that did not get a value. """ + def __init__(self, flag: Flag) -> None: self.flag: Flag = flag super().__init__(f'Flag {flag.name!r} does not have an argument') diff --git a/discord/ext/commands/help.py b/discord/ext/commands/help.py index 67507b010..a2ae7d38c 100644 --- a/discord/ext/commands/help.py +++ b/discord/ext/commands/help.py @@ -932,7 +932,7 @@ class DefaultHelpCommand(HelpCommand): def shorten_text(self, text): """:class:`str`: Shortens text to fit into the :attr:`width`.""" if len(text) > self.width: - return text[:self.width - 3].rstrip() + '...' + return text[: self.width - 3].rstrip() + '...' return text def get_ending_note(self): diff --git a/discord/ext/commands/view.py b/discord/ext/commands/view.py index a7dc72367..6ba31e1a7 100644 --- a/discord/ext/commands/view.py +++ b/discord/ext/commands/view.py @@ -46,6 +46,7 @@ _quotes = { } _all_quotes = set(_quotes.keys()) | set(_quotes.values()) + class StringView: def __init__(self, buffer): self.index = 0 @@ -81,20 +82,20 @@ class StringView: def skip_string(self, string): strlen = len(string) - if self.buffer[self.index:self.index + strlen] == string: + if self.buffer[self.index : self.index + strlen] == string: self.previous = self.index self.index += strlen return True return False def read_rest(self): - result = self.buffer[self.index:] + result = self.buffer[self.index :] self.previous = self.index self.index = self.end return result def read(self, n): - result = self.buffer[self.index:self.index + n] + result = self.buffer[self.index : self.index + n] self.previous = self.index self.index += n return result @@ -120,7 +121,7 @@ class StringView: except IndexError: break self.previous = self.index - result = self.buffer[self.index:self.index + pos] + result = self.buffer[self.index : self.index + pos] self.index += pos return result @@ -187,6 +188,5 @@ class StringView: result.append(current) - def __repr__(self): return f'' diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py index 5b78f10e7..c4838d9a3 100644 --- a/discord/ext/tasks/__init__.py +++ b/discord/ext/tasks/__init__.py @@ -48,9 +48,11 @@ from collections.abc import Sequence from discord.backoff import ExponentialBackoff from discord.utils import MISSING +# fmt: off __all__ = ( 'loop', ) +# fmt: on T = TypeVar('T') _func = Callable[..., Awaitable[Any]] diff --git a/discord/file.py b/discord/file.py index 6f176eef8..4b060554b 100644 --- a/discord/file.py +++ b/discord/file.py @@ -28,9 +28,11 @@ from typing import Any, Dict, Optional, Union import os import io +# fmt: off __all__ = ( 'File', ) +# fmt: on class File: diff --git a/discord/flags.py b/discord/flags.py index 3b1afd61e..31a594529 100644 --- a/discord/flags.py +++ b/discord/flags.py @@ -82,7 +82,7 @@ def fill_with_flags(*, inverted: bool = False): if inverted: max_bits = max(cls.VALID_FLAGS.values()).bit_length() - cls.DEFAULT_VALUE = -1 + (2 ** max_bits) + cls.DEFAULT_VALUE = -1 + (2**max_bits) else: cls.DEFAULT_VALUE = 0 @@ -908,7 +908,7 @@ class Intents(BaseFlags): - :func:`on_message_edit` - :func:`on_message_delete` - :func:`on_raw_message_edit` - + For more information go to the :ref:`message content intent documentation `. .. note:: diff --git a/discord/gateway.py b/discord/gateway.py index aa0c6ba06..a8f5c0dd0 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -50,19 +50,25 @@ __all__ = ( 'ReconnectWebSocket', ) + class ReconnectWebSocket(Exception): """Signals to safely reconnect the websocket.""" + def __init__(self, shard_id, *, resume=True): self.shard_id = shard_id self.resume = resume self.op = 'RESUME' if resume else 'IDENTIFY' + class WebSocketClosure(Exception): """An exception to make up for the fact that aiohttp doesn't signal closure.""" + pass + EventListener = namedtuple('EventListener', 'predicate event result future') + class GatewayRatelimiter: def __init__(self, count=110, per=60.0): # The default is 110 to give room for at least 10 heartbeats per minute @@ -171,7 +177,7 @@ class KeepAliveHandler(threading.Thread): def get_payload(self): return { 'op': self.ws.HEARTBEAT, - 'd': self.ws.sequence + 'd': self.ws.sequence, } def stop(self): @@ -187,6 +193,7 @@ class KeepAliveHandler(threading.Thread): if self.latency > 10: _log.warning(self.behind_msg, self.shard_id, self.latency) + class VoiceKeepAliveHandler(KeepAliveHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -198,7 +205,7 @@ class VoiceKeepAliveHandler(KeepAliveHandler): def get_payload(self): return { 'op': self.ws.HEARTBEAT, - 'd': int(time.time() * 1000) + 'd': int(time.time() * 1000), } def ack(self): @@ -208,10 +215,12 @@ class VoiceKeepAliveHandler(KeepAliveHandler): self.latency = ack_time - self._last_send self.recent_ack_latencies.append(self.latency) + class DiscordClientWebSocketResponse(aiohttp.ClientWebSocketResponse): async def close(self, *, code: int = 4000, message: bytes = b'') -> bool: return await super().close(code=code, message=message) + class DiscordWebSocket: """Implements a WebSocket for Discord's gateway v6. @@ -252,6 +261,7 @@ class DiscordWebSocket: The authentication token for discord. """ + # fmt: off DISPATCH = 0 HEARTBEAT = 1 IDENTIFY = 2 @@ -265,6 +275,7 @@ class DiscordWebSocket: HELLO = 10 HEARTBEAT_ACK = 11 GUILD_SYNC = 12 + # fmt: on def __init__(self, socket, *, loop): self.socket = socket @@ -300,7 +311,17 @@ class DiscordWebSocket: pass @classmethod - async def from_client(cls, client, *, initial=False, gateway=None, shard_id=None, session=None, sequence=None, resume=False): + async def from_client( + cls, + client, + *, + initial=False, + gateway=None, + shard_id=None, + session=None, + sequence=None, + resume=False, + ): """Creates a main websocket for Discord from a :class:`Client`. This is for internal use only. @@ -378,12 +399,12 @@ class DiscordWebSocket: '$browser': 'discord.py', '$device': 'discord.py', '$referrer': '', - '$referring_domain': '' + '$referring_domain': '', }, 'compress': True, 'large_threshold': 250, - 'v': 3 - } + 'v': 3, + }, } if self.shard_id is not None and self.shard_count is not None: @@ -395,7 +416,7 @@ class DiscordWebSocket: 'status': state._status, 'game': state._activity, 'since': 0, - 'afk': False + 'afk': False, } if state._intents is not None: @@ -412,8 +433,8 @@ class DiscordWebSocket: 'd': { 'seq': self.sequence, 'session_id': self.session_id, - 'token': self.token - } + 'token': self.token, + }, } await self.send_as_json(payload) @@ -494,15 +515,23 @@ class DiscordWebSocket: self.session_id = data['session_id'] # pass back shard ID to ready handler data['__shard_id__'] = self.shard_id - _log.info('Shard ID %s has connected to Gateway: %s (Session ID: %s).', - self.shard_id, ', '.join(trace), self.session_id) + _log.info( + 'Shard ID %s has connected to Gateway: %s (Session ID: %s).', + self.shard_id, + ', '.join(trace), + self.session_id, + ) elif event == 'RESUMED': self._trace = trace = data.get('_trace', []) # pass back the shard ID to the resumed handler data['__shard_id__'] = self.shard_id - _log.info('Shard ID %s has successfully RESUMED session %s under trace %s.', - self.shard_id, self.session_id, ', '.join(trace)) + _log.info( + 'Shard ID %s has successfully RESUMED session %s under trace %s.', + self.shard_id, + self.session_id, + ', '.join(trace), + ) try: func = self._discord_parsers[event] @@ -625,8 +654,8 @@ class DiscordWebSocket: 'activities': activity, 'afk': False, 'since': since, - 'status': status - } + 'status': status, + }, } sent = utils._to_json(payload) @@ -639,8 +668,8 @@ class DiscordWebSocket: 'd': { 'guild_id': guild_id, 'presences': presences, - 'limit': limit - } + 'limit': limit, + }, } if nonce: @@ -652,7 +681,6 @@ class DiscordWebSocket: if query is not None: payload['d']['query'] = query - await self.send_as_json(payload) async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False): @@ -662,8 +690,8 @@ class DiscordWebSocket: 'guild_id': guild_id, 'channel_id': channel_id, 'self_mute': self_mute, - 'self_deaf': self_deaf - } + 'self_deaf': self_deaf, + }, } _log.debug('Updating our voice state to %s.', payload) @@ -677,6 +705,7 @@ class DiscordWebSocket: self._close_code = code await self.socket.close(code=code) + class DiscordVoiceWebSocket: """Implements the websocket protocol for handling voice connections. @@ -708,6 +737,7 @@ class DiscordVoiceWebSocket: Receive only. Indicates a user has disconnected from voice. """ + # fmt: off IDENTIFY = 0 SELECT_PROTOCOL = 1 READY = 2 @@ -720,6 +750,7 @@ class DiscordVoiceWebSocket: RESUMED = 9 CLIENT_CONNECT = 12 CLIENT_DISCONNECT = 13 + # fmt: on def __init__(self, socket, loop, *, hook=None): self.ws = socket @@ -746,8 +777,8 @@ class DiscordVoiceWebSocket: 'd': { 'token': state.token, 'server_id': str(state.server_id), - 'session_id': state.session_id - } + 'session_id': state.session_id, + }, } await self.send_as_json(payload) @@ -759,8 +790,8 @@ class DiscordVoiceWebSocket: 'server_id': str(state.server_id), 'user_id': str(state.user.id), 'session_id': state.session_id, - 'token': state.token - } + 'token': state.token, + }, } await self.send_as_json(payload) @@ -791,9 +822,9 @@ class DiscordVoiceWebSocket: 'data': { 'address': ip, 'port': port, - 'mode': mode - } - } + 'mode': mode, + }, + }, } await self.send_as_json(payload) @@ -802,8 +833,8 @@ class DiscordVoiceWebSocket: payload = { 'op': self.CLIENT_CONNECT, 'd': { - 'audio_ssrc': self._connection.ssrc - } + 'audio_ssrc': self._connection.ssrc, + }, } await self.send_as_json(payload) @@ -813,8 +844,8 @@ class DiscordVoiceWebSocket: 'op': self.SPEAKING, 'd': { 'speaking': int(state), - 'delay': 0 - } + 'delay': 0, + }, } await self.send_as_json(payload) @@ -847,8 +878,8 @@ class DiscordVoiceWebSocket: state.endpoint_ip = data['ip'] packet = bytearray(70) - struct.pack_into('>H', packet, 0, 1) # 1 = Send - struct.pack_into('>H', packet, 2, 70) # 70 = Length + struct.pack_into('>H', packet, 0, 1) # 1 = Send + struct.pack_into('>H', packet, 2, 70) # 70 = Length struct.pack_into('>I', packet, 4, state.ssrc) state.socket.sendto(packet, (state.endpoint_ip, state.voice_port)) recv = await self.loop.sock_recv(state.socket, 70) diff --git a/discord/guild.py b/discord/guild.py index 2d55b1932..dbc2d7329 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -81,9 +81,11 @@ from .audit_logs import AuditLogEntry from .object import OLDEST_OBJECT, Object +# fmt: off __all__ = ( 'Guild', ) +# fmt: on MISSING = utils.MISSING @@ -2830,9 +2832,9 @@ class Guild(Hashable): if data and entries: if limit is not None: limit -= len(data) - + before = Object(id=int(entries[-1]['id'])) - + return data.get('users', []), entries, before, limit async def _after_strategy(retrieve, after, limit): @@ -2846,7 +2848,7 @@ class Guild(Hashable): if data and entries: if limit is not None: limit -= len(data) - + after = Object(id=int(entries[0]['id'])) return data.get('users', []), entries, after, limit @@ -2864,7 +2866,6 @@ class Guild(Hashable): if isinstance(after, datetime.datetime): after = Object(id=utils.time_snowflake(after, high=True)) - if oldest_first is None: reverse = after is not None else: diff --git a/discord/http.py b/discord/http.py index 42aa2199a..005726a87 100644 --- a/discord/http.py +++ b/discord/http.py @@ -594,7 +594,9 @@ class HTTPClient: return self.request(r, json=payload, reason=reason) - def edit_message(self, channel_id: Snowflake, message_id: Snowflake, *, params: MultipartParameters) -> Response[message.Message]: + def edit_message( + self, channel_id: Snowflake, message_id: Snowflake, *, params: MultipartParameters + ) -> Response[message.Message]: r = Route('PATCH', '/channels/{channel_id}/messages/{message_id}', channel_id=channel_id, message_id=message_id) if params.files: return self.request(r, files=params.files, form=params.multipart) diff --git a/discord/member.py b/discord/member.py index 8b97e92d8..4d05888ce 100644 --- a/discord/member.py +++ b/discord/member.py @@ -767,13 +767,15 @@ class Member(discord.abc.Messageable, _UserTag): if roles is not MISSING: payload['roles'] = tuple(r.id for r in roles) - + if timed_out_until is not MISSING: if timed_out_until is None: payload['communication_disabled_until'] = None else: if timed_out_until.tzinfo is None: - raise TypeError('timed_out_until must be an aware datetime. Consider using discord.utils.utcnow() or datetime.datetime.now().astimezone() for local time.') + raise TypeError( + 'timed_out_until must be an aware datetime. Consider using discord.utils.utcnow() or datetime.datetime.now().astimezone() for local time.' + ) payload['communication_disabled_until'] = timed_out_until.isoformat() if payload: @@ -940,7 +942,7 @@ class Member(discord.abc.Messageable, _UserTag): """Returns whether this member is timed out. .. versionadded:: 2.0 - + Returns -------- :class:`bool` diff --git a/discord/mentions.py b/discord/mentions.py index 0516decfc..eacbbaaa9 100644 --- a/discord/mentions.py +++ b/discord/mentions.py @@ -25,9 +25,11 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations from typing import Type, TypeVar, Union, List, TYPE_CHECKING, Any, Union +# fmt: off __all__ = ( 'AllowedMentions', ) +# fmt: on if TYPE_CHECKING: from .types.message import AllowedMentions as AllowedMentionsPayload diff --git a/discord/message.py b/discord/message.py index 6a5864f45..65688bd45 100644 --- a/discord/message.py +++ b/discord/message.py @@ -1348,7 +1348,7 @@ class Message(Hashable): ----------- \*attachments: :class:`Attachment` Attachments to remove from the message. - + Raises ------- HTTPException diff --git a/discord/mixins.py b/discord/mixins.py index 32ee222b7..9556f9d90 100644 --- a/discord/mixins.py +++ b/discord/mixins.py @@ -27,6 +27,7 @@ __all__ = ( 'Hashable', ) + class EqualityComparable: __slots__ = () @@ -40,6 +41,7 @@ class EqualityComparable: return other.id != self.id return True + class Hashable(EqualityComparable): __slots__ = () diff --git a/discord/object.py b/discord/object.py index 8ba0afd72..867eb27fe 100644 --- a/discord/object.py +++ b/discord/object.py @@ -35,11 +35,15 @@ from typing import ( if TYPE_CHECKING: import datetime + SupportsIntCast = Union[SupportsInt, str, bytes, bytearray] +# fmt: off __all__ = ( 'Object', ) +# fmt: on + class Object(Hashable): """Represents a generic Discord object. diff --git a/discord/oggparse.py b/discord/oggparse.py index e0347d2cb..09ee8b984 100644 --- a/discord/oggparse.py +++ b/discord/oggparse.py @@ -36,13 +36,17 @@ __all__ = ( 'OggStream', ) + class OggError(DiscordException): """An exception that is thrown for Ogg stream parsing errors.""" + pass + # https://tools.ietf.org/html/rfc3533 # https://tools.ietf.org/html/rfc7845 + class OggPage: _header: ClassVar[struct.Struct] = struct.Struct(' None: self.stream: IO[bytes] = stream diff --git a/discord/opus.py b/discord/opus.py index 97d437a36..c546b9040 100644 --- a/discord/opus.py +++ b/discord/opus.py @@ -42,6 +42,7 @@ if TYPE_CHECKING: BAND_CTL = Literal['narrow', 'medium', 'wide', 'superwide', 'full'] SIGNAL_CTL = Literal['auto', 'voice', 'music'] + class BandCtl(TypedDict): narrow: int medium: int @@ -49,11 +50,13 @@ class BandCtl(TypedDict): superwide: int full: int + class SignalCtl(TypedDict): auto: int voice: int music: int + __all__ = ( 'Encoder', 'OpusError', @@ -62,23 +65,27 @@ __all__ = ( _log = logging.getLogger(__name__) -c_int_ptr = ctypes.POINTER(ctypes.c_int) +c_int_ptr = ctypes.POINTER(ctypes.c_int) c_int16_ptr = ctypes.POINTER(ctypes.c_int16) c_float_ptr = ctypes.POINTER(ctypes.c_float) _lib = None + class EncoderStruct(ctypes.Structure): pass + class DecoderStruct(ctypes.Structure): pass + EncoderStructPtr = ctypes.POINTER(EncoderStruct) DecoderStructPtr = ctypes.POINTER(DecoderStruct) ## Some constants from opus_defines.h # Error codes +# fmt: off OK = 0 BAD_ARG = -1 @@ -96,6 +103,7 @@ CTL_SET_SIGNAL = 4024 # Decoder CTLs CTL_SET_GAIN = 4034 CTL_LAST_PACKET_DURATION = 4039 +# fmt: on band_ctl: BandCtl = { 'narrow': 1101, @@ -111,12 +119,14 @@ signal_ctl: SignalCtl = { 'music': 3002, } + def _err_lt(result: int, func: Callable, args: List) -> int: if result < OK: _log.info('error has happened in %s', func.__name__) raise OpusError(result) return result + def _err_ne(result: T, func: Callable, args: List) -> T: ret = args[-1]._obj if ret.value != OK: @@ -124,6 +134,7 @@ def _err_ne(result: T, func: Callable, args: List) -> T: raise OpusError(ret.value) return result + # A list of exported functions. # The first argument is obviously the name. # The second one are the types of arguments it takes. @@ -131,54 +142,46 @@ def _err_ne(result: T, func: Callable, args: List) -> T: # The fourth is the error handler. exported_functions: List[Tuple[Any, ...]] = [ # Generic - ('opus_get_version_string', - None, ctypes.c_char_p, None), - ('opus_strerror', - [ctypes.c_int], ctypes.c_char_p, None), - + ('opus_get_version_string', None, ctypes.c_char_p, None), + ('opus_strerror', [ctypes.c_int], ctypes.c_char_p, None), # Encoder functions - ('opus_encoder_get_size', - [ctypes.c_int], ctypes.c_int, None), - ('opus_encoder_create', - [ctypes.c_int, ctypes.c_int, ctypes.c_int, c_int_ptr], EncoderStructPtr, _err_ne), - ('opus_encode', - [EncoderStructPtr, c_int16_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int32, _err_lt), - ('opus_encode_float', - [EncoderStructPtr, c_float_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int32, _err_lt), - ('opus_encoder_ctl', - None, ctypes.c_int32, _err_lt), - ('opus_encoder_destroy', - [EncoderStructPtr], None, None), - + ('opus_encoder_get_size', [ctypes.c_int], ctypes.c_int, None), + ('opus_encoder_create', [ctypes.c_int, ctypes.c_int, ctypes.c_int, c_int_ptr], EncoderStructPtr, _err_ne), + ('opus_encode', [EncoderStructPtr, c_int16_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int32, _err_lt), + ( + 'opus_encode_float', + [EncoderStructPtr, c_float_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], + ctypes.c_int32, + _err_lt, + ), + ('opus_encoder_ctl', None, ctypes.c_int32, _err_lt), + ('opus_encoder_destroy', [EncoderStructPtr], None, None), # Decoder functions - ('opus_decoder_get_size', - [ctypes.c_int], ctypes.c_int, None), - ('opus_decoder_create', - [ctypes.c_int, ctypes.c_int, c_int_ptr], DecoderStructPtr, _err_ne), - ('opus_decode', + ('opus_decoder_get_size', [ctypes.c_int], ctypes.c_int, None), + ('opus_decoder_create', [ctypes.c_int, ctypes.c_int, c_int_ptr], DecoderStructPtr, _err_ne), + ( + 'opus_decode', [DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32, c_int16_ptr, ctypes.c_int, ctypes.c_int], - ctypes.c_int, _err_lt), - ('opus_decode_float', + ctypes.c_int, + _err_lt, + ), + ( + 'opus_decode_float', [DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32, c_float_ptr, ctypes.c_int, ctypes.c_int], - ctypes.c_int, _err_lt), - ('opus_decoder_ctl', - None, ctypes.c_int32, _err_lt), - ('opus_decoder_destroy', - [DecoderStructPtr], None, None), - ('opus_decoder_get_nb_samples', - [DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int, _err_lt), - + ctypes.c_int, + _err_lt, + ), + ('opus_decoder_ctl', None, ctypes.c_int32, _err_lt), + ('opus_decoder_destroy', [DecoderStructPtr], None, None), + ('opus_decoder_get_nb_samples', [DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int, _err_lt), # Packet functions - ('opus_packet_get_bandwidth', - [ctypes.c_char_p], ctypes.c_int, _err_lt), - ('opus_packet_get_nb_channels', - [ctypes.c_char_p], ctypes.c_int, _err_lt), - ('opus_packet_get_nb_frames', - [ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt), - ('opus_packet_get_samples_per_frame', - [ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt), + ('opus_packet_get_bandwidth', [ctypes.c_char_p], ctypes.c_int, _err_lt), + ('opus_packet_get_nb_channels', [ctypes.c_char_p], ctypes.c_int, _err_lt), + ('opus_packet_get_nb_frames', [ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt), + ('opus_packet_get_samples_per_frame', [ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt), ] + def libopus_loader(name: str) -> Any: # create the library... lib = ctypes.cdll.LoadLibrary(name) @@ -203,6 +206,7 @@ def libopus_loader(name: str) -> Any: return lib + def _load_default() -> bool: global _lib try: @@ -219,6 +223,7 @@ def _load_default() -> bool: return _lib is not None + def load_opus(name: str) -> None: """Loads the libopus shared library for use with voice. @@ -257,6 +262,7 @@ def load_opus(name: str) -> None: global _lib _lib = libopus_loader(name) + def is_loaded() -> bool: """Function to check if opus lib is successfully loaded either via the :func:`ctypes.util.find_library` call of :func:`load_opus`. @@ -271,6 +277,7 @@ def is_loaded() -> bool: global _lib return _lib is not None + class OpusError(DiscordException): """An exception that is thrown for libopus related errors. @@ -286,10 +293,13 @@ class OpusError(DiscordException): _log.info('"%s" has happened', msg) super().__init__(msg) + class OpusNotLoaded(DiscordException): """An exception that is thrown for when libopus is not loaded.""" + pass + class _OpusStruct: SAMPLING_RATE = 48000 CHANNELS = 2 @@ -306,6 +316,7 @@ class _OpusStruct: return _lib.opus_get_version_string().decode('utf-8') + class Encoder(_OpusStruct): def __init__(self, application: int = APPLICATION_AUDIO): _OpusStruct.get_opus_version() @@ -322,7 +333,7 @@ class Encoder(_OpusStruct): if hasattr(self, '_state'): _lib.opus_encoder_destroy(self._state) # This is a destructor, so it's okay to assign None - self._state = None # type: ignore + self._state = None # type: ignore def _create_state(self) -> EncoderStruct: ret = ctypes.c_int() @@ -352,18 +363,19 @@ class Encoder(_OpusStruct): _lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0) def set_expected_packet_loss_percent(self, percentage: float) -> None: - _lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore + _lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore def encode(self, pcm: bytes, frame_size: int) -> bytes: max_data_bytes = len(pcm) # bytes can be used to reference pointer - pcm_ptr = ctypes.cast(pcm, c_int16_ptr) # type: ignore + pcm_ptr = ctypes.cast(pcm, c_int16_ptr) # type: ignore data = (ctypes.c_char * max_data_bytes)() ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes) # array can be initialized with bytes but mypy doesn't know - return array.array('b', data[:ret]).tobytes() # type: ignore + return array.array('b', data[:ret]).tobytes() # type: ignore + class Decoder(_OpusStruct): def __init__(self): @@ -375,7 +387,7 @@ class Decoder(_OpusStruct): if hasattr(self, '_state'): _lib.opus_decoder_destroy(self._state) # This is a destructor, so it's okay to assign None - self._state = None # type: ignore + self._state = None # type: ignore def _create_state(self) -> DecoderStruct: ret = ctypes.c_int() @@ -411,12 +423,12 @@ class Decoder(_OpusStruct): def set_gain(self, dB: float) -> int: """Sets the decoder gain in dB, from -128 to 128.""" - dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8) + dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8) return self._set_gain(dB_Q8) def set_volume(self, mult: float) -> int: """Sets the output volume as a float percent, i.e. 0.5 for 50%, 1.75 for 175%, etc.""" - return self.set_gain(20 * math.log10(mult)) # amplitude ratio + return self.set_gain(20 * math.log10(mult)) # amplitude ratio def _get_last_packet_duration(self) -> int: """Gets the duration (in samples) of the last packet successfully decoded or concealed.""" @@ -428,7 +440,7 @@ class Decoder(_OpusStruct): @overload def decode(self, data: bytes, *, fec: bool) -> bytes: ... - + @overload def decode(self, data: Literal[None], *, fec: Literal[False]) -> bytes: ... @@ -451,4 +463,4 @@ class Decoder(_OpusStruct): ret = _lib.opus_decode(self._state, data, len(data) if data else 0, pcm_ptr, frame_size, fec) - return array.array('h', pcm[:ret * channel_count]).tobytes() + return array.array('h', pcm[: ret * channel_count]).tobytes() diff --git a/discord/partial_emoji.py b/discord/partial_emoji.py index e2c689e24..a9f68c90c 100644 --- a/discord/partial_emoji.py +++ b/discord/partial_emoji.py @@ -31,15 +31,18 @@ from .asset import Asset, AssetMixin from .errors import InvalidArgument from . import utils +# fmt: off __all__ = ( 'PartialEmoji', ) +# fmt: on if TYPE_CHECKING: from .state import ConnectionState from datetime import datetime from .types.message import PartialEmoji as PartialEmojiPayload + class _EmojiTag: __slots__ = () diff --git a/discord/permissions.py b/discord/permissions.py index c0584a1db..13a9a7aa7 100644 --- a/discord/permissions.py +++ b/discord/permissions.py @@ -46,8 +46,10 @@ def make_permission_alias(alias: str) -> Callable[[Callable[[Any], int]], permis return decorator + P = TypeVar('P', bound='Permissions') + @fill_with_flags() class Permissions(BaseFlags): """Wraps up the Discord permission value. @@ -554,21 +556,23 @@ class Permissions(BaseFlags): @flag_value def start_embedded_activities(self) -> int: """:class:`bool`: Returns ``True`` if a user can launch an embedded application in a Voice channel. - + .. versionadded:: 2.0 """ return 1 << 39 - + @flag_value def moderate_members(self) -> int: """:class:`bool`: Returns ``True`` if a user can time out other members. - + .. versionadded:: 2.0 """ return 1 << 40 + PO = TypeVar('PO', bound='PermissionOverwrite') + def _augment_from_permissions(cls): cls.VALID_NAMES = set(Permissions.VALID_FLAGS) aliases = set() diff --git a/discord/player.py b/discord/player.py index 143e36fb9..dc5c9ba78 100644 --- a/discord/player.py +++ b/discord/player.py @@ -36,7 +36,7 @@ import sys import re import io -from typing import Any, Callable, Generic, IO, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Generic, IO, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union from .errors import ClientException from .opus import Encoder as OpusEncoder @@ -68,6 +68,7 @@ if sys.platform != 'win32': else: CREATE_NO_WINDOW = 0x08000000 + class AudioSource: """Represents an audio stream. @@ -114,6 +115,7 @@ class AudioSource: def __del__(self) -> None: self.cleanup() + class PCMAudio(AudioSource): """Represents raw 16-bit 48KHz stereo PCM audio source. @@ -122,6 +124,7 @@ class PCMAudio(AudioSource): stream: :term:`py:file object` A file-like object that reads byte data representing raw PCM. """ + def __init__(self, stream: io.BufferedIOBase) -> None: self.stream: io.BufferedIOBase = stream @@ -131,6 +134,7 @@ class PCMAudio(AudioSource): return b'' return ret + class FFmpegAudio(AudioSource): """Represents an FFmpeg (or AVConv) based AudioSource. @@ -140,7 +144,14 @@ class FFmpegAudio(AudioSource): .. versionadded:: 1.3 """ - def __init__(self, source: Union[str, io.BufferedIOBase], *, executable: str = 'ffmpeg', args: Any, **subprocess_kwargs: Any): + def __init__( + self, + source: Union[str, io.BufferedIOBase], + *, + executable: str = 'ffmpeg', + args: Any, + **subprocess_kwargs: Any, + ): piping = subprocess_kwargs.get('stdin') == subprocess.PIPE if piping and isinstance(source, str): raise TypeError("parameter conflict: 'source' parameter cannot be a string when piping to stdin") @@ -191,7 +202,6 @@ class FFmpegAudio(AudioSource): else: _log.info('ffmpeg process %s successfully terminated with return code of %s.', proc.pid, proc.returncode) - def _pipe_writer(self, source: io.BufferedIOBase) -> None: while self._process: # arbitrarily large read size @@ -211,6 +221,7 @@ class FFmpegAudio(AudioSource): self._kill_process() self._process = self._stdout = self._stdin = MISSING + class FFmpegPCMAudio(FFmpegAudio): """An audio source from FFmpeg (or AVConv). @@ -254,7 +265,7 @@ class FFmpegPCMAudio(FFmpegAudio): pipe: bool = False, stderr: Optional[IO[str]] = None, before_options: Optional[str] = None, - options: Optional[str] = None + options: Optional[str] = None, ) -> None: args = [] subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr} @@ -282,6 +293,7 @@ class FFmpegPCMAudio(FFmpegAudio): def is_opus(self) -> bool: return False + class FFmpegOpusAudio(FFmpegAudio): """An audio source from FFmpeg (or AVConv). @@ -367,6 +379,7 @@ class FFmpegOpusAudio(FFmpegAudio): codec = 'copy' if codec in ('opus', 'libopus') else 'libopus' + # fmt: off args.extend(('-map_metadata', '-1', '-f', 'opus', '-c:a', codec, @@ -374,6 +387,7 @@ class FFmpegOpusAudio(FFmpegAudio): '-ac', '2', '-b:a', f'{bitrate}k', '-loglevel', 'warning')) + # fmt: on if isinstance(options, str): args.extend(shlex.split(options)) @@ -500,8 +514,7 @@ class FFmpegOpusAudio(FFmpegAudio): probefunc = method fallback = cls._probe_codec_fallback else: - raise TypeError("Expected str or callable for parameter 'probe', " \ - f"not '{method.__class__.__name__}'") + raise TypeError(f"Expected str or callable for parameter 'probe', not '{method.__class__.__name__}'") codec = bitrate = None loop = asyncio.get_event_loop() @@ -537,13 +550,13 @@ class FFmpegOpusAudio(FFmpegAudio): codec = streamdata.get('codec_name') bitrate = int(streamdata.get('bit_rate', 0)) - bitrate = max(round(bitrate/1000), 512) + bitrate = max(round(bitrate / 1000), 512) return codec, bitrate @staticmethod def _probe_codec_fallback(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]: - args = [executable, '-hide_banner', '-i', source] + args = [executable, '-hide_banner', '-i', source] proc = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) out, _ = proc.communicate(timeout=20) output = out.decode('utf8') @@ -565,6 +578,7 @@ class FFmpegOpusAudio(FFmpegAudio): def is_opus(self) -> bool: return True + class PCMVolumeTransformer(AudioSource, Generic[AT]): """Transforms a previous :class:`AudioSource` to have volume controls. @@ -613,6 +627,7 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]): ret = self.original.read() return audioop.mul(ret, 2, min(self._volume, 2.0)) + class AudioPlayer(threading.Thread): DELAY: float = OpusEncoder.FRAME_LENGTH / 1000.0 @@ -625,7 +640,7 @@ class AudioPlayer(threading.Thread): self._end: threading.Event = threading.Event() self._resumed: threading.Event = threading.Event() - self._resumed.set() # we are not paused + self._resumed.set() # we are not paused self._current_error: Optional[Exception] = None self._connected: threading.Event = client._connected self._lock: threading.Lock = threading.Lock() diff --git a/discord/raw_models.py b/discord/raw_models.py index cda754d1b..b8a5acc3e 100644 --- a/discord/raw_models.py +++ b/discord/raw_models.py @@ -34,7 +34,7 @@ if TYPE_CHECKING: MessageUpdateEvent, ReactionClearEvent, ReactionClearEmojiEvent, - IntegrationDeleteEvent + IntegrationDeleteEvent, ) from .message import Message from .partial_emoji import PartialEmoji @@ -179,8 +179,7 @@ class RawReactionActionEvent(_RawReprMixin): .. versionadded:: 1.3 """ - __slots__ = ('message_id', 'user_id', 'channel_id', 'guild_id', 'emoji', - 'event_type', 'member') + __slots__ = ('message_id', 'user_id', 'channel_id', 'guild_id', 'emoji', 'event_type', 'member') def __init__(self, data: ReactionActionEvent, emoji: PartialEmoji, event_type: str) -> None: self.message_id: int = int(data['message_id']) diff --git a/discord/reaction.py b/discord/reaction.py index 733937141..9835acb3a 100644 --- a/discord/reaction.py +++ b/discord/reaction.py @@ -27,9 +27,11 @@ from typing import Any, TYPE_CHECKING, AsyncIterator, List, Union, Optional from .object import Object +# fmt: off __all__ = ( 'Reaction', ) +# fmt: on if TYPE_CHECKING: from .user import User @@ -40,6 +42,7 @@ if TYPE_CHECKING: from .emoji import Emoji from .abc import Snowflake + class Reaction: """Represents a reaction to a message. @@ -77,6 +80,7 @@ class Reaction: message: :class:`Message` Message this reaction is for. """ + __slots__ = ('message', 'count', 'emoji', 'me') def __init__(self, *, message: Message, data: ReactionPayload, emoji: Optional[Union[PartialEmoji, Emoji, str]] = None): @@ -157,7 +161,9 @@ class Reaction: """ await self.message.clear_reaction(self.emoji) - async def users(self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None) -> AsyncIterator[Union[Member, User]]: + async def users( + self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None + ) -> AsyncIterator[Union[Member, User]]: """Returns an :term:`asynchronous iterator` representing the users that have reacted to the message. The ``after`` parameter must represent a member @@ -222,9 +228,7 @@ class Reaction: state = message._state after_id = after.id if after else None - data = await state.http.get_reaction_users( - message.channel.id, message.id, emoji, retrieve, after=after_id - ) + data = await state.http.get_reaction_users(message.channel.id, message.id, emoji, retrieve, after=after_id) if data: limit -= len(data) @@ -241,4 +245,3 @@ class Reaction: member = guild.get_member(member_id) yield member or User(state=state, data=raw_user) - diff --git a/discord/stage_instance.py b/discord/stage_instance.py index 479e89f2c..3ac4fee03 100644 --- a/discord/stage_instance.py +++ b/discord/stage_instance.py @@ -31,9 +31,11 @@ from .mixins import Hashable from .errors import InvalidArgument from .enums import StagePrivacyLevel, try_enum +# fmt: off __all__ = ( 'StageInstance', ) +# fmt: on if TYPE_CHECKING: from .types.channel import StageInstance as StageInstancePayload @@ -107,12 +109,18 @@ class StageInstance(Hashable): def channel(self) -> Optional[StageChannel]: """Optional[:class:`StageChannel`]: The channel that stage instance is running in.""" # the returned channel will always be a StageChannel or None - return self._state.get_channel(self.channel_id) # type: ignore + return self._state.get_channel(self.channel_id) # type: ignore def is_public(self) -> bool: return self.privacy_level is StagePrivacyLevel.public - async def edit(self, *, topic: str = MISSING, privacy_level: StagePrivacyLevel = MISSING, reason: Optional[str] = None) -> None: + async def edit( + self, + *, + topic: str = MISSING, + privacy_level: StagePrivacyLevel = MISSING, + reason: Optional[str] = None, + ) -> None: """|coro| Edits the stage instance. diff --git a/discord/state.py b/discord/state.py index c939c7e8d..39a676773 100644 --- a/discord/state.py +++ b/discord/state.py @@ -446,7 +446,9 @@ class ConnectionState: # If presences are enabled then we get back the old guild.large behaviour return self._chunk_guilds and not guild.chunked and not (self._intents.presences and not guild.large) - def _get_guild_channel(self, data: MessagePayload, guild_id: Optional[int] = None) -> Tuple[Union[Channel, Thread], Optional[Guild]]: + def _get_guild_channel( + self, data: MessagePayload, guild_id: Optional[int] = None + ) -> Tuple[Union[Channel, Thread], Optional[Guild]]: channel_id = int(data['channel_id']) try: guild_id = guild_id or int(data['guild_id']) @@ -691,7 +693,7 @@ class ConnectionState: self._view_store.dispatch_view(component_type, custom_id, interaction) elif data['type'] == 5: # modal submit custom_id = interaction.data['custom_id'] # type: ignore - components = interaction.data['components'] # type: ignore + components = interaction.data['components'] # type: ignore self._view_store.dispatch_modal(custom_id, interaction, components) # type: ignore self.dispatch('interaction', interaction) diff --git a/discord/sticker.py b/discord/sticker.py index b0b5c678d..933b9f425 100644 --- a/discord/sticker.py +++ b/discord/sticker.py @@ -116,7 +116,7 @@ class StickerPack(Hashable): self.name: str = data['name'] self.sku_id: int = int(data['sku_id']) self.cover_sticker_id: int = int(data['cover_sticker_id']) - self.cover_sticker: StandardSticker = get(self.stickers, id=self.cover_sticker_id) # type: ignore + self.cover_sticker: StandardSticker = get(self.stickers, id=self.cover_sticker_id) # type: ignore self.description: str = data['description'] self._banner: int = int(data['banner_asset_id']) diff --git a/discord/template.py b/discord/template.py index 30af3a4d9..0af24c305 100644 --- a/discord/template.py +++ b/discord/template.py @@ -29,9 +29,11 @@ from .utils import parse_time, _get_as_snowflake, _bytes_to_base64_data, MISSING from .enums import VoiceRegion from .guild import Guild +# fmt: off __all__ = ( 'Template', ) +# fmt: on if TYPE_CHECKING: import datetime @@ -310,7 +312,7 @@ class Template: @property def url(self) -> str: """:class:`str`: The template url. - + .. versionadded:: 2.0 """ return f'https://discord.new/{self.code}' diff --git a/discord/types/appinfo.py b/discord/types/appinfo.py index 912d5ad5d..e691e812c 100644 --- a/discord/types/appinfo.py +++ b/discord/types/appinfo.py @@ -30,6 +30,7 @@ from .user import User from .team import Team from .snowflake import Snowflake + class BaseAppInfo(TypedDict): id: Snowflake name: str @@ -38,6 +39,7 @@ class BaseAppInfo(TypedDict): summary: str description: str + class _AppInfoOptional(TypedDict, total=False): team: Team guild_id: Snowflake @@ -48,12 +50,14 @@ class _AppInfoOptional(TypedDict, total=False): hook: bool max_participants: int + class AppInfo(BaseAppInfo, _AppInfoOptional): rpc_origins: List[str] owner: User bot_public: bool bot_require_code_grant: bool + class _PartialAppInfoOptional(TypedDict, total=False): rpc_origins: List[str] cover_image: str @@ -63,5 +67,6 @@ class _PartialAppInfoOptional(TypedDict, total=False): max_participants: int flags: int + class PartialAppInfo(_PartialAppInfoOptional, BaseAppInfo): pass diff --git a/discord/types/embed.py b/discord/types/embed.py index b38c9314c..2e56d272b 100644 --- a/discord/types/embed.py +++ b/discord/types/embed.py @@ -24,50 +24,61 @@ DEALINGS IN THE SOFTWARE. from typing import List, Literal, TypedDict + class _EmbedFooterOptional(TypedDict, total=False): icon_url: str proxy_icon_url: str + class EmbedFooter(_EmbedFooterOptional): text: str + class _EmbedFieldOptional(TypedDict, total=False): inline: bool + class EmbedField(_EmbedFieldOptional): name: str value: str + class EmbedThumbnail(TypedDict, total=False): url: str proxy_url: str height: int width: int + class EmbedVideo(TypedDict, total=False): url: str proxy_url: str height: int width: int + class EmbedImage(TypedDict, total=False): url: str proxy_url: str height: int width: int + class EmbedProvider(TypedDict, total=False): name: str url: str + class EmbedAuthor(TypedDict, total=False): name: str url: str icon_url: str proxy_icon_url: str + EmbedType = Literal['rich', 'image', 'video', 'gifv', 'article', 'link'] + class Embed(TypedDict, total=False): title: str type: EmbedType diff --git a/discord/types/team.py b/discord/types/team.py index 918ede605..83ed08137 100644 --- a/discord/types/team.py +++ b/discord/types/team.py @@ -29,12 +29,14 @@ from typing import TypedDict, List, Optional from .user import PartialUser from .snowflake import Snowflake + class TeamMember(TypedDict): user: PartialUser membership_state: int permissions: List[str] team_id: Snowflake + class Team(TypedDict): id: Snowflake name: str diff --git a/discord/ui/item.py b/discord/ui/item.py index 317327e58..6478f9418 100644 --- a/discord/ui/item.py +++ b/discord/ui/item.py @@ -28,9 +28,11 @@ from typing import Any, Callable, Coroutine, Dict, Generic, Optional, TYPE_CHECK from ..interactions import Interaction +# fmt: off __all__ = ( 'Item', ) +# fmt: on if TYPE_CHECKING: from ..enums import ComponentType diff --git a/discord/ui/modal.py b/discord/ui/modal.py index 2386a0c23..d368af9e4 100644 --- a/discord/ui/modal.py +++ b/discord/ui/modal.py @@ -42,9 +42,11 @@ if TYPE_CHECKING: from ..types.interactions import ModalSubmitComponentInteractionData as ModalSubmitComponentInteractionDataPayload +# fmt: off __all__ = ( 'Modal', ) +# fmt: on _log = logging.getLogger(__name__) diff --git a/discord/ui/select.py b/discord/ui/select.py index 3fe57c734..f04a5511c 100644 --- a/discord/ui/select.py +++ b/discord/ui/select.py @@ -224,7 +224,6 @@ class Select(Item[V]): default=default, ) - self.append_option(option) def append_option(self, option: SelectOption): diff --git a/discord/ui/text_input.py b/discord/ui/text_input.py index 43fdbbe77..fd2a3cca0 100644 --- a/discord/ui/text_input.py +++ b/discord/ui/text_input.py @@ -40,10 +40,11 @@ if TYPE_CHECKING: from .view import View +# fmt: off __all__ = ( 'TextInput', ) - +# fmt: on V = TypeVar('V', bound='View', covariant=True) @@ -177,7 +178,7 @@ class TextInput(Item[V]): def max_length(self) -> Optional[int]: """:class:`int`: The maximum length of the text input.""" return self._underlying.max_length - + @max_length.setter def max_length(self, value: Optional[int]) -> None: self._underlying.max_length = value diff --git a/discord/ui/view.py b/discord/ui/view.py index 0879a3399..b99e1ec4c 100644 --- a/discord/ui/view.py +++ b/discord/ui/view.py @@ -43,9 +43,11 @@ from ..components import ( ) from ..utils import MISSING +# fmt: off __all__ = ( 'View', ) +# fmt: on if TYPE_CHECKING: @@ -81,9 +83,11 @@ def _component_to_item(component: Component) -> Item: class _ViewWeights: + # fmt: off __slots__ = ( 'weights', ) + # fmt: on def __init__(self, children: List[Item]): self.weights: List[int] = [0, 0, 0, 0, 0] @@ -517,7 +521,7 @@ class ViewStore: def remove_view(self, view: View): if view.__discord_ui_modal__: self._modals.pop(view.custom_id, None) # type: ignore - return + return for item in view.children: if item.is_dispatchable(): @@ -542,7 +546,12 @@ class ViewStore: item.refresh_state(interaction.data) # type: ignore view._dispatch_item(item, interaction) - def dispatch_modal(self, custom_id: str, interaction: Interaction, components: List[ModalSubmitComponentInteractionDataPayload]): + def dispatch_modal( + self, + custom_id: str, + interaction: Interaction, + components: List[ModalSubmitComponentInteractionDataPayload], + ): modal = self._modals.get(custom_id) if modal is None: _log.debug("Modal interaction referencing unknown custom_id %s. Discarding", custom_id) diff --git a/discord/utils.py b/discord/utils.py index 867401f92..68b14407d 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -360,7 +360,7 @@ def time_snowflake(dt: datetime.datetime, high: bool = False) -> int: The snowflake representing the time given. """ discord_millis = int(dt.timestamp() * 1000 - DISCORD_EPOCH) - return (discord_millis << 22) + (2 ** 22 - 1 if high else 0) + return (discord_millis << 22) + (2**22 - 1 if high else 0) def find(predicate: Callable[[T], Any], seq: Iterable[T]) -> Optional[T]: diff --git a/discord/voice_client.py b/discord/voice_client.py index d382a74d7..fca3aa592 100644 --- a/discord/voice_client.py +++ b/discord/voice_client.py @@ -66,12 +66,13 @@ if TYPE_CHECKING: VoiceServerUpdate as VoiceServerUpdatePayload, SupportedModes, ) - + has_nacl: bool try: import nacl.secret # type: ignore + has_nacl = True except ImportError: has_nacl = False @@ -82,10 +83,9 @@ __all__ = ( ) - - _log = logging.getLogger(__name__) + class VoiceProtocol: """A class that represents the Discord voice protocol. @@ -195,6 +195,7 @@ class VoiceProtocol: key_id, _ = self.channel._get_voice_client_key() self.client._connection._remove_voice_client(key_id) + class VoiceClient(VoiceProtocol): """Represents a Discord voice connection. @@ -221,12 +222,12 @@ class VoiceClient(VoiceProtocol): loop: :class:`asyncio.AbstractEventLoop` The event loop that the voice client is running on. """ + endpoint_ip: str voice_port: int secret_key: List[int] ssrc: int - def __init__(self, client: Client, channel: abc.Connectable): if not has_nacl: raise RuntimeError("PyNaCl library needed in order to use voice") @@ -309,8 +310,10 @@ class VoiceClient(VoiceProtocol): endpoint = data.get('endpoint') if endpoint is None or self.token is None: - _log.warning('Awaiting endpoint... This requires waiting. ' \ - 'If timeout occurred considering raising the timeout and reconnecting.') + _log.warning( + 'Awaiting endpoint... This requires waiting. ' + 'If timeout occurred considering raising the timeout and reconnecting.' + ) return self.endpoint, _, _ = endpoint.rpartition(':') @@ -359,7 +362,7 @@ class VoiceClient(VoiceProtocol): self._connected.set() return ws - async def connect(self, *, reconnect: bool, timeout: float) ->None: + async def connect(self, *, reconnect: bool, timeout: float) -> None: _log.info('Connecting to voice...') self.timeout = timeout @@ -556,7 +559,7 @@ class VoiceClient(VoiceProtocol): return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext + nonce[:4] - def play(self, source: AudioSource, *, after: Callable[[Optional[Exception]], Any]=None) -> None: + def play(self, source: AudioSource, *, after: Callable[[Optional[Exception]], Any] = None) -> None: """Plays an :class:`AudioSource`. The finalizer, ``after`` is called after the source has been exhausted diff --git a/discord/webhook/async_.py b/discord/webhook/async_.py index a2f128564..989b2b445 100644 --- a/discord/webhook/async_.py +++ b/discord/webhook/async_.py @@ -681,7 +681,7 @@ class WebhookMessage(Message): attachments: List[Union[:class:`Attachment`, :class:`File`]] A list of attachments to keep in the message as well as new files to upload. If ``[]`` is passed then all attachments are removed. - + .. note:: New files will always appear after current attachments. @@ -761,7 +761,7 @@ class WebhookMessage(Message): ----------- \*attachments: :class:`Attachment` Attachments to remove from the message. - + Raises ------- HTTPException diff --git a/discord/webhook/sync.py b/discord/webhook/sync.py index c2adf3f6c..dc579d755 100644 --- a/discord/webhook/sync.py +++ b/discord/webhook/sync.py @@ -469,7 +469,7 @@ class SyncWebhookMessage(Message): ----------- \*attachments: :class:`Attachment` Attachments to remove from the message. - + Raises ------- HTTPException diff --git a/discord/widget.py b/discord/widget.py index 36b6e3dd5..10075caf6 100644 --- a/discord/widget.py +++ b/discord/widget.py @@ -46,6 +46,7 @@ __all__ = ( 'Widget', ) + class WidgetChannel: """Represents a "partial" widget channel. @@ -76,6 +77,7 @@ class WidgetChannel: position: :class:`int` The channel's position """ + __slots__ = ('id', 'name', 'position') def __init__(self, id: int, name: str, position: int) -> None: @@ -99,6 +101,7 @@ class WidgetChannel: """:class:`datetime.datetime`: Returns the channel's creation time in UTC.""" return snowflake_time(self.id) + class WidgetMember(BaseUser): """Represents a "partial" member of the widget's guild. @@ -147,9 +150,21 @@ class WidgetMember(BaseUser): connected_channel: Optional[:class:`WidgetChannel`] Which channel the member is connected to. """ - __slots__ = ('name', 'status', 'nick', 'avatar', 'discriminator', - 'id', 'bot', 'activity', 'deafened', 'suppress', 'muted', - 'connected_channel') + + __slots__ = ( + 'name', + 'status', + 'nick', + 'avatar', + 'discriminator', + 'id', + 'bot', + 'activity', + 'deafened', + 'suppress', + 'muted', + 'connected_channel', + ) if TYPE_CHECKING: activity: Optional[Union[BaseActivity, Spotify]] @@ -159,7 +174,7 @@ class WidgetMember(BaseUser): *, state: ConnectionState, data: WidgetMemberPayload, - connected_channel: Optional[WidgetChannel] = None + connected_channel: Optional[WidgetChannel] = None, ) -> None: super().__init__(state=state, data=data) self.nick: Optional[str] = data.get('nick') @@ -181,8 +196,7 @@ class WidgetMember(BaseUser): def __repr__(self) -> str: return ( - f"" + f"" ) @property @@ -190,6 +204,7 @@ class WidgetMember(BaseUser): """:class:`str`: Returns the member's display name.""" return self.nick or self.name + class Widget: """Represents a :class:`Guild` widget. @@ -227,6 +242,7 @@ class Widget: retrieved is capped. """ + __slots__ = ('_state', 'channels', '_invite', 'id', 'members', 'name') def __init__(self, *, state: ConnectionState, data: WidgetPayload) -> None: