Browse Source

Fix typing issues and improve typing completeness across the library

Co-authored-by: Danny <[email protected]>
Co-authored-by: Josh <[email protected]>
pull/7681/head
Stocker 3 years ago
committed by GitHub
parent
commit
5aa696ccfa
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 21
      discord/__main__.py
  2. 56
      discord/abc.py
  3. 42
      discord/activity.py
  4. 22
      discord/app_commands/commands.py
  5. 8
      discord/app_commands/errors.py
  6. 87
      discord/app_commands/models.py
  7. 4
      discord/app_commands/transformers.py
  8. 53
      discord/app_commands/tree.py
  9. 53
      discord/asset.py
  10. 25
      discord/audit_logs.py
  11. 22
      discord/channel.py
  12. 10
      discord/client.py
  13. 13
      discord/colour.py
  14. 10
      discord/components.py
  15. 10
      discord/context_managers.py
  16. 12
      discord/embeds.py
  17. 4
      discord/emoji.py
  18. 44
      discord/enums.py
  19. 18
      discord/ext/commands/_types.py
  20. 88
      discord/ext/commands/bot.py
  21. 16
      discord/ext/commands/cog.py
  22. 15
      discord/ext/commands/context.py
  23. 86
      discord/ext/commands/converter.py
  24. 2
      discord/ext/commands/cooldowns.py
  25. 115
      discord/ext/commands/core.py
  26. 14
      discord/ext/commands/errors.py
  27. 18
      discord/ext/commands/flags.py
  28. 293
      discord/ext/commands/help.py
  29. 37
      discord/ext/commands/view.py
  30. 8
      discord/ext/tasks/__init__.py
  31. 2
      discord/file.py
  32. 20
      discord/flags.py
  33. 30
      discord/gateway.py
  34. 5
      discord/guild.py
  35. 22
      discord/http.py
  36. 17
      discord/integrations.py
  37. 7
      discord/interactions.py
  38. 4
      discord/invite.py
  39. 13
      discord/member.py
  40. 8
      discord/mentions.py
  41. 17
      discord/message.py
  42. 5
      discord/opus.py
  43. 15
      discord/partial_emoji.py
  44. 4
      discord/permissions.py
  45. 21
      discord/player.py
  46. 4
      discord/reaction.py
  47. 4
      discord/role.py
  48. 6
      discord/scheduled_event.py
  49. 5
      discord/shard.py
  50. 6
      discord/stage_instance.py
  51. 58
      discord/state.py
  52. 2
      discord/sticker.py
  53. 71
      discord/threads.py
  54. 1
      discord/types/activity.py
  55. 5
      discord/types/widget.py
  56. 17
      discord/ui/button.py
  57. 4
      discord/ui/item.py
  58. 8
      discord/ui/modal.py
  59. 18
      discord/ui/select.py
  60. 30
      discord/ui/view.py
  61. 6
      discord/user.py
  62. 36
      discord/utils.py
  63. 10
      discord/voice_client.py
  64. 122
      discord/webhook/async_.py
  65. 62
      discord/webhook/sync.py
  66. 2
      discord/widget.py

21
discord/__main__.py

@ -23,7 +23,8 @@ DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations from __future__ import annotations
from typing import Dict, Optional
from typing import Optional, Tuple, Dict
import argparse import argparse
import sys import sys
@ -35,7 +36,7 @@ import aiohttp
import platform import platform
def show_version(): def show_version() -> None:
entries = [] entries = []
entries.append('- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}'.format(sys.version_info)) entries.append('- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}'.format(sys.version_info))
@ -52,7 +53,7 @@ def show_version():
print('\n'.join(entries)) print('\n'.join(entries))
def core(parser, args): def core(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
if args.version: if args.version:
show_version() show_version()
@ -185,7 +186,7 @@ _base_table.update((chr(i), None) for i in range(32))
_translation_table = str.maketrans(_base_table) _translation_table = str.maketrans(_base_table)
def to_path(parser, name, *, replace_spaces=False): def to_path(parser: argparse.ArgumentParser, name: str, *, replace_spaces: bool = False) -> Path:
if isinstance(name, Path): if isinstance(name, Path):
return name return name
@ -223,7 +224,7 @@ def to_path(parser, name, *, replace_spaces=False):
return Path(name) return Path(name)
def newbot(parser, args): def newbot(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
new_directory = to_path(parser, args.directory) / to_path(parser, args.name) new_directory = to_path(parser, args.directory) / to_path(parser, args.name)
# as a note exist_ok for Path is a 3.5+ only feature # as a note exist_ok for Path is a 3.5+ only feature
@ -265,7 +266,7 @@ def newbot(parser, args):
print('successfully made bot at', new_directory) print('successfully made bot at', new_directory)
def newcog(parser, args): def newcog(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
cog_dir = to_path(parser, args.directory) cog_dir = to_path(parser, args.directory)
try: try:
cog_dir.mkdir(exist_ok=True) cog_dir.mkdir(exist_ok=True)
@ -299,7 +300,7 @@ def newcog(parser, args):
print('successfully made cog at', directory) print('successfully made cog at', directory)
def add_newbot_args(subparser): def add_newbot_args(subparser: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
parser = subparser.add_parser('newbot', help='creates a command bot project quickly') parser = subparser.add_parser('newbot', help='creates a command bot project quickly')
parser.set_defaults(func=newbot) parser.set_defaults(func=newbot)
@ -310,7 +311,7 @@ def add_newbot_args(subparser):
parser.add_argument('--no-git', help='do not create a .gitignore file', action='store_true', dest='no_git') parser.add_argument('--no-git', help='do not create a .gitignore file', action='store_true', dest='no_git')
def add_newcog_args(subparser): def add_newcog_args(subparser: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
parser = subparser.add_parser('newcog', help='creates a new cog template quickly') parser = subparser.add_parser('newcog', help='creates a new cog template quickly')
parser.set_defaults(func=newcog) parser.set_defaults(func=newcog)
@ -322,7 +323,7 @@ def add_newcog_args(subparser):
parser.add_argument('--full', help='add all special methods as well', action='store_true') parser.add_argument('--full', help='add all special methods as well', action='store_true')
def parse_args(): def parse_args() -> Tuple[argparse.ArgumentParser, argparse.Namespace]:
parser = argparse.ArgumentParser(prog='discord', description='Tools for helping with discord.py') parser = argparse.ArgumentParser(prog='discord', description='Tools for helping with discord.py')
parser.add_argument('-v', '--version', action='store_true', help='shows the library version') parser.add_argument('-v', '--version', action='store_true', help='shows the library version')
parser.set_defaults(func=core) parser.set_defaults(func=core)
@ -333,7 +334,7 @@ def parse_args():
return parser, parser.parse_args() return parser, parser.parse_args()
def main(): def main() -> None:
parser, args = parse_args() parser, args = parse_args()
args.func(parser, args) args.func(parser, args)

56
discord/abc.py

@ -91,6 +91,9 @@ if TYPE_CHECKING:
GuildChannel as GuildChannelPayload, GuildChannel as GuildChannelPayload,
OverwriteType, OverwriteType,
) )
from .types.snowflake import (
SnowflakeList,
)
PartialMessageableChannel = Union[TextChannel, Thread, DMChannel, PartialMessageable] PartialMessageableChannel = Union[TextChannel, Thread, DMChannel, PartialMessageable]
MessageableChannel = Union[PartialMessageableChannel, GroupChannel] MessageableChannel = Union[PartialMessageableChannel, GroupChannel]
@ -708,7 +711,14 @@ class GuildChannel:
) -> None: ) -> None:
... ...
async def set_permissions(self, target, *, overwrite=_undefined, reason=None, **permissions): async def set_permissions(
self,
target: Union[Member, Role],
*,
overwrite: Any = _undefined,
reason: Optional[str] = None,
**permissions: bool,
) -> None:
r"""|coro| r"""|coro|
Sets the channel specific permission overwrites for a target in the Sets the channel specific permission overwrites for a target in the
@ -917,7 +927,7 @@ class GuildChannel:
) -> None: ) -> None:
... ...
async def move(self, **kwargs) -> None: async def move(self, **kwargs: Any) -> None:
"""|coro| """|coro|
A rich interface to help move a channel relative to other channels. A rich interface to help move a channel relative to other channels.
@ -1248,22 +1258,22 @@ class Messageable:
async def send( async def send(
self, self,
content=None, content: Optional[str] = None,
*, *,
tts=False, tts: bool = False,
embed=None, embed: Optional[Embed] = None,
embeds=None, embeds: Optional[List[Embed]] = None,
file=None, file: Optional[File] = None,
files=None, files: Optional[List[File]] = None,
stickers=None, stickers: Optional[Sequence[Union[GuildSticker, StickerItem]]] = None,
delete_after=None, delete_after: Optional[float] = None,
nonce=None, nonce: Optional[Union[str, int]] = None,
allowed_mentions=None, allowed_mentions: Optional[AllowedMentions] = None,
reference=None, reference: Optional[Union[Message, MessageReference, PartialMessage]] = None,
mention_author=None, mention_author: Optional[bool] = None,
view=None, view: Optional[View] = None,
suppress_embeds=False, suppress_embeds: bool = False,
): ) -> Message:
"""|coro| """|coro|
Sends a message to the destination with the content given. Sends a message to the destination with the content given.
@ -1368,17 +1378,17 @@ class Messageable:
previous_allowed_mention = state.allowed_mentions previous_allowed_mention = state.allowed_mentions
if stickers is not None: if stickers is not None:
stickers = [sticker.id for sticker in stickers] sticker_ids: SnowflakeList = [sticker.id for sticker in stickers]
else: else:
stickers = MISSING sticker_ids = MISSING
if reference is not None: if reference is not None:
try: try:
reference = reference.to_message_reference_dict() reference_dict = reference.to_message_reference_dict()
except AttributeError: except AttributeError:
raise TypeError('reference parameter must be Message, MessageReference, or PartialMessage') from None raise TypeError('reference parameter must be Message, MessageReference, or PartialMessage') from None
else: else:
reference = MISSING reference_dict = MISSING
if view and not hasattr(view, '__discord_ui_view__'): if view and not hasattr(view, '__discord_ui_view__'):
raise TypeError(f'view parameter must be View not {view.__class__!r}') raise TypeError(f'view parameter must be View not {view.__class__!r}')
@ -1399,10 +1409,10 @@ class Messageable:
embeds=embeds if embeds is not None else MISSING, embeds=embeds if embeds is not None else MISSING,
nonce=nonce, nonce=nonce,
allowed_mentions=allowed_mentions, allowed_mentions=allowed_mentions,
message_reference=reference, message_reference=reference_dict,
previous_allowed_mentions=previous_allowed_mention, previous_allowed_mentions=previous_allowed_mention,
mention_author=mention_author, mention_author=mention_author,
stickers=stickers, stickers=sticker_ids,
view=view, view=view,
flags=flags, flags=flags,
) as params: ) as params:

42
discord/activity.py

@ -123,7 +123,7 @@ class BaseActivity:
__slots__ = ('_created_at',) __slots__ = ('_created_at',)
def __init__(self, **kwargs): def __init__(self, **kwargs: Any) -> None:
self._created_at: Optional[float] = kwargs.pop('created_at', None) self._created_at: Optional[float] = kwargs.pop('created_at', None)
@property @property
@ -218,7 +218,7 @@ class Activity(BaseActivity):
'buttons', 'buttons',
) )
def __init__(self, **kwargs): def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
self.state: Optional[str] = kwargs.pop('state', None) self.state: Optional[str] = kwargs.pop('state', None)
self.details: Optional[str] = kwargs.pop('details', None) self.details: Optional[str] = kwargs.pop('details', None)
@ -363,7 +363,7 @@ class Game(BaseActivity):
__slots__ = ('name', '_end', '_start') __slots__ = ('name', '_end', '_start')
def __init__(self, name: str, **extra): def __init__(self, name: str, **extra: Any) -> None:
super().__init__(**extra) super().__init__(**extra)
self.name: str = name self.name: str = name
@ -420,10 +420,10 @@ class Game(BaseActivity):
} }
# fmt: on # fmt: on
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
return isinstance(other, Game) and other.name == self.name return isinstance(other, Game) and other.name == self.name
def __ne__(self, other: Any) -> bool: def __ne__(self, other: object) -> bool:
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self) -> int:
@ -477,7 +477,7 @@ class Streaming(BaseActivity):
__slots__ = ('platform', 'name', 'game', 'url', 'details', 'assets') __slots__ = ('platform', 'name', 'game', 'url', 'details', 'assets')
def __init__(self, *, name: Optional[str], url: str, **extra: Any): def __init__(self, *, name: Optional[str], url: str, **extra: Any) -> None:
super().__init__(**extra) super().__init__(**extra)
self.platform: Optional[str] = name self.platform: Optional[str] = name
self.name: Optional[str] = extra.pop('details', name) self.name: Optional[str] = extra.pop('details', name)
@ -501,7 +501,7 @@ class Streaming(BaseActivity):
return f'<Streaming name={self.name!r}>' return f'<Streaming name={self.name!r}>'
@property @property
def twitch_name(self): def twitch_name(self) -> Optional[str]:
"""Optional[:class:`str`]: If provided, the twitch name of the user streaming. """Optional[:class:`str`]: If provided, the twitch name of the user streaming.
This corresponds to the ``large_image`` key of the :attr:`Streaming.assets` This corresponds to the ``large_image`` key of the :attr:`Streaming.assets`
@ -528,10 +528,10 @@ class Streaming(BaseActivity):
ret['details'] = self.details ret['details'] = self.details
return ret return ret
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
return isinstance(other, Streaming) and other.name == self.name and other.url == self.url return isinstance(other, Streaming) and other.name == self.name and other.url == self.url
def __ne__(self, other: Any) -> bool: def __ne__(self, other: object) -> bool:
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self) -> int:
@ -563,14 +563,14 @@ class Spotify:
__slots__ = ('_state', '_details', '_timestamps', '_assets', '_party', '_sync_id', '_session_id', '_created_at') __slots__ = ('_state', '_details', '_timestamps', '_assets', '_party', '_sync_id', '_session_id', '_created_at')
def __init__(self, **data): def __init__(self, **data: Any) -> None:
self._state: str = data.pop('state', '') self._state: str = data.pop('state', '')
self._details: str = data.pop('details', '') self._details: str = data.pop('details', '')
self._timestamps: Dict[str, int] = data.pop('timestamps', {}) self._timestamps: ActivityTimestamps = data.pop('timestamps', {})
self._assets: ActivityAssets = data.pop('assets', {}) self._assets: ActivityAssets = data.pop('assets', {})
self._party: ActivityParty = data.pop('party', {}) self._party: ActivityParty = data.pop('party', {})
self._sync_id: str = data.pop('sync_id') self._sync_id: str = data.pop('sync_id', '')
self._session_id: str = data.pop('session_id') self._session_id: Optional[str] = data.pop('session_id')
self._created_at: Optional[float] = data.pop('created_at', None) self._created_at: Optional[float] = data.pop('created_at', None)
@property @property
@ -622,7 +622,7 @@ class Spotify:
""":class:`str`: The activity's name. This will always return "Spotify".""" """:class:`str`: The activity's name. This will always return "Spotify"."""
return 'Spotify' return 'Spotify'
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
return ( return (
isinstance(other, Spotify) isinstance(other, Spotify)
and other._session_id == self._session_id and other._session_id == self._session_id
@ -630,7 +630,7 @@ class Spotify:
and other.start == self.start and other.start == self.start
) )
def __ne__(self, other: Any) -> bool: def __ne__(self, other: object) -> bool:
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self) -> int:
@ -691,12 +691,14 @@ class Spotify:
@property @property
def start(self) -> datetime.datetime: def start(self) -> datetime.datetime:
""":class:`datetime.datetime`: When the user started playing this song in UTC.""" """:class:`datetime.datetime`: When the user started playing this song in UTC."""
return datetime.datetime.fromtimestamp(self._timestamps['start'] / 1000, tz=datetime.timezone.utc) # the start key will be present here
return datetime.datetime.fromtimestamp(self._timestamps['start'] / 1000, tz=datetime.timezone.utc) # type: ignore
@property @property
def end(self) -> datetime.datetime: def end(self) -> datetime.datetime:
""":class:`datetime.datetime`: When the user will stop playing this song in UTC.""" """:class:`datetime.datetime`: When the user will stop playing this song in UTC."""
return datetime.datetime.fromtimestamp(self._timestamps['end'] / 1000, tz=datetime.timezone.utc) # the end key will be present here
return datetime.datetime.fromtimestamp(self._timestamps['end'] / 1000, tz=datetime.timezone.utc) # type: ignore
@property @property
def duration(self) -> datetime.timedelta: def duration(self) -> datetime.timedelta:
@ -742,7 +744,7 @@ class CustomActivity(BaseActivity):
__slots__ = ('name', 'emoji', 'state') __slots__ = ('name', 'emoji', 'state')
def __init__(self, name: Optional[str], *, emoji: Optional[PartialEmoji] = None, **extra: Any): def __init__(self, name: Optional[str], *, emoji: Optional[PartialEmoji] = None, **extra: Any) -> None:
super().__init__(**extra) super().__init__(**extra)
self.name: Optional[str] = name self.name: Optional[str] = name
self.state: Optional[str] = extra.pop('state', None) self.state: Optional[str] = extra.pop('state', None)
@ -786,10 +788,10 @@ class CustomActivity(BaseActivity):
o['emoji'] = self.emoji.to_dict() o['emoji'] = self.emoji.to_dict()
return o return o
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
return isinstance(other, CustomActivity) and other.name == self.name and other.emoji == self.emoji return isinstance(other, CustomActivity) and other.name == self.name and other.emoji == self.emoji
def __ne__(self, other: Any) -> bool: def __ne__(self, other: object) -> bool:
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self) -> int:

22
discord/app_commands/commands.py

@ -166,7 +166,7 @@ def _validate_auto_complete_callback(
return callback return callback
def _context_menu_annotation(annotation: Any, *, _none=NoneType) -> AppCommandType: def _context_menu_annotation(annotation: Any, *, _none: type = NoneType) -> AppCommandType:
if annotation is Message: if annotation is Message:
return AppCommandType.message return AppCommandType.message
@ -686,7 +686,7 @@ class Group:
The parent group. ``None`` if there isn't one. The parent group. ``None`` if there isn't one.
""" """
__discord_app_commands_group_children__: ClassVar[List[Union[Command, Group]]] = [] __discord_app_commands_group_children__: ClassVar[List[Union[Command[Any, ..., Any], Group]]] = []
__discord_app_commands_skip_init_binding__: bool = False __discord_app_commands_skip_init_binding__: bool = False
__discord_app_commands_group_name__: str = MISSING __discord_app_commands_group_name__: str = MISSING
__discord_app_commands_group_description__: str = MISSING __discord_app_commands_group_description__: str = MISSING
@ -694,10 +694,12 @@ class Group:
def __init_subclass__(cls, *, name: str = MISSING, description: str = MISSING) -> None: def __init_subclass__(cls, *, name: str = MISSING, description: str = MISSING) -> None:
if not cls.__discord_app_commands_group_children__: if not cls.__discord_app_commands_group_children__:
cls.__discord_app_commands_group_children__ = children = [ children: List[Union[Command[Any, ..., Any], Group]] = [
member for member in cls.__dict__.values() if isinstance(member, (Group, Command)) and member.parent is None member for member in cls.__dict__.values() if isinstance(member, (Group, Command)) and member.parent is None
] ]
cls.__discord_app_commands_group_children__ = children
found = set() found = set()
for child in children: for child in children:
if child.name in found: if child.name in found:
@ -796,15 +798,15 @@ class Group:
"""Optional[:class:`Group`]: The parent of this group.""" """Optional[:class:`Group`]: The parent of this group."""
return self.parent return self.parent
def _get_internal_command(self, name: str) -> Optional[Union[Command, Group]]: def _get_internal_command(self, name: str) -> Optional[Union[Command[Any, ..., Any], Group]]:
return self._children.get(name) return self._children.get(name)
@property @property
def commands(self) -> List[Union[Command, Group]]: def commands(self) -> List[Union[Command[Any, ..., Any], Group]]:
"""List[Union[:class:`Command`, :class:`Group`]]: The commands that this group contains.""" """List[Union[:class:`Command`, :class:`Group`]]: The commands that this group contains."""
return list(self._children.values()) return list(self._children.values())
async def on_error(self, interaction: Interaction, command: Command, error: AppCommandError) -> None: async def on_error(self, interaction: Interaction, command: Command[Any, ..., Any], error: AppCommandError) -> None:
"""|coro| """|coro|
A callback that is called when a child's command raises an :exc:`AppCommandError`. A callback that is called when a child's command raises an :exc:`AppCommandError`.
@ -823,7 +825,7 @@ class Group:
pass pass
def add_command(self, command: Union[Command, Group], /, *, override: bool = False): def add_command(self, command: Union[Command[Any, ..., Any], Group], /, *, override: bool = False) -> None:
"""Adds a command or group to this group's internal list of commands. """Adds a command or group to this group's internal list of commands.
Parameters Parameters
@ -855,7 +857,7 @@ class Group:
if len(self._children) > 25: if len(self._children) > 25:
raise ValueError('maximum number of child commands exceeded') raise ValueError('maximum number of child commands exceeded')
def remove_command(self, name: str, /) -> Optional[Union[Command, Group]]: def remove_command(self, name: str, /) -> Optional[Union[Command[Any, ..., Any], Group]]:
"""Removes a command or group from the internal list of commands. """Removes a command or group from the internal list of commands.
Parameters Parameters
@ -872,7 +874,7 @@ class Group:
self._children.pop(name, None) self._children.pop(name, None)
def get_command(self, name: str, /) -> Optional[Union[Command, Group]]: def get_command(self, name: str, /) -> Optional[Union[Command[Any, ..., Any], Group]]:
"""Retrieves a command or group from its name. """Retrieves a command or group from its name.
Parameters Parameters
@ -1046,7 +1048,7 @@ def describe(**parameters: str) -> Callable[[T], T]:
return decorator return decorator
def choices(**parameters: List[Choice]) -> Callable[[T], T]: def choices(**parameters: List[Choice[ChoiceT]]) -> Callable[[T], T]:
r"""Instructs the given parameters by their name to use the given choices for their choices. r"""Instructs the given parameters by their name to use the given choices for their choices.
Example: Example:

8
discord/app_commands/errors.py

@ -79,9 +79,9 @@ class CommandInvokeError(AppCommandError):
The command that failed. The command that failed.
""" """
def __init__(self, command: Union[Command, ContextMenu], e: Exception) -> None: def __init__(self, command: Union[Command[Any, ..., Any], ContextMenu], e: Exception) -> None:
self.original: Exception = e self.original: Exception = e
self.command: Union[Command, ContextMenu] = command self.command: Union[Command[Any, ..., Any], ContextMenu] = command
super().__init__(f'Command {command.name!r} raised an exception: {e.__class__.__name__}: {e}') super().__init__(f'Command {command.name!r} raised an exception: {e.__class__.__name__}: {e}')
@ -191,8 +191,8 @@ class CommandSignatureMismatch(AppCommandError):
The command that had the signature mismatch. The command that had the signature mismatch.
""" """
def __init__(self, command: Union[Command, ContextMenu, Group]): def __init__(self, command: Union[Command[Any, ..., Any], ContextMenu, Group]):
self.command: Union[Command, ContextMenu, Group] = command self.command: Union[Command[Any, ..., Any], ContextMenu, Group] = command
msg = ( msg = (
f'The signature for command {command.name!r} is different from the one provided by Discord. ' f'The signature for command {command.name!r} is different from the one provided by Discord. '
'This can happen because either your code is out of date or you have not synced the ' 'This can happen because either your code is out of date or you have not synced the '

87
discord/app_commands/models.py

@ -58,7 +58,10 @@ if TYPE_CHECKING:
PartialChannel, PartialChannel,
PartialThread, PartialThread,
) )
from ..types.threads import ThreadMetadata from ..types.threads import (
ThreadMetadata,
ThreadArchiveDuration,
)
from ..state import ConnectionState from ..state import ConnectionState
from ..guild import GuildChannel, Guild from ..guild import GuildChannel, Guild
from ..channel import TextChannel from ..channel import TextChannel
@ -117,17 +120,19 @@ class AppCommand(Hashable):
'_state', '_state',
) )
def __init__(self, *, data: ApplicationCommandPayload, state=None): def __init__(self, *, data: ApplicationCommandPayload, state: Optional[ConnectionState] = None) -> None:
self._state = state self._state: Optional[ConnectionState] = state
self._from_data(data) self._from_data(data)
def _from_data(self, data: ApplicationCommandPayload): def _from_data(self, data: ApplicationCommandPayload) -> None:
self.id: int = int(data['id']) self.id: int = int(data['id'])
self.application_id: int = int(data['application_id']) self.application_id: int = int(data['application_id'])
self.name: str = data['name'] self.name: str = data['name']
self.description: str = data['description'] self.description: str = data['description']
self.type: AppCommandType = try_enum(AppCommandType, data.get('type', 1)) self.type: AppCommandType = try_enum(AppCommandType, data.get('type', 1))
self.options = [app_command_option_factory(data=d, parent=self, state=self._state) for d in data.get('options', [])] self.options: List[Union[Argument, AppCommandGroup]] = [
app_command_option_factory(data=d, parent=self, state=self._state) for d in data.get('options', [])
]
def to_dict(self) -> ApplicationCommandPayload: def to_dict(self) -> ApplicationCommandPayload:
return { return {
@ -262,12 +267,12 @@ class AppCommandChannel(Hashable):
data: PartialChannel, data: PartialChannel,
guild_id: int, guild_id: int,
): ):
self._state = state self._state: ConnectionState = state
self.guild_id = guild_id self.guild_id: int = guild_id
self.id = int(data['id']) self.id: int = int(data['id'])
self.type = try_enum(ChannelType, data['type']) self.type: ChannelType = try_enum(ChannelType, data['type'])
self.name = data['name'] self.name: str = data['name']
self.permissions = Permissions(int(data['permissions'])) self.permissions: Permissions = Permissions(int(data['permissions']))
def __str__(self) -> str: def __str__(self) -> str:
return self.name return self.name
@ -405,13 +410,13 @@ class AppCommandThread(Hashable):
data: PartialThread, data: PartialThread,
guild_id: int, guild_id: int,
): ):
self._state = state self._state: ConnectionState = state
self.guild_id = guild_id self.guild_id: int = guild_id
self.id = int(data['id']) self.id: int = int(data['id'])
self.parent_id = int(data['parent_id']) self.parent_id: int = int(data['parent_id'])
self.type = try_enum(ChannelType, data['type']) self.type: ChannelType = try_enum(ChannelType, data['type'])
self.name = data['name'] self.name: str = data['name']
self.permissions = Permissions(int(data['permissions'])) self.permissions: Permissions = Permissions(int(data['permissions']))
self._unroll_metadata(data['thread_metadata']) self._unroll_metadata(data['thread_metadata'])
def __str__(self) -> str: def __str__(self) -> str:
@ -425,14 +430,14 @@ class AppCommandThread(Hashable):
"""Optional[:class:`~discord.Guild`]: The channel's guild, from cache, if found.""" """Optional[:class:`~discord.Guild`]: The channel's guild, from cache, if found."""
return self._state._get_guild(self.guild_id) return self._state._get_guild(self.guild_id)
def _unroll_metadata(self, data: ThreadMetadata): def _unroll_metadata(self, data: ThreadMetadata) -> None:
self.archived = data['archived'] self.archived: bool = data['archived']
self.archiver_id = _get_as_snowflake(data, 'archiver_id') self.archiver_id: Optional[int] = _get_as_snowflake(data, 'archiver_id')
self.auto_archive_duration = data['auto_archive_duration'] self.auto_archive_duration: ThreadArchiveDuration = data['auto_archive_duration']
self.archive_timestamp = parse_time(data['archive_timestamp']) self.archive_timestamp: datetime = parse_time(data['archive_timestamp'])
self.locked = data.get('locked', False) self.locked: bool = data.get('locked', False)
self.invitable = data.get('invitable', True) self.invitable: bool = data.get('invitable', True)
self._created_at = parse_time(data.get('create_timestamp')) self._created_at: Optional[datetime] = parse_time(data.get('create_timestamp'))
@property @property
def parent(self) -> Optional[TextChannel]: def parent(self) -> Optional[TextChannel]:
@ -522,20 +527,24 @@ class Argument:
'_state', '_state',
) )
def __init__(self, *, parent: ApplicationCommandParent, data: ApplicationCommandOption, state=None): def __init__(
self._state = state self, *, parent: ApplicationCommandParent, data: ApplicationCommandOption, state: Optional[ConnectionState] = None
self.parent = parent ) -> None:
self._state: Optional[ConnectionState] = state
self.parent: ApplicationCommandParent = parent
self._from_data(data) self._from_data(data)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<{self.__class__.__name__} name={self.name!r} type={self.type!r} required={self.required}>' return f'<{self.__class__.__name__} name={self.name!r} type={self.type!r} required={self.required}>'
def _from_data(self, data: ApplicationCommandOption): def _from_data(self, data: ApplicationCommandOption) -> None:
self.type: AppCommandOptionType = try_enum(AppCommandOptionType, data['type']) self.type: AppCommandOptionType = try_enum(AppCommandOptionType, data['type'])
self.name: str = data['name'] self.name: str = data['name']
self.description: str = data['description'] self.description: str = data['description']
self.required: bool = data.get('required', False) self.required: bool = data.get('required', False)
self.choices: List[Choice] = [Choice(name=d['name'], value=d['value']) for d in data.get('choices', [])] self.choices: List[Choice[Union[int, float, str]]] = [
Choice(name=d['name'], value=d['value']) for d in data.get('choices', [])
]
def to_dict(self) -> ApplicationCommandOption: def to_dict(self) -> ApplicationCommandOption:
return { return {
@ -582,20 +591,24 @@ class AppCommandGroup:
'_state', '_state',
) )
def __init__(self, *, parent: ApplicationCommandParent, data: ApplicationCommandOption, state=None): def __init__(
self.parent = parent self, *, parent: ApplicationCommandParent, data: ApplicationCommandOption, state: Optional[ConnectionState] = None
self._state = state ) -> None:
self.parent: ApplicationCommandParent = parent
self._state: Optional[ConnectionState] = state
self._from_data(data) self._from_data(data)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<{self.__class__.__name__} name={self.name!r} type={self.type!r} required={self.required}>' return f'<{self.__class__.__name__} name={self.name!r} type={self.type!r} required={self.required}>'
def _from_data(self, data: ApplicationCommandOption): def _from_data(self, data: ApplicationCommandOption) -> None:
self.type: AppCommandOptionType = try_enum(AppCommandOptionType, data['type']) self.type: AppCommandOptionType = try_enum(AppCommandOptionType, data['type'])
self.name: str = data['name'] self.name: str = data['name']
self.description: str = data['description'] self.description: str = data['description']
self.required: bool = data.get('required', False) self.required: bool = data.get('required', False)
self.choices: List[Choice] = [Choice(name=d['name'], value=d['value']) for d in data.get('choices', [])] self.choices: List[Choice[Union[int, float, str]]] = [
Choice(name=d['name'], value=d['value']) for d in data.get('choices', [])
]
self.arguments: List[Argument] = [ self.arguments: List[Argument] = [
Argument(parent=self, state=self._state, data=d) Argument(parent=self, state=self._state, data=d)
for d in data.get('options', []) for d in data.get('options', [])
@ -614,7 +627,7 @@ class AppCommandGroup:
def app_command_option_factory( def app_command_option_factory(
parent: ApplicationCommandParent, data: ApplicationCommandOption, *, state=None parent: ApplicationCommandParent, data: ApplicationCommandOption, *, state: Optional[ConnectionState] = None
) -> Union[Argument, AppCommandGroup]: ) -> Union[Argument, AppCommandGroup]:
if is_app_command_argument_type(data['type']): if is_app_command_argument_type(data['type']):
return Argument(parent=parent, data=data, state=state) return Argument(parent=parent, data=data, state=state)

4
discord/app_commands/transformers.py

@ -95,7 +95,7 @@ class CommandParameter:
description: str = MISSING description: str = MISSING
required: bool = MISSING required: bool = MISSING
default: Any = MISSING default: Any = MISSING
choices: List[Choice] = MISSING choices: List[Choice[Union[str, int, float]]] = MISSING
type: AppCommandOptionType = MISSING type: AppCommandOptionType = MISSING
channel_types: List[ChannelType] = MISSING channel_types: List[ChannelType] = MISSING
min_value: Optional[Union[int, float]] = None min_value: Optional[Union[int, float]] = None
@ -549,7 +549,7 @@ ALLOWED_DEFAULTS: Dict[AppCommandOptionType, Tuple[Type[Any], ...]] = {
def get_supported_annotation( def get_supported_annotation(
annotation: Any, annotation: Any,
*, *,
_none=NoneType, _none: type = NoneType,
_mapping: Dict[Any, Type[Transformer]] = BUILT_IN_TRANSFORMERS, _mapping: Dict[Any, Type[Transformer]] = BUILT_IN_TRANSFORMERS,
) -> Tuple[Any, Any]: ) -> Tuple[Any, Any]:
"""Returns an appropriate, yet supported, annotation along with an optional default value. """Returns an appropriate, yet supported, annotation along with an optional default value.

53
discord/app_commands/tree.py

@ -26,7 +26,22 @@ from __future__ import annotations
import inspect import inspect
import sys import sys
import traceback import traceback
from typing import Any, Callable, Dict, Generic, List, Literal, Optional, TYPE_CHECKING, Set, Tuple, TypeVar, Union, overload
from typing import (
Any,
TYPE_CHECKING,
Callable,
Dict,
Generic,
List,
Literal,
Optional,
Set,
Tuple,
TypeVar,
Union,
overload,
)
from collections import Counter from collections import Counter
@ -194,13 +209,13 @@ class CommandTree(Generic[ClientT]):
def add_command( def add_command(
self, self,
command: Union[Command, ContextMenu, Group], command: Union[Command[Any, ..., Any], ContextMenu, Group],
/, /,
*, *,
guild: Optional[Snowflake] = MISSING, guild: Optional[Snowflake] = MISSING,
guilds: List[Snowflake] = MISSING, guilds: List[Snowflake] = MISSING,
override: bool = False, override: bool = False,
): ) -> None:
"""Adds an application command to the tree. """Adds an application command to the tree.
This only adds the command locally -- in order to sync the commands This only adds the command locally -- in order to sync the commands
@ -317,7 +332,7 @@ class CommandTree(Generic[ClientT]):
*, *,
guild: Optional[Snowflake] = ..., guild: Optional[Snowflake] = ...,
type: Literal[AppCommandType.chat_input] = ..., type: Literal[AppCommandType.chat_input] = ...,
) -> Optional[Union[Command, Group]]: ) -> Optional[Union[Command[Any, ..., Any], Group]]:
... ...
@overload @overload
@ -328,7 +343,7 @@ class CommandTree(Generic[ClientT]):
*, *,
guild: Optional[Snowflake] = ..., guild: Optional[Snowflake] = ...,
type: AppCommandType = ..., type: AppCommandType = ...,
) -> Optional[Union[Command, ContextMenu, Group]]: ) -> Optional[Union[Command[Any, ..., Any], ContextMenu, Group]]:
... ...
def remove_command( def remove_command(
@ -338,7 +353,7 @@ class CommandTree(Generic[ClientT]):
*, *,
guild: Optional[Snowflake] = None, guild: Optional[Snowflake] = None,
type: AppCommandType = AppCommandType.chat_input, type: AppCommandType = AppCommandType.chat_input,
) -> Optional[Union[Command, ContextMenu, Group]]: ) -> Optional[Union[Command[Any, ..., Any], ContextMenu, Group]]:
"""Removes an application command from the tree. """Removes an application command from the tree.
This only removes the command locally -- in order to sync the commands This only removes the command locally -- in order to sync the commands
@ -396,7 +411,7 @@ class CommandTree(Generic[ClientT]):
*, *,
guild: Optional[Snowflake] = ..., guild: Optional[Snowflake] = ...,
type: Literal[AppCommandType.chat_input] = ..., type: Literal[AppCommandType.chat_input] = ...,
) -> Optional[Union[Command, Group]]: ) -> Optional[Union[Command[Any, ..., Any], Group]]:
... ...
@overload @overload
@ -407,7 +422,7 @@ class CommandTree(Generic[ClientT]):
*, *,
guild: Optional[Snowflake] = ..., guild: Optional[Snowflake] = ...,
type: AppCommandType = ..., type: AppCommandType = ...,
) -> Optional[Union[Command, ContextMenu, Group]]: ) -> Optional[Union[Command[Any, ..., Any], ContextMenu, Group]]:
... ...
def get_command( def get_command(
@ -417,7 +432,7 @@ class CommandTree(Generic[ClientT]):
*, *,
guild: Optional[Snowflake] = None, guild: Optional[Snowflake] = None,
type: AppCommandType = AppCommandType.chat_input, type: AppCommandType = AppCommandType.chat_input,
) -> Optional[Union[Command, ContextMenu, Group]]: ) -> Optional[Union[Command[Any, ..., Any], ContextMenu, Group]]:
"""Gets a application command from the tree. """Gets a application command from the tree.
Parameters Parameters
@ -468,7 +483,7 @@ class CommandTree(Generic[ClientT]):
*, *,
guild: Optional[Snowflake] = ..., guild: Optional[Snowflake] = ...,
type: Literal[AppCommandType.chat_input] = ..., type: Literal[AppCommandType.chat_input] = ...,
) -> List[Union[Command, Group]]: ) -> List[Union[Command[Any, ..., Any], Group]]:
... ...
@overload @overload
@ -477,7 +492,7 @@ class CommandTree(Generic[ClientT]):
*, *,
guild: Optional[Snowflake] = ..., guild: Optional[Snowflake] = ...,
type: AppCommandType = ..., type: AppCommandType = ...,
) -> Union[List[Union[Command, Group]], List[ContextMenu]]: ) -> Union[List[Union[Command[Any, ..., Any], Group]], List[ContextMenu]]:
... ...
def get_commands( def get_commands(
@ -485,7 +500,7 @@ class CommandTree(Generic[ClientT]):
*, *,
guild: Optional[Snowflake] = None, guild: Optional[Snowflake] = None,
type: AppCommandType = AppCommandType.chat_input, type: AppCommandType = AppCommandType.chat_input,
) -> Union[List[Union[Command, Group]], List[ContextMenu]]: ) -> Union[List[Union[Command[Any, ..., Any], Group]], List[ContextMenu]]:
"""Gets all application commands from the tree. """Gets all application commands from the tree.
Parameters Parameters
@ -518,9 +533,11 @@ class CommandTree(Generic[ClientT]):
value = type.value value = type.value
return [command for ((_, g, t), command) in self._context_menus.items() if g == guild_id and t == value] return [command for ((_, g, t), command) in self._context_menus.items() if g == guild_id and t == value]
def _get_all_commands(self, *, guild: Optional[Snowflake] = None) -> List[Union[Command, Group, ContextMenu]]: def _get_all_commands(
self, *, guild: Optional[Snowflake] = None
) -> List[Union[Command[Any, ..., Any], Group, ContextMenu]]:
if guild is None: if guild is None:
base: List[Union[Command, Group, ContextMenu]] = list(self._global_commands.values()) base: List[Union[Command[Any, ..., Any], Group, ContextMenu]] = list(self._global_commands.values())
base.extend(cmd for ((_, g, _), cmd) in self._context_menus.items() if g is None) base.extend(cmd for ((_, g, _), cmd) in self._context_menus.items() if g is None)
return base return base
else: else:
@ -530,7 +547,7 @@ class CommandTree(Generic[ClientT]):
guild_id = guild.id guild_id = guild.id
return [cmd for ((_, g, _), cmd) in self._context_menus.items() if g == guild_id] return [cmd for ((_, g, _), cmd) in self._context_menus.items() if g == guild_id]
else: else:
base: List[Union[Command, Group, ContextMenu]] = list(commands.values()) base: List[Union[Command[Any, ..., Any], Group, ContextMenu]] = list(commands.values())
guild_id = guild.id guild_id = guild.id
base.extend(cmd for ((_, g, _), cmd) in self._context_menus.items() if g == guild_id) base.extend(cmd for ((_, g, _), cmd) in self._context_menus.items() if g == guild_id)
return base return base
@ -564,7 +581,7 @@ class CommandTree(Generic[ClientT]):
async def on_error( async def on_error(
self, self,
interaction: Interaction, interaction: Interaction,
command: Optional[Union[ContextMenu, Command]], command: Optional[Union[ContextMenu, Command[Any, ..., Any]]],
error: AppCommandError, error: AppCommandError,
) -> None: ) -> None:
"""|coro| """|coro|
@ -742,7 +759,7 @@ class CommandTree(Generic[ClientT]):
self.client.loop.create_task(wrapper(), name='CommandTree-invoker') self.client.loop.create_task(wrapper(), name='CommandTree-invoker')
async def _call_context_menu(self, interaction: Interaction, data: ApplicationCommandInteractionData, type: int): async def _call_context_menu(self, interaction: Interaction, data: ApplicationCommandInteractionData, type: int) -> None:
name = data['name'] name = data['name']
guild_id = _get_as_snowflake(data, 'guild_id') guild_id = _get_as_snowflake(data, 'guild_id')
ctx_menu = self._context_menus.get((name, guild_id, type)) ctx_menu = self._context_menus.get((name, guild_id, type))
@ -770,7 +787,7 @@ class CommandTree(Generic[ClientT]):
except AppCommandError as e: except AppCommandError as e:
await self.on_error(interaction, ctx_menu, e) await self.on_error(interaction, ctx_menu, e)
async def call(self, interaction: Interaction): async def call(self, interaction: Interaction) -> None:
"""|coro| """|coro|
Given an :class:`~discord.Interaction`, calls the matching Given an :class:`~discord.Interaction`, calls the matching

53
discord/asset.py

@ -39,6 +39,13 @@ __all__ = (
# fmt: on # fmt: on
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self
from .state import ConnectionState
from .webhook.async_ import _WebhookState
_State = Union[ConnectionState, _WebhookState]
ValidStaticFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png'] ValidStaticFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png']
ValidAssetFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png', 'gif'] ValidAssetFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png', 'gif']
@ -77,7 +84,7 @@ class AssetMixin:
return await self._state.http.get_from_cdn(self.url) return await self._state.http.get_from_cdn(self.url)
async def save(self, fp: Union[str, bytes, os.PathLike, io.BufferedIOBase], *, seek_begin: bool = True) -> int: async def save(self, fp: Union[str, bytes, os.PathLike[Any], io.BufferedIOBase], *, seek_begin: bool = True) -> int:
"""|coro| """|coro|
Saves this asset into a file-like object. Saves this asset into a file-like object.
@ -153,14 +160,14 @@ class Asset(AssetMixin):
BASE = 'https://cdn.discordapp.com' BASE = 'https://cdn.discordapp.com'
def __init__(self, state, *, url: str, key: str, animated: bool = False): def __init__(self, state: _State, *, url: str, key: str, animated: bool = False) -> None:
self._state = state self._state: _State = state
self._url = url self._url: str = url
self._animated = animated self._animated: bool = animated
self._key = key self._key: str = key
@classmethod @classmethod
def _from_default_avatar(cls, state, index: int) -> Asset: def _from_default_avatar(cls, state: _State, index: int) -> Self:
return cls( return cls(
state, state,
url=f'{cls.BASE}/embed/avatars/{index}.png', url=f'{cls.BASE}/embed/avatars/{index}.png',
@ -169,7 +176,7 @@ class Asset(AssetMixin):
) )
@classmethod @classmethod
def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset: def _from_avatar(cls, state: _State, user_id: int, avatar: str) -> Self:
animated = avatar.startswith('a_') animated = avatar.startswith('a_')
format = 'gif' if animated else 'png' format = 'gif' if animated else 'png'
return cls( return cls(
@ -180,7 +187,7 @@ class Asset(AssetMixin):
) )
@classmethod @classmethod
def _from_guild_avatar(cls, state, guild_id: int, member_id: int, avatar: str) -> Asset: def _from_guild_avatar(cls, state: _State, guild_id: int, member_id: int, avatar: str) -> Self:
animated = avatar.startswith('a_') animated = avatar.startswith('a_')
format = 'gif' if animated else 'png' format = 'gif' if animated else 'png'
return cls( return cls(
@ -191,7 +198,7 @@ class Asset(AssetMixin):
) )
@classmethod @classmethod
def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset: def _from_icon(cls, state: _State, object_id: int, icon_hash: str, path: str) -> Self:
return cls( return cls(
state, state,
url=f'{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024', url=f'{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024',
@ -200,7 +207,7 @@ class Asset(AssetMixin):
) )
@classmethod @classmethod
def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asset: def _from_cover_image(cls, state: _State, object_id: int, cover_image_hash: str) -> Self:
return cls( return cls(
state, state,
url=f'{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024', url=f'{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024',
@ -209,7 +216,7 @@ class Asset(AssetMixin):
) )
@classmethod @classmethod
def _from_scheduled_event_cover_image(cls, state, scheduled_event_id: int, cover_image_hash: str) -> Asset: def _from_scheduled_event_cover_image(cls, state: _State, scheduled_event_id: int, cover_image_hash: str) -> Self:
return cls( return cls(
state, state,
url=f'{cls.BASE}/guild-events/{scheduled_event_id}/{cover_image_hash}.png?size=1024', url=f'{cls.BASE}/guild-events/{scheduled_event_id}/{cover_image_hash}.png?size=1024',
@ -218,7 +225,7 @@ class Asset(AssetMixin):
) )
@classmethod @classmethod
def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset: def _from_guild_image(cls, state: _State, guild_id: int, image: str, path: str) -> Self:
animated = image.startswith('a_') animated = image.startswith('a_')
format = 'gif' if animated else 'png' format = 'gif' if animated else 'png'
return cls( return cls(
@ -229,7 +236,7 @@ class Asset(AssetMixin):
) )
@classmethod @classmethod
def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset: def _from_guild_icon(cls, state: _State, guild_id: int, icon_hash: str) -> Self:
animated = icon_hash.startswith('a_') animated = icon_hash.startswith('a_')
format = 'gif' if animated else 'png' format = 'gif' if animated else 'png'
return cls( return cls(
@ -240,7 +247,7 @@ class Asset(AssetMixin):
) )
@classmethod @classmethod
def _from_sticker_banner(cls, state, banner: int) -> Asset: def _from_sticker_banner(cls, state: _State, banner: int) -> Self:
return cls( return cls(
state, state,
url=f'{cls.BASE}/app-assets/710982414301790216/store/{banner}.png', url=f'{cls.BASE}/app-assets/710982414301790216/store/{banner}.png',
@ -249,7 +256,7 @@ class Asset(AssetMixin):
) )
@classmethod @classmethod
def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset: def _from_user_banner(cls, state: _State, user_id: int, banner_hash: str) -> Self:
animated = banner_hash.startswith('a_') animated = banner_hash.startswith('a_')
format = 'gif' if animated else 'png' format = 'gif' if animated else 'png'
return cls( return cls(
@ -265,14 +272,14 @@ class Asset(AssetMixin):
def __len__(self) -> int: def __len__(self) -> int:
return len(self._url) return len(self._url)
def __repr__(self): def __repr__(self) -> str:
shorten = self._url.replace(self.BASE, '') shorten = self._url.replace(self.BASE, '')
return f'<Asset url={shorten!r}>' return f'<Asset url={shorten!r}>'
def __eq__(self, other): def __eq__(self, other: object) -> bool:
return isinstance(other, Asset) and self._url == other._url return isinstance(other, Asset) and self._url == other._url
def __hash__(self): def __hash__(self) -> int:
return hash(self._url) return hash(self._url)
@property @property
@ -295,7 +302,7 @@ class Asset(AssetMixin):
size: int = MISSING, size: int = MISSING,
format: ValidAssetFormatTypes = MISSING, format: ValidAssetFormatTypes = MISSING,
static_format: ValidStaticFormatTypes = MISSING, static_format: ValidStaticFormatTypes = MISSING,
) -> Asset: ) -> Self:
"""Returns a new asset with the passed components replaced. """Returns a new asset with the passed components replaced.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
@ -350,7 +357,7 @@ class Asset(AssetMixin):
url = str(url) url = str(url)
return Asset(state=self._state, url=url, key=self._key, animated=self._animated) return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
def with_size(self, size: int, /) -> Asset: def with_size(self, size: int, /) -> Self:
"""Returns a new asset with the specified size. """Returns a new asset with the specified size.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
@ -378,7 +385,7 @@ class Asset(AssetMixin):
url = str(yarl.URL(self._url).with_query(size=size)) url = str(yarl.URL(self._url).with_query(size=size))
return Asset(state=self._state, url=url, key=self._key, animated=self._animated) return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
def with_format(self, format: ValidAssetFormatTypes, /) -> Asset: def with_format(self, format: ValidAssetFormatTypes, /) -> Self:
"""Returns a new asset with the specified format. """Returns a new asset with the specified format.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
@ -413,7 +420,7 @@ class Asset(AssetMixin):
url = str(url.with_path(f'{path}.{format}').with_query(url.raw_query_string)) url = str(url.with_path(f'{path}.{format}').with_query(url.raw_query_string))
return Asset(state=self._state, url=url, key=self._key, animated=self._animated) return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
def with_static_format(self, format: ValidStaticFormatTypes, /) -> Asset: def with_static_format(self, format: ValidStaticFormatTypes, /) -> Self:
"""Returns a new asset with the specified static format. """Returns a new asset with the specified static format.
This only changes the format if the underlying asset is This only changes the format if the underlying asset is

25
discord/audit_logs.py

@ -50,12 +50,12 @@ if TYPE_CHECKING:
from .member import Member from .member import Member
from .role import Role from .role import Role
from .scheduled_event import ScheduledEvent from .scheduled_event import ScheduledEvent
from .state import ConnectionState
from .types.audit_log import ( from .types.audit_log import (
AuditLogChange as AuditLogChangePayload, AuditLogChange as AuditLogChangePayload,
AuditLogEntry as AuditLogEntryPayload, AuditLogEntry as AuditLogEntryPayload,
) )
from .types.channel import ( from .types.channel import (
PartialChannel as PartialChannelPayload,
PermissionOverwrite as PermissionOverwritePayload, PermissionOverwrite as PermissionOverwritePayload,
) )
from .types.invite import Invite as InvitePayload from .types.invite import Invite as InvitePayload
@ -242,8 +242,8 @@ class AuditLogChanges:
# fmt: on # fmt: on
def __init__(self, entry: AuditLogEntry, data: List[AuditLogChangePayload]): def __init__(self, entry: AuditLogEntry, data: List[AuditLogChangePayload]):
self.before = AuditLogDiff() self.before: AuditLogDiff = AuditLogDiff()
self.after = AuditLogDiff() self.after: AuditLogDiff = AuditLogDiff()
for elem in data: for elem in data:
attr = elem['key'] attr = elem['key']
@ -390,17 +390,17 @@ class AuditLogEntry(Hashable):
""" """
def __init__(self, *, users: Dict[int, User], data: AuditLogEntryPayload, guild: Guild): def __init__(self, *, users: Dict[int, User], data: AuditLogEntryPayload, guild: Guild):
self._state = guild._state self._state: ConnectionState = guild._state
self.guild = guild self.guild: Guild = guild
self._users = users self._users: Dict[int, User] = users
self._from_data(data) self._from_data(data)
def _from_data(self, data: AuditLogEntryPayload) -> None: def _from_data(self, data: AuditLogEntryPayload) -> None:
self.action = enums.try_enum(enums.AuditLogAction, data['action_type']) self.action: enums.AuditLogAction = enums.try_enum(enums.AuditLogAction, data['action_type'])
self.id = int(data['id']) self.id: int = int(data['id'])
# this key is technically not usually present # this key is technically not usually present
self.reason = data.get('reason') self.reason: Optional[str] = data.get('reason')
extra = data.get('options') extra = data.get('options')
# fmt: off # fmt: off
@ -464,10 +464,13 @@ class AuditLogEntry(Hashable):
self._changes = data.get('changes', []) self._changes = data.get('changes', [])
user_id = utils._get_as_snowflake(data, 'user_id') user_id = utils._get_as_snowflake(data, 'user_id')
self.user = user_id and self._get_member(user_id) self.user: Optional[Union[User, Member]] = self._get_member(user_id)
self._target_id = utils._get_as_snowflake(data, 'target_id') self._target_id = utils._get_as_snowflake(data, 'target_id')
def _get_member(self, user_id: int) -> Union[Member, User, None]: def _get_member(self, user_id: Optional[int]) -> Union[Member, User, None]:
if user_id is None:
return None
return self.guild.get_member(user_id) or self._users.get(user_id) return self.guild.get_member(user_id) or self._users.get(user_id)
def __repr__(self) -> str: def __repr__(self) -> str:

22
discord/channel.py

@ -198,7 +198,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id') self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id')
self._fill_overwrites(data) self._fill_overwrites(data)
async def _get_channel(self): async def _get_channel(self) -> Self:
return self return self
@property @property
@ -283,7 +283,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
async def edit(self) -> Optional[TextChannel]: async def edit(self) -> Optional[TextChannel]:
... ...
async def edit(self, *, reason=None, **options): async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[TextChannel]:
"""|coro| """|coro|
Edits the channel. Edits the channel.
@ -908,7 +908,7 @@ class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hasha
return self.guild.id, self.id return self.guild.id, self.id
def _update(self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]) -> None: def _update(self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]) -> None:
self.guild = guild self.guild: Guild = guild
self.name: str = data['name'] self.name: str = data['name']
self.rtc_region: Optional[str] = data.get('rtc_region') self.rtc_region: Optional[str] = data.get('rtc_region')
self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get('video_quality_mode', 1)) self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get('video_quality_mode', 1))
@ -1076,7 +1076,7 @@ class VoiceChannel(VocalGuildChannel):
async def edit(self) -> Optional[VoiceChannel]: async def edit(self) -> Optional[VoiceChannel]:
... ...
async def edit(self, *, reason=None, **options): async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[VoiceChannel]:
"""|coro| """|coro|
Edits the channel. Edits the channel.
@ -1220,7 +1220,7 @@ class StageChannel(VocalGuildChannel):
def _update(self, guild: Guild, data: StageChannelPayload) -> None: def _update(self, guild: Guild, data: StageChannelPayload) -> None:
super()._update(guild, data) super()._update(guild, data)
self.topic = data.get('topic') self.topic: Optional[str] = data.get('topic')
@property @property
def requesting_to_speak(self) -> List[Member]: def requesting_to_speak(self) -> List[Member]:
@ -1361,7 +1361,7 @@ class StageChannel(VocalGuildChannel):
async def edit(self) -> Optional[StageChannel]: async def edit(self) -> Optional[StageChannel]:
... ...
async def edit(self, *, reason=None, **options): async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[StageChannel]:
"""|coro| """|coro|
Edits the channel. Edits the channel.
@ -1522,7 +1522,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
async def edit(self) -> Optional[CategoryChannel]: async def edit(self) -> Optional[CategoryChannel]:
... ...
async def edit(self, *, reason=None, **options): async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[CategoryChannel]:
"""|coro| """|coro|
Edits the channel. Edits the channel.
@ -1578,7 +1578,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
@utils.copy_doc(discord.abc.GuildChannel.move) @utils.copy_doc(discord.abc.GuildChannel.move)
async def move(self, **kwargs): async def move(self, **kwargs: Any) -> None:
kwargs.pop('category', None) kwargs.pop('category', None)
await super().move(**kwargs) await super().move(**kwargs)
@ -1772,7 +1772,7 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
async def edit(self) -> Optional[StoreChannel]: async def edit(self) -> Optional[StoreChannel]:
... ...
async def edit(self, *, reason=None, **options): async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[StoreChannel]:
"""|coro| """|coro|
Edits the channel. Edits the channel.
@ -1874,7 +1874,7 @@ class DMChannel(discord.abc.Messageable, Hashable):
self.me: ClientUser = me self.me: ClientUser = me
self.id: int = int(data['id']) self.id: int = int(data['id'])
async def _get_channel(self): async def _get_channel(self) -> Self:
return self return self
def __str__(self) -> str: def __str__(self) -> str:
@ -2026,7 +2026,7 @@ class GroupChannel(discord.abc.Messageable, Hashable):
else: else:
self.owner = utils.find(lambda u: u.id == self.owner_id, self.recipients) self.owner = utils.find(lambda u: u.id == self.owner_id, self.recipients)
async def _get_channel(self): async def _get_channel(self) -> Self:
return self return self
def __str__(self) -> str: def __str__(self) -> str:

10
discord/client.py

@ -196,11 +196,11 @@ class Client:
unsync_clock: bool = options.pop('assume_unsync_clock', True) unsync_clock: bool = options.pop('assume_unsync_clock', True)
self.http: HTTPClient = HTTPClient(self.loop, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock) self.http: HTTPClient = HTTPClient(self.loop, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock)
self._handlers: Dict[str, Callable] = { self._handlers: Dict[str, Callable[..., None]] = {
'ready': self._handle_ready, 'ready': self._handle_ready,
} }
self._hooks: Dict[str, Callable] = { self._hooks: Dict[str, Callable[..., Coroutine[Any, Any, Any]]] = {
'before_identify': self._call_before_identify_hook, 'before_identify': self._call_before_identify_hook,
} }
@ -698,7 +698,7 @@ class Client:
raise TypeError('activity must derive from BaseActivity.') raise TypeError('activity must derive from BaseActivity.')
@property @property
def status(self): def status(self) -> Status:
""":class:`.Status`: """:class:`.Status`:
The status being used upon logging on to Discord. The status being used upon logging on to Discord.
@ -709,7 +709,7 @@ class Client:
return Status.online return Status.online
@status.setter @status.setter
def status(self, value): def status(self, value: Status) -> None:
if value is Status.offline: if value is Status.offline:
self._connection._status = 'invisible' self._connection._status = 'invisible'
elif isinstance(value, Status): elif isinstance(value, Status):
@ -1077,7 +1077,7 @@ class Client:
*, *,
activity: Optional[BaseActivity] = None, activity: Optional[BaseActivity] = None,
status: Optional[Status] = None, status: Optional[Status] = None,
): ) -> None:
"""|coro| """|coro|
Changes the client's presence. Changes the client's presence.

13
discord/colour.py

@ -32,7 +32,6 @@ from typing import (
Callable, Callable,
Optional, Optional,
Tuple, Tuple,
Type,
Union, Union,
) )
@ -90,10 +89,10 @@ class Colour:
def _get_byte(self, byte: int) -> int: def _get_byte(self, byte: int) -> int:
return (self.value >> (8 * byte)) & 0xFF return (self.value >> (8 * byte)) & 0xFF
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
return isinstance(other, Colour) and self.value == other.value return isinstance(other, Colour) and self.value == other.value
def __ne__(self, other: Any) -> bool: def __ne__(self, other: object) -> bool:
return not self.__eq__(other) return not self.__eq__(other)
def __str__(self) -> str: def __str__(self) -> str:
@ -265,28 +264,28 @@ class Colour:
"""A factory method that returns a :class:`Colour` with a value of ``0x95a5a6``.""" """A factory method that returns a :class:`Colour` with a value of ``0x95a5a6``."""
return cls(0x95A5A6) return cls(0x95A5A6)
lighter_gray: Callable[[Type[Self]], Self] = lighter_grey lighter_gray = lighter_grey
@classmethod @classmethod
def dark_grey(cls) -> Self: def dark_grey(cls) -> Self:
"""A factory method that returns a :class:`Colour` with a value of ``0x607d8b``.""" """A factory method that returns a :class:`Colour` with a value of ``0x607d8b``."""
return cls(0x607D8B) return cls(0x607D8B)
dark_gray: Callable[[Type[Self]], Self] = dark_grey dark_gray = dark_grey
@classmethod @classmethod
def light_grey(cls) -> Self: def light_grey(cls) -> Self:
"""A factory method that returns a :class:`Colour` with a value of ``0x979c9f``.""" """A factory method that returns a :class:`Colour` with a value of ``0x979c9f``."""
return cls(0x979C9F) return cls(0x979C9F)
light_gray: Callable[[Type[Self]], Self] = light_grey light_gray = light_grey
@classmethod @classmethod
def darker_grey(cls) -> Self: def darker_grey(cls) -> Self:
"""A factory method that returns a :class:`Colour` with a value of ``0x546e7a``.""" """A factory method that returns a :class:`Colour` with a value of ``0x546e7a``."""
return cls(0x546E7A) return cls(0x546E7A)
darker_gray: Callable[[Type[Self]], Self] = darker_grey darker_gray = darker_grey
@classmethod @classmethod
def og_blurple(cls) -> Self: def og_blurple(cls) -> Self:

10
discord/components.py

@ -310,9 +310,9 @@ class SelectOption:
emoji: Optional[Union[str, Emoji, PartialEmoji]] = None, emoji: Optional[Union[str, Emoji, PartialEmoji]] = None,
default: bool = False, default: bool = False,
) -> None: ) -> None:
self.label = label self.label: str = label
self.value = label if value is MISSING else value self.value: str = label if value is MISSING else value
self.description = description self.description: Optional[str] = description
if emoji is not None: if emoji is not None:
if isinstance(emoji, str): if isinstance(emoji, str):
@ -322,8 +322,8 @@ class SelectOption:
else: else:
raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}') raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}')
self.emoji = emoji self.emoji: Optional[Union[str, Emoji, PartialEmoji]] = emoji
self.default = default self.default: bool = default
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (

10
discord/context_managers.py

@ -25,13 +25,15 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from typing import TYPE_CHECKING, Optional, Type from typing import TYPE_CHECKING, Optional, Type, TypeVar
if TYPE_CHECKING: if TYPE_CHECKING:
from .abc import Messageable from .abc import Messageable
from types import TracebackType from types import TracebackType
BE = TypeVar('BE', bound=BaseException)
# fmt: off # fmt: off
__all__ = ( __all__ = (
'Typing', 'Typing',
@ -67,13 +69,13 @@ class Typing:
async def __aenter__(self) -> None: async def __aenter__(self) -> None:
self._channel = channel = await self.messageable._get_channel() self._channel = channel = await self.messageable._get_channel()
await channel._state.http.send_typing(channel.id) await channel._state.http.send_typing(channel.id)
self.task: asyncio.Task = self.loop.create_task(self.do_typing()) self.task: asyncio.Task[None] = self.loop.create_task(self.do_typing())
self.task.add_done_callback(_typing_done_callback) self.task.add_done_callback(_typing_done_callback)
async def __aexit__( async def __aexit__(
self, self,
exc_type: Optional[Type[BaseException]], exc_type: Optional[Type[BE]],
exc_value: Optional[BaseException], exc: Optional[BE],
traceback: Optional[TracebackType], traceback: Optional[TracebackType],
) -> None: ) -> None:
self.task.cancel() self.task.cancel()

12
discord/embeds.py

@ -189,10 +189,10 @@ class Embed:
): ):
self.colour = colour if colour is not EmptyEmbed else color self.colour = colour if colour is not EmptyEmbed else color
self.title = title self.title: MaybeEmpty[str] = title
self.type = type self.type: EmbedType = type
self.url = url self.url: MaybeEmpty[str] = url
self.description = description self.description: MaybeEmpty[str] = description
if self.title is not EmptyEmbed: if self.title is not EmptyEmbed:
self.title = str(self.title) self.title = str(self.title)
@ -311,7 +311,7 @@ class Embed:
return getattr(self, '_colour', EmptyEmbed) return getattr(self, '_colour', EmptyEmbed)
@colour.setter @colour.setter
def colour(self, value: Union[int, Colour, _EmptyEmbed]): def colour(self, value: Union[int, Colour, _EmptyEmbed]) -> None:
if isinstance(value, (Colour, _EmptyEmbed)): if isinstance(value, (Colour, _EmptyEmbed)):
self._colour = value self._colour = value
elif isinstance(value, int): elif isinstance(value, int):
@ -326,7 +326,7 @@ class Embed:
return getattr(self, '_timestamp', EmptyEmbed) return getattr(self, '_timestamp', EmptyEmbed)
@timestamp.setter @timestamp.setter
def timestamp(self, value: MaybeEmpty[datetime.datetime]): def timestamp(self, value: MaybeEmpty[datetime.datetime]) -> None:
if isinstance(value, datetime.datetime): if isinstance(value, datetime.datetime):
if value.tzinfo is None: if value.tzinfo is None:
value = value.astimezone() value = value.astimezone()

4
discord/emoji.py

@ -142,10 +142,10 @@ class Emoji(_EmojiTag, AssetMixin):
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<Emoji id={self.id} name={self.name!r} animated={self.animated} managed={self.managed}>' return f'<Emoji id={self.id} name={self.name!r} animated={self.animated} managed={self.managed}>'
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
return isinstance(other, _EmojiTag) and self.id == other.id return isinstance(other, _EmojiTag) and self.id == other.id
def __ne__(self, other: Any) -> bool: def __ne__(self, other: object) -> bool:
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self) -> int:

44
discord/enums.py

@ -25,7 +25,7 @@ from __future__ import annotations
import types import types
from collections import namedtuple from collections import namedtuple
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Iterator, Mapping
__all__ = ( __all__ = (
'Enum', 'Enum',
@ -131,38 +131,38 @@ class EnumMeta(type):
value_cls._actual_enum_cls_ = actual_cls # type: ignore - Runtime attribute isn't understood value_cls._actual_enum_cls_ = actual_cls # type: ignore - Runtime attribute isn't understood
return actual_cls return actual_cls
def __iter__(cls): def __iter__(cls) -> Iterator[Any]:
return (cls._enum_member_map_[name] for name in cls._enum_member_names_) return (cls._enum_member_map_[name] for name in cls._enum_member_names_)
def __reversed__(cls): def __reversed__(cls) -> Iterator[Any]:
return (cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_)) return (cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_))
def __len__(cls): def __len__(cls) -> int:
return len(cls._enum_member_names_) return len(cls._enum_member_names_)
def __repr__(cls): def __repr__(cls) -> str:
return f'<enum {cls.__name__}>' return f'<enum {cls.__name__}>'
@property @property
def __members__(cls): def __members__(cls) -> Mapping[str, Any]:
return types.MappingProxyType(cls._enum_member_map_) return types.MappingProxyType(cls._enum_member_map_)
def __call__(cls, value): def __call__(cls, value: str) -> Any:
try: try:
return cls._enum_value_map_[value] return cls._enum_value_map_[value]
except (KeyError, TypeError): except (KeyError, TypeError):
raise ValueError(f"{value!r} is not a valid {cls.__name__}") raise ValueError(f"{value!r} is not a valid {cls.__name__}")
def __getitem__(cls, key): def __getitem__(cls, key: str) -> Any:
return cls._enum_member_map_[key] return cls._enum_member_map_[key]
def __setattr__(cls, name, value): def __setattr__(cls, name: str, value: Any) -> None:
raise TypeError('Enums are immutable.') raise TypeError('Enums are immutable.')
def __delattr__(cls, attr): def __delattr__(cls, attr: str) -> None:
raise TypeError('Enums are immutable') raise TypeError('Enums are immutable')
def __instancecheck__(self, instance): def __instancecheck__(self, instance: Any) -> bool:
# isinstance(x, Y) # isinstance(x, Y)
# -> __instancecheck__(Y, x) # -> __instancecheck__(Y, x)
try: try:
@ -197,7 +197,7 @@ class ChannelType(Enum):
private_thread = 12 private_thread = 12
stage_voice = 13 stage_voice = 13
def __str__(self): def __str__(self) -> str:
return self.name return self.name
@ -233,10 +233,10 @@ class SpeakingState(Enum):
soundshare = 2 soundshare = 2
priority = 4 priority = 4
def __str__(self): def __str__(self) -> str:
return self.name return self.name
def __int__(self): def __int__(self) -> int:
return self.value return self.value
@ -247,7 +247,7 @@ class VerificationLevel(Enum, comparable=True):
high = 3 high = 3
highest = 4 highest = 4
def __str__(self): def __str__(self) -> str:
return self.name return self.name
@ -256,7 +256,7 @@ class ContentFilter(Enum, comparable=True):
no_role = 1 no_role = 1
all_members = 2 all_members = 2
def __str__(self): def __str__(self) -> str:
return self.name return self.name
@ -268,7 +268,7 @@ class Status(Enum):
do_not_disturb = 'dnd' do_not_disturb = 'dnd'
invisible = 'invisible' invisible = 'invisible'
def __str__(self): def __str__(self) -> str:
return self.value return self.value
@ -280,7 +280,7 @@ class DefaultAvatar(Enum):
orange = 3 orange = 3
red = 4 red = 4
def __str__(self): def __str__(self) -> str:
return self.name return self.name
@ -467,7 +467,7 @@ class ActivityType(Enum):
custom = 4 custom = 4
competing = 5 competing = 5
def __int__(self): def __int__(self) -> int:
return self.value return self.value
@ -542,7 +542,7 @@ class VideoQualityMode(Enum):
auto = 1 auto = 1
full = 2 full = 2
def __int__(self): def __int__(self) -> int:
return self.value return self.value
@ -552,7 +552,7 @@ class ComponentType(Enum):
select = 3 select = 3
text_input = 4 text_input = 4
def __int__(self): def __int__(self) -> int:
return self.value return self.value
@ -571,7 +571,7 @@ class ButtonStyle(Enum):
red = 4 red = 4
url = 5 url = 5
def __int__(self): def __int__(self) -> int:
return self.value return self.value

18
discord/ext/commands/_types.py

@ -23,21 +23,35 @@ DEALINGS IN THE SOFTWARE.
""" """
from typing import Any, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union from typing import Any, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union, Tuple
T = TypeVar('T')
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import ParamSpec
from .bot import Bot, AutoShardedBot
from .context import Context from .context import Context
from .cog import Cog from .cog import Cog
from .errors import CommandError from .errors import CommandError
T = TypeVar('T') P = ParamSpec('P')
MaybeCoroFunc = Union[
Callable[P, 'Coro[T]'],
Callable[P, T],
]
else:
P = TypeVar('P')
MaybeCoroFunc = Tuple[P, T]
Coro = Coroutine[Any, Any, T] Coro = Coroutine[Any, Any, T]
MaybeCoro = Union[T, Coro[T]] MaybeCoro = Union[T, Coro[T]]
CoroFunc = Callable[..., Coro[Any]] CoroFunc = Callable[..., Coro[Any]]
ContextT = TypeVar('ContextT', bound='Context') ContextT = TypeVar('ContextT', bound='Context')
_Bot = Union['Bot', 'AutoShardedBot']
BotT = TypeVar('BotT', bound=_Bot)
Check = Union[Callable[["Cog", "ContextT"], MaybeCoro[bool]], Callable[["ContextT"], MaybeCoro[bool]]] Check = Union[Callable[["Cog", "ContextT"], MaybeCoro[bool]], Callable[["ContextT"], MaybeCoro[bool]]]
Hook = Union[Callable[["Cog", "ContextT"], Coro[Any]], Callable[["ContextT"], Coro[Any]]] Hook = Union[Callable[["Cog", "ContextT"], Coro[Any]], Callable[["ContextT"], Coro[Any]]]

88
discord/ext/commands/bot.py

@ -33,7 +33,21 @@ import importlib.util
import sys import sys
import traceback import traceback
import types import types
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union, overload from typing import (
Any,
Callable,
Mapping,
List,
Dict,
TYPE_CHECKING,
Optional,
TypeVar,
Type,
Union,
Iterable,
Collection,
overload,
)
import discord import discord
from discord import app_commands from discord import app_commands
@ -55,10 +69,18 @@ if TYPE_CHECKING:
from discord.message import Message from discord.message import Message
from discord.abc import User, Snowflake from discord.abc import User, Snowflake
from ._types import ( from ._types import (
_Bot,
BotT,
Check, Check,
CoroFunc, CoroFunc,
ContextT,
MaybeCoroFunc,
) )
_Prefix = Union[Iterable[str], str]
_PrefixCallable = MaybeCoroFunc[[BotT, Message], _Prefix]
PrefixType = Union[_Prefix, _PrefixCallable[BotT]]
__all__ = ( __all__ = (
'when_mentioned', 'when_mentioned',
'when_mentioned_or', 'when_mentioned_or',
@ -68,11 +90,9 @@ __all__ = (
T = TypeVar('T') T = TypeVar('T')
CFT = TypeVar('CFT', bound='CoroFunc') CFT = TypeVar('CFT', bound='CoroFunc')
CXT = TypeVar('CXT', bound='Context')
BT = TypeVar('BT', bound='Union[Bot, AutoShardedBot]')
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]: def when_mentioned(bot: _Bot, msg: Message) -> List[str]:
"""A callable that implements a command prefix equivalent to being mentioned. """A callable that implements a command prefix equivalent to being mentioned.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute. These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
@ -81,7 +101,7 @@ def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore
def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]: def when_mentioned_or(*prefixes: str) -> Callable[[_Bot, Message], List[str]]:
"""A callable that implements when mentioned or other prefixes provided. """A callable that implements when mentioned or other prefixes provided.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute. These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
@ -124,27 +144,33 @@ class _DefaultRepr:
return '<default-help-command>' return '<default-help-command>'
_default = _DefaultRepr() _default: Any = _DefaultRepr()
class BotBase(GroupMixin): class BotBase(GroupMixin[None]):
def __init__(self, command_prefix, help_command=_default, description=None, **options): def __init__(
self,
command_prefix: PrefixType[BotT],
help_command: HelpCommand = _default,
description: Optional[str] = None,
**options: Any,
) -> None:
super().__init__(**options) super().__init__(**options)
self.command_prefix = command_prefix self.command_prefix: PrefixType[BotT] = command_prefix
self.extra_events: Dict[str, List[CoroFunc]] = {} self.extra_events: Dict[str, List[CoroFunc]] = {}
# Self doesn't have the ClientT bound, but since this is a mixin it technically does # Self doesn't have the ClientT bound, but since this is a mixin it technically does
self.__tree: app_commands.CommandTree[Self] = app_commands.CommandTree(self) # type: ignore self.__tree: app_commands.CommandTree[Self] = app_commands.CommandTree(self) # type: ignore
self.__cogs: Dict[str, Cog] = {} self.__cogs: Dict[str, Cog] = {}
self.__extensions: Dict[str, types.ModuleType] = {} self.__extensions: Dict[str, types.ModuleType] = {}
self._checks: List[Check] = [] self._checks: List[Check] = []
self._check_once = [] self._check_once: List[Check] = []
self._before_invoke = None self._before_invoke: Optional[CoroFunc] = None
self._after_invoke = None self._after_invoke: Optional[CoroFunc] = None
self._help_command = None self._help_command: Optional[HelpCommand] = None
self.description = inspect.cleandoc(description) if description else '' self.description: str = inspect.cleandoc(description) if description else ''
self.owner_id = options.get('owner_id') self.owner_id: Optional[int] = options.get('owner_id')
self.owner_ids = options.get('owner_ids', set()) self.owner_ids: Optional[Collection[int]] = options.get('owner_ids', set())
self.strip_after_prefix = options.get('strip_after_prefix', False) self.strip_after_prefix: bool = options.get('strip_after_prefix', False)
if self.owner_id and self.owner_ids: if self.owner_id and self.owner_ids:
raise TypeError('Both owner_id and owner_ids are set.') raise TypeError('Both owner_id and owner_ids are set.')
@ -182,7 +208,7 @@ class BotBase(GroupMixin):
await super().close() # type: ignore await super().close() # type: ignore
async def on_command_error(self, context: Context, exception: errors.CommandError) -> None: async def on_command_error(self, context: Context[BotT], exception: errors.CommandError) -> None:
"""|coro| """|coro|
The default command error handler provided by the bot. The default command error handler provided by the bot.
@ -237,7 +263,7 @@ class BotBase(GroupMixin):
self.add_check(func) # type: ignore self.add_check(func) # type: ignore
return func return func
def add_check(self, func: Check, /, *, call_once: bool = False) -> None: def add_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None:
"""Adds a global check to the bot. """Adds a global check to the bot.
This is the non-decorator interface to :meth:`.check` This is the non-decorator interface to :meth:`.check`
@ -261,7 +287,7 @@ class BotBase(GroupMixin):
else: else:
self._checks.append(func) self._checks.append(func)
def remove_check(self, func: Check, /, *, call_once: bool = False) -> None: def remove_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None:
"""Removes a global check from the bot. """Removes a global check from the bot.
This function is idempotent and will not raise an exception This function is idempotent and will not raise an exception
@ -324,7 +350,7 @@ class BotBase(GroupMixin):
self.add_check(func, call_once=True) self.add_check(func, call_once=True)
return func return func
async def can_run(self, ctx: Context, *, call_once: bool = False) -> bool: async def can_run(self, ctx: Context[BotT], *, call_once: bool = False) -> bool:
data = self._check_once if call_once else self._checks data = self._check_once if call_once else self._checks
if len(data) == 0: if len(data) == 0:
@ -947,7 +973,7 @@ class BotBase(GroupMixin):
# if the load failed, the remnants should have been # if the load failed, the remnants should have been
# cleaned from the load_extension function call # cleaned from the load_extension function call
# so let's load it from our old compiled library. # so let's load it from our old compiled library.
await lib.setup(self) # type: ignore await lib.setup(self)
self.__extensions[name] = lib self.__extensions[name] = lib
# revert sys.modules back to normal and raise back to caller # revert sys.modules back to normal and raise back to caller
@ -1015,11 +1041,12 @@ class BotBase(GroupMixin):
""" """
prefix = ret = self.command_prefix prefix = ret = self.command_prefix
if callable(prefix): if callable(prefix):
ret = await discord.utils.maybe_coroutine(prefix, self, message) # self will be a Bot or AutoShardedBot
ret = await discord.utils.maybe_coroutine(prefix, self, message) # type: ignore
if not isinstance(ret, str): if not isinstance(ret, str):
try: try:
ret = list(ret) ret = list(ret) # type: ignore
except TypeError: except TypeError:
# It's possible that a generator raised this exception. Don't # It's possible that a generator raised this exception. Don't
# replace it with our own error if that's the case. # replace it with our own error if that's the case.
@ -1048,15 +1075,15 @@ class BotBase(GroupMixin):
self, self,
message: Message, message: Message,
*, *,
cls: Type[CXT] = ..., cls: Type[ContextT] = ...,
) -> CXT: # type: ignore ) -> ContextT:
... ...
async def get_context( async def get_context(
self, self,
message: Message, message: Message,
*, *,
cls: Type[CXT] = MISSING, cls: Type[ContextT] = MISSING,
) -> Any: ) -> Any:
r"""|coro| r"""|coro|
@ -1137,7 +1164,7 @@ class BotBase(GroupMixin):
ctx.command = self.all_commands.get(invoker) ctx.command = self.all_commands.get(invoker)
return ctx return ctx
async def invoke(self, ctx: Context) -> None: async def invoke(self, ctx: Context[BotT]) -> None:
"""|coro| """|coro|
Invokes the command given under the invocation context and Invokes the command given under the invocation context and
@ -1189,9 +1216,10 @@ class BotBase(GroupMixin):
return return
ctx = await self.get_context(message) ctx = await self.get_context(message)
await self.invoke(ctx) # the type of the invocation context's bot attribute will be correct
await self.invoke(ctx) # type: ignore
async def on_message(self, message): async def on_message(self, message: Message) -> None:
await self.process_commands(message) await self.process_commands(message)

16
discord/ext/commands/cog.py

@ -30,7 +30,7 @@ from discord.utils import maybe_coroutine
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union
from ._types import _BaseCommand from ._types import _BaseCommand, BotT
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self from typing_extensions import Self
@ -112,7 +112,7 @@ class CogMeta(type):
__cog_name__: str __cog_name__: str
__cog_settings__: Dict[str, Any] __cog_settings__: Dict[str, Any]
__cog_commands__: List[Command] __cog_commands__: List[Command[Any, ..., Any]]
__cog_is_app_commands_group__: bool __cog_is_app_commands_group__: bool
__cog_app_commands__: List[Union[app_commands.Group, app_commands.Command[Any, ..., Any]]] __cog_app_commands__: List[Union[app_commands.Group, app_commands.Command[Any, ..., Any]]]
__cog_listeners__: List[Tuple[str, str]] __cog_listeners__: List[Tuple[str, str]]
@ -406,7 +406,7 @@ class Cog(metaclass=CogMeta):
pass pass
@_cog_special_method @_cog_special_method
def bot_check_once(self, ctx: Context) -> bool: def bot_check_once(self, ctx: Context[BotT]) -> bool:
"""A special method that registers as a :meth:`.Bot.check_once` """A special method that registers as a :meth:`.Bot.check_once`
check. check.
@ -416,7 +416,7 @@ class Cog(metaclass=CogMeta):
return True return True
@_cog_special_method @_cog_special_method
def bot_check(self, ctx: Context) -> bool: def bot_check(self, ctx: Context[BotT]) -> bool:
"""A special method that registers as a :meth:`.Bot.check` """A special method that registers as a :meth:`.Bot.check`
check. check.
@ -426,7 +426,7 @@ class Cog(metaclass=CogMeta):
return True return True
@_cog_special_method @_cog_special_method
def cog_check(self, ctx: Context) -> bool: def cog_check(self, ctx: Context[BotT]) -> bool:
"""A special method that registers as a :func:`~discord.ext.commands.check` """A special method that registers as a :func:`~discord.ext.commands.check`
for every command and subcommand in this cog. for every command and subcommand in this cog.
@ -436,7 +436,7 @@ class Cog(metaclass=CogMeta):
return True return True
@_cog_special_method @_cog_special_method
async def cog_command_error(self, ctx: Context, error: Exception) -> None: async def cog_command_error(self, ctx: Context[BotT], error: Exception) -> None:
"""A special method that is called whenever an error """A special method that is called whenever an error
is dispatched inside this cog. is dispatched inside this cog.
@ -455,7 +455,7 @@ class Cog(metaclass=CogMeta):
pass pass
@_cog_special_method @_cog_special_method
async def cog_before_invoke(self, ctx: Context) -> None: async def cog_before_invoke(self, ctx: Context[BotT]) -> None:
"""A special method that acts as a cog local pre-invoke hook. """A special method that acts as a cog local pre-invoke hook.
This is similar to :meth:`.Command.before_invoke`. This is similar to :meth:`.Command.before_invoke`.
@ -470,7 +470,7 @@ class Cog(metaclass=CogMeta):
pass pass
@_cog_special_method @_cog_special_method
async def cog_after_invoke(self, ctx: Context) -> None: async def cog_after_invoke(self, ctx: Context[BotT]) -> None:
"""A special method that acts as a cog local post-invoke hook. """A special method that acts as a cog local post-invoke hook.
This is similar to :meth:`.Command.after_invoke`. This is similar to :meth:`.Command.after_invoke`.

15
discord/ext/commands/context.py

@ -28,6 +28,8 @@ import re
from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, TypeVar, Union from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, TypeVar, Union
from ._types import BotT
import discord.abc import discord.abc
import discord.utils import discord.utils
@ -59,7 +61,6 @@ MISSING: Any = discord.utils.MISSING
T = TypeVar('T') T = TypeVar('T')
BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]")
CogT = TypeVar('CogT', bound="Cog") CogT = TypeVar('CogT', bound="Cog")
if TYPE_CHECKING: if TYPE_CHECKING:
@ -133,10 +134,10 @@ class Context(discord.abc.Messageable, Generic[BotT]):
args: List[Any] = MISSING, args: List[Any] = MISSING,
kwargs: Dict[str, Any] = MISSING, kwargs: Dict[str, Any] = MISSING,
prefix: Optional[str] = None, prefix: Optional[str] = None,
command: Optional[Command] = None, command: Optional[Command[Any, ..., Any]] = None,
invoked_with: Optional[str] = None, invoked_with: Optional[str] = None,
invoked_parents: List[str] = MISSING, invoked_parents: List[str] = MISSING,
invoked_subcommand: Optional[Command] = None, invoked_subcommand: Optional[Command[Any, ..., Any]] = None,
subcommand_passed: Optional[str] = None, subcommand_passed: Optional[str] = None,
command_failed: bool = False, command_failed: bool = False,
current_parameter: Optional[inspect.Parameter] = None, current_parameter: Optional[inspect.Parameter] = None,
@ -146,11 +147,11 @@ class Context(discord.abc.Messageable, Generic[BotT]):
self.args: List[Any] = args or [] self.args: List[Any] = args or []
self.kwargs: Dict[str, Any] = kwargs or {} self.kwargs: Dict[str, Any] = kwargs or {}
self.prefix: Optional[str] = prefix self.prefix: Optional[str] = prefix
self.command: Optional[Command] = command self.command: Optional[Command[Any, ..., Any]] = command
self.view: StringView = view self.view: StringView = view
self.invoked_with: Optional[str] = invoked_with self.invoked_with: Optional[str] = invoked_with
self.invoked_parents: List[str] = invoked_parents or [] self.invoked_parents: List[str] = invoked_parents or []
self.invoked_subcommand: Optional[Command] = invoked_subcommand self.invoked_subcommand: Optional[Command[Any, ..., Any]] = invoked_subcommand
self.subcommand_passed: Optional[str] = subcommand_passed self.subcommand_passed: Optional[str] = subcommand_passed
self.command_failed: bool = command_failed self.command_failed: bool = command_failed
self.current_parameter: Optional[inspect.Parameter] = current_parameter self.current_parameter: Optional[inspect.Parameter] = current_parameter
@ -361,7 +362,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
return None return None
cmd = cmd.copy() cmd = cmd.copy()
cmd.context = self cmd.context = self # type: ignore
if len(args) == 0: if len(args) == 0:
await cmd.prepare_help_command(self, None) await cmd.prepare_help_command(self, None)
mapping = cmd.get_bot_mapping() mapping = cmd.get_bot_mapping()
@ -390,7 +391,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
try: try:
if hasattr(entity, '__cog_commands__'): if hasattr(entity, '__cog_commands__'):
injected = wrap_callback(cmd.send_cog_help) injected = wrap_callback(cmd.send_cog_help)
return await injected(entity) return await injected(entity) # type: ignore
elif isinstance(entity, Group): elif isinstance(entity, Group):
injected = wrap_callback(cmd.send_group_help) injected = wrap_callback(cmd.send_group_help)
return await injected(entity) return await injected(entity)

86
discord/ext/commands/converter.py

@ -41,7 +41,6 @@ from typing import (
Tuple, Tuple,
Union, Union,
runtime_checkable, runtime_checkable,
overload,
) )
import discord import discord
@ -51,9 +50,8 @@ if TYPE_CHECKING:
from .context import Context from .context import Context
from discord.state import Channel from discord.state import Channel
from discord.threads import Thread from discord.threads import Thread
from .bot import Bot, AutoShardedBot
_Bot = TypeVar('_Bot', bound=Union[Bot, AutoShardedBot]) from ._types import BotT, _Bot
__all__ = ( __all__ = (
@ -87,7 +85,7 @@ __all__ = (
) )
def _get_from_guilds(bot, getter, argument): def _get_from_guilds(bot: _Bot, getter: str, argument: Any) -> Any:
result = None result = None
for guild in bot.guilds: for guild in bot.guilds:
result = getattr(guild, getter)(argument) result = getattr(guild, getter)(argument)
@ -115,7 +113,7 @@ class Converter(Protocol[T_co]):
method to do its conversion logic. This method must be a :ref:`coroutine <coroutine>`. method to do its conversion logic. This method must be a :ref:`coroutine <coroutine>`.
""" """
async def convert(self, ctx: Context, argument: str) -> T_co: async def convert(self, ctx: Context[BotT], argument: str) -> T_co:
"""|coro| """|coro|
The method to override to do conversion logic. The method to override to do conversion logic.
@ -163,7 +161,7 @@ class ObjectConverter(IDConverter[discord.Object]):
2. Lookup by member, role, or channel mention. 2. Lookup by member, role, or channel mention.
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Object: async def convert(self, ctx: Context[BotT], argument: str) -> discord.Object:
match = self._get_id_match(argument) or re.match(r'<(?:@(?:!|&)?|#)([0-9]{15,20})>$', argument) match = self._get_id_match(argument) or re.match(r'<(?:@(?:!|&)?|#)([0-9]{15,20})>$', argument)
if match is None: if match is None:
@ -196,7 +194,7 @@ class MemberConverter(IDConverter[discord.Member]):
optionally caching the result if :attr:`.MemberCacheFlags.joined` is enabled. optionally caching the result if :attr:`.MemberCacheFlags.joined` is enabled.
""" """
async def query_member_named(self, guild, argument): async def query_member_named(self, guild: discord.Guild, argument: str) -> Optional[discord.Member]:
cache = guild._state.member_cache_flags.joined cache = guild._state.member_cache_flags.joined
if len(argument) > 5 and argument[-5] == '#': if len(argument) > 5 and argument[-5] == '#':
username, _, discriminator = argument.rpartition('#') username, _, discriminator = argument.rpartition('#')
@ -206,7 +204,7 @@ class MemberConverter(IDConverter[discord.Member]):
members = await guild.query_members(argument, limit=100, cache=cache) members = await guild.query_members(argument, limit=100, cache=cache)
return discord.utils.find(lambda m: m.name == argument or m.nick == argument, members) return discord.utils.find(lambda m: m.name == argument or m.nick == argument, members)
async def query_member_by_id(self, bot, guild, user_id): async def query_member_by_id(self, bot: _Bot, guild: discord.Guild, user_id: int) -> Optional[discord.Member]:
ws = bot._get_websocket(shard_id=guild.shard_id) ws = bot._get_websocket(shard_id=guild.shard_id)
cache = guild._state.member_cache_flags.joined cache = guild._state.member_cache_flags.joined
if ws.is_ratelimited(): if ws.is_ratelimited():
@ -227,7 +225,7 @@ class MemberConverter(IDConverter[discord.Member]):
return None return None
return members[0] return members[0]
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Member: async def convert(self, ctx: Context[BotT], argument: str) -> discord.Member:
bot = ctx.bot bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument) match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
guild = ctx.guild guild = ctx.guild
@ -281,7 +279,7 @@ class UserConverter(IDConverter[discord.User]):
and it's not available in cache. and it's not available in cache.
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.User: async def convert(self, ctx: Context[BotT], argument: str) -> discord.User:
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument) match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
result = None result = None
state = ctx._state state = ctx._state
@ -359,7 +357,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
@staticmethod @staticmethod
def _resolve_channel( def _resolve_channel(
ctx: Context[_Bot], guild_id: Optional[int], channel_id: Optional[int] ctx: Context[BotT], guild_id: Optional[int], channel_id: Optional[int]
) -> Optional[Union[Channel, Thread]]: ) -> Optional[Union[Channel, Thread]]:
if channel_id is None: if channel_id is None:
# we were passed just a message id so we can assume the channel is the current context channel # we were passed just a message id so we can assume the channel is the current context channel
@ -373,7 +371,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
return ctx.bot.get_channel(channel_id) return ctx.bot.get_channel(channel_id)
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.PartialMessage: async def convert(self, ctx: Context[BotT], argument: str) -> discord.PartialMessage:
guild_id, message_id, channel_id = self._get_id_matches(ctx, argument) guild_id, message_id, channel_id = self._get_id_matches(ctx, argument)
channel = self._resolve_channel(ctx, guild_id, channel_id) channel = self._resolve_channel(ctx, guild_id, channel_id)
if not channel or not isinstance(channel, discord.abc.Messageable): if not channel or not isinstance(channel, discord.abc.Messageable):
@ -396,7 +394,7 @@ class MessageConverter(IDConverter[discord.Message]):
Raise :exc:`.ChannelNotFound`, :exc:`.MessageNotFound` or :exc:`.ChannelNotReadable` instead of generic :exc:`.BadArgument` Raise :exc:`.ChannelNotFound`, :exc:`.MessageNotFound` or :exc:`.ChannelNotReadable` instead of generic :exc:`.BadArgument`
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Message: async def convert(self, ctx: Context[BotT], argument: str) -> discord.Message:
guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(ctx, argument) guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(ctx, argument)
message = ctx.bot._connection._get_message(message_id) message = ctx.bot._connection._get_message(message_id)
if message: if message:
@ -427,11 +425,11 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.abc.GuildChannel: async def convert(self, ctx: Context[BotT], argument: str) -> discord.abc.GuildChannel:
return self._resolve_channel(ctx, argument, 'channels', discord.abc.GuildChannel) return self._resolve_channel(ctx, argument, 'channels', discord.abc.GuildChannel)
@staticmethod @staticmethod
def _resolve_channel(ctx: Context, argument: str, attribute: str, type: Type[CT]) -> CT: def _resolve_channel(ctx: Context[BotT], argument: str, attribute: str, type: Type[CT]) -> CT:
bot = ctx.bot bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument) match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
@ -448,7 +446,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
def check(c): def check(c):
return isinstance(c, type) and c.name == argument return isinstance(c, type) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels()) result = discord.utils.find(check, bot.get_all_channels()) # type: ignore
else: else:
channel_id = int(match.group(1)) channel_id = int(match.group(1))
if guild: if guild:
@ -463,7 +461,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
return result return result
@staticmethod @staticmethod
def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT: def _resolve_thread(ctx: Context[BotT], argument: str, attribute: str, type: Type[TT]) -> TT:
bot = ctx.bot bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument) match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
@ -502,7 +500,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.TextChannel: async def convert(self, ctx: Context[BotT], argument: str) -> discord.TextChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'text_channels', discord.TextChannel) return GuildChannelConverter._resolve_channel(ctx, argument, 'text_channels', discord.TextChannel)
@ -522,7 +520,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.VoiceChannel: async def convert(self, ctx: Context[BotT], argument: str) -> discord.VoiceChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'voice_channels', discord.VoiceChannel) return GuildChannelConverter._resolve_channel(ctx, argument, 'voice_channels', discord.VoiceChannel)
@ -541,7 +539,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
3. Lookup by name 3. Lookup by name
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.StageChannel: async def convert(self, ctx: Context[BotT], argument: str) -> discord.StageChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'stage_channels', discord.StageChannel) return GuildChannelConverter._resolve_channel(ctx, argument, 'stage_channels', discord.StageChannel)
@ -561,7 +559,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.CategoryChannel: async def convert(self, ctx: Context[BotT], argument: str) -> discord.CategoryChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'categories', discord.CategoryChannel) return GuildChannelConverter._resolve_channel(ctx, argument, 'categories', discord.CategoryChannel)
@ -580,7 +578,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]):
.. versionadded:: 1.7 .. versionadded:: 1.7
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.StoreChannel: async def convert(self, ctx: Context[BotT], argument: str) -> discord.StoreChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel) return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel)
@ -598,7 +596,7 @@ class ThreadConverter(IDConverter[discord.Thread]):
.. versionadded: 2.0 .. versionadded: 2.0
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Thread: async def convert(self, ctx: Context[BotT], argument: str) -> discord.Thread:
return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread) return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread)
@ -630,7 +628,7 @@ class ColourConverter(Converter[discord.Colour]):
RGB_REGEX = re.compile(r'rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)') RGB_REGEX = re.compile(r'rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)')
def parse_hex_number(self, argument): def parse_hex_number(self, argument: str) -> discord.Colour:
arg = ''.join(i * 2 for i in argument) if len(argument) == 3 else argument arg = ''.join(i * 2 for i in argument) if len(argument) == 3 else argument
try: try:
value = int(arg, base=16) value = int(arg, base=16)
@ -641,7 +639,7 @@ class ColourConverter(Converter[discord.Colour]):
else: else:
return discord.Color(value=value) return discord.Color(value=value)
def parse_rgb_number(self, argument, number): def parse_rgb_number(self, argument: str, number: str) -> int:
if number[-1] == '%': if number[-1] == '%':
value = int(number[:-1]) value = int(number[:-1])
if not (0 <= value <= 100): if not (0 <= value <= 100):
@ -653,7 +651,7 @@ class ColourConverter(Converter[discord.Colour]):
raise BadColourArgument(argument) raise BadColourArgument(argument)
return value return value
def parse_rgb(self, argument, *, regex=RGB_REGEX): def parse_rgb(self, argument: str, *, regex: re.Pattern[str] = RGB_REGEX) -> discord.Colour:
match = regex.match(argument) match = regex.match(argument)
if match is None: if match is None:
raise BadColourArgument(argument) raise BadColourArgument(argument)
@ -663,7 +661,7 @@ class ColourConverter(Converter[discord.Colour]):
blue = self.parse_rgb_number(argument, match.group('b')) blue = self.parse_rgb_number(argument, match.group('b'))
return discord.Color.from_rgb(red, green, blue) return discord.Color.from_rgb(red, green, blue)
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Colour: async def convert(self, ctx: Context[BotT], argument: str) -> discord.Colour:
if argument[0] == '#': if argument[0] == '#':
return self.parse_hex_number(argument[1:]) return self.parse_hex_number(argument[1:])
@ -704,7 +702,7 @@ class RoleConverter(IDConverter[discord.Role]):
Raise :exc:`.RoleNotFound` instead of generic :exc:`.BadArgument` Raise :exc:`.RoleNotFound` instead of generic :exc:`.BadArgument`
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Role: async def convert(self, ctx: Context[BotT], argument: str) -> discord.Role:
guild = ctx.guild guild = ctx.guild
if not guild: if not guild:
raise NoPrivateMessage() raise NoPrivateMessage()
@ -723,7 +721,7 @@ class RoleConverter(IDConverter[discord.Role]):
class GameConverter(Converter[discord.Game]): class GameConverter(Converter[discord.Game]):
"""Converts to :class:`~discord.Game`.""" """Converts to :class:`~discord.Game`."""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Game: async def convert(self, ctx: Context[BotT], argument: str) -> discord.Game:
return discord.Game(name=argument) return discord.Game(name=argument)
@ -736,7 +734,7 @@ class InviteConverter(Converter[discord.Invite]):
Raise :exc:`.BadInviteArgument` instead of generic :exc:`.BadArgument` Raise :exc:`.BadInviteArgument` instead of generic :exc:`.BadArgument`
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Invite: async def convert(self, ctx: Context[BotT], argument: str) -> discord.Invite:
try: try:
invite = await ctx.bot.fetch_invite(argument) invite = await ctx.bot.fetch_invite(argument)
return invite return invite
@ -755,7 +753,7 @@ class GuildConverter(IDConverter[discord.Guild]):
.. versionadded:: 1.7 .. versionadded:: 1.7
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Guild: async def convert(self, ctx: Context[BotT], argument: str) -> discord.Guild:
match = self._get_id_match(argument) match = self._get_id_match(argument)
result = None result = None
@ -787,7 +785,7 @@ class EmojiConverter(IDConverter[discord.Emoji]):
Raise :exc:`.EmojiNotFound` instead of generic :exc:`.BadArgument` Raise :exc:`.EmojiNotFound` instead of generic :exc:`.BadArgument`
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Emoji: async def convert(self, ctx: Context[BotT], argument: str) -> discord.Emoji:
match = self._get_id_match(argument) or re.match(r'<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$', argument) match = self._get_id_match(argument) or re.match(r'<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$', argument)
result = None result = None
bot = ctx.bot bot = ctx.bot
@ -821,7 +819,7 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
Raise :exc:`.PartialEmojiConversionFailure` instead of generic :exc:`.BadArgument` Raise :exc:`.PartialEmojiConversionFailure` instead of generic :exc:`.BadArgument`
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.PartialEmoji: async def convert(self, ctx: Context[BotT], argument: str) -> discord.PartialEmoji:
match = re.match(r'<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$', argument) match = re.match(r'<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$', argument)
if match: if match:
@ -850,7 +848,7 @@ class GuildStickerConverter(IDConverter[discord.GuildSticker]):
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.GuildSticker: async def convert(self, ctx: Context[BotT], argument: str) -> discord.GuildSticker:
match = self._get_id_match(argument) match = self._get_id_match(argument)
result = None result = None
bot = ctx.bot bot = ctx.bot
@ -890,7 +888,7 @@ class ScheduledEventConverter(IDConverter[discord.ScheduledEvent]):
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.ScheduledEvent: async def convert(self, ctx: Context[BotT], argument: str) -> discord.ScheduledEvent:
guild = ctx.guild guild = ctx.guild
match = self._get_id_match(argument) match = self._get_id_match(argument)
result = None result = None
@ -967,7 +965,7 @@ class clean_content(Converter[str]):
self.escape_markdown = escape_markdown self.escape_markdown = escape_markdown
self.remove_markdown = remove_markdown self.remove_markdown = remove_markdown
async def convert(self, ctx: Context[_Bot], argument: str) -> str: async def convert(self, ctx: Context[BotT], argument: str) -> str:
msg = ctx.message msg = ctx.message
if ctx.guild: if ctx.guild:
@ -1047,10 +1045,10 @@ class Greedy(List[T]):
__slots__ = ('converter',) __slots__ = ('converter',)
def __init__(self, *, converter: T): def __init__(self, *, converter: T) -> None:
self.converter = converter self.converter: T = converter
def __repr__(self): def __repr__(self) -> str:
converter = getattr(self.converter, '__name__', repr(self.converter)) converter = getattr(self.converter, '__name__', repr(self.converter))
return f'Greedy[{converter}]' return f'Greedy[{converter}]'
@ -1099,11 +1097,11 @@ def get_converter(param: inspect.Parameter) -> Any:
_GenericAlias = type(List[T]) _GenericAlias = type(List[T])
def is_generic_type(tp: Any, *, _GenericAlias: Type = _GenericAlias) -> bool: def is_generic_type(tp: Any, *, _GenericAlias: type = _GenericAlias) -> bool:
return isinstance(tp, type) and issubclass(tp, Generic) or isinstance(tp, _GenericAlias) # type: ignore return isinstance(tp, type) and issubclass(tp, Generic) or isinstance(tp, _GenericAlias)
CONVERTER_MAPPING: Dict[Type[Any], Any] = { CONVERTER_MAPPING: Dict[type, Any] = {
discord.Object: ObjectConverter, discord.Object: ObjectConverter,
discord.Member: MemberConverter, discord.Member: MemberConverter,
discord.User: UserConverter, discord.User: UserConverter,
@ -1128,7 +1126,7 @@ CONVERTER_MAPPING: Dict[Type[Any], Any] = {
} }
async def _actual_conversion(ctx: Context, converter, argument: str, param: inspect.Parameter): async def _actual_conversion(ctx: Context[BotT], converter, argument: str, param: inspect.Parameter):
if converter is bool: if converter is bool:
return _convert_to_bool(argument) return _convert_to_bool(argument)
@ -1166,7 +1164,7 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
async def run_converters(ctx: Context, converter, argument: str, param: inspect.Parameter): async def run_converters(ctx: Context[BotT], converter: Any, argument: str, param: inspect.Parameter) -> Any:
"""|coro| """|coro|
Runs converters for a given converter, argument, and parameter. Runs converters for a given converter, argument, and parameter.

2
discord/ext/commands/cooldowns.py

@ -220,7 +220,7 @@ class CooldownMapping:
return self._type return self._type
@classmethod @classmethod
def from_cooldown(cls, rate, per, type) -> Self: def from_cooldown(cls, rate: float, per: float, type: Callable[[Message], Any]) -> Self:
return cls(Cooldown(rate, per), type) return cls(Cooldown(rate, per), type)
def _bucket_key(self, msg: Message) -> Any: def _bucket_key(self, msg: Message) -> Any:

115
discord/ext/commands/core.py

@ -61,6 +61,8 @@ if TYPE_CHECKING:
from discord.message import Message from discord.message import Message
from ._types import ( from ._types import (
BotT,
ContextT,
Coro, Coro,
CoroFunc, CoroFunc,
Check, Check,
@ -101,7 +103,6 @@ MISSING: Any = discord.utils.MISSING
T = TypeVar('T') T = TypeVar('T')
CogT = TypeVar('CogT', bound='Optional[Cog]') CogT = TypeVar('CogT', bound='Optional[Cog]')
CommandT = TypeVar('CommandT', bound='Command') CommandT = TypeVar('CommandT', bound='Command')
ContextT = TypeVar('ContextT', bound='Context')
# CHT = TypeVar('CHT', bound='Check') # CHT = TypeVar('CHT', bound='Check')
GroupT = TypeVar('GroupT', bound='Group') GroupT = TypeVar('GroupT', bound='Group')
FuncT = TypeVar('FuncT', bound=Callable[..., Any]) FuncT = TypeVar('FuncT', bound=Callable[..., Any])
@ -159,9 +160,9 @@ def get_signature_parameters(
return params return params
def wrap_callback(coro): def wrap_callback(coro: Callable[P, Coro[T]]) -> Callable[P, Coro[Optional[T]]]:
@functools.wraps(coro) @functools.wraps(coro)
async def wrapped(*args, **kwargs): async def wrapped(*args: P.args, **kwargs: P.kwargs) -> Optional[T]:
try: try:
ret = await coro(*args, **kwargs) ret = await coro(*args, **kwargs)
except CommandError: except CommandError:
@ -175,9 +176,11 @@ def wrap_callback(coro):
return wrapped return wrapped
def hooked_wrapped_callback(command, ctx, coro): def hooked_wrapped_callback(
command: Command[Any, ..., Any], ctx: Context[BotT], coro: Callable[P, Coro[T]]
) -> Callable[P, Coro[Optional[T]]]:
@functools.wraps(coro) @functools.wraps(coro)
async def wrapped(*args, **kwargs): async def wrapped(*args: P.args, **kwargs: P.kwargs) -> Optional[T]:
try: try:
ret = await coro(*args, **kwargs) ret = await coro(*args, **kwargs)
except CommandError: except CommandError:
@ -191,7 +194,7 @@ def hooked_wrapped_callback(command, ctx, coro):
raise CommandInvokeError(exc) from exc raise CommandInvokeError(exc) from exc
finally: finally:
if command._max_concurrency is not None: if command._max_concurrency is not None:
await command._max_concurrency.release(ctx) await command._max_concurrency.release(ctx.message)
await command.call_after_hooks(ctx) await command.call_after_hooks(ctx)
return ret return ret
@ -359,7 +362,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
except AttributeError: except AttributeError:
checks = kwargs.get('checks', []) checks = kwargs.get('checks', [])
self.checks: List[Check] = checks self.checks: List[Check[ContextT]] = checks
try: try:
cooldown = func.__commands_cooldown__ cooldown = func.__commands_cooldown__
@ -387,8 +390,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.cog: CogT = None self.cog: CogT = None
# bandaid for the fact that sometimes parent can be the bot instance # bandaid for the fact that sometimes parent can be the bot instance
parent = kwargs.get('parent') parent: Optional[GroupMixin[Any]] = kwargs.get('parent')
self.parent: Optional[GroupMixin] = parent if isinstance(parent, _BaseCommand) else None # type: ignore self.parent: Optional[GroupMixin[Any]] = parent if isinstance(parent, _BaseCommand) else None
self._before_invoke: Optional[Hook] = None self._before_invoke: Optional[Hook] = None
try: try:
@ -422,16 +425,16 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
) -> None: ) -> None:
self._callback = function self._callback = function
unwrap = unwrap_function(function) unwrap = unwrap_function(function)
self.module = unwrap.__module__ self.module: str = unwrap.__module__
try: try:
globalns = unwrap.__globals__ globalns = unwrap.__globals__
except AttributeError: except AttributeError:
globalns = {} globalns = {}
self.params = get_signature_parameters(function, globalns) self.params: Dict[str, inspect.Parameter] = get_signature_parameters(function, globalns)
def add_check(self, func: Check, /) -> None: def add_check(self, func: Check[ContextT], /) -> None:
"""Adds a check to the command. """Adds a check to the command.
This is the non-decorator interface to :func:`.check`. This is the non-decorator interface to :func:`.check`.
@ -450,7 +453,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.checks.append(func) self.checks.append(func)
def remove_check(self, func: Check, /) -> None: def remove_check(self, func: Check[ContextT], /) -> None:
"""Removes a check from the command. """Removes a check from the command.
This function is idempotent and will not raise an exception This function is idempotent and will not raise an exception
@ -484,7 +487,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.__init__(self.callback, **dict(self.__original_kwargs__, **kwargs)) self.__init__(self.callback, **dict(self.__original_kwargs__, **kwargs))
self.cog = cog self.cog = cog
async def __call__(self, context: Context, *args: P.args, **kwargs: P.kwargs) -> T: async def __call__(self, context: Context[BotT], *args: P.args, **kwargs: P.kwargs) -> T:
"""|coro| """|coro|
Calls the internal callback that the command holds. Calls the internal callback that the command holds.
@ -539,7 +542,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
else: else:
return self.copy() return self.copy()
async def dispatch_error(self, ctx: Context, error: Exception) -> None: async def dispatch_error(self, ctx: Context[BotT], error: CommandError) -> None:
ctx.command_failed = True ctx.command_failed = True
cog = self.cog cog = self.cog
try: try:
@ -549,7 +552,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
else: else:
injected = wrap_callback(coro) injected = wrap_callback(coro)
if cog is not None: if cog is not None:
await injected(cog, ctx, error) await injected(cog, ctx, error) # type: ignore
else: else:
await injected(ctx, error) await injected(ctx, error)
@ -562,7 +565,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
finally: finally:
ctx.bot.dispatch('command_error', ctx, error) ctx.bot.dispatch('command_error', ctx, error)
async def transform(self, ctx: Context, param: inspect.Parameter) -> Any: async def transform(self, ctx: Context[BotT], param: inspect.Parameter) -> Any:
required = param.default is param.empty required = param.default is param.empty
converter = get_converter(param) converter = get_converter(param)
consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw
@ -610,7 +613,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
# type-checker fails to narrow argument # type-checker fails to narrow argument
return await run_converters(ctx, converter, argument, param) # type: ignore return await run_converters(ctx, converter, argument, param) # type: ignore
async def _transform_greedy_pos(self, ctx: Context, param: inspect.Parameter, required: bool, converter: Any) -> Any: async def _transform_greedy_pos(
self, ctx: Context[BotT], param: inspect.Parameter, required: bool, converter: Any
) -> Any:
view = ctx.view view = ctx.view
result = [] result = []
while not view.eof: while not view.eof:
@ -631,7 +636,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return param.default return param.default
return result return result
async def _transform_greedy_var_pos(self, ctx: Context, param: inspect.Parameter, converter: Any) -> Any: async def _transform_greedy_var_pos(self, ctx: Context[BotT], param: inspect.Parameter, converter: Any) -> Any:
view = ctx.view view = ctx.view
previous = view.index previous = view.index
try: try:
@ -669,7 +674,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return ' '.join(reversed(entries)) return ' '.join(reversed(entries))
@property @property
def parents(self) -> List[Group]: def parents(self) -> List[Group[Any, ..., Any]]:
"""List[:class:`Group`]: Retrieves the parents of this command. """List[:class:`Group`]: Retrieves the parents of this command.
If the command has no parents then it returns an empty :class:`list`. If the command has no parents then it returns an empty :class:`list`.
@ -687,7 +692,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return entries return entries
@property @property
def root_parent(self) -> Optional[Group]: def root_parent(self) -> Optional[Group[Any, ..., Any]]:
"""Optional[:class:`Group`]: Retrieves the root parent of this command. """Optional[:class:`Group`]: Retrieves the root parent of this command.
If the command has no parents then it returns ``None``. If the command has no parents then it returns ``None``.
@ -716,7 +721,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
def __str__(self) -> str: def __str__(self) -> str:
return self.qualified_name return self.qualified_name
async def _parse_arguments(self, ctx: Context) -> None: async def _parse_arguments(self, ctx: Context[BotT]) -> None:
ctx.args = [ctx] if self.cog is None else [self.cog, ctx] ctx.args = [ctx] if self.cog is None else [self.cog, ctx]
ctx.kwargs = {} ctx.kwargs = {}
args = ctx.args args = ctx.args
@ -752,7 +757,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if not self.ignore_extra and not view.eof: if not self.ignore_extra and not view.eof:
raise TooManyArguments('Too many arguments passed to ' + self.qualified_name) raise TooManyArguments('Too many arguments passed to ' + self.qualified_name)
async def call_before_hooks(self, ctx: Context) -> None: async def call_before_hooks(self, ctx: Context[BotT]) -> None:
# now that we're done preparing we can call the pre-command hooks # now that we're done preparing we can call the pre-command hooks
# first, call the command local hook: # first, call the command local hook:
cog = self.cog cog = self.cog
@ -777,7 +782,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if hook is not None: if hook is not None:
await hook(ctx) await hook(ctx)
async def call_after_hooks(self, ctx: Context) -> None: async def call_after_hooks(self, ctx: Context[BotT]) -> None:
cog = self.cog cog = self.cog
if self._after_invoke is not None: if self._after_invoke is not None:
instance = getattr(self._after_invoke, '__self__', cog) instance = getattr(self._after_invoke, '__self__', cog)
@ -796,7 +801,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if hook is not None: if hook is not None:
await hook(ctx) await hook(ctx)
def _prepare_cooldowns(self, ctx: Context) -> None: def _prepare_cooldowns(self, ctx: Context[BotT]) -> None:
if self._buckets.valid: if self._buckets.valid:
dt = ctx.message.edited_at or ctx.message.created_at dt = ctx.message.edited_at or ctx.message.created_at
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
@ -806,7 +811,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if retry_after: if retry_after:
raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore
async def prepare(self, ctx: Context) -> None: async def prepare(self, ctx: Context[BotT]) -> None:
ctx.command = self ctx.command = self
if not await self.can_run(ctx): if not await self.can_run(ctx):
@ -830,7 +835,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
await self._max_concurrency.release(ctx) # type: ignore await self._max_concurrency.release(ctx) # type: ignore
raise raise
def is_on_cooldown(self, ctx: Context) -> bool: def is_on_cooldown(self, ctx: Context[BotT]) -> bool:
"""Checks whether the command is currently on cooldown. """Checks whether the command is currently on cooldown.
Parameters Parameters
@ -851,7 +856,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
return bucket.get_tokens(current) == 0 return bucket.get_tokens(current) == 0
def reset_cooldown(self, ctx: Context) -> None: def reset_cooldown(self, ctx: Context[BotT]) -> None:
"""Resets the cooldown on this command. """Resets the cooldown on this command.
Parameters Parameters
@ -863,7 +868,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
bucket = self._buckets.get_bucket(ctx.message) bucket = self._buckets.get_bucket(ctx.message)
bucket.reset() bucket.reset()
def get_cooldown_retry_after(self, ctx: Context) -> float: def get_cooldown_retry_after(self, ctx: Context[BotT]) -> float:
"""Retrieves the amount of seconds before this command can be tried again. """Retrieves the amount of seconds before this command can be tried again.
.. versionadded:: 1.4 .. versionadded:: 1.4
@ -887,7 +892,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return 0.0 return 0.0
async def invoke(self, ctx: Context) -> None: async def invoke(self, ctx: Context[BotT]) -> None:
await self.prepare(ctx) await self.prepare(ctx)
# terminate the invoked_subcommand chain. # terminate the invoked_subcommand chain.
@ -896,9 +901,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
ctx.invoked_subcommand = None ctx.invoked_subcommand = None
ctx.subcommand_passed = None ctx.subcommand_passed = None
injected = hooked_wrapped_callback(self, ctx, self.callback) injected = hooked_wrapped_callback(self, ctx, self.callback)
await injected(*ctx.args, **ctx.kwargs) await injected(*ctx.args, **ctx.kwargs) # type: ignore
async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None: async def reinvoke(self, ctx: Context[BotT], *, call_hooks: bool = False) -> None:
ctx.command = self ctx.command = self
await self._parse_arguments(ctx) await self._parse_arguments(ctx)
@ -936,7 +941,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if not asyncio.iscoroutinefunction(coro): if not asyncio.iscoroutinefunction(coro):
raise TypeError('The error handler must be a coroutine.') raise TypeError('The error handler must be a coroutine.')
self.on_error: Error = coro self.on_error: Error[Any] = coro
return coro return coro
def has_error_handler(self) -> bool: def has_error_handler(self) -> bool:
@ -1075,7 +1080,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return ' '.join(result) return ' '.join(result)
async def can_run(self, ctx: Context) -> bool: async def can_run(self, ctx: Context[BotT]) -> bool:
"""|coro| """|coro|
Checks if the command can be executed by checking all the predicates Checks if the command can be executed by checking all the predicates
@ -1341,7 +1346,7 @@ class GroupMixin(Generic[CogT]):
def command( def command(
self, self,
name: str = MISSING, name: str = MISSING,
cls: Type[Command] = MISSING, cls: Type[Command[Any, ..., Any]] = MISSING,
*args: Any, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
@ -1401,7 +1406,7 @@ class GroupMixin(Generic[CogT]):
def group( def group(
self, self,
name: str = MISSING, name: str = MISSING,
cls: Type[Group] = MISSING, cls: Type[Group[Any, ..., Any]] = MISSING,
*args: Any, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
@ -1461,9 +1466,9 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
ret = super().copy() ret = super().copy()
for cmd in self.commands: for cmd in self.commands:
ret.add_command(cmd.copy()) ret.add_command(cmd.copy())
return ret # type: ignore return ret
async def invoke(self, ctx: Context) -> None: async def invoke(self, ctx: Context[BotT]) -> None:
ctx.invoked_subcommand = None ctx.invoked_subcommand = None
ctx.subcommand_passed = None ctx.subcommand_passed = None
early_invoke = not self.invoke_without_command early_invoke = not self.invoke_without_command
@ -1481,7 +1486,7 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
if early_invoke: if early_invoke:
injected = hooked_wrapped_callback(self, ctx, self.callback) injected = hooked_wrapped_callback(self, ctx, self.callback)
await injected(*ctx.args, **ctx.kwargs) await injected(*ctx.args, **ctx.kwargs) # type: ignore
ctx.invoked_parents.append(ctx.invoked_with) # type: ignore ctx.invoked_parents.append(ctx.invoked_with) # type: ignore
@ -1494,7 +1499,7 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
view.previous = previous view.previous = previous
await super().invoke(ctx) await super().invoke(ctx)
async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None: async def reinvoke(self, ctx: Context[BotT], *, call_hooks: bool = False) -> None:
ctx.invoked_subcommand = None ctx.invoked_subcommand = None
early_invoke = not self.invoke_without_command early_invoke = not self.invoke_without_command
if early_invoke: if early_invoke:
@ -1592,7 +1597,7 @@ def command(
def command( def command(
name: str = MISSING, name: str = MISSING,
cls: Type[Command] = MISSING, cls: Type[Command[Any, ..., Any]] = MISSING,
**attrs: Any, **attrs: Any,
) -> Any: ) -> Any:
"""A decorator that transforms a function into a :class:`.Command` """A decorator that transforms a function into a :class:`.Command`
@ -1662,7 +1667,7 @@ def group(
def group( def group(
name: str = MISSING, name: str = MISSING,
cls: Type[Group] = MISSING, cls: Type[Group[Any, ..., Any]] = MISSING,
**attrs: Any, **attrs: Any,
) -> Any: ) -> Any:
"""A decorator that transforms a function into a :class:`.Group`. """A decorator that transforms a function into a :class:`.Group`.
@ -1679,7 +1684,7 @@ def group(
return command(name=name, cls=cls, **attrs) return command(name=name, cls=cls, **attrs)
def check(predicate: Check) -> Callable[[T], T]: def check(predicate: Check[ContextT]) -> Callable[[T], T]:
r"""A decorator that adds a check to the :class:`.Command` or its r"""A decorator that adds a check to the :class:`.Command` or its
subclasses. These checks could be accessed via :attr:`.Command.checks`. subclasses. These checks could be accessed via :attr:`.Command.checks`.
@ -1774,7 +1779,7 @@ def check(predicate: Check) -> Callable[[T], T]:
return decorator # type: ignore return decorator # type: ignore
def check_any(*checks: Check) -> Callable[[T], T]: def check_any(*checks: Check[ContextT]) -> Callable[[T], T]:
r"""A :func:`check` that is added that checks if any of the checks passed r"""A :func:`check` that is added that checks if any of the checks passed
will pass, i.e. using logical OR. will pass, i.e. using logical OR.
@ -1827,7 +1832,7 @@ def check_any(*checks: Check) -> Callable[[T], T]:
else: else:
unwrapped.append(pred) unwrapped.append(pred)
async def predicate(ctx: Context) -> bool: async def predicate(ctx: Context[BotT]) -> bool:
errors = [] errors = []
for func in unwrapped: for func in unwrapped:
try: try:
@ -1870,7 +1875,7 @@ def has_role(item: Union[int, str]) -> Callable[[T], T]:
The name or ID of the role to check. The name or ID of the role to check.
""" """
def predicate(ctx: Context) -> bool: def predicate(ctx: Context[BotT]) -> bool:
if ctx.guild is None: if ctx.guild is None:
raise NoPrivateMessage() raise NoPrivateMessage()
@ -1923,7 +1928,7 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]:
raise NoPrivateMessage() raise NoPrivateMessage()
# ctx.guild is None doesn't narrow ctx.author to Member # ctx.guild is None doesn't narrow ctx.author to Member
getter = functools.partial(discord.utils.get, ctx.author.roles) # type: ignore getter = functools.partial(discord.utils.get, ctx.author.roles)
if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items): if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items):
return True return True
raise MissingAnyRole(list(items)) raise MissingAnyRole(list(items))
@ -2022,7 +2027,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]:
if invalid: if invalid:
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}") raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
def predicate(ctx: Context) -> bool: def predicate(ctx: Context[BotT]) -> bool:
ch = ctx.channel ch = ctx.channel
permissions = ch.permissions_for(ctx.author) # type: ignore permissions = ch.permissions_for(ctx.author) # type: ignore
@ -2048,7 +2053,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
if invalid: if invalid:
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}") raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
def predicate(ctx: Context) -> bool: def predicate(ctx: Context[BotT]) -> bool:
guild = ctx.guild guild = ctx.guild
me = guild.me if guild is not None else ctx.bot.user me = guild.me if guild is not None else ctx.bot.user
permissions = ctx.channel.permissions_for(me) # type: ignore permissions = ctx.channel.permissions_for(me) # type: ignore
@ -2077,7 +2082,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
if invalid: if invalid:
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}") raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
def predicate(ctx: Context) -> bool: def predicate(ctx: Context[BotT]) -> bool:
if not ctx.guild: if not ctx.guild:
raise NoPrivateMessage raise NoPrivateMessage
@ -2103,7 +2108,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
if invalid: if invalid:
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}") raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
def predicate(ctx: Context) -> bool: def predicate(ctx: Context[BotT]) -> bool:
if not ctx.guild: if not ctx.guild:
raise NoPrivateMessage raise NoPrivateMessage
@ -2129,7 +2134,7 @@ def dm_only() -> Callable[[T], T]:
.. versionadded:: 1.1 .. versionadded:: 1.1
""" """
def predicate(ctx: Context) -> bool: def predicate(ctx: Context[BotT]) -> bool:
if ctx.guild is not None: if ctx.guild is not None:
raise PrivateMessageOnly() raise PrivateMessageOnly()
return True return True
@ -2146,7 +2151,7 @@ def guild_only() -> Callable[[T], T]:
that is inherited from :exc:`.CheckFailure`. that is inherited from :exc:`.CheckFailure`.
""" """
def predicate(ctx: Context) -> bool: def predicate(ctx: Context[BotT]) -> bool:
if ctx.guild is None: if ctx.guild is None:
raise NoPrivateMessage() raise NoPrivateMessage()
return True return True
@ -2164,7 +2169,7 @@ def is_owner() -> Callable[[T], T]:
from :exc:`.CheckFailure`. from :exc:`.CheckFailure`.
""" """
async def predicate(ctx: Context) -> bool: async def predicate(ctx: Context[BotT]) -> bool:
if not await ctx.bot.is_owner(ctx.author): if not await ctx.bot.is_owner(ctx.author):
raise NotOwner('You do not own this bot.') raise NotOwner('You do not own this bot.')
return True return True
@ -2184,7 +2189,7 @@ def is_nsfw() -> Callable[[T], T]:
DM channels will also now pass this check. DM channels will also now pass this check.
""" """
def pred(ctx: Context) -> bool: def pred(ctx: Context[BotT]) -> bool:
ch = ctx.channel ch = ctx.channel
if ctx.guild is None or (isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()): if ctx.guild is None or (isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()):
return True return True

14
discord/ext/commands/errors.py

@ -39,6 +39,8 @@ if TYPE_CHECKING:
from discord.threads import Thread from discord.threads import Thread
from discord.types.snowflake import Snowflake, SnowflakeList from discord.types.snowflake import Snowflake, SnowflakeList
from ._types import BotT
__all__ = ( __all__ = (
'CommandError', 'CommandError',
@ -135,8 +137,8 @@ class ConversionError(CommandError):
the ``__cause__`` attribute. the ``__cause__`` attribute.
""" """
def __init__(self, converter: Converter, original: Exception) -> None: def __init__(self, converter: Converter[Any], original: Exception) -> None:
self.converter: Converter = converter self.converter: Converter[Any] = converter
self.original: Exception = original self.original: Exception = original
@ -224,9 +226,9 @@ class CheckAnyFailure(CheckFailure):
A list of check predicates that failed. A list of check predicates that failed.
""" """
def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]]) -> None: def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context[BotT]], bool]]) -> None:
self.checks: List[CheckFailure] = checks self.checks: List[CheckFailure] = checks
self.errors: List[Callable[[Context], bool]] = errors self.errors: List[Callable[[Context[BotT]], bool]] = errors
super().__init__('You do not have permission to run this command.') super().__init__('You do not have permission to run this command.')
@ -807,9 +809,9 @@ class BadUnionArgument(UserInputError):
A list of errors that were caught from failing the conversion. A list of errors that were caught from failing the conversion.
""" """
def __init__(self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]) -> None: def __init__(self, param: Parameter, converters: Tuple[type, ...], errors: List[CommandError]) -> None:
self.param: Parameter = param self.param: Parameter = param
self.converters: Tuple[Type, ...] = converters self.converters: Tuple[type, ...] = converters
self.errors: List[CommandError] = errors self.errors: List[CommandError] = errors
def _get_name(x): def _get_name(x):

18
discord/ext/commands/flags.py

@ -49,8 +49,6 @@ from typing import (
Tuple, Tuple,
List, List,
Any, Any,
Type,
TypeVar,
Union, Union,
) )
@ -70,6 +68,8 @@ if TYPE_CHECKING:
from .context import Context from .context import Context
from ._types import BotT
@dataclass @dataclass
class Flag: class Flag:
@ -148,7 +148,7 @@ def flag(
return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override) return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override)
def validate_flag_name(name: str, forbidden: Set[str]): def validate_flag_name(name: str, forbidden: Set[str]) -> None:
if not name: if not name:
raise ValueError('flag names should not be empty') raise ValueError('flag names should not be empty')
@ -348,7 +348,7 @@ class FlagsMeta(type):
return type.__new__(cls, name, bases, attrs) return type.__new__(cls, name, bases, attrs)
async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]: async def tuple_convert_all(ctx: Context[BotT], argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]:
view = StringView(argument) view = StringView(argument)
results = [] results = []
param: inspect.Parameter = ctx.current_parameter # type: ignore param: inspect.Parameter = ctx.current_parameter # type: ignore
@ -373,7 +373,7 @@ async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter:
return tuple(results) return tuple(results)
async def tuple_convert_flag(ctx: Context, argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]: async def tuple_convert_flag(ctx: Context[BotT], argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]:
view = StringView(argument) view = StringView(argument)
results = [] results = []
param: inspect.Parameter = ctx.current_parameter # type: ignore param: inspect.Parameter = ctx.current_parameter # type: ignore
@ -401,7 +401,7 @@ async def tuple_convert_flag(ctx: Context, argument: str, flag: Flag, converters
return tuple(results) return tuple(results)
async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -> Any: async def convert_flag(ctx: Context[BotT], argument: str, flag: Flag, annotation: Any = None) -> Any:
param: inspect.Parameter = ctx.current_parameter # type: ignore param: inspect.Parameter = ctx.current_parameter # type: ignore
annotation = annotation or flag.annotation annotation = annotation or flag.annotation
try: try:
@ -480,7 +480,7 @@ class FlagConverter(metaclass=FlagsMeta):
yield (flag.name, getattr(self, flag.attribute)) yield (flag.name, getattr(self, flag.attribute))
@classmethod @classmethod
async def _construct_default(cls, ctx: Context) -> Self: async def _construct_default(cls, ctx: Context[BotT]) -> Self:
self = cls.__new__(cls) self = cls.__new__(cls)
flags = cls.__commands_flags__ flags = cls.__commands_flags__
for flag in flags.values(): for flag in flags.values():
@ -546,7 +546,7 @@ class FlagConverter(metaclass=FlagsMeta):
return result return result
@classmethod @classmethod
async def convert(cls, ctx: Context, argument: str) -> Self: async def convert(cls, ctx: Context[BotT], argument: str) -> Self:
"""|coro| """|coro|
The method that actually converters an argument to the flag mapping. The method that actually converters an argument to the flag mapping.
@ -610,7 +610,7 @@ class FlagConverter(metaclass=FlagsMeta):
values = [await convert_flag(ctx, value, flag) for value in values] values = [await convert_flag(ctx, value, flag) for value in values]
if flag.cast_to_dict: if flag.cast_to_dict:
values = dict(values) # type: ignore values = dict(values)
setattr(self, flag.attribute, values) setattr(self, flag.attribute, values)

293
discord/ext/commands/help.py

@ -22,13 +22,27 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
import itertools import itertools
import copy import copy
import functools import functools
import inspect
import re import re
from typing import Optional, TYPE_CHECKING from typing import (
TYPE_CHECKING,
Optional,
Generator,
List,
TypeVar,
Callable,
Any,
Dict,
Tuple,
Iterable,
Sequence,
Mapping,
)
import discord.utils import discord.utils
@ -36,7 +50,21 @@ from .core import Group, Command, get_signature_parameters
from .errors import CommandError from .errors import CommandError
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self
import inspect
import discord.abc
from .bot import BotBase
from .context import Context from .context import Context
from .cog import Cog
from ._types import (
Check,
ContextT,
BotT,
_Bot,
)
__all__ = ( __all__ = (
'Paginator', 'Paginator',
@ -45,7 +73,9 @@ __all__ = (
'MinimalHelpCommand', 'MinimalHelpCommand',
) )
MISSING = discord.utils.MISSING FuncT = TypeVar('FuncT', bound=Callable[..., Any])
MISSING: Any = discord.utils.MISSING
# help -> shows info of bot on top/bottom and lists subcommands # help -> shows info of bot on top/bottom and lists subcommands
# help command -> shows detailed info of command # help command -> shows detailed info of command
@ -80,10 +110,10 @@ class Paginator:
Attributes Attributes
----------- -----------
prefix: :class:`str` prefix: Optional[:class:`str`]
The prefix inserted to every page. e.g. three backticks. The prefix inserted to every page. e.g. three backticks, if any.
suffix: :class:`str` suffix: Optional[:class:`str`]
The suffix appended at the end of every page. e.g. three backticks. The suffix appended at the end of every page. e.g. three backticks, if any.
max_size: :class:`int` max_size: :class:`int`
The maximum amount of codepoints allowed in a page. The maximum amount of codepoints allowed in a page.
linesep: :class:`str` linesep: :class:`str`
@ -91,36 +121,38 @@ class Paginator:
.. versionadded:: 1.7 .. versionadded:: 1.7
""" """
def __init__(self, prefix='```', suffix='```', max_size=2000, linesep='\n'): def __init__(
self.prefix = prefix self, prefix: Optional[str] = '```', suffix: Optional[str] = '```', max_size: int = 2000, linesep: str = '\n'
self.suffix = suffix ) -> None:
self.max_size = max_size self.prefix: Optional[str] = prefix
self.linesep = linesep self.suffix: Optional[str] = suffix
self.max_size: int = max_size
self.linesep: str = linesep
self.clear() self.clear()
def clear(self): def clear(self) -> None:
"""Clears the paginator to have no pages.""" """Clears the paginator to have no pages."""
if self.prefix is not None: if self.prefix is not None:
self._current_page = [self.prefix] self._current_page: List[str] = [self.prefix]
self._count = len(self.prefix) + self._linesep_len # prefix + newline self._count: int = len(self.prefix) + self._linesep_len # prefix + newline
else: else:
self._current_page = [] self._current_page = []
self._count = 0 self._count = 0
self._pages = [] self._pages: List[str] = []
@property @property
def _prefix_len(self): def _prefix_len(self) -> int:
return len(self.prefix) if self.prefix else 0 return len(self.prefix) if self.prefix else 0
@property @property
def _suffix_len(self): def _suffix_len(self) -> int:
return len(self.suffix) if self.suffix else 0 return len(self.suffix) if self.suffix else 0
@property @property
def _linesep_len(self): def _linesep_len(self) -> int:
return len(self.linesep) return len(self.linesep)
def add_line(self, line='', *, empty=False): def add_line(self, line: str = '', *, empty: bool = False) -> None:
"""Adds a line to the current page. """Adds a line to the current page.
If the line exceeds the :attr:`max_size` then an exception If the line exceeds the :attr:`max_size` then an exception
@ -152,7 +184,7 @@ class Paginator:
self._current_page.append('') self._current_page.append('')
self._count += self._linesep_len self._count += self._linesep_len
def close_page(self): def close_page(self) -> None:
"""Prematurely terminate a page.""" """Prematurely terminate a page."""
if self.suffix is not None: if self.suffix is not None:
self._current_page.append(self.suffix) self._current_page.append(self.suffix)
@ -165,36 +197,38 @@ class Paginator:
self._current_page = [] self._current_page = []
self._count = 0 self._count = 0
def __len__(self): def __len__(self) -> int:
total = sum(len(p) for p in self._pages) total = sum(len(p) for p in self._pages)
return total + self._count return total + self._count
@property @property
def pages(self): def pages(self) -> List[str]:
"""List[:class:`str`]: Returns the rendered list of pages.""" """List[:class:`str`]: Returns the rendered list of pages."""
# we have more than just the prefix in our current page # we have more than just the prefix in our current page
if len(self._current_page) > (0 if self.prefix is None else 1): if len(self._current_page) > (0 if self.prefix is None else 1):
self.close_page() self.close_page()
return self._pages return self._pages
def __repr__(self): def __repr__(self) -> str:
fmt = '<Paginator prefix: {0.prefix!r} suffix: {0.suffix!r} linesep: {0.linesep!r} max_size: {0.max_size} count: {0._count}>' fmt = '<Paginator prefix: {0.prefix!r} suffix: {0.suffix!r} linesep: {0.linesep!r} max_size: {0.max_size} count: {0._count}>'
return fmt.format(self) return fmt.format(self)
def _not_overridden(f): def _not_overridden(f: FuncT) -> FuncT:
f.__help_command_not_overridden__ = True f.__help_command_not_overridden__ = True
return f return f
class _HelpCommandImpl(Command): class _HelpCommandImpl(Command):
def __init__(self, inject, *args, **kwargs): def __init__(self, inject: HelpCommand, *args: Any, **kwargs: Any) -> None:
super().__init__(inject.command_callback, *args, **kwargs) super().__init__(inject.command_callback, *args, **kwargs)
self._original = inject self._original: HelpCommand = inject
self._injected = inject self._injected: HelpCommand = inject
self.params = get_signature_parameters(inject.command_callback, globals(), skip_parameters=1) self.params: Dict[str, inspect.Parameter] = get_signature_parameters(
inject.command_callback, globals(), skip_parameters=1
)
async def prepare(self, ctx): async def prepare(self, ctx: Context[Any]) -> None:
self._injected = injected = self._original.copy() self._injected = injected = self._original.copy()
injected.context = ctx injected.context = ctx
self.callback = injected.command_callback self.callback = injected.command_callback
@ -209,7 +243,7 @@ class _HelpCommandImpl(Command):
await super().prepare(ctx) await super().prepare(ctx)
async def _parse_arguments(self, ctx): async def _parse_arguments(self, ctx: Context[BotT]) -> None:
# Make the parser think we don't have a cog so it doesn't # Make the parser think we don't have a cog so it doesn't
# inject the parameter into `ctx.args`. # inject the parameter into `ctx.args`.
original_cog = self.cog original_cog = self.cog
@ -219,22 +253,26 @@ class _HelpCommandImpl(Command):
finally: finally:
self.cog = original_cog self.cog = original_cog
async def _on_error_cog_implementation(self, dummy, ctx, error): async def _on_error_cog_implementation(self, _, ctx: Context[BotT], error: CommandError) -> None:
await self._injected.on_help_command_error(ctx, error) await self._injected.on_help_command_error(ctx, error)
def _inject_into_cog(self, cog): def _inject_into_cog(self, cog: Cog) -> None:
# Warning: hacky # Warning: hacky
# Make the cog think that get_commands returns this command # Make the cog think that get_commands returns this command
# as well if we inject it without modifying __cog_commands__ # as well if we inject it without modifying __cog_commands__
# since that's used for the injection and ejection of cogs. # since that's used for the injection and ejection of cogs.
def wrapped_get_commands(*, _original=cog.get_commands): def wrapped_get_commands(
*, _original: Callable[[], List[Command[Any, ..., Any]]] = cog.get_commands
) -> List[Command[Any, ..., Any]]:
ret = _original() ret = _original()
ret.append(self) ret.append(self)
return ret return ret
# Ditto here # Ditto here
def wrapped_walk_commands(*, _original=cog.walk_commands): def wrapped_walk_commands(
*, _original: Callable[[], Generator[Command[Any, ..., Any], None, None]] = cog.walk_commands
):
yield from _original() yield from _original()
yield self yield self
@ -244,7 +282,7 @@ class _HelpCommandImpl(Command):
cog.walk_commands = wrapped_walk_commands cog.walk_commands = wrapped_walk_commands
self.cog = cog self.cog = cog
def _eject_cog(self): def _eject_cog(self) -> None:
if self.cog is None: if self.cog is None:
return return
@ -298,7 +336,11 @@ class HelpCommand:
MENTION_PATTERN = re.compile('|'.join(MENTION_TRANSFORMS.keys())) MENTION_PATTERN = re.compile('|'.join(MENTION_TRANSFORMS.keys()))
def __new__(cls, *args, **kwargs): if TYPE_CHECKING:
__original_kwargs__: Dict[str, Any]
__original_args__: Tuple[Any, ...]
def __new__(cls, *args: Any, **kwargs: Any) -> Self:
# To prevent race conditions of a single instance while also allowing # To prevent race conditions of a single instance while also allowing
# for settings to be passed the original arguments passed must be assigned # for settings to be passed the original arguments passed must be assigned
# to allow for easier copies (which will be made when the help command is actually called) # to allow for easier copies (which will be made when the help command is actually called)
@ -314,30 +356,31 @@ class HelpCommand:
self.__original_args__ = deepcopy(args) self.__original_args__ = deepcopy(args)
return self return self
def __init__(self, **options): def __init__(self, **options: Any) -> None:
self.show_hidden = options.pop('show_hidden', False) self.show_hidden: bool = options.pop('show_hidden', False)
self.verify_checks = options.pop('verify_checks', True) self.verify_checks: bool = options.pop('verify_checks', True)
self.command_attrs: Dict[str, Any]
self.command_attrs = attrs = options.pop('command_attrs', {}) self.command_attrs = attrs = options.pop('command_attrs', {})
attrs.setdefault('name', 'help') attrs.setdefault('name', 'help')
attrs.setdefault('help', 'Shows this message') attrs.setdefault('help', 'Shows this message')
self.context: Context = MISSING self.context: Context[_Bot] = MISSING
self._command_impl = _HelpCommandImpl(self, **self.command_attrs) self._command_impl = _HelpCommandImpl(self, **self.command_attrs)
def copy(self): def copy(self) -> Self:
obj = self.__class__(*self.__original_args__, **self.__original_kwargs__) obj = self.__class__(*self.__original_args__, **self.__original_kwargs__)
obj._command_impl = self._command_impl obj._command_impl = self._command_impl
return obj return obj
def _add_to_bot(self, bot): def _add_to_bot(self, bot: BotBase) -> None:
command = _HelpCommandImpl(self, **self.command_attrs) command = _HelpCommandImpl(self, **self.command_attrs)
bot.add_command(command) bot.add_command(command)
self._command_impl = command self._command_impl = command
def _remove_from_bot(self, bot): def _remove_from_bot(self, bot: BotBase) -> None:
bot.remove_command(self._command_impl.name) bot.remove_command(self._command_impl.name)
self._command_impl._eject_cog() self._command_impl._eject_cog()
def add_check(self, func, /): def add_check(self, func: Check[ContextT], /) -> None:
""" """
Adds a check to the help command. Adds a check to the help command.
@ -355,7 +398,7 @@ class HelpCommand:
self._command_impl.add_check(func) self._command_impl.add_check(func)
def remove_check(self, func, /): def remove_check(self, func: Check[ContextT], /) -> None:
""" """
Removes a check from the help command. Removes a check from the help command.
@ -376,15 +419,15 @@ class HelpCommand:
self._command_impl.remove_check(func) self._command_impl.remove_check(func)
def get_bot_mapping(self): def get_bot_mapping(self) -> Dict[Optional[Cog], List[Command[Any, ..., Any]]]:
"""Retrieves the bot mapping passed to :meth:`send_bot_help`.""" """Retrieves the bot mapping passed to :meth:`send_bot_help`."""
bot = self.context.bot bot = self.context.bot
mapping = {cog: cog.get_commands() for cog in bot.cogs.values()} mapping: Dict[Optional[Cog], List[Command[Any, ..., Any]]] = {cog: cog.get_commands() for cog in bot.cogs.values()}
mapping[None] = [c for c in bot.commands if c.cog is None] mapping[None] = [c for c in bot.commands if c.cog is None]
return mapping return mapping
@property @property
def invoked_with(self): def invoked_with(self) -> Optional[str]:
"""Similar to :attr:`Context.invoked_with` except properly handles """Similar to :attr:`Context.invoked_with` except properly handles
the case where :meth:`Context.send_help` is used. the case where :meth:`Context.send_help` is used.
@ -395,7 +438,7 @@ class HelpCommand:
Returns Returns
--------- ---------
:class:`str` Optional[:class:`str`]
The command name that triggered this invocation. The command name that triggered this invocation.
""" """
command_name = self._command_impl.name command_name = self._command_impl.name
@ -404,7 +447,7 @@ class HelpCommand:
return command_name return command_name
return ctx.invoked_with return ctx.invoked_with
def get_command_signature(self, command): def get_command_signature(self, command: Command[Any, ..., Any]) -> str:
"""Retrieves the signature portion of the help page. """Retrieves the signature portion of the help page.
Parameters Parameters
@ -418,14 +461,14 @@ class HelpCommand:
The signature for the command. The signature for the command.
""" """
parent = command.parent parent: Optional[Group[Any, ..., Any]] = command.parent # type: ignore - the parent will be a Group
entries = [] entries = []
while parent is not None: while parent is not None:
if not parent.signature or parent.invoke_without_command: if not parent.signature or parent.invoke_without_command:
entries.append(parent.name) entries.append(parent.name)
else: else:
entries.append(parent.name + ' ' + parent.signature) entries.append(parent.name + ' ' + parent.signature)
parent = parent.parent parent = parent.parent # type: ignore
parent_sig = ' '.join(reversed(entries)) parent_sig = ' '.join(reversed(entries))
if len(command.aliases) > 0: if len(command.aliases) > 0:
@ -439,7 +482,7 @@ class HelpCommand:
return f'{self.context.clean_prefix}{alias} {command.signature}' return f'{self.context.clean_prefix}{alias} {command.signature}'
def remove_mentions(self, string): def remove_mentions(self, string: str) -> str:
"""Removes mentions from the string to prevent abuse. """Removes mentions from the string to prevent abuse.
This includes ``@everyone``, ``@here``, member mentions and role mentions. This includes ``@everyone``, ``@here``, member mentions and role mentions.
@ -450,13 +493,13 @@ class HelpCommand:
The string with mentions removed. The string with mentions removed.
""" """
def replace(obj, *, transforms=self.MENTION_TRANSFORMS): def replace(obj: re.Match, *, transforms: Dict[str, str] = self.MENTION_TRANSFORMS) -> str:
return transforms.get(obj.group(0), '@invalid') return transforms.get(obj.group(0), '@invalid')
return self.MENTION_PATTERN.sub(replace, string) return self.MENTION_PATTERN.sub(replace, string)
@property @property
def cog(self): def cog(self) -> Optional[Cog]:
"""A property for retrieving or setting the cog for the help command. """A property for retrieving or setting the cog for the help command.
When a cog is set for the help command, it is as-if the help command When a cog is set for the help command, it is as-if the help command
@ -473,7 +516,7 @@ class HelpCommand:
return self._command_impl.cog return self._command_impl.cog
@cog.setter @cog.setter
def cog(self, cog): def cog(self, cog: Optional[Cog]) -> None:
# Remove whatever cog is currently valid, if any # Remove whatever cog is currently valid, if any
self._command_impl._eject_cog() self._command_impl._eject_cog()
@ -481,7 +524,7 @@ class HelpCommand:
if cog is not None: if cog is not None:
self._command_impl._inject_into_cog(cog) self._command_impl._inject_into_cog(cog)
def command_not_found(self, string): def command_not_found(self, string: str) -> str:
"""|maybecoro| """|maybecoro|
A method called when a command is not found in the help command. A method called when a command is not found in the help command.
@ -502,7 +545,7 @@ class HelpCommand:
""" """
return f'No command called "{string}" found.' return f'No command called "{string}" found.'
def subcommand_not_found(self, command, string): def subcommand_not_found(self, command: Command[Any, ..., Any], string: str) -> str:
"""|maybecoro| """|maybecoro|
A method called when a command did not have a subcommand requested in the help command. A method called when a command did not have a subcommand requested in the help command.
@ -532,7 +575,13 @@ class HelpCommand:
return f'Command "{command.qualified_name}" has no subcommand named {string}' return f'Command "{command.qualified_name}" has no subcommand named {string}'
return f'Command "{command.qualified_name}" has no subcommands.' return f'Command "{command.qualified_name}" has no subcommands.'
async def filter_commands(self, commands, *, sort=False, key=None): async def filter_commands(
self,
commands: Iterable[Command[Any, ..., Any]],
*,
sort: bool = False,
key: Optional[Callable[[Command[Any, ..., Any]], Any]] = None,
) -> List[Command[Any, ..., Any]]:
"""|coro| """|coro|
Returns a filtered list of commands and optionally sorts them. Returns a filtered list of commands and optionally sorts them.
@ -546,7 +595,7 @@ class HelpCommand:
An iterable of commands that are getting filtered. An iterable of commands that are getting filtered.
sort: :class:`bool` sort: :class:`bool`
Whether to sort the result. Whether to sort the result.
key: Optional[Callable[:class:`Command`, Any]] key: Optional[Callable[[:class:`Command`], Any]]
An optional key function to pass to :func:`py:sorted` that An optional key function to pass to :func:`py:sorted` that
takes a :class:`Command` as its sole parameter. If ``sort`` is takes a :class:`Command` as its sole parameter. If ``sort`` is
passed as ``True`` then this will default as the command name. passed as ``True`` then this will default as the command name.
@ -565,14 +614,14 @@ class HelpCommand:
if self.verify_checks is False: if self.verify_checks is False:
# if we do not need to verify the checks then we can just # if we do not need to verify the checks then we can just
# run it straight through normally without using await. # run it straight through normally without using await.
return sorted(iterator, key=key) if sort else list(iterator) return sorted(iterator, key=key) if sort else list(iterator) # type: ignore - the key shouldn't be None
if self.verify_checks is None and not self.context.guild: if self.verify_checks is None and not self.context.guild:
# if verify_checks is None and we're in a DM, don't verify # if verify_checks is None and we're in a DM, don't verify
return sorted(iterator, key=key) if sort else list(iterator) return sorted(iterator, key=key) if sort else list(iterator) # type: ignore
# if we're here then we need to check every command if it can run # if we're here then we need to check every command if it can run
async def predicate(cmd): async def predicate(cmd: Command[Any, ..., Any]) -> bool:
try: try:
return await cmd.can_run(self.context) return await cmd.can_run(self.context)
except CommandError: except CommandError:
@ -588,7 +637,7 @@ class HelpCommand:
ret.sort(key=key) ret.sort(key=key)
return ret return ret
def get_max_size(self, commands): def get_max_size(self, commands: Sequence[Command[Any, ..., Any]]) -> int:
"""Returns the largest name length of the specified command list. """Returns the largest name length of the specified command list.
Parameters Parameters
@ -605,7 +654,7 @@ class HelpCommand:
as_lengths = (discord.utils._string_width(c.name) for c in commands) as_lengths = (discord.utils._string_width(c.name) for c in commands)
return max(as_lengths, default=0) return max(as_lengths, default=0)
def get_destination(self): def get_destination(self) -> discord.abc.MessageableChannel:
"""Returns the :class:`~discord.abc.Messageable` where the help command will be output. """Returns the :class:`~discord.abc.Messageable` where the help command will be output.
You can override this method to customise the behaviour. You can override this method to customise the behaviour.
@ -619,7 +668,7 @@ class HelpCommand:
""" """
return self.context.channel return self.context.channel
async def send_error_message(self, error): async def send_error_message(self, error: str) -> None:
"""|coro| """|coro|
Handles the implementation when an error happens in the help command. Handles the implementation when an error happens in the help command.
@ -644,7 +693,7 @@ class HelpCommand:
await destination.send(error) await destination.send(error)
@_not_overridden @_not_overridden
async def on_help_command_error(self, ctx, error): async def on_help_command_error(self, ctx: Context[BotT], error: CommandError) -> None:
"""|coro| """|coro|
The help command's error handler, as specified by :ref:`ext_commands_error_handler`. The help command's error handler, as specified by :ref:`ext_commands_error_handler`.
@ -664,7 +713,7 @@ class HelpCommand:
""" """
pass pass
async def send_bot_help(self, mapping): async def send_bot_help(self, mapping: Mapping[Optional[Cog], List[Command[Any, ..., Any]]]) -> None:
"""|coro| """|coro|
Handles the implementation of the bot command page in the help command. Handles the implementation of the bot command page in the help command.
@ -693,7 +742,7 @@ class HelpCommand:
""" """
return None return None
async def send_cog_help(self, cog): async def send_cog_help(self, cog: Cog) -> None:
"""|coro| """|coro|
Handles the implementation of the cog page in the help command. Handles the implementation of the cog page in the help command.
@ -721,7 +770,7 @@ class HelpCommand:
""" """
return None return None
async def send_group_help(self, group): async def send_group_help(self, group: Group[Any, ..., Any]) -> None:
"""|coro| """|coro|
Handles the implementation of the group page in the help command. Handles the implementation of the group page in the help command.
@ -749,7 +798,7 @@ class HelpCommand:
""" """
return None return None
async def send_command_help(self, command): async def send_command_help(self, command: Command[Any, ..., Any]) -> None:
"""|coro| """|coro|
Handles the implementation of the single command page in the help command. Handles the implementation of the single command page in the help command.
@ -787,7 +836,7 @@ class HelpCommand:
""" """
return None return None
async def prepare_help_command(self, ctx, command=None): async def prepare_help_command(self, ctx: Context[BotT], command: Optional[str] = None) -> None:
"""|coro| """|coro|
A low level method that can be used to prepare the help command A low level method that can be used to prepare the help command
@ -811,7 +860,7 @@ class HelpCommand:
""" """
pass pass
async def command_callback(self, ctx, *, command=None): async def command_callback(self, ctx: Context[BotT], *, command: Optional[str] = None) -> None:
"""|coro| """|coro|
The actual implementation of the help command. The actual implementation of the help command.
@ -856,7 +905,7 @@ class HelpCommand:
for key in keys[1:]: for key in keys[1:]:
try: try:
found = cmd.all_commands.get(key) found = cmd.all_commands.get(key) # type: ignore
except AttributeError: except AttributeError:
string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key)) string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key))
return await self.send_error_message(string) return await self.send_error_message(string)
@ -908,28 +957,28 @@ class DefaultHelpCommand(HelpCommand):
The paginator used to paginate the help command output. The paginator used to paginate the help command output.
""" """
def __init__(self, **options): def __init__(self, **options: Any) -> None:
self.width = options.pop('width', 80) self.width: int = options.pop('width', 80)
self.indent = options.pop('indent', 2) self.indent: int = options.pop('indent', 2)
self.sort_commands = options.pop('sort_commands', True) self.sort_commands: bool = options.pop('sort_commands', True)
self.dm_help = options.pop('dm_help', False) self.dm_help: bool = options.pop('dm_help', False)
self.dm_help_threshold = options.pop('dm_help_threshold', 1000) self.dm_help_threshold: int = options.pop('dm_help_threshold', 1000)
self.commands_heading = options.pop('commands_heading', "Commands:") self.commands_heading: str = options.pop('commands_heading', "Commands:")
self.no_category = options.pop('no_category', 'No Category') self.no_category: str = options.pop('no_category', 'No Category')
self.paginator = options.pop('paginator', None) self.paginator: Paginator = options.pop('paginator', None)
if self.paginator is None: if self.paginator is None:
self.paginator = Paginator() self.paginator: Paginator = Paginator()
super().__init__(**options) super().__init__(**options)
def shorten_text(self, text): def shorten_text(self, text: str) -> str:
""":class:`str`: Shortens text to fit into the :attr:`width`.""" """:class:`str`: Shortens text to fit into the :attr:`width`."""
if len(text) > self.width: if len(text) > self.width:
return text[: self.width - 3].rstrip() + '...' return text[: self.width - 3].rstrip() + '...'
return text return text
def get_ending_note(self): def get_ending_note(self) -> str:
""":class:`str`: Returns help command's ending note. This is mainly useful to override for i18n purposes.""" """:class:`str`: Returns help command's ending note. This is mainly useful to override for i18n purposes."""
command_name = self.invoked_with command_name = self.invoked_with
return ( return (
@ -937,7 +986,9 @@ class DefaultHelpCommand(HelpCommand):
f"You can also type {self.context.clean_prefix}{command_name} category for more info on a category." f"You can also type {self.context.clean_prefix}{command_name} category for more info on a category."
) )
def add_indented_commands(self, commands, *, heading, max_size=None): def add_indented_commands(
self, commands: Sequence[Command[Any, ..., Any]], *, heading: str, max_size: Optional[int] = None
) -> None:
"""Indents a list of commands after the specified heading. """Indents a list of commands after the specified heading.
The formatting is added to the :attr:`paginator`. The formatting is added to the :attr:`paginator`.
@ -973,13 +1024,13 @@ class DefaultHelpCommand(HelpCommand):
entry = f'{self.indent * " "}{name:<{width}} {command.short_doc}' entry = f'{self.indent * " "}{name:<{width}} {command.short_doc}'
self.paginator.add_line(self.shorten_text(entry)) self.paginator.add_line(self.shorten_text(entry))
async def send_pages(self): async def send_pages(self) -> None:
"""A helper utility to send the page output from :attr:`paginator` to the destination.""" """A helper utility to send the page output from :attr:`paginator` to the destination."""
destination = self.get_destination() destination = self.get_destination()
for page in self.paginator.pages: for page in self.paginator.pages:
await destination.send(page) await destination.send(page)
def add_command_formatting(self, command): def add_command_formatting(self, command: Command[Any, ..., Any]) -> None:
"""A utility function to format the non-indented block of commands and groups. """A utility function to format the non-indented block of commands and groups.
Parameters Parameters
@ -1002,7 +1053,7 @@ class DefaultHelpCommand(HelpCommand):
self.paginator.add_line(line) self.paginator.add_line(line)
self.paginator.add_line() self.paginator.add_line()
def get_destination(self): def get_destination(self) -> discord.abc.Messageable:
ctx = self.context ctx = self.context
if self.dm_help is True: if self.dm_help is True:
return ctx.author return ctx.author
@ -1011,11 +1062,11 @@ class DefaultHelpCommand(HelpCommand):
else: else:
return ctx.channel return ctx.channel
async def prepare_help_command(self, ctx, command): async def prepare_help_command(self, ctx: Context[BotT], command: str) -> None:
self.paginator.clear() self.paginator.clear()
await super().prepare_help_command(ctx, command) await super().prepare_help_command(ctx, command)
async def send_bot_help(self, mapping): async def send_bot_help(self, mapping: Mapping[Optional[Cog], List[Command[Any, ..., Any]]]) -> None:
ctx = self.context ctx = self.context
bot = ctx.bot bot = ctx.bot
@ -1045,12 +1096,12 @@ class DefaultHelpCommand(HelpCommand):
await self.send_pages() await self.send_pages()
async def send_command_help(self, command): async def send_command_help(self, command: Command[Any, ..., Any]) -> None:
self.add_command_formatting(command) self.add_command_formatting(command)
self.paginator.close_page() self.paginator.close_page()
await self.send_pages() await self.send_pages()
async def send_group_help(self, group): async def send_group_help(self, group: Group[Any, ..., Any]) -> None:
self.add_command_formatting(group) self.add_command_formatting(group)
filtered = await self.filter_commands(group.commands, sort=self.sort_commands) filtered = await self.filter_commands(group.commands, sort=self.sort_commands)
@ -1064,7 +1115,7 @@ class DefaultHelpCommand(HelpCommand):
await self.send_pages() await self.send_pages()
async def send_cog_help(self, cog): async def send_cog_help(self, cog: Cog) -> None:
if cog.description: if cog.description:
self.paginator.add_line(cog.description, empty=True) self.paginator.add_line(cog.description, empty=True)
@ -1111,27 +1162,27 @@ class MinimalHelpCommand(HelpCommand):
The paginator used to paginate the help command output. The paginator used to paginate the help command output.
""" """
def __init__(self, **options): def __init__(self, **options: Any) -> None:
self.sort_commands = options.pop('sort_commands', True) self.sort_commands: bool = options.pop('sort_commands', True)
self.commands_heading = options.pop('commands_heading', "Commands") self.commands_heading: str = options.pop('commands_heading', "Commands")
self.dm_help = options.pop('dm_help', False) self.dm_help: bool = options.pop('dm_help', False)
self.dm_help_threshold = options.pop('dm_help_threshold', 1000) self.dm_help_threshold: int = options.pop('dm_help_threshold', 1000)
self.aliases_heading = options.pop('aliases_heading', "Aliases:") self.aliases_heading: str = options.pop('aliases_heading', "Aliases:")
self.no_category = options.pop('no_category', 'No Category') self.no_category: str = options.pop('no_category', 'No Category')
self.paginator = options.pop('paginator', None) self.paginator: Paginator = options.pop('paginator', None)
if self.paginator is None: if self.paginator is None:
self.paginator = Paginator(suffix=None, prefix=None) self.paginator: Paginator = Paginator(suffix=None, prefix=None)
super().__init__(**options) super().__init__(**options)
async def send_pages(self): async def send_pages(self) -> None:
"""A helper utility to send the page output from :attr:`paginator` to the destination.""" """A helper utility to send the page output from :attr:`paginator` to the destination."""
destination = self.get_destination() destination = self.get_destination()
for page in self.paginator.pages: for page in self.paginator.pages:
await destination.send(page) await destination.send(page)
def get_opening_note(self): def get_opening_note(self) -> str:
"""Returns help command's opening note. This is mainly useful to override for i18n purposes. """Returns help command's opening note. This is mainly useful to override for i18n purposes.
The default implementation returns :: The default implementation returns ::
@ -1150,10 +1201,10 @@ class MinimalHelpCommand(HelpCommand):
f"You can also use `{self.context.clean_prefix}{command_name} [category]` for more info on a category." f"You can also use `{self.context.clean_prefix}{command_name} [category]` for more info on a category."
) )
def get_command_signature(self, command): def get_command_signature(self, command: Command[Any, ..., Any]) -> str:
return f'{self.context.clean_prefix}{command.qualified_name} {command.signature}' return f'{self.context.clean_prefix}{command.qualified_name} {command.signature}'
def get_ending_note(self): def get_ending_note(self) -> str:
"""Return the help command's ending note. This is mainly useful to override for i18n purposes. """Return the help command's ending note. This is mainly useful to override for i18n purposes.
The default implementation does nothing. The default implementation does nothing.
@ -1163,9 +1214,9 @@ class MinimalHelpCommand(HelpCommand):
:class:`str` :class:`str`
The help command ending note. The help command ending note.
""" """
return None return ''
def add_bot_commands_formatting(self, commands, heading): def add_bot_commands_formatting(self, commands: Sequence[Command[Any, ..., Any]], heading: str) -> None:
"""Adds the minified bot heading with commands to the output. """Adds the minified bot heading with commands to the output.
The formatting should be added to the :attr:`paginator`. The formatting should be added to the :attr:`paginator`.
@ -1186,7 +1237,7 @@ class MinimalHelpCommand(HelpCommand):
self.paginator.add_line(f'__**{heading}**__') self.paginator.add_line(f'__**{heading}**__')
self.paginator.add_line(joined) self.paginator.add_line(joined)
def add_subcommand_formatting(self, command): def add_subcommand_formatting(self, command: Command[Any, ..., Any]) -> None:
"""Adds formatting information on a subcommand. """Adds formatting information on a subcommand.
The formatting should be added to the :attr:`paginator`. The formatting should be added to the :attr:`paginator`.
@ -1202,7 +1253,7 @@ class MinimalHelpCommand(HelpCommand):
fmt = '{0}{1} \N{EN DASH} {2}' if command.short_doc else '{0}{1}' fmt = '{0}{1} \N{EN DASH} {2}' if command.short_doc else '{0}{1}'
self.paginator.add_line(fmt.format(self.context.clean_prefix, command.qualified_name, command.short_doc)) self.paginator.add_line(fmt.format(self.context.clean_prefix, command.qualified_name, command.short_doc))
def add_aliases_formatting(self, aliases): def add_aliases_formatting(self, aliases: Sequence[str]) -> None:
"""Adds the formatting information on a command's aliases. """Adds the formatting information on a command's aliases.
The formatting should be added to the :attr:`paginator`. The formatting should be added to the :attr:`paginator`.
@ -1219,7 +1270,7 @@ class MinimalHelpCommand(HelpCommand):
""" """
self.paginator.add_line(f'**{self.aliases_heading}** {", ".join(aliases)}', empty=True) self.paginator.add_line(f'**{self.aliases_heading}** {", ".join(aliases)}', empty=True)
def add_command_formatting(self, command): def add_command_formatting(self, command: Command[Any, ..., Any]) -> None:
"""A utility function to format commands and groups. """A utility function to format commands and groups.
Parameters Parameters
@ -1246,7 +1297,7 @@ class MinimalHelpCommand(HelpCommand):
self.paginator.add_line(line) self.paginator.add_line(line)
self.paginator.add_line() self.paginator.add_line()
def get_destination(self): def get_destination(self) -> discord.abc.Messageable:
ctx = self.context ctx = self.context
if self.dm_help is True: if self.dm_help is True:
return ctx.author return ctx.author
@ -1255,11 +1306,11 @@ class MinimalHelpCommand(HelpCommand):
else: else:
return ctx.channel return ctx.channel
async def prepare_help_command(self, ctx, command): async def prepare_help_command(self, ctx: Context[BotT], command: str) -> None:
self.paginator.clear() self.paginator.clear()
await super().prepare_help_command(ctx, command) await super().prepare_help_command(ctx, command)
async def send_bot_help(self, mapping): async def send_bot_help(self, mapping: Mapping[Optional[Cog], List[Command[Any, ..., Any]]]) -> None:
ctx = self.context ctx = self.context
bot = ctx.bot bot = ctx.bot
@ -1272,7 +1323,7 @@ class MinimalHelpCommand(HelpCommand):
no_category = f'\u200b{self.no_category}' no_category = f'\u200b{self.no_category}'
def get_category(command, *, no_category=no_category): def get_category(command: Command[Any, ..., Any], *, no_category: str = no_category) -> str:
cog = command.cog cog = command.cog
return cog.qualified_name if cog is not None else no_category return cog.qualified_name if cog is not None else no_category
@ -1290,7 +1341,7 @@ class MinimalHelpCommand(HelpCommand):
await self.send_pages() await self.send_pages()
async def send_cog_help(self, cog): async def send_cog_help(self, cog: Cog) -> None:
bot = self.context.bot bot = self.context.bot
if bot.description: if bot.description:
self.paginator.add_line(bot.description, empty=True) self.paginator.add_line(bot.description, empty=True)
@ -1315,7 +1366,7 @@ class MinimalHelpCommand(HelpCommand):
await self.send_pages() await self.send_pages()
async def send_group_help(self, group): async def send_group_help(self, group: Group[Any, ..., Any]) -> None:
self.add_command_formatting(group) self.add_command_formatting(group)
filtered = await self.filter_commands(group.commands, sort=self.sort_commands) filtered = await self.filter_commands(group.commands, sort=self.sort_commands)
@ -1335,7 +1386,7 @@ class MinimalHelpCommand(HelpCommand):
await self.send_pages() await self.send_pages()
async def send_command_help(self, command): async def send_command_help(self, command: Command[Any, ..., Any]) -> None:
self.add_command_formatting(command) self.add_command_formatting(command)
self.paginator.close_page() self.paginator.close_page()
await self.send_pages() await self.send_pages()

37
discord/ext/commands/view.py

@ -21,6 +21,11 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
from typing import Optional
from .errors import UnexpectedQuoteError, InvalidEndOfQuotedStringError, ExpectedClosingQuoteError from .errors import UnexpectedQuoteError, InvalidEndOfQuotedStringError, ExpectedClosingQuoteError
# map from opening quotes to closing quotes # map from opening quotes to closing quotes
@ -47,24 +52,24 @@ _all_quotes = set(_quotes.keys()) | set(_quotes.values())
class StringView: class StringView:
def __init__(self, buffer): def __init__(self, buffer: str) -> None:
self.index = 0 self.index: int = 0
self.buffer = buffer self.buffer: str = buffer
self.end = len(buffer) self.end: int = len(buffer)
self.previous = 0 self.previous = 0
@property @property
def current(self): def current(self) -> Optional[str]:
return None if self.eof else self.buffer[self.index] return None if self.eof else self.buffer[self.index]
@property @property
def eof(self): def eof(self) -> bool:
return self.index >= self.end return self.index >= self.end
def undo(self): def undo(self) -> None:
self.index = self.previous self.index = self.previous
def skip_ws(self): def skip_ws(self) -> bool:
pos = 0 pos = 0
while not self.eof: while not self.eof:
try: try:
@ -79,7 +84,7 @@ class StringView:
self.index += pos self.index += pos
return self.previous != self.index return self.previous != self.index
def skip_string(self, string): def skip_string(self, string: str) -> bool:
strlen = len(string) strlen = len(string)
if self.buffer[self.index : self.index + strlen] == string: if self.buffer[self.index : self.index + strlen] == string:
self.previous = self.index self.previous = self.index
@ -87,19 +92,19 @@ class StringView:
return True return True
return False return False
def read_rest(self): def read_rest(self) -> str:
result = self.buffer[self.index :] result = self.buffer[self.index :]
self.previous = self.index self.previous = self.index
self.index = self.end self.index = self.end
return result return result
def read(self, n): def read(self, n: int) -> str:
result = self.buffer[self.index : self.index + n] result = self.buffer[self.index : self.index + n]
self.previous = self.index self.previous = self.index
self.index += n self.index += n
return result return result
def get(self): def get(self) -> Optional[str]:
try: try:
result = self.buffer[self.index + 1] result = self.buffer[self.index + 1]
except IndexError: except IndexError:
@ -109,7 +114,7 @@ class StringView:
self.index += 1 self.index += 1
return result return result
def get_word(self): def get_word(self) -> str:
pos = 0 pos = 0
while not self.eof: while not self.eof:
try: try:
@ -119,12 +124,12 @@ class StringView:
pos += 1 pos += 1
except IndexError: except IndexError:
break break
self.previous = self.index self.previous: int = self.index
result = self.buffer[self.index : self.index + pos] result = self.buffer[self.index : self.index + pos]
self.index += pos self.index += pos
return result return result
def get_quoted_word(self): def get_quoted_word(self) -> Optional[str]:
current = self.current current = self.current
if current is None: if current is None:
return None return None
@ -187,5 +192,5 @@ class StringView:
result.append(current) result.append(current)
def __repr__(self): def __repr__(self) -> str:
return f'<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>' return f'<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>'

8
discord/ext/tasks/__init__.py

@ -110,15 +110,15 @@ class SleepHandle:
__slots__ = ('future', 'loop', 'handle') __slots__ = ('future', 'loop', 'handle')
def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None: def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None:
self.loop = loop self.loop: asyncio.AbstractEventLoop = loop
self.future = future = loop.create_future() self.future: asyncio.Future[None] = loop.create_future()
relative_delta = discord.utils.compute_timedelta(dt) relative_delta = discord.utils.compute_timedelta(dt)
self.handle = loop.call_later(relative_delta, future.set_result, True) self.handle = loop.call_later(relative_delta, self.future.set_result, True)
def recalculate(self, dt: datetime.datetime) -> None: def recalculate(self, dt: datetime.datetime) -> None:
self.handle.cancel() self.handle.cancel()
relative_delta = discord.utils.compute_timedelta(dt) relative_delta = discord.utils.compute_timedelta(dt)
self.handle = self.loop.call_later(relative_delta, self.future.set_result, True) self.handle: asyncio.TimerHandle = self.loop.call_later(relative_delta, self.future.set_result, True)
def wait(self) -> asyncio.Future[Any]: def wait(self) -> asyncio.Future[Any]:
return self.future return self.future

2
discord/file.py

@ -74,7 +74,7 @@ class File:
def __init__( def __init__(
self, self,
fp: Union[str, bytes, os.PathLike, io.BufferedIOBase], fp: Union[str, bytes, os.PathLike[Any], io.BufferedIOBase],
filename: Optional[str] = None, filename: Optional[str] = None,
*, *,
spoiler: bool = False, spoiler: bool = False,

20
discord/flags.py

@ -46,8 +46,8 @@ BF = TypeVar('BF', bound='BaseFlags')
class flag_value: class flag_value:
def __init__(self, func: Callable[[Any], int]): def __init__(self, func: Callable[[Any], int]):
self.flag = func(None) self.flag: int = func(None)
self.__doc__ = func.__doc__ self.__doc__: Optional[str] = func.__doc__
@overload @overload
def __get__(self, instance: None, owner: Type[BF]) -> Self: def __get__(self, instance: None, owner: Type[BF]) -> Self:
@ -65,7 +65,7 @@ class flag_value:
def __set__(self, instance: BaseFlags, value: bool) -> None: def __set__(self, instance: BaseFlags, value: bool) -> None:
instance._set_flag(self.flag, value) instance._set_flag(self.flag, value)
def __repr__(self): def __repr__(self) -> str:
return f'<flag_value flag={self.flag!r}>' return f'<flag_value flag={self.flag!r}>'
@ -73,8 +73,8 @@ class alias_flag_value(flag_value):
pass pass
def fill_with_flags(*, inverted: bool = False): def fill_with_flags(*, inverted: bool = False) -> Callable[[Type[BF]], Type[BF]]:
def decorator(cls: Type[BF]): def decorator(cls: Type[BF]) -> Type[BF]:
# fmt: off # fmt: off
cls.VALID_FLAGS = { cls.VALID_FLAGS = {
name: value.flag name: value.flag
@ -116,10 +116,10 @@ class BaseFlags:
self.value = value self.value = value
return self return self
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and self.value == other.value return isinstance(other, self.__class__) and self.value == other.value
def __ne__(self, other: Any) -> bool: def __ne__(self, other: object) -> bool:
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self) -> int:
@ -504,8 +504,8 @@ class Intents(BaseFlags):
__slots__ = () __slots__ = ()
def __init__(self, **kwargs: bool): def __init__(self, **kwargs: bool) -> None:
self.value = self.DEFAULT_VALUE self.value: int = self.DEFAULT_VALUE
for key, value in kwargs.items(): for key, value in kwargs.items():
if key not in self.VALID_FLAGS: if key not in self.VALID_FLAGS:
raise TypeError(f'{key!r} is not a valid flag name.') raise TypeError(f'{key!r} is not a valid flag name.')
@ -1005,7 +1005,7 @@ class MemberCacheFlags(BaseFlags):
def __init__(self, **kwargs: bool): def __init__(self, **kwargs: bool):
bits = max(self.VALID_FLAGS.values()).bit_length() bits = max(self.VALID_FLAGS.values()).bit_length()
self.value = (1 << bits) - 1 self.value: int = (1 << bits) - 1
for key, value in kwargs.items(): for key, value in kwargs.items():
if key not in self.VALID_FLAGS: if key not in self.VALID_FLAGS:
raise TypeError(f'{key!r} is not a valid flag name.') raise TypeError(f'{key!r} is not a valid flag name.')

30
discord/gateway.py

@ -54,6 +54,8 @@ __all__ = (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self
from .client import Client from .client import Client
from .state import ConnectionState from .state import ConnectionState
from .voice_client import VoiceClient from .voice_client import VoiceClient
@ -62,10 +64,10 @@ if TYPE_CHECKING:
class ReconnectWebSocket(Exception): class ReconnectWebSocket(Exception):
"""Signals to safely reconnect the websocket.""" """Signals to safely reconnect the websocket."""
def __init__(self, shard_id, *, resume=True): def __init__(self, shard_id: Optional[int], *, resume: bool = True) -> None:
self.shard_id = shard_id self.shard_id: Optional[int] = shard_id
self.resume = resume self.resume: bool = resume
self.op = 'RESUME' if resume else 'IDENTIFY' self.op: str = 'RESUME' if resume else 'IDENTIFY'
class WebSocketClosure(Exception): class WebSocketClosure(Exception):
@ -225,7 +227,7 @@ class VoiceKeepAliveHandler(KeepAliveHandler):
ack_time = time.perf_counter() ack_time = time.perf_counter()
self._last_ack = ack_time self._last_ack = ack_time
self._last_recv = ack_time self._last_recv = ack_time
self.latency = ack_time - self._last_send self.latency: float = ack_time - self._last_send
self.recent_ack_latencies.append(self.latency) self.recent_ack_latencies.append(self.latency)
@ -339,7 +341,7 @@ class DiscordWebSocket:
@classmethod @classmethod
async def from_client( async def from_client(
cls: Type[DWS], cls,
client: Client, client: Client,
*, *,
initial: bool = False, initial: bool = False,
@ -348,7 +350,7 @@ class DiscordWebSocket:
session: Optional[str] = None, session: Optional[str] = None,
sequence: Optional[int] = None, sequence: Optional[int] = None,
resume: bool = False, resume: bool = False,
) -> DWS: ) -> Self:
"""Creates a main websocket for Discord from a :class:`Client`. """Creates a main websocket for Discord from a :class:`Client`.
This is for internal use only. This is for internal use only.
@ -821,11 +823,11 @@ class DiscordVoiceWebSocket:
*, *,
hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None, hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None,
) -> None: ) -> None:
self.ws = socket self.ws: aiohttp.ClientWebSocketResponse = socket
self.loop = loop self.loop: asyncio.AbstractEventLoop = loop
self._keep_alive = None self._keep_alive: Optional[VoiceKeepAliveHandler] = None
self._close_code = None self._close_code: Optional[int] = None
self.secret_key = None self.secret_key: Optional[str] = None
if hook: if hook:
self._hook = hook # type: ignore - type-checker doesn't like overriding methods self._hook = hook # type: ignore - type-checker doesn't like overriding methods
@ -864,7 +866,9 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload) await self.send_as_json(payload)
@classmethod @classmethod
async def from_client(cls: Type[DVWS], client: VoiceClient, *, resume=False, hook=None) -> DVWS: async def from_client(
cls, client: VoiceClient, *, resume: bool = False, hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None
) -> Self:
"""Creates a voice websocket for the :class:`VoiceClient`.""" """Creates a voice websocket for the :class:`VoiceClient`."""
gateway = 'wss://' + client.endpoint + '/?v=4' gateway = 'wss://' + client.endpoint + '/?v=4'
http = client._state.http http = client._state.http

5
discord/guild.py

@ -123,6 +123,7 @@ if TYPE_CHECKING:
) )
from .types.integration import IntegrationType from .types.integration import IntegrationType
from .types.snowflake import SnowflakeList from .types.snowflake import SnowflakeList
from .types.widget import EditWidgetSettings
VocalGuildChannel = Union[VoiceChannel, StageChannel] VocalGuildChannel = Union[VoiceChannel, StageChannel]
GuildChannel = Union[VocalGuildChannel, TextChannel, CategoryChannel, StoreChannel] GuildChannel = Union[VocalGuildChannel, TextChannel, CategoryChannel, StoreChannel]
@ -3379,7 +3380,7 @@ class Guild(Hashable):
HTTPException HTTPException
Editing the widget failed. Editing the widget failed.
""" """
payload = {} payload: EditWidgetSettings = {}
if channel is not MISSING: if channel is not MISSING:
payload['channel_id'] = None if channel is None else channel.id payload['channel_id'] = None if channel is None else channel.id
if enabled is not MISSING: if enabled is not MISSING:
@ -3492,7 +3493,7 @@ class Guild(Hashable):
async def change_voice_state( async def change_voice_state(
self, *, channel: Optional[abc.Snowflake], self_mute: bool = False, self_deaf: bool = False self, *, channel: Optional[abc.Snowflake], self_mute: bool = False, self_deaf: bool = False
): ) -> None:
"""|coro| """|coro|
Changes client's voice state in the guild. Changes client's voice state in the guild.

22
discord/http.py

@ -76,12 +76,9 @@ if TYPE_CHECKING:
audit_log, audit_log,
channel, channel,
command, command,
components,
emoji, emoji,
embed,
guild, guild,
integration, integration,
interactions,
invite, invite,
member, member,
message, message,
@ -92,7 +89,6 @@ if TYPE_CHECKING:
channel, channel,
widget, widget,
threads, threads,
voice,
scheduled_event, scheduled_event,
sticker, sticker,
) )
@ -122,7 +118,7 @@ class MultipartParameters(NamedTuple):
multipart: Optional[List[Dict[str, Any]]] multipart: Optional[List[Dict[str, Any]]]
files: Optional[List[File]] files: Optional[List[File]]
def __enter__(self): def __enter__(self) -> Self:
return self return self
def __exit__( def __exit__(
@ -577,7 +573,7 @@ class HTTPClient:
return self.request(Route('POST', '/users/{user_id}/channels', user_id=user_id), json=payload) return self.request(Route('POST', '/users/{user_id}/channels', user_id=user_id), json=payload)
def leave_group(self, channel_id) -> Response[None]: def leave_group(self, channel_id: Snowflake) -> Response[None]:
return self.request(Route('DELETE', '/channels/{channel_id}', channel_id=channel_id)) return self.request(Route('DELETE', '/channels/{channel_id}', channel_id=channel_id))
# Message management # Message management
@ -1160,7 +1156,7 @@ class HTTPClient:
def sync_template(self, guild_id: Snowflake, code: str) -> Response[template.Template]: def sync_template(self, guild_id: Snowflake, code: str) -> Response[template.Template]:
return self.request(Route('PUT', '/guilds/{guild_id}/templates/{code}', guild_id=guild_id, code=code)) return self.request(Route('PUT', '/guilds/{guild_id}/templates/{code}', guild_id=guild_id, code=code))
def edit_template(self, guild_id: Snowflake, code: str, payload) -> Response[template.Template]: def edit_template(self, guild_id: Snowflake, code: str, payload: Dict[str, Any]) -> Response[template.Template]:
valid_keys = ( valid_keys = (
'name', 'name',
'description', 'description',
@ -1420,7 +1416,9 @@ class HTTPClient:
def get_widget(self, guild_id: Snowflake) -> Response[widget.Widget]: def get_widget(self, guild_id: Snowflake) -> Response[widget.Widget]:
return self.request(Route('GET', '/guilds/{guild_id}/widget.json', guild_id=guild_id)) return self.request(Route('GET', '/guilds/{guild_id}/widget.json', guild_id=guild_id))
def edit_widget(self, guild_id: Snowflake, payload, reason: Optional[str] = None) -> Response[widget.WidgetSettings]: def edit_widget(
self, guild_id: Snowflake, payload: widget.EditWidgetSettings, reason: Optional[str] = None
) -> Response[widget.WidgetSettings]:
return self.request(Route('PATCH', '/guilds/{guild_id}/widget', guild_id=guild_id), json=payload, reason=reason) return self.request(Route('PATCH', '/guilds/{guild_id}/widget', guild_id=guild_id), json=payload, reason=reason)
# Invite management # Invite management
@ -1812,7 +1810,9 @@ class HTTPClient:
) )
return self.request(r) return self.request(r)
def upsert_global_command(self, application_id: Snowflake, payload) -> Response[command.ApplicationCommand]: def upsert_global_command(
self, application_id: Snowflake, payload: command.ApplicationCommand
) -> Response[command.ApplicationCommand]:
r = Route('POST', '/applications/{application_id}/commands', application_id=application_id) r = Route('POST', '/applications/{application_id}/commands', application_id=application_id)
return self.request(r, json=payload) return self.request(r, json=payload)
@ -1845,7 +1845,9 @@ class HTTPClient:
) )
return self.request(r) return self.request(r)
def bulk_upsert_global_commands(self, application_id: Snowflake, payload) -> Response[List[command.ApplicationCommand]]: def bulk_upsert_global_commands(
self, application_id: Snowflake, payload: List[Dict[str, Any]]
) -> Response[List[command.ApplicationCommand]]:
r = Route('PUT', '/applications/{application_id}/commands', application_id=application_id) r = Route('PUT', '/applications/{application_id}/commands', application_id=application_id)
return self.request(r, json=payload) return self.request(r, json=payload)

17
discord/integrations.py

@ -39,6 +39,9 @@ __all__ = (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from .guild import Guild
from .role import Role
from .state import ConnectionState
from .types.integration import ( from .types.integration import (
IntegrationAccount as IntegrationAccountPayload, IntegrationAccount as IntegrationAccountPayload,
Integration as IntegrationPayload, Integration as IntegrationPayload,
@ -47,8 +50,6 @@ if TYPE_CHECKING:
IntegrationType, IntegrationType,
IntegrationApplication as IntegrationApplicationPayload, IntegrationApplication as IntegrationApplicationPayload,
) )
from .guild import Guild
from .role import Role
class IntegrationAccount: class IntegrationAccount:
@ -109,11 +110,11 @@ class Integration:
) )
def __init__(self, *, data: IntegrationPayload, guild: Guild) -> None: def __init__(self, *, data: IntegrationPayload, guild: Guild) -> None:
self.guild = guild self.guild: Guild = guild
self._state = guild._state self._state: ConnectionState = guild._state
self._from_data(data) self._from_data(data)
def __repr__(self): def __repr__(self) -> str:
return f"<{self.__class__.__name__} id={self.id} name={self.name!r}>" return f"<{self.__class__.__name__} id={self.id} name={self.name!r}>"
def _from_data(self, data: IntegrationPayload) -> None: def _from_data(self, data: IntegrationPayload) -> None:
@ -123,7 +124,7 @@ class Integration:
self.account: IntegrationAccount = IntegrationAccount(data['account']) self.account: IntegrationAccount = IntegrationAccount(data['account'])
user = data.get('user') user = data.get('user')
self.user = User(state=self._state, data=user) if user else None self.user: Optional[User] = User(state=self._state, data=user) if user else None
self.enabled: bool = data['enabled'] self.enabled: bool = data['enabled']
async def delete(self, *, reason: Optional[str] = None) -> None: async def delete(self, *, reason: Optional[str] = None) -> None:
@ -319,7 +320,7 @@ class IntegrationApplication:
'user', 'user',
) )
def __init__(self, *, data: IntegrationApplicationPayload, state): def __init__(self, *, data: IntegrationApplicationPayload, state: ConnectionState) -> None:
self.id: int = int(data['id']) self.id: int = int(data['id'])
self.name: str = data['name'] self.name: str = data['name']
self.icon: Optional[str] = data['icon'] self.icon: Optional[str] = data['icon']
@ -358,7 +359,7 @@ class BotIntegration(Integration):
def _from_data(self, data: BotIntegrationPayload) -> None: def _from_data(self, data: BotIntegrationPayload) -> None:
super()._from_data(data) super()._from_data(data)
self.application = IntegrationApplication(data=data['application'], state=self._state) self.application: IntegrationApplication = IntegrationApplication(data=data['application'], state=self._state)
def _integration_factory(value: str) -> Tuple[Type[Integration], str]: def _integration_factory(value: str) -> Tuple[Type[Integration], str]:

7
discord/interactions.py

@ -54,6 +54,9 @@ if TYPE_CHECKING:
Interaction as InteractionPayload, Interaction as InteractionPayload,
InteractionData, InteractionData,
) )
from .types.webhook import (
Webhook as WebhookPayload,
)
from .client import Client from .client import Client
from .guild import Guild from .guild import Guild
from .state import ConnectionState from .state import ConnectionState
@ -229,7 +232,7 @@ class Interaction:
@utils.cached_slot_property('_cs_followup') @utils.cached_slot_property('_cs_followup')
def followup(self) -> Webhook: def followup(self) -> Webhook:
""":class:`Webhook`: Returns the follow up webhook for follow up interactions.""" """:class:`Webhook`: Returns the follow up webhook for follow up interactions."""
payload = { payload: WebhookPayload = {
'id': self.application_id, 'id': self.application_id,
'type': 3, 'type': 3,
'token': self.token, 'token': self.token,
@ -703,7 +706,7 @@ class InteractionResponse:
self._responded = True self._responded = True
async def send_modal(self, modal: Modal, /): async def send_modal(self, modal: Modal, /) -> None:
"""|coro| """|coro|
Responds to this interaction by sending a modal. Responds to this interaction by sending a modal.

4
discord/invite.py

@ -456,7 +456,7 @@ class Invite(Hashable):
guild: Optional[Union[Guild, Object]] = state._get_guild(guild_id) guild: Optional[Union[Guild, Object]] = state._get_guild(guild_id)
channel_id = int(data['channel_id']) channel_id = int(data['channel_id'])
if guild is not None: if guild is not None:
channel = guild.get_channel(channel_id) or Object(id=channel_id) # type: ignore channel = guild.get_channel(channel_id) or Object(id=channel_id)
else: else:
guild = Object(id=guild_id) if guild_id is not None else None guild = Object(id=guild_id) if guild_id is not None else None
channel = Object(id=channel_id) channel = Object(id=channel_id)
@ -539,7 +539,7 @@ class Invite(Hashable):
return self return self
async def delete(self, *, reason: Optional[str] = None): async def delete(self, *, reason: Optional[str] = None) -> None:
"""|coro| """|coro|
Revokes the instant invite. Revokes the instant invite.

13
discord/member.py

@ -27,9 +27,8 @@ from __future__ import annotations
import datetime import datetime
import inspect import inspect
import itertools import itertools
import sys
from operator import attrgetter from operator import attrgetter
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, Type
import discord.abc import discord.abc
@ -207,7 +206,7 @@ class _ClientStatus:
return self return self
def flatten_user(cls): def flatten_user(cls: Any) -> Type[Member]:
for attr, value in itertools.chain(BaseUser.__dict__.items(), User.__dict__.items()): for attr, value in itertools.chain(BaseUser.__dict__.items(), User.__dict__.items()):
# ignore private/special methods # ignore private/special methods
if attr.startswith('_'): if attr.startswith('_'):
@ -333,7 +332,7 @@ class Member(discord.abc.Messageable, _UserTag):
default_avatar: Asset default_avatar: Asset
avatar: Optional[Asset] avatar: Optional[Asset]
dm_channel: Optional[DMChannel] dm_channel: Optional[DMChannel]
create_dm = User.create_dm create_dm: Callable[[], Coroutine[Any, Any, DMChannel]]
mutual_guilds: List[Guild] mutual_guilds: List[Guild]
public_flags: PublicUserFlags public_flags: PublicUserFlags
banner: Optional[Asset] banner: Optional[Asset]
@ -369,10 +368,10 @@ class Member(discord.abc.Messageable, _UserTag):
f' bot={self._user.bot} nick={self.nick!r} guild={self.guild!r}>' f' bot={self._user.bot} nick={self.nick!r} guild={self.guild!r}>'
) )
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
return isinstance(other, _UserTag) and other.id == self.id return isinstance(other, _UserTag) and other.id == self.id
def __ne__(self, other: Any) -> bool: def __ne__(self, other: object) -> bool:
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self) -> int:
@ -425,7 +424,7 @@ class Member(discord.abc.Messageable, _UserTag):
self._user = member._user self._user = member._user
return self return self
async def _get_channel(self): async def _get_channel(self) -> DMChannel:
ch = await self.create_dm() ch = await self.create_dm()
return ch return ch

8
discord/mentions.py

@ -92,10 +92,10 @@ class AllowedMentions:
roles: Union[bool, List[Snowflake]] = default, roles: Union[bool, List[Snowflake]] = default,
replied_user: bool = default, replied_user: bool = default,
): ):
self.everyone = everyone self.everyone: bool = everyone
self.users = users self.users: Union[bool, List[Snowflake]] = users
self.roles = roles self.roles: Union[bool, List[Snowflake]] = roles
self.replied_user = replied_user self.replied_user: bool = replied_user
@classmethod @classmethod
def all(cls) -> Self: def all(cls) -> Self:

17
discord/message.py

@ -40,6 +40,7 @@ from typing import (
Tuple, Tuple,
ClassVar, ClassVar,
Optional, Optional,
Type,
overload, overload,
) )
@ -71,7 +72,6 @@ if TYPE_CHECKING:
MessageReference as MessageReferencePayload, MessageReference as MessageReferencePayload,
MessageApplication as MessageApplicationPayload, MessageApplication as MessageApplicationPayload,
MessageActivity as MessageActivityPayload, MessageActivity as MessageActivityPayload,
Reaction as ReactionPayload,
) )
from .types.components import Component as ComponentPayload from .types.components import Component as ComponentPayload
@ -87,7 +87,7 @@ if TYPE_CHECKING:
from .abc import GuildChannel, PartialMessageableChannel, MessageableChannel from .abc import GuildChannel, PartialMessageableChannel, MessageableChannel
from .components import Component from .components import Component
from .state import ConnectionState from .state import ConnectionState
from .channel import TextChannel, GroupChannel, DMChannel from .channel import TextChannel
from .mentions import AllowedMentions from .mentions import AllowedMentions
from .user import User from .user import User
from .role import Role from .role import Role
@ -95,6 +95,7 @@ if TYPE_CHECKING:
EmojiInputType = Union[Emoji, PartialEmoji, str] EmojiInputType = Union[Emoji, PartialEmoji, str]
__all__ = ( __all__ = (
'Attachment', 'Attachment',
'Message', 'Message',
@ -104,7 +105,7 @@ __all__ = (
) )
def convert_emoji_reaction(emoji): def convert_emoji_reaction(emoji: Union[EmojiInputType, Reaction]) -> str:
if isinstance(emoji, Reaction): if isinstance(emoji, Reaction):
emoji = emoji.emoji emoji = emoji.emoji
@ -216,7 +217,7 @@ class Attachment(Hashable):
async def save( async def save(
self, self,
fp: Union[io.BufferedIOBase, PathLike], fp: Union[io.BufferedIOBase, PathLike[Any]],
*, *,
seek_begin: bool = True, seek_begin: bool = True,
use_cached: bool = False, use_cached: bool = False,
@ -510,7 +511,7 @@ class MessageReference:
to_message_reference_dict = to_dict to_message_reference_dict = to_dict
def flatten_handlers(cls): def flatten_handlers(cls: Type[Message]) -> Type[Message]:
prefix = len('_handle_') prefix = len('_handle_')
handlers = [ handlers = [
(key[prefix:], value) (key[prefix:], value)
@ -1036,7 +1037,7 @@ class Message(Hashable):
) )
@utils.cached_slot_property('_cs_system_content') @utils.cached_slot_property('_cs_system_content')
def system_content(self): def system_content(self) -> Optional[str]:
r""":class:`str`: A property that returns the content that is rendered r""":class:`str`: A property that returns the content that is rendered
regardless of the :attr:`Message.type`. regardless of the :attr:`Message.type`.
@ -1657,7 +1658,7 @@ class Message(Hashable):
) )
return Thread(guild=self.guild, state=self._state, data=data) return Thread(guild=self.guild, state=self._state, data=data)
async def reply(self, content: Optional[str] = None, **kwargs) -> Message: async def reply(self, content: Optional[str] = None, **kwargs: Any) -> Message:
"""|coro| """|coro|
A shortcut method to :meth:`.abc.Messageable.send` to reply to the A shortcut method to :meth:`.abc.Messageable.send` to reply to the
@ -1798,7 +1799,7 @@ class PartialMessage(Hashable):
# Also needed for duck typing purposes # Also needed for duck typing purposes
# n.b. not exposed # n.b. not exposed
pinned = property(None, lambda x, y: None) pinned: Any = property(None, lambda x, y: None)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<PartialMessage id={self.id} channel={self.channel!r}>' return f'<PartialMessage id={self.id} channel={self.channel!r}>'

5
discord/opus.py

@ -363,7 +363,7 @@ class Encoder(_OpusStruct):
_lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0) _lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0)
def set_expected_packet_loss_percent(self, percentage: float) -> None: def set_expected_packet_loss_percent(self, percentage: float) -> None:
_lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore _lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100))))
def encode(self, pcm: bytes, frame_size: int) -> bytes: def encode(self, pcm: bytes, frame_size: int) -> bytes:
max_data_bytes = len(pcm) max_data_bytes = len(pcm)
@ -373,8 +373,7 @@ class Encoder(_OpusStruct):
ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes) ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes)
# array can be initialized with bytes but mypy doesn't know return array.array('b', data[:ret]).tobytes()
return array.array('b', data[:ret]).tobytes() # type: ignore
class Decoder(_OpusStruct): class Decoder(_OpusStruct):

15
discord/partial_emoji.py

@ -42,6 +42,7 @@ if TYPE_CHECKING:
from .state import ConnectionState from .state import ConnectionState
from datetime import datetime from datetime import datetime
from .types.message import PartialEmoji as PartialEmojiPayload from .types.message import PartialEmoji as PartialEmojiPayload
from .types.activity import ActivityEmoji
class _EmojiTag: class _EmojiTag:
@ -99,13 +100,13 @@ class PartialEmoji(_EmojiTag, AssetMixin):
id: Optional[int] id: Optional[int]
def __init__(self, *, name: str, animated: bool = False, id: Optional[int] = None): def __init__(self, *, name: str, animated: bool = False, id: Optional[int] = None):
self.animated = animated self.animated: bool = animated
self.name = name self.name: str = name
self.id = id self.id: Optional[int] = id
self._state: Optional[ConnectionState] = None self._state: Optional[ConnectionState] = None
@classmethod @classmethod
def from_dict(cls, data: Union[PartialEmojiPayload, Dict[str, Any]]) -> Self: def from_dict(cls, data: Union[PartialEmojiPayload, ActivityEmoji, Dict[str, Any]]) -> Self:
return cls( return cls(
animated=data.get('animated', False), animated=data.get('animated', False),
id=utils._get_as_snowflake(data, 'id'), id=utils._get_as_snowflake(data, 'id'),
@ -178,10 +179,10 @@ class PartialEmoji(_EmojiTag, AssetMixin):
return f'<a:{self.name}:{self.id}>' return f'<a:{self.name}:{self.id}>'
return f'<:{self.name}:{self.id}>' return f'<:{self.name}:{self.id}>'
def __repr__(self): def __repr__(self) -> str:
return f'<{self.__class__.__name__} animated={self.animated} name={self.name!r} id={self.id}>' return f'<{self.__class__.__name__} animated={self.animated} name={self.name!r} id={self.id}>'
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
if self.is_unicode_emoji(): if self.is_unicode_emoji():
return isinstance(other, PartialEmoji) and self.name == other.name return isinstance(other, PartialEmoji) and self.name == other.name
@ -189,7 +190,7 @@ class PartialEmoji(_EmojiTag, AssetMixin):
return self.id == other.id return self.id == other.id
return False return False
def __ne__(self, other: Any) -> bool: def __ne__(self, other: object) -> bool:
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self) -> int:

4
discord/permissions.py

@ -276,7 +276,7 @@ class Permissions(BaseFlags):
# So 0000 OP2 0101 -> 0101 # So 0000 OP2 0101 -> 0101
# The OP is base & ~denied. # The OP is base & ~denied.
# The OP2 is base | allowed. # The OP2 is base | allowed.
self.value = (self.value & ~deny) | allow self.value: int = (self.value & ~deny) | allow
@flag_value @flag_value
def create_instant_invite(self) -> int: def create_instant_invite(self) -> int:
@ -691,7 +691,7 @@ class PermissionOverwrite:
setattr(self, key, value) setattr(self, key, value)
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
return isinstance(other, PermissionOverwrite) and self._values == other._values return isinstance(other, PermissionOverwrite) and self._values == other._values
def _set(self, key: str, value: Optional[bool]) -> None: def _set(self, key: str, value: Optional[bool]) -> None:

21
discord/player.py

@ -365,12 +365,11 @@ class FFmpegOpusAudio(FFmpegAudio):
bitrate: Optional[int] = None, bitrate: Optional[int] = None,
codec: Optional[str] = None, codec: Optional[str] = None,
executable: str = 'ffmpeg', executable: str = 'ffmpeg',
pipe=False, pipe: bool = False,
stderr=None, stderr: Optional[IO[bytes]] = None,
before_options=None, before_options: Optional[str] = None,
options=None, options: Optional[str] = None,
) -> None: ) -> None:
args = [] args = []
subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr} subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr}
@ -635,7 +634,13 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]):
class AudioPlayer(threading.Thread): class AudioPlayer(threading.Thread):
DELAY: float = OpusEncoder.FRAME_LENGTH / 1000.0 DELAY: float = OpusEncoder.FRAME_LENGTH / 1000.0
def __init__(self, source: AudioSource, client: VoiceClient, *, after=None): def __init__(
self,
source: AudioSource,
client: VoiceClient,
*,
after: Optional[Callable[[Optional[Exception]], Any]] = None,
) -> None:
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.daemon: bool = True self.daemon: bool = True
self.source: AudioSource = source self.source: AudioSource = source
@ -724,8 +729,8 @@ class AudioPlayer(threading.Thread):
self._speak(SpeakingState.none) self._speak(SpeakingState.none)
def resume(self, *, update_speaking: bool = True) -> None: def resume(self, *, update_speaking: bool = True) -> None:
self.loops = 0 self.loops: int = 0
self._start = time.perf_counter() self._start: float = time.perf_counter()
self._resumed.set() self._resumed.set()
if update_speaking: if update_speaking:
self._speak(SpeakingState.voice) self._speak(SpeakingState.voice)

4
discord/reaction.py

@ -94,10 +94,10 @@ class Reaction:
""":class:`bool`: If this is a custom emoji.""" """:class:`bool`: If this is a custom emoji."""
return not isinstance(self.emoji, str) return not isinstance(self.emoji, str)
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and other.emoji == self.emoji return isinstance(other, self.__class__) and other.emoji == self.emoji
def __ne__(self, other: Any) -> bool: def __ne__(self, other: object) -> bool:
if isinstance(other, self.__class__): if isinstance(other, self.__class__):
return other.emoji != self.emoji return other.emoji != self.emoji
return True return True

4
discord/role.py

@ -211,7 +211,7 @@ class Role(Hashable):
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<Role id={self.id} name={self.name!r}>' return f'<Role id={self.id} name={self.name!r}>'
def __lt__(self, other: Any) -> bool: def __lt__(self, other: object) -> bool:
if not isinstance(other, Role) or not isinstance(self, Role): if not isinstance(other, Role) or not isinstance(self, Role):
return NotImplemented return NotImplemented
@ -241,7 +241,7 @@ class Role(Hashable):
def __gt__(self, other: Any) -> bool: def __gt__(self, other: Any) -> bool:
return Role.__lt__(other, self) return Role.__lt__(other, self)
def __ge__(self, other: Any) -> bool: def __ge__(self, other: object) -> bool:
r = Role.__lt__(self, other) r = Role.__lt__(self, other)
if r is NotImplemented: if r is NotImplemented:
return NotImplemented return NotImplemented

6
discord/scheduled_event.py

@ -132,7 +132,7 @@ class ScheduledEvent(Hashable):
self.guild_id: int = int(data['guild_id']) self.guild_id: int = int(data['guild_id'])
self.name: str = data['name'] self.name: str = data['name']
self.description: Optional[str] = data.get('description') self.description: Optional[str] = data.get('description')
self.entity_type = try_enum(EntityType, data['entity_type']) self.entity_type: EntityType = try_enum(EntityType, data['entity_type'])
self.entity_id: Optional[int] = _get_as_snowflake(data, 'entity_id') self.entity_id: Optional[int] = _get_as_snowflake(data, 'entity_id')
self.start_time: datetime = parse_time(data['scheduled_start_time']) self.start_time: datetime = parse_time(data['scheduled_start_time'])
self.privacy_level: PrivacyLevel = try_enum(PrivacyLevel, data['status']) self.privacy_level: PrivacyLevel = try_enum(PrivacyLevel, data['status'])
@ -153,7 +153,7 @@ class ScheduledEvent(Hashable):
self.location: Optional[str] = data.get('location') if data else None self.location: Optional[str] = data.get('location') if data else None
@classmethod @classmethod
def from_creation(cls, *, state: ConnectionState, data: GuildScheduledEventPayload): def from_creation(cls, *, state: ConnectionState, data: GuildScheduledEventPayload) -> None:
creator_id = data.get('creator_id') creator_id = data.get('creator_id')
self = cls(state=state, data=data) self = cls(state=state, data=data)
if creator_id: if creator_id:
@ -180,7 +180,7 @@ class ScheduledEvent(Hashable):
return self.guild.get_channel(self.channel_id) # type: ignore return self.guild.get_channel(self.channel_id) # type: ignore
@property @property
def url(self): def url(self) -> str:
""":class:`str`: The url for the scheduled event.""" """:class:`str`: The url for the scheduled event."""
return f'https://discord.com/events/{self.guild_id}/{self.id}' return f'https://discord.com/events/{self.guild_id}/{self.id}'

5
discord/shard.py

@ -75,12 +75,12 @@ class EventItem:
self.shard: Optional['Shard'] = shard self.shard: Optional['Shard'] = shard
self.error: Optional[Exception] = error self.error: Optional[Exception] = error
def __lt__(self, other: Any) -> bool: def __lt__(self, other: object) -> bool:
if not isinstance(other, EventItem): if not isinstance(other, EventItem):
return NotImplemented return NotImplemented
return self.type < other.type return self.type < other.type
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, EventItem): if not isinstance(other, EventItem):
return NotImplemented return NotImplemented
return self.type == other.type return self.type == other.type
@ -409,6 +409,7 @@ class AutoShardedClient(Client):
async def launch_shards(self) -> None: async def launch_shards(self) -> None:
if self.shard_count is None: if self.shard_count is None:
self.shard_count: int
self.shard_count, gateway = await self.http.get_bot_gateway() self.shard_count, gateway = await self.http.get_bot_gateway()
else: else:
gateway = await self.http.get_gateway() gateway = await self.http.get_gateway()

6
discord/stage_instance.py

@ -97,11 +97,11 @@ class StageInstance(Hashable):
) )
def __init__(self, *, state: ConnectionState, guild: Guild, data: StageInstancePayload) -> None: def __init__(self, *, state: ConnectionState, guild: Guild, data: StageInstancePayload) -> None:
self._state = state self._state: ConnectionState = state
self.guild = guild self.guild: Guild = guild
self._update(data) self._update(data)
def _update(self, data: StageInstancePayload): def _update(self, data: StageInstancePayload) -> None:
self.id: int = int(data['id']) self.id: int = int(data['id'])
self.channel_id: int = int(data['channel_id']) self.channel_id: int = int(data['channel_id'])
self.topic: str = data['topic'] self.topic: str = data['topic']

58
discord/state.py

@ -43,6 +43,8 @@ from typing import (
Sequence, Sequence,
Tuple, Tuple,
Deque, Deque,
Literal,
overload,
) )
import weakref import weakref
import inspect import inspect
@ -88,7 +90,7 @@ if TYPE_CHECKING:
from .types.activity import Activity as ActivityPayload from .types.activity import Activity as ActivityPayload
from .types.channel import DMChannel as DMChannelPayload from .types.channel import DMChannel as DMChannelPayload
from .types.user import User as UserPayload, PartialUser as PartialUserPayload from .types.user import User as UserPayload, PartialUser as PartialUserPayload
from .types.emoji import Emoji as EmojiPayload from .types.emoji import Emoji as EmojiPayload, PartialEmoji as PartialEmojiPayload
from .types.sticker import GuildSticker as GuildStickerPayload from .types.sticker import GuildSticker as GuildStickerPayload
from .types.guild import Guild as GuildPayload from .types.guild import Guild as GuildPayload
from .types.message import Message as MessagePayload, PartialMessage as PartialMessagePayload from .types.message import Message as MessagePayload, PartialMessage as PartialMessagePayload
@ -165,9 +167,9 @@ class ConnectionState:
def __init__( def __init__(
self, self,
*, *,
dispatch: Callable, dispatch: Callable[..., Any],
handlers: Dict[str, Callable], handlers: Dict[str, Callable[..., Any]],
hooks: Dict[str, Callable], hooks: Dict[str, Callable[..., Coroutine[Any, Any, Any]]],
http: HTTPClient, http: HTTPClient,
**options: Any, **options: Any,
) -> None: ) -> None:
@ -178,9 +180,9 @@ class ConnectionState:
if self.max_messages is not None and self.max_messages <= 0: if self.max_messages is not None and self.max_messages <= 0:
self.max_messages = 1000 self.max_messages = 1000
self.dispatch: Callable = dispatch self.dispatch: Callable[..., Any] = dispatch
self.handlers: Dict[str, Callable] = handlers self.handlers: Dict[str, Callable[..., Any]] = handlers
self.hooks: Dict[str, Callable] = hooks self.hooks: Dict[str, Callable[..., Coroutine[Any, Any, Any]]] = hooks
self.shard_count: Optional[int] = None self.shard_count: Optional[int] = None
self._ready_task: Optional[asyncio.Task] = None self._ready_task: Optional[asyncio.Task] = None
self.application_id: Optional[int] = utils._get_as_snowflake(options, 'application_id') self.application_id: Optional[int] = utils._get_as_snowflake(options, 'application_id')
@ -245,6 +247,7 @@ class ConnectionState:
if not intents.members or cache_flags._empty: if not intents.members or cache_flags._empty:
self.store_user = self.store_user_no_intents # type: ignore - This reassignment is on purpose self.store_user = self.store_user_no_intents # type: ignore - This reassignment is on purpose
self.parsers: Dict[str, Callable[[Any], None]]
self.parsers = parsers = {} self.parsers = parsers = {}
for attr, func in inspect.getmembers(self): for attr, func in inspect.getmembers(self):
if attr.startswith('parse_'): if attr.startswith('parse_'):
@ -343,13 +346,13 @@ class ConnectionState:
self._users[user_id] = user self._users[user_id] = user
return user return user
def store_user_no_intents(self, data): def store_user_no_intents(self, data: Union[UserPayload, PartialUserPayload]) -> User:
return User(state=self, data=data) return User(state=self, data=data)
def create_user(self, data: Union[UserPayload, PartialUserPayload]) -> User: def create_user(self, data: Union[UserPayload, PartialUserPayload]) -> User:
return User(state=self, data=data) return User(state=self, data=data)
def get_user(self, id): def get_user(self, id: int) -> Optional[User]:
return self._users.get(id) return self._users.get(id)
def store_emoji(self, guild: Guild, data: EmojiPayload) -> Emoji: def store_emoji(self, guild: Guild, data: EmojiPayload) -> Emoji:
@ -571,8 +574,7 @@ class ConnectionState:
pass pass
else: else:
self.application_id = utils._get_as_snowflake(application, 'id') self.application_id = utils._get_as_snowflake(application, 'id')
# flags will always be present here self.application_flags: ApplicationFlags = ApplicationFlags._from_value(application['flags'])
self.application_flags = ApplicationFlags._from_value(application['flags'])
for guild_data in data['guilds']: for guild_data in data['guilds']:
self._add_guild_from_data(guild_data) # type: ignore self._add_guild_from_data(guild_data) # type: ignore
@ -743,7 +745,7 @@ class ConnectionState:
self.dispatch('presence_update', old_member, member) self.dispatch('presence_update', old_member, member)
def parse_user_update(self, data: gw.UserUpdateEvent): def parse_user_update(self, data: gw.UserUpdateEvent) -> None:
if self.user: if self.user:
self.user._update(data) self.user._update(data)
@ -1050,7 +1052,7 @@ class ConnectionState:
guild.stickers = tuple(map(lambda d: self.store_sticker(guild, d), data['stickers'])) guild.stickers = tuple(map(lambda d: self.store_sticker(guild, d), data['stickers']))
self.dispatch('guild_stickers_update', guild, before_stickers, guild.stickers) self.dispatch('guild_stickers_update', guild, before_stickers, guild.stickers)
def _get_create_guild(self, data): def _get_create_guild(self, data: gw.GuildCreateEvent) -> Guild:
if data.get('unavailable') is False: if data.get('unavailable') is False:
# GUILD_CREATE with unavailable in the response # GUILD_CREATE with unavailable in the response
# usually means that the guild has become available # usually means that the guild has become available
@ -1063,10 +1065,22 @@ class ConnectionState:
return self._add_guild_from_data(data) return self._add_guild_from_data(data)
def is_guild_evicted(self, guild) -> bool: def is_guild_evicted(self, guild: Guild) -> bool:
return guild.id not in self._guilds return guild.id not in self._guilds
async def chunk_guild(self, guild, *, wait=True, cache=None): @overload
async def chunk_guild(self, guild: Guild, *, wait: Literal[True] = ..., cache: Optional[bool] = ...) -> List[Member]:
...
@overload
async def chunk_guild(
self, guild: Guild, *, wait: Literal[False] = ..., cache: Optional[bool] = ...
) -> asyncio.Future[List[Member]]:
...
async def chunk_guild(
self, guild: Guild, *, wait: bool = True, cache: Optional[bool] = None
) -> Union[List[Member], asyncio.Future[List[Member]]]:
cache = cache or self.member_cache_flags.joined cache = cache or self.member_cache_flags.joined
request = self._chunk_requests.get(guild.id) request = self._chunk_requests.get(guild.id)
if request is None: if request is None:
@ -1445,16 +1459,19 @@ class ConnectionState:
return channel.guild.get_member(user_id) return channel.guild.get_member(user_id)
return self.get_user(user_id) return self.get_user(user_id)
def get_reaction_emoji(self, data) -> Union[Emoji, PartialEmoji]: def get_reaction_emoji(self, data: PartialEmojiPayload) -> Union[Emoji, PartialEmoji, str]:
emoji_id = utils._get_as_snowflake(data, 'id') emoji_id = utils._get_as_snowflake(data, 'id')
if not emoji_id: if not emoji_id:
return data['name'] # the name key will be a str
return data['name'] # type: ignore
try: try:
return self._emojis[emoji_id] return self._emojis[emoji_id]
except KeyError: except KeyError:
return PartialEmoji.with_state(self, animated=data.get('animated', False), id=emoji_id, name=data['name']) return PartialEmoji.with_state(
self, animated=data.get('animated', False), id=emoji_id, name=data['name'] # type: ignore
)
def _upgrade_partial_emoji(self, emoji: PartialEmoji) -> Union[Emoji, PartialEmoji, str]: def _upgrade_partial_emoji(self, emoji: PartialEmoji) -> Union[Emoji, PartialEmoji, str]:
emoji_id = emoji.id emoji_id = emoji.id
@ -1589,6 +1606,7 @@ class AutoShardedConnectionState(ConnectionState):
if not hasattr(self, '_ready_state'): if not hasattr(self, '_ready_state'):
self._ready_state = asyncio.Queue() self._ready_state = asyncio.Queue()
self.user: Optional[ClientUser]
self.user = user = ClientUser(state=self, data=data['user']) self.user = user = ClientUser(state=self, data=data['user'])
# self._users is a list of Users, we're setting a ClientUser # self._users is a list of Users, we're setting a ClientUser
self._users[user.id] = user # type: ignore self._users[user.id] = user # type: ignore
@ -1599,8 +1617,8 @@ class AutoShardedConnectionState(ConnectionState):
except KeyError: except KeyError:
pass pass
else: else:
self.application_id = utils._get_as_snowflake(application, 'id') self.application_id: Optional[int] = utils._get_as_snowflake(application, 'id')
self.application_flags = ApplicationFlags._from_value(application['flags']) self.application_flags: ApplicationFlags = ApplicationFlags._from_value(application['flags'])
for guild_data in data['guilds']: for guild_data in data['guilds']:
self._add_guild_from_data(guild_data) # type: ignore - _add_guild_from_data requires a complete Guild payload self._add_guild_from_data(guild_data) # type: ignore - _add_guild_from_data requires a complete Guild payload

2
discord/sticker.py

@ -228,7 +228,7 @@ class StickerItem(_StickerTag):
The retrieved sticker. The retrieved sticker.
""" """
data: StickerPayload = await self._state.http.get_sticker(self.id) data: StickerPayload = await self._state.http.get_sticker(self.id)
cls, _ = _sticker_factory(data['type']) # type: ignore cls, _ = _sticker_factory(data['type'])
return cls(state=self._state, data=data) return cls(state=self._state, data=data)

71
discord/threads.py

@ -41,6 +41,8 @@ __all__ = (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self
from .types.threads import ( from .types.threads import (
Thread as ThreadPayload, Thread as ThreadPayload,
ThreadMember as ThreadMemberPayload, ThreadMember as ThreadMemberPayload,
@ -147,13 +149,13 @@ class Thread(Messageable, Hashable):
'_created_at', '_created_at',
) )
def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload): def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload) -> None:
self._state: ConnectionState = state self._state: ConnectionState = state
self.guild = guild self.guild: Guild = guild
self._members: Dict[int, ThreadMember] = {} self._members: Dict[int, ThreadMember] = {}
self._from_data(data) self._from_data(data)
async def _get_channel(self): async def _get_channel(self) -> Self:
return self return self
def __repr__(self) -> str: def __repr__(self) -> str:
@ -166,17 +168,18 @@ class Thread(Messageable, Hashable):
return self.name return self.name
def _from_data(self, data: ThreadPayload): def _from_data(self, data: ThreadPayload):
self.id = int(data['id']) self.id: int = int(data['id'])
self.parent_id = int(data['parent_id']) self.parent_id: int = int(data['parent_id'])
self.owner_id = int(data['owner_id']) self.owner_id: int = int(data['owner_id'])
self.name = data['name'] self.name: str = data['name']
self._type = try_enum(ChannelType, data['type']) self._type: ChannelType = try_enum(ChannelType, data['type'])
self.last_message_id = _get_as_snowflake(data, 'last_message_id') self.last_message_id: Optional[int] = _get_as_snowflake(data, 'last_message_id')
self.slowmode_delay = data.get('rate_limit_per_user', 0) self.slowmode_delay: int = data.get('rate_limit_per_user', 0)
self.message_count = data['message_count'] self.message_count: int = data['message_count']
self.member_count = data['member_count'] self.member_count: int = data['member_count']
self._unroll_metadata(data['thread_metadata']) self._unroll_metadata(data['thread_metadata'])
self.me: Optional[ThreadMember]
try: try:
member = data['member'] member = data['member']
except KeyError: except KeyError:
@ -185,15 +188,15 @@ class Thread(Messageable, Hashable):
self.me = ThreadMember(self, member) self.me = ThreadMember(self, member)
def _unroll_metadata(self, data: ThreadMetadata): def _unroll_metadata(self, data: ThreadMetadata):
self.archived = data['archived'] self.archived: bool = data['archived']
self.archiver_id = _get_as_snowflake(data, 'archiver_id') self.archiver_id: Optional[int] = _get_as_snowflake(data, 'archiver_id')
self.auto_archive_duration = data['auto_archive_duration'] self.auto_archive_duration: int = data['auto_archive_duration']
self.archive_timestamp = parse_time(data['archive_timestamp']) self.archive_timestamp: datetime = parse_time(data['archive_timestamp'])
self.locked = data.get('locked', False) self.locked: bool = data.get('locked', False)
self.invitable = data.get('invitable', True) self.invitable: bool = data.get('invitable', True)
self._created_at = parse_time(data.get('create_timestamp')) self._created_at: Optional[datetime] = parse_time(data.get('create_timestamp'))
def _update(self, data): def _update(self, data: ThreadPayload) -> None:
try: try:
self.name = data['name'] self.name = data['name']
except KeyError: except KeyError:
@ -602,7 +605,7 @@ class Thread(Messageable, Hashable):
# The data payload will always be a Thread payload # The data payload will always be a Thread payload
return Thread(data=data, state=self._state, guild=self.guild) # type: ignore return Thread(data=data, state=self._state, guild=self.guild) # type: ignore
async def join(self): async def join(self) -> None:
"""|coro| """|coro|
Joins this thread. Joins this thread.
@ -619,7 +622,7 @@ class Thread(Messageable, Hashable):
""" """
await self._state.http.join_thread(self.id) await self._state.http.join_thread(self.id)
async def leave(self): async def leave(self) -> None:
"""|coro| """|coro|
Leaves this thread. Leaves this thread.
@ -631,7 +634,7 @@ class Thread(Messageable, Hashable):
""" """
await self._state.http.leave_thread(self.id) await self._state.http.leave_thread(self.id)
async def add_user(self, user: Snowflake, /): async def add_user(self, user: Snowflake, /) -> None:
"""|coro| """|coro|
Adds a user to this thread. Adds a user to this thread.
@ -654,7 +657,7 @@ class Thread(Messageable, Hashable):
""" """
await self._state.http.add_user_to_thread(self.id, user.id) await self._state.http.add_user_to_thread(self.id, user.id)
async def remove_user(self, user: Snowflake, /): async def remove_user(self, user: Snowflake, /) -> None:
"""|coro| """|coro|
Removes a user from this thread. Removes a user from this thread.
@ -718,7 +721,7 @@ class Thread(Messageable, Hashable):
members = await self._state.http.get_thread_members(self.id) members = await self._state.http.get_thread_members(self.id)
return [ThreadMember(parent=self, data=data) for data in members] return [ThreadMember(parent=self, data=data) for data in members]
async def delete(self): async def delete(self) -> None:
"""|coro| """|coro|
Deletes this thread. Deletes this thread.
@ -806,28 +809,28 @@ class ThreadMember(Hashable):
'parent', 'parent',
) )
def __init__(self, parent: Thread, data: ThreadMemberPayload): def __init__(self, parent: Thread, data: ThreadMemberPayload) -> None:
self.parent = parent self.parent: Thread = parent
self._state = parent._state self._state: ConnectionState = parent._state
self._from_data(data) self._from_data(data)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<ThreadMember id={self.id} thread_id={self.thread_id} joined_at={self.joined_at!r}>' return f'<ThreadMember id={self.id} thread_id={self.thread_id} joined_at={self.joined_at!r}>'
def _from_data(self, data: ThreadMemberPayload): def _from_data(self, data: ThreadMemberPayload) -> None:
try: try:
self.id = int(data['user_id']) self.id = int(data['user_id'])
except KeyError: except KeyError:
assert self._state.self_id is not None self.id = self._state.self_id # type: ignore
self.id = self._state.self_id
self.thread_id: int
try: try:
self.thread_id = int(data['id']) self.thread_id = int(data['id'])
except KeyError: except KeyError:
self.thread_id = self.parent.id self.thread_id = self.parent.id
self.joined_at = parse_time(data['join_timestamp']) self.joined_at: datetime = parse_time(data['join_timestamp'])
self.flags = data['flags'] self.flags: int = data['flags']
@property @property
def thread(self) -> Thread: def thread(self) -> Thread:

1
discord/types/activity.py

@ -112,3 +112,4 @@ class Activity(_BaseActivity, total=False):
session_id: Optional[str] session_id: Optional[str]
instance: bool instance: bool
buttons: List[ActivityButton] buttons: List[ActivityButton]
sync_id: str

5
discord/types/widget.py

@ -58,3 +58,8 @@ class Widget(TypedDict):
class WidgetSettings(TypedDict): class WidgetSettings(TypedDict):
enabled: bool enabled: bool
channel_id: Optional[Snowflake] channel_id: Optional[Snowflake]
class EditWidgetSettings(TypedDict, total=False):
enabled: bool
channel_id: Optional[Snowflake]

17
discord/ui/button.py

@ -44,6 +44,7 @@ if TYPE_CHECKING:
from .view import View from .view import View
from ..emoji import Emoji from ..emoji import Emoji
from ..types.components import ButtonComponent as ButtonComponentPayload
V = TypeVar('V', bound='View', covariant=True) V = TypeVar('V', bound='View', covariant=True)
@ -124,7 +125,7 @@ class Button(Item[V]):
style=style, style=style,
emoji=emoji, emoji=emoji,
) )
self.row = row self.row: Optional[int] = row
@property @property
def style(self) -> ButtonStyle: def style(self) -> ButtonStyle:
@ -132,7 +133,7 @@ class Button(Item[V]):
return self._underlying.style return self._underlying.style
@style.setter @style.setter
def style(self, value: ButtonStyle): def style(self, value: ButtonStyle) -> None:
self._underlying.style = value self._underlying.style = value
@property @property
@ -144,7 +145,7 @@ class Button(Item[V]):
return self._underlying.custom_id return self._underlying.custom_id
@custom_id.setter @custom_id.setter
def custom_id(self, value: Optional[str]): def custom_id(self, value: Optional[str]) -> None:
if value is not None and not isinstance(value, str): if value is not None and not isinstance(value, str):
raise TypeError('custom_id must be None or str') raise TypeError('custom_id must be None or str')
@ -156,7 +157,7 @@ class Button(Item[V]):
return self._underlying.url return self._underlying.url
@url.setter @url.setter
def url(self, value: Optional[str]): def url(self, value: Optional[str]) -> None:
if value is not None and not isinstance(value, str): if value is not None and not isinstance(value, str):
raise TypeError('url must be None or str') raise TypeError('url must be None or str')
self._underlying.url = value self._underlying.url = value
@ -167,7 +168,7 @@ class Button(Item[V]):
return self._underlying.disabled return self._underlying.disabled
@disabled.setter @disabled.setter
def disabled(self, value: bool): def disabled(self, value: bool) -> None:
self._underlying.disabled = bool(value) self._underlying.disabled = bool(value)
@property @property
@ -176,7 +177,7 @@ class Button(Item[V]):
return self._underlying.label return self._underlying.label
@label.setter @label.setter
def label(self, value: Optional[str]): def label(self, value: Optional[str]) -> None:
self._underlying.label = str(value) if value is not None else value self._underlying.label = str(value) if value is not None else value
@property @property
@ -185,7 +186,7 @@ class Button(Item[V]):
return self._underlying.emoji return self._underlying.emoji
@emoji.setter @emoji.setter
def emoji(self, value: Optional[Union[str, Emoji, PartialEmoji]]): # type: ignore def emoji(self, value: Optional[Union[str, Emoji, PartialEmoji]]) -> None:
if value is not None: if value is not None:
if isinstance(value, str): if isinstance(value, str):
self._underlying.emoji = PartialEmoji.from_str(value) self._underlying.emoji = PartialEmoji.from_str(value)
@ -212,7 +213,7 @@ class Button(Item[V]):
def type(self) -> ComponentType: def type(self) -> ComponentType:
return self._underlying.type return self._underlying.type
def to_component_dict(self): def to_component_dict(self) -> ButtonComponentPayload:
return self._underlying.to_dict() return self._underlying.to_dict()
def is_dispatchable(self) -> bool: def is_dispatchable(self) -> bool:

4
discord/ui/item.py

@ -101,7 +101,7 @@ class Item(Generic[V]):
return self._row return self._row
@row.setter @row.setter
def row(self, value: Optional[int]): def row(self, value: Optional[int]) -> None:
if value is None: if value is None:
self._row = None self._row = None
elif 5 > value >= 0: elif 5 > value >= 0:
@ -118,7 +118,7 @@ class Item(Generic[V]):
"""Optional[:class:`View`]: The underlying view for this item.""" """Optional[:class:`View`]: The underlying view for this item."""
return self._view return self._view
async def callback(self, interaction: Interaction): async def callback(self, interaction: Interaction) -> Any:
"""|coro| """|coro|
The callback associated with this UI item. The callback associated with this UI item.

8
discord/ui/modal.py

@ -38,6 +38,8 @@ from .item import Item
from .view import View from .view import View
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self
from ..interactions import Interaction from ..interactions import Interaction
from ..types.interactions import ModalSubmitComponentInteractionData as ModalSubmitComponentInteractionDataPayload from ..types.interactions import ModalSubmitComponentInteractionData as ModalSubmitComponentInteractionDataPayload
@ -101,7 +103,7 @@ class Modal(View):
title: str title: str
__discord_ui_modal__ = True __discord_ui_modal__ = True
__modal_children_items__: ClassVar[Dict[str, Item]] = {} __modal_children_items__: ClassVar[Dict[str, Item[Self]]] = {}
def __init_subclass__(cls, *, title: str = MISSING) -> None: def __init_subclass__(cls, *, title: str = MISSING) -> None:
if title is not MISSING: if title is not MISSING:
@ -139,7 +141,7 @@ class Modal(View):
super().__init__(timeout=timeout) super().__init__(timeout=timeout)
async def on_submit(self, interaction: Interaction): async def on_submit(self, interaction: Interaction) -> None:
"""|coro| """|coro|
Called when the modal is submitted. Called when the modal is submitted.
@ -169,7 +171,7 @@ class Modal(View):
print(f'Ignoring exception in modal {self}:', file=sys.stderr) print(f'Ignoring exception in modal {self}:', file=sys.stderr)
traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr) traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr)
def refresh(self, components: Sequence[ModalSubmitComponentInteractionDataPayload]): def refresh(self, components: Sequence[ModalSubmitComponentInteractionDataPayload]) -> None:
for component in components: for component in components:
if component['type'] == 1: if component['type'] == 1:
self.refresh(component['components']) self.refresh(component['components'])

18
discord/ui/select.py

@ -121,7 +121,7 @@ class Select(Item[V]):
options=options, options=options,
disabled=disabled, disabled=disabled,
) )
self.row = row self.row: Optional[int] = row
@property @property
def custom_id(self) -> str: def custom_id(self) -> str:
@ -129,7 +129,7 @@ class Select(Item[V]):
return self._underlying.custom_id return self._underlying.custom_id
@custom_id.setter @custom_id.setter
def custom_id(self, value: str): def custom_id(self, value: str) -> None:
if not isinstance(value, str): if not isinstance(value, str):
raise TypeError('custom_id must be None or str') raise TypeError('custom_id must be None or str')
@ -141,7 +141,7 @@ class Select(Item[V]):
return self._underlying.placeholder return self._underlying.placeholder
@placeholder.setter @placeholder.setter
def placeholder(self, value: Optional[str]): def placeholder(self, value: Optional[str]) -> None:
if value is not None and not isinstance(value, str): if value is not None and not isinstance(value, str):
raise TypeError('placeholder must be None or str') raise TypeError('placeholder must be None or str')
@ -153,7 +153,7 @@ class Select(Item[V]):
return self._underlying.min_values return self._underlying.min_values
@min_values.setter @min_values.setter
def min_values(self, value: int): def min_values(self, value: int) -> None:
self._underlying.min_values = int(value) self._underlying.min_values = int(value)
@property @property
@ -162,7 +162,7 @@ class Select(Item[V]):
return self._underlying.max_values return self._underlying.max_values
@max_values.setter @max_values.setter
def max_values(self, value: int): def max_values(self, value: int) -> None:
self._underlying.max_values = int(value) self._underlying.max_values = int(value)
@property @property
@ -171,7 +171,7 @@ class Select(Item[V]):
return self._underlying.options return self._underlying.options
@options.setter @options.setter
def options(self, value: List[SelectOption]): def options(self, value: List[SelectOption]) -> None:
if not isinstance(value, list): if not isinstance(value, list):
raise TypeError('options must be a list of SelectOption') raise TypeError('options must be a list of SelectOption')
if not all(isinstance(obj, SelectOption) for obj in value): if not all(isinstance(obj, SelectOption) for obj in value):
@ -187,7 +187,7 @@ class Select(Item[V]):
description: Optional[str] = None, description: Optional[str] = None,
emoji: Optional[Union[str, Emoji, PartialEmoji]] = None, emoji: Optional[Union[str, Emoji, PartialEmoji]] = None,
default: bool = False, default: bool = False,
): ) -> None:
"""Adds an option to the select menu. """Adds an option to the select menu.
To append a pre-existing :class:`discord.SelectOption` use the To append a pre-existing :class:`discord.SelectOption` use the
@ -226,7 +226,7 @@ class Select(Item[V]):
self.append_option(option) self.append_option(option)
def append_option(self, option: SelectOption): def append_option(self, option: SelectOption) -> None:
"""Appends an option to the select menu. """Appends an option to the select menu.
Parameters Parameters
@ -251,7 +251,7 @@ class Select(Item[V]):
return self._underlying.disabled return self._underlying.disabled
@disabled.setter @disabled.setter
def disabled(self, value: bool): def disabled(self, value: bool) -> None:
self._underlying.disabled = bool(value) self._underlying.disabled = bool(value)
@property @property

30
discord/ui/view.py

@ -50,6 +50,8 @@ __all__ = (
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self
from ..interactions import Interaction from ..interactions import Interaction
from ..message import Message from ..message import Message
from ..types.components import Component as ComponentPayload from ..types.components import Component as ComponentPayload
@ -163,7 +165,7 @@ class View:
cls.__view_children_items__ = children cls.__view_children_items__ = children
def _init_children(self) -> List[Item]: def _init_children(self) -> List[Item[Self]]:
children = [] children = []
for func in self.__view_children_items__: for func in self.__view_children_items__:
item: Item = func.__discord_ui_model_type__(**func.__discord_ui_model_kwargs__) item: Item = func.__discord_ui_model_type__(**func.__discord_ui_model_kwargs__)
@ -175,7 +177,7 @@ class View:
def __init__(self, *, timeout: Optional[float] = 180.0): def __init__(self, *, timeout: Optional[float] = 180.0):
self.timeout = timeout self.timeout = timeout
self.children: List[Item] = self._init_children() self.children: List[Item[Self]] = self._init_children()
self.__weights = _ViewWeights(self.children) self.__weights = _ViewWeights(self.children)
self.id: str = os.urandom(16).hex() self.id: str = os.urandom(16).hex()
self.__cancel_callback: Optional[Callable[[View], None]] = None self.__cancel_callback: Optional[Callable[[View], None]] = None
@ -250,7 +252,7 @@ class View:
view.add_item(_component_to_item(component)) view.add_item(_component_to_item(component))
return view return view
def add_item(self, item: Item) -> None: def add_item(self, item: Item[Any]) -> None:
"""Adds an item to the view. """Adds an item to the view.
Parameters Parameters
@ -278,7 +280,7 @@ class View:
item._view = self item._view = self
self.children.append(item) self.children.append(item)
def remove_item(self, item: Item) -> None: def remove_item(self, item: Item[Any]) -> None:
"""Removes an item from the view. """Removes an item from the view.
Parameters Parameters
@ -334,7 +336,7 @@ class View:
""" """
pass pass
async def on_error(self, error: Exception, item: Item, interaction: Interaction) -> None: async def on_error(self, error: Exception, item: Item[Any], interaction: Interaction) -> None:
"""|coro| """|coro|
A callback that is called when an item's callback or :meth:`interaction_check` A callback that is called when an item's callback or :meth:`interaction_check`
@ -395,16 +397,16 @@ class View:
asyncio.create_task(self._scheduled_task(item, interaction), name=f'discord-ui-view-dispatch-{self.id}') asyncio.create_task(self._scheduled_task(item, interaction), name=f'discord-ui-view-dispatch-{self.id}')
def refresh(self, components: List[Component]): def refresh(self, components: List[Component]) -> None:
# This is pretty hacky at the moment # This is pretty hacky at the moment
# fmt: off # fmt: off
old_state: Dict[Tuple[int, str], Item] = { old_state: Dict[Tuple[int, str], Item[Any]] = {
(item.type.value, item.custom_id): item # type: ignore (item.type.value, item.custom_id): item # type: ignore
for item in self.children for item in self.children
if item.is_dispatchable() if item.is_dispatchable()
} }
# fmt: on # fmt: on
children: List[Item] = [] children: List[Item[Any]] = []
for component in _walk_all_components(components): for component in _walk_all_components(components):
try: try:
older = old_state[(component.type.value, component.custom_id)] # type: ignore older = old_state[(component.type.value, component.custom_id)] # type: ignore
@ -494,7 +496,7 @@ class ViewStore:
for k in to_remove: for k in to_remove:
del self._views[k] del self._views[k]
def add_view(self, view: View, message_id: Optional[int] = None): def add_view(self, view: View, message_id: Optional[int] = None) -> None:
view._start_listening_from_store(self) view._start_listening_from_store(self)
if view.__discord_ui_modal__: if view.__discord_ui_modal__:
self._modals[view.custom_id] = view # type: ignore self._modals[view.custom_id] = view # type: ignore
@ -509,7 +511,7 @@ class ViewStore:
if message_id is not None: if message_id is not None:
self._synced_message_views[message_id] = view self._synced_message_views[message_id] = view
def remove_view(self, view: View): def remove_view(self, view: View) -> None:
if view.__discord_ui_modal__: if view.__discord_ui_modal__:
self._modals.pop(view.custom_id, None) # type: ignore self._modals.pop(view.custom_id, None) # type: ignore
return return
@ -523,7 +525,7 @@ class ViewStore:
del self._synced_message_views[key] del self._synced_message_views[key]
break break
def dispatch_view(self, component_type: int, custom_id: str, interaction: Interaction): def dispatch_view(self, component_type: int, custom_id: str, interaction: Interaction) -> None:
self.__verify_integrity() self.__verify_integrity()
message_id: Optional[int] = interaction.message and interaction.message.id message_id: Optional[int] = interaction.message and interaction.message.id
key = (component_type, message_id, custom_id) key = (component_type, message_id, custom_id)
@ -542,7 +544,7 @@ class ViewStore:
custom_id: str, custom_id: str,
interaction: Interaction, interaction: Interaction,
components: List[ModalSubmitComponentInteractionDataPayload], components: List[ModalSubmitComponentInteractionDataPayload],
): ) -> None:
modal = self._modals.get(custom_id) modal = self._modals.get(custom_id)
if modal is None: if modal is None:
_log.debug("Modal interaction referencing unknown custom_id %s. Discarding", custom_id) _log.debug("Modal interaction referencing unknown custom_id %s. Discarding", custom_id)
@ -551,13 +553,13 @@ class ViewStore:
modal.refresh(components) modal.refresh(components)
modal._dispatch_submit(interaction) modal._dispatch_submit(interaction)
def is_message_tracked(self, message_id: int): def is_message_tracked(self, message_id: int) -> bool:
return message_id in self._synced_message_views return message_id in self._synced_message_views
def remove_message_tracking(self, message_id: int) -> Optional[View]: def remove_message_tracking(self, message_id: int) -> Optional[View]:
return self._synced_message_views.pop(message_id, None) return self._synced_message_views.pop(message_id, None)
def update_from_message(self, message_id: int, components: List[ComponentPayload]): def update_from_message(self, message_id: int, components: List[ComponentPayload]) -> None:
# pre-req: is_message_tracked == true # pre-req: is_message_tracked == true
view = self._synced_message_views[message_id] view = self._synced_message_views[message_id]
view.refresh([_component_factory(d) for d in components]) view.refresh([_component_factory(d) for d in components])

6
discord/user.py

@ -99,10 +99,10 @@ class BaseUser(_UserTag):
def __str__(self) -> str: def __str__(self) -> str:
return f'{self.name}#{self.discriminator}' return f'{self.name}#{self.discriminator}'
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
return isinstance(other, _UserTag) and other.id == self.id return isinstance(other, _UserTag) and other.id == self.id
def __ne__(self, other: Any) -> bool: def __ne__(self, other: object) -> bool:
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self) -> int:
@ -444,7 +444,7 @@ class User(BaseUser, discord.abc.Messageable):
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<User id={self.id} name={self.name!r} discriminator={self.discriminator!r} bot={self.bot}>' return f'<User id={self.id} name={self.name!r} discriminator={self.discriminator!r} bot={self.bot}>'
async def _get_channel(self): async def _get_channel(self) -> DMChannel:
ch = await self.create_dm() ch = await self.create_dm()
return ch return ch

36
discord/utils.py

@ -29,6 +29,7 @@ from typing import (
Any, Any,
AsyncIterable, AsyncIterable,
AsyncIterator, AsyncIterator,
Awaitable,
Callable, Callable,
Coroutine, Coroutine,
Dict, Dict,
@ -42,6 +43,7 @@ from typing import (
NamedTuple, NamedTuple,
Optional, Optional,
Protocol, Protocol,
Set,
Sequence, Sequence,
Tuple, Tuple,
Type, Type,
@ -66,7 +68,7 @@ import warnings
import yarl import yarl
try: try:
import orjson import orjson # type: ignore
except ModuleNotFoundError: except ModuleNotFoundError:
HAS_ORJSON = False HAS_ORJSON = False
else: else:
@ -123,7 +125,7 @@ class _cached_property:
if TYPE_CHECKING: if TYPE_CHECKING:
from functools import cached_property as cached_property from functools import cached_property as cached_property
from typing_extensions import ParamSpec from typing_extensions import ParamSpec, Self
from .permissions import Permissions from .permissions import Permissions
from .abc import Snowflake from .abc import Snowflake
@ -135,8 +137,16 @@ if TYPE_CHECKING:
P = ParamSpec('P') P = ParamSpec('P')
MaybeCoroFunc = Union[
Callable[P, Coroutine[Any, Any, 'T']],
Callable[P, 'T'],
]
_SnowflakeListBase = array.array[int]
else: else:
cached_property = _cached_property cached_property = _cached_property
_SnowflakeListBase = array.array
T = TypeVar('T') T = TypeVar('T')
@ -178,7 +188,7 @@ class classproperty(Generic[T_co]):
def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co: def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co:
return self.fget(owner) return self.fget(owner)
def __set__(self, instance, value) -> None: def __set__(self, instance: Optional[Any], value: Any) -> None:
raise AttributeError('cannot set attribute') raise AttributeError('cannot set attribute')
@ -210,7 +220,7 @@ class SequenceProxy(Sequence[T_co]):
def __reversed__(self) -> Iterator[T_co]: def __reversed__(self) -> Iterator[T_co]:
return reversed(self.__proxied) return reversed(self.__proxied)
def index(self, value: Any, *args, **kwargs) -> int: def index(self, value: Any, *args: Any, **kwargs: Any) -> int:
return self.__proxied.index(value, *args, **kwargs) return self.__proxied.index(value, *args, **kwargs)
def count(self, value: Any) -> int: def count(self, value: Any) -> int:
@ -578,7 +588,7 @@ def _is_submodule(parent: str, child: str) -> bool:
if HAS_ORJSON: if HAS_ORJSON:
def _to_json(obj: Any) -> str: # type: ignore def _to_json(obj: Any) -> str:
return orjson.dumps(obj).decode('utf-8') return orjson.dumps(obj).decode('utf-8')
_from_json = orjson.loads # type: ignore _from_json = orjson.loads # type: ignore
@ -602,15 +612,15 @@ def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float:
return float(reset_after) return float(reset_after)
async def maybe_coroutine(f, *args, **kwargs): async def maybe_coroutine(f: MaybeCoroFunc[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
value = f(*args, **kwargs) value = f(*args, **kwargs)
if _isawaitable(value): if _isawaitable(value):
return await value return await value
else: else:
return value return value # type: ignore
async def async_all(gen, *, check=_isawaitable): async def async_all(gen: Iterable[Awaitable[T]], *, check: Callable[[T], bool] = _isawaitable) -> bool:
for elem in gen: for elem in gen:
if check(elem): if check(elem):
elem = await elem elem = await elem
@ -619,7 +629,7 @@ async def async_all(gen, *, check=_isawaitable):
return True return True
async def sane_wait_for(futures, *, timeout): async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: Optional[float]) -> Set[asyncio.Task[T]]:
ensured = [asyncio.ensure_future(fut) for fut in futures] ensured = [asyncio.ensure_future(fut) for fut in futures]
done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED) done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED)
@ -637,7 +647,7 @@ def get_slots(cls: Type[Any]) -> Iterator[str]:
continue continue
def compute_timedelta(dt: datetime.datetime): def compute_timedelta(dt: datetime.datetime) -> float:
if dt.tzinfo is None: if dt.tzinfo is None:
dt = dt.astimezone() dt = dt.astimezone()
now = datetime.datetime.now(datetime.timezone.utc) now = datetime.datetime.now(datetime.timezone.utc)
@ -686,7 +696,7 @@ def valid_icon_size(size: int) -> bool:
return not size & (size - 1) and 4096 >= size >= 16 return not size & (size - 1) and 4096 >= size >= 16
class SnowflakeList(array.array): class SnowflakeList(_SnowflakeListBase):
"""Internal data storage class to efficiently store a list of snowflakes. """Internal data storage class to efficiently store a list of snowflakes.
This should have the following characteristics: This should have the following characteristics:
@ -705,7 +715,7 @@ class SnowflakeList(array.array):
def __init__(self, data: Iterable[int], *, is_sorted: bool = False): def __init__(self, data: Iterable[int], *, is_sorted: bool = False):
... ...
def __new__(cls, data: Iterable[int], *, is_sorted: bool = False): def __new__(cls, data: Iterable[int], *, is_sorted: bool = False) -> Self:
return array.array.__new__(cls, 'Q', data if is_sorted else sorted(data)) # type: ignore return array.array.__new__(cls, 'Q', data if is_sorted else sorted(data)) # type: ignore
def add(self, element: int) -> None: def add(self, element: int) -> None:
@ -1010,7 +1020,7 @@ def evaluate_annotation(
cache: Dict[str, Any], cache: Dict[str, Any],
*, *,
implicit_str: bool = True, implicit_str: bool = True,
): ) -> Any:
if isinstance(tp, ForwardRef): if isinstance(tp, ForwardRef):
tp = tp.__forward_arg__ tp = tp.__forward_arg__
# ForwardRefs always evaluate their internals # ForwardRefs always evaluate their internals

10
discord/voice_client.py

@ -262,7 +262,7 @@ class VoiceClient(VoiceProtocol):
self._lite_nonce: int = 0 self._lite_nonce: int = 0
self.ws: DiscordVoiceWebSocket = MISSING self.ws: DiscordVoiceWebSocket = MISSING
warn_nacl = not has_nacl warn_nacl: bool = not has_nacl
supported_modes: Tuple[SupportedModes, ...] = ( supported_modes: Tuple[SupportedModes, ...] = (
'xsalsa20_poly1305_lite', 'xsalsa20_poly1305_lite',
'xsalsa20_poly1305_suffix', 'xsalsa20_poly1305_suffix',
@ -279,7 +279,7 @@ class VoiceClient(VoiceProtocol):
""":class:`ClientUser`: The user connected to voice (i.e. ourselves).""" """:class:`ClientUser`: The user connected to voice (i.e. ourselves)."""
return self._state.user # type: ignore - user can't be None after login return self._state.user # type: ignore - user can't be None after login
def checked_add(self, attr, value, limit): def checked_add(self, attr: str, value: int, limit: int) -> None:
val = getattr(self, attr) val = getattr(self, attr)
if val + value > limit: if val + value > limit:
setattr(self, attr, 0) setattr(self, attr, 0)
@ -289,7 +289,7 @@ class VoiceClient(VoiceProtocol):
# connection related # connection related
async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None: async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:
self.session_id = data['session_id'] self.session_id: str = data['session_id']
channel_id = data['channel_id'] channel_id = data['channel_id']
if not self._handshaking or self._potentially_reconnecting: if not self._handshaking or self._potentially_reconnecting:
@ -323,12 +323,12 @@ class VoiceClient(VoiceProtocol):
self.endpoint, _, _ = endpoint.rpartition(':') self.endpoint, _, _ = endpoint.rpartition(':')
if self.endpoint.startswith('wss://'): if self.endpoint.startswith('wss://'):
# Just in case, strip it off since we're going to add it later # Just in case, strip it off since we're going to add it later
self.endpoint = self.endpoint[6:] self.endpoint: str = self.endpoint[6:]
# This gets set later # This gets set later
self.endpoint_ip = MISSING self.endpoint_ip = MISSING
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.setblocking(False) self.socket.setblocking(False)
if not self._handshaking: if not self._handshaking:

122
discord/webhook/async_.py

@ -30,7 +30,7 @@ import json
import re import re
from urllib.parse import quote as urlquote from urllib.parse import quote as urlquote
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, TypeVar, Type, overload
from contextvars import ContextVar from contextvars import ContextVar
import weakref import weakref
@ -43,7 +43,7 @@ from ..enums import try_enum, WebhookType
from ..user import BaseUser, User from ..user import BaseUser, User
from ..flags import MessageFlags from ..flags import MessageFlags
from ..asset import Asset from ..asset import Asset
from ..http import Route, handle_message_parameters, MultipartParameters from ..http import Route, handle_message_parameters, MultipartParameters, HTTPClient
from ..mixins import Hashable from ..mixins import Hashable
from ..channel import PartialMessageable from ..channel import PartialMessageable
from ..file import File from ..file import File
@ -58,24 +58,38 @@ __all__ = (
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self
from types import TracebackType
from ..embeds import Embed from ..embeds import Embed
from ..mentions import AllowedMentions from ..mentions import AllowedMentions
from ..message import Attachment from ..message import Attachment
from ..state import ConnectionState from ..state import ConnectionState
from ..http import Response from ..http import Response
from ..guild import Guild
from ..channel import TextChannel
from ..abc import Snowflake
from ..ui.view import View
import datetime
from ..types.webhook import ( from ..types.webhook import (
Webhook as WebhookPayload, Webhook as WebhookPayload,
SourceGuild as SourceGuildPayload,
) )
from ..types.message import ( from ..types.message import (
Message as MessagePayload, Message as MessagePayload,
) )
from ..guild import Guild from ..types.user import (
from ..channel import TextChannel User as UserPayload,
from ..abc import Snowflake PartialUser as PartialUserPayload,
from ..ui.view import View )
import datetime from ..types.channel import (
PartialChannel as PartialChannelPayload,
)
BE = TypeVar('BE', bound=BaseException)
_State = Union[ConnectionState, '_WebhookState']
MISSING = utils.MISSING MISSING: Any = utils.MISSING
class AsyncDeferredLock: class AsyncDeferredLock:
@ -83,14 +97,19 @@ class AsyncDeferredLock:
self.lock = lock self.lock = lock
self.delta: Optional[float] = None self.delta: Optional[float] = None
async def __aenter__(self): async def __aenter__(self) -> Self:
await self.lock.acquire() await self.lock.acquire()
return self return self
def delay_by(self, delta: float) -> None: def delay_by(self, delta: float) -> None:
self.delta = delta self.delta = delta
async def __aexit__(self, type, value, traceback): async def __aexit__(
self,
exc_type: Optional[Type[BE]],
exc: Optional[BE],
traceback: Optional[TracebackType],
) -> None:
if self.delta: if self.delta:
await asyncio.sleep(self.delta) await asyncio.sleep(self.delta)
self.lock.release() self.lock.release()
@ -545,11 +564,11 @@ class PartialWebhookChannel(Hashable):
__slots__ = ('id', 'name') __slots__ = ('id', 'name')
def __init__(self, *, data): def __init__(self, *, data: PartialChannelPayload) -> None:
self.id = int(data['id']) self.id: int = int(data['id'])
self.name = data['name'] self.name: str = data['name']
def __repr__(self): def __repr__(self) -> str:
return f'<PartialWebhookChannel name={self.name!r} id={self.id}>' return f'<PartialWebhookChannel name={self.name!r} id={self.id}>'
@ -570,13 +589,13 @@ class PartialWebhookGuild(Hashable):
__slots__ = ('id', 'name', '_icon', '_state') __slots__ = ('id', 'name', '_icon', '_state')
def __init__(self, *, data, state): def __init__(self, *, data: SourceGuildPayload, state: _State) -> None:
self._state = state self._state: _State = state
self.id = int(data['id']) self.id: int = int(data['id'])
self.name = data['name'] self.name: str = data['name']
self._icon = data['icon'] self._icon: str = data['icon']
def __repr__(self): def __repr__(self) -> str:
return f'<PartialWebhookGuild name={self.name!r} id={self.id}>' return f'<PartialWebhookGuild name={self.name!r} id={self.id}>'
@property @property
@ -590,14 +609,14 @@ class PartialWebhookGuild(Hashable):
class _FriendlyHttpAttributeErrorHelper: class _FriendlyHttpAttributeErrorHelper:
__slots__ = () __slots__ = ()
def __getattr__(self, attr): def __getattr__(self, attr: str) -> Any:
raise AttributeError('PartialWebhookState does not support http methods.') raise AttributeError('PartialWebhookState does not support http methods.')
class _WebhookState: class _WebhookState:
__slots__ = ('_parent', '_webhook') __slots__ = ('_parent', '_webhook')
def __init__(self, webhook: Any, parent: Optional[Union[ConnectionState, _WebhookState]]): def __init__(self, webhook: Any, parent: Optional[_State]):
self._webhook: Any = webhook self._webhook: Any = webhook
self._parent: Optional[ConnectionState] self._parent: Optional[ConnectionState]
@ -606,23 +625,23 @@ class _WebhookState:
else: else:
self._parent = parent self._parent = parent
def _get_guild(self, guild_id): def _get_guild(self, guild_id: Optional[int]) -> Optional[Guild]:
if self._parent is not None: if self._parent is not None:
return self._parent._get_guild(guild_id) return self._parent._get_guild(guild_id)
return None return None
def store_user(self, data): def store_user(self, data: Union[UserPayload, PartialUserPayload]) -> BaseUser:
if self._parent is not None: if self._parent is not None:
return self._parent.store_user(data) return self._parent.store_user(data)
# state parameter is artificial # state parameter is artificial
return BaseUser(state=self, data=data) # type: ignore return BaseUser(state=self, data=data) # type: ignore
def create_user(self, data): def create_user(self, data: Union[UserPayload, PartialUserPayload]) -> BaseUser:
# state parameter is artificial # state parameter is artificial
return BaseUser(state=self, data=data) # type: ignore return BaseUser(state=self, data=data) # type: ignore
@property @property
def http(self): def http(self) -> Union[HTTPClient, _FriendlyHttpAttributeErrorHelper]:
if self._parent is not None: if self._parent is not None:
return self._parent.http return self._parent.http
@ -630,7 +649,7 @@ class _WebhookState:
# however, using it should result in a late-binding error. # however, using it should result in a late-binding error.
return _FriendlyHttpAttributeErrorHelper() return _FriendlyHttpAttributeErrorHelper()
def __getattr__(self, attr): def __getattr__(self, attr: str) -> Any:
if self._parent is not None: if self._parent is not None:
return getattr(self._parent, attr) return getattr(self._parent, attr)
@ -830,19 +849,24 @@ class BaseWebhook(Hashable):
'_state', '_state',
) )
def __init__(self, data: WebhookPayload, token: Optional[str] = None, state: Optional[ConnectionState] = None): def __init__(
self,
data: WebhookPayload,
token: Optional[str] = None,
state: Optional[_State] = None,
) -> None:
self.auth_token: Optional[str] = token self.auth_token: Optional[str] = token
self._state: Union[ConnectionState, _WebhookState] = state or _WebhookState(self, parent=state) self._state: _State = state or _WebhookState(self, parent=state)
self._update(data) self._update(data)
def _update(self, data: WebhookPayload): def _update(self, data: WebhookPayload) -> None:
self.id = int(data['id']) self.id: int = int(data['id'])
self.type = try_enum(WebhookType, int(data['type'])) self.type: WebhookType = try_enum(WebhookType, int(data['type']))
self.channel_id = utils._get_as_snowflake(data, 'channel_id') self.channel_id: Optional[int] = utils._get_as_snowflake(data, 'channel_id')
self.guild_id = utils._get_as_snowflake(data, 'guild_id') self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id')
self.name = data.get('name') self.name: Optional[str] = data.get('name')
self._avatar = data.get('avatar') self._avatar: Optional[str] = data.get('avatar')
self.token = data.get('token') self.token: Optional[str] = data.get('token')
user = data.get('user') user = data.get('user')
self.user: Optional[Union[BaseUser, User]] = None self.user: Optional[Union[BaseUser, User]] = None
@ -1010,11 +1034,17 @@ class Webhook(BaseWebhook):
__slots__: Tuple[str, ...] = ('session',) __slots__: Tuple[str, ...] = ('session',)
def __init__(self, data: WebhookPayload, session: aiohttp.ClientSession, token: Optional[str] = None, state=None): def __init__(
self,
data: WebhookPayload,
session: aiohttp.ClientSession,
token: Optional[str] = None,
state: Optional[_State] = None,
) -> None:
super().__init__(data, token, state) super().__init__(data, token, state)
self.session = session self.session: aiohttp.ClientSession = session
def __repr__(self): def __repr__(self) -> str:
return f'<Webhook id={self.id!r}>' return f'<Webhook id={self.id!r}>'
@property @property
@ -1023,7 +1053,7 @@ class Webhook(BaseWebhook):
return f'https://discord.com/api/webhooks/{self.id}/{self.token}' return f'https://discord.com/api/webhooks/{self.id}/{self.token}'
@classmethod @classmethod
def partial(cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook: def partial(cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Self:
"""Creates a partial :class:`Webhook`. """Creates a partial :class:`Webhook`.
Parameters Parameters
@ -1059,7 +1089,7 @@ class Webhook(BaseWebhook):
return cls(data, session, token=bot_token) return cls(data, session, token=bot_token)
@classmethod @classmethod
def from_url(cls, url: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook: def from_url(cls, url: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Self:
"""Creates a partial :class:`Webhook` from a webhook URL. """Creates a partial :class:`Webhook` from a webhook URL.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
@ -1102,7 +1132,7 @@ class Webhook(BaseWebhook):
return cls(data, session, token=bot_token) # type: ignore return cls(data, session, token=bot_token) # type: ignore
@classmethod @classmethod
def _as_follower(cls, data, *, channel, user) -> Webhook: def _as_follower(cls, data, *, channel, user) -> Self:
name = f"{channel.guild} #{channel}" name = f"{channel.guild} #{channel}"
feed: WebhookPayload = { feed: WebhookPayload = {
'id': data['webhook_id'], 'id': data['webhook_id'],
@ -1118,8 +1148,8 @@ class Webhook(BaseWebhook):
return cls(feed, session=session, state=state, token=state.http.token) return cls(feed, session=session, state=state, token=state.http.token)
@classmethod @classmethod
def from_state(cls, data, state) -> Webhook: def from_state(cls, data: WebhookPayload, state: ConnectionState) -> Self:
session = state.http._HTTPClient__session session = state.http._HTTPClient__session # type: ignore
return cls(data, session=session, state=state, token=state.http.token) return cls(data, session=session, state=state, token=state.http.token)
async def fetch(self, *, prefer_auth: bool = True) -> Webhook: async def fetch(self, *, prefer_auth: bool = True) -> Webhook:
@ -1168,7 +1198,7 @@ class Webhook(BaseWebhook):
return Webhook(data, self.session, token=self.auth_token, state=self._state) return Webhook(data, self.session, token=self.auth_token, state=self._state)
async def delete(self, *, reason: Optional[str] = None, prefer_auth: bool = True): async def delete(self, *, reason: Optional[str] = None, prefer_auth: bool = True) -> None:
"""|coro| """|coro|
Deletes this Webhook. Deletes this Webhook.

62
discord/webhook/sync.py

@ -37,7 +37,7 @@ import time
import re import re
from urllib.parse import quote as urlquote from urllib.parse import quote as urlquote
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, TypeVar, Type, overload
import weakref import weakref
from .. import utils from .. import utils
@ -56,36 +56,50 @@ __all__ = (
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self
from types import TracebackType
from ..file import File from ..file import File
from ..embeds import Embed from ..embeds import Embed
from ..mentions import AllowedMentions from ..mentions import AllowedMentions
from ..message import Attachment from ..message import Attachment
from ..abc import Snowflake
from ..state import ConnectionState
from ..types.webhook import ( from ..types.webhook import (
Webhook as WebhookPayload, Webhook as WebhookPayload,
) )
from ..abc import Snowflake from ..types.message import (
Message as MessagePayload,
)
BE = TypeVar('BE', bound=BaseException)
try: try:
from requests import Session, Response from requests import Session, Response
except ModuleNotFoundError: except ModuleNotFoundError:
pass pass
MISSING = utils.MISSING MISSING: Any = utils.MISSING
class DeferredLock: class DeferredLock:
def __init__(self, lock: threading.Lock): def __init__(self, lock: threading.Lock) -> None:
self.lock = lock self.lock: threading.Lock = lock
self.delta: Optional[float] = None self.delta: Optional[float] = None
def __enter__(self): def __enter__(self) -> Self:
self.lock.acquire() self.lock.acquire()
return self return self
def delay_by(self, delta: float) -> None: def delay_by(self, delta: float) -> None:
self.delta = delta self.delta = delta
def __exit__(self, type, value, traceback): def __exit__(
self,
exc_type: Optional[Type[BE]],
exc: Optional[BE],
traceback: Optional[TracebackType],
) -> None:
if self.delta: if self.delta:
time.sleep(self.delta) time.sleep(self.delta)
self.lock.release() self.lock.release()
@ -218,7 +232,7 @@ class WebhookAdapter:
token: Optional[str] = None, token: Optional[str] = None,
session: Session, session: Session,
reason: Optional[str] = None, reason: Optional[str] = None,
): ) -> None:
route = Route('DELETE', '/webhooks/{webhook_id}', webhook_id=webhook_id) route = Route('DELETE', '/webhooks/{webhook_id}', webhook_id=webhook_id)
return self.request(route, session, reason=reason, auth_token=token) return self.request(route, session, reason=reason, auth_token=token)
@ -229,7 +243,7 @@ class WebhookAdapter:
*, *,
session: Session, session: Session,
reason: Optional[str] = None, reason: Optional[str] = None,
): ) -> None:
route = Route('DELETE', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token) route = Route('DELETE', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, reason=reason) return self.request(route, session, reason=reason)
@ -241,7 +255,7 @@ class WebhookAdapter:
*, *,
session: Session, session: Session,
reason: Optional[str] = None, reason: Optional[str] = None,
): ) -> WebhookPayload:
route = Route('PATCH', '/webhooks/{webhook_id}', webhook_id=webhook_id) route = Route('PATCH', '/webhooks/{webhook_id}', webhook_id=webhook_id)
return self.request(route, session, reason=reason, payload=payload, auth_token=token) return self.request(route, session, reason=reason, payload=payload, auth_token=token)
@ -253,7 +267,7 @@ class WebhookAdapter:
*, *,
session: Session, session: Session,
reason: Optional[str] = None, reason: Optional[str] = None,
): ) -> WebhookPayload:
route = Route('PATCH', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token) route = Route('PATCH', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, reason=reason, payload=payload) return self.request(route, session, reason=reason, payload=payload)
@ -268,7 +282,7 @@ class WebhookAdapter:
files: Optional[List[File]] = None, files: Optional[List[File]] = None,
thread_id: Optional[int] = None, thread_id: Optional[int] = None,
wait: bool = False, wait: bool = False,
): ) -> MessagePayload:
params = {'wait': int(wait)} params = {'wait': int(wait)}
if thread_id: if thread_id:
params['thread_id'] = thread_id params['thread_id'] = thread_id
@ -282,7 +296,7 @@ class WebhookAdapter:
message_id: int, message_id: int,
*, *,
session: Session, session: Session,
): ) -> MessagePayload:
route = Route( route = Route(
'GET', 'GET',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}', '/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
@ -302,7 +316,7 @@ class WebhookAdapter:
payload: Optional[Dict[str, Any]] = None, payload: Optional[Dict[str, Any]] = None,
multipart: Optional[List[Dict[str, Any]]] = None, multipart: Optional[List[Dict[str, Any]]] = None,
files: Optional[List[File]] = None, files: Optional[List[File]] = None,
): ) -> MessagePayload:
route = Route( route = Route(
'PATCH', 'PATCH',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}', '/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
@ -319,7 +333,7 @@ class WebhookAdapter:
message_id: int, message_id: int,
*, *,
session: Session, session: Session,
): ) -> None:
route = Route( route = Route(
'DELETE', 'DELETE',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}', '/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
@ -335,7 +349,7 @@ class WebhookAdapter:
token: str, token: str,
*, *,
session: Session, session: Session,
): ) -> WebhookPayload:
route = Route('GET', '/webhooks/{webhook_id}', webhook_id=webhook_id) route = Route('GET', '/webhooks/{webhook_id}', webhook_id=webhook_id)
return self.request(route, session=session, auth_token=token) return self.request(route, session=session, auth_token=token)
@ -345,7 +359,7 @@ class WebhookAdapter:
token: str, token: str,
*, *,
session: Session, session: Session,
): ) -> WebhookPayload:
route = Route('GET', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token) route = Route('GET', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
return self.request(route, session=session) return self.request(route, session=session)
@ -569,11 +583,17 @@ class SyncWebhook(BaseWebhook):
__slots__: Tuple[str, ...] = ('session',) __slots__: Tuple[str, ...] = ('session',)
def __init__(self, data: WebhookPayload, session: Session, token: Optional[str] = None, state=None): def __init__(
self,
data: WebhookPayload,
session: Session,
token: Optional[str] = None,
state: Optional[Union[ConnectionState, _WebhookState]] = None,
) -> None:
super().__init__(data, token, state) super().__init__(data, token, state)
self.session = session self.session: Session = session
def __repr__(self): def __repr__(self) -> str:
return f'<Webhook id={self.id!r}>' return f'<Webhook id={self.id!r}>'
@property @property
@ -812,7 +832,7 @@ class SyncWebhook(BaseWebhook):
return SyncWebhook(data=data, session=self.session, token=self.auth_token, state=self._state) return SyncWebhook(data=data, session=self.session, token=self.auth_token, state=self._state)
def _create_message(self, data): def _create_message(self, data: MessagePayload) -> SyncWebhookMessage:
state = _WebhookState(self, parent=self._state) state = _WebhookState(self, parent=self._state)
# state may be artificial (unlikely at this point...) # state may be artificial (unlikely at this point...)
channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore

2
discord/widget.py

@ -278,7 +278,7 @@ class Widget:
def __str__(self) -> str: def __str__(self) -> str:
return self.json_url return self.json_url
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
if isinstance(other, Widget): if isinstance(other, Widget):
return self.id == other.id return self.id == other.id
return False return False

Loading…
Cancel
Save